More flexible debugging of triton IRs #223

Open · wants to merge 1 commit into base: main
50 changes: 36 additions & 14 deletions jax_triton/triton_lib.py
@@ -143,42 +143,56 @@ def aval_size_bytes(aval):


def ptx_get_kernel_name(module) -> str:
return tc.get_kernel_name(module, pattern='// .globl')
return tc.get_kernel_name(module, pattern="// .globl")


def _maybe_dump(
dump: Callable[[str, str], None] | bool, ir_name: str, ir_body: Any
):
"""Do the right thing w.r.t. logging the IR.

Args:
dump: if it is a callable, use it to log the IR. If it is a bool, its value
decides whether to dump the IR to stdout or do nothing.
ir_name: the name of the IR.
ir_body: the text of the IR. Must implement `__str__()`.
"""

if callable(dump):
dump(ir_name, str(ir_body))
elif dump:
print(ir_body)


def compile_ttir_to_ptx_inplace(
ttir,
device: int = 0,
num_warps: int = 4,
num_stages: Optional[int] = None,
dump: bool = False,
dump: Callable[[str, str], None] | bool = False,
) -> Tuple[str, str, int, int]:
compute_capability = triton_kernel_call_lib.get_compute_capability(device)
if num_stages is None:
num_stages = 3 if compute_capability >= 75 else 2
if dump:
print(ttir)
_maybe_dump(dump, "ttir", ttir)
try:
ttir = tc.optimize_ttir(ttir, compute_capability)
ttgir = tc.ttir_to_ttgir(ttir, num_warps)
ttgir = tc.optimize_ttgir(ttgir, num_stages, compute_capability)
except RuntimeError as e:
ttir.dump()
raise ValueError("TTIR->TTGIR pass failed!") from e
if dump:
print(ttgir)
_maybe_dump(dump, "ttgir", ttgir)
extern_libs = {}
try:
llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability)
except RuntimeError as e:
ttgir.dump()
raise ValueError("TTGIR->LLIR pass failed!") from e
shared_mem_bytes = _triton.get_shared_memory_size(ttgir)
if dump:
print(llir)
_maybe_dump(dump, "llir", llir)
ptx = tc.llir_to_ptx(llir, compute_capability)
if dump:
print(ptx)
_maybe_dump(dump, "ptx", ptx)
name = ptx_get_kernel_name(ptx)
return ptx, name, shared_mem_bytes, compute_capability

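Note on the `dump` hook threaded through the pipeline above: it only needs to accept `(ir_name, ir_body)` pairs, as `_maybe_dump` stringifies the body before calling it. Below is a minimal sketch (not part of this change; the output directory and file naming are illustrative) of a callable that writes each stage to its own file:

```python
import pathlib

def make_file_dumper(out_dir: str):
  """Returns a `dump` callable that writes each IR stage to `<out_dir>/<ir_name>.txt`."""
  path = pathlib.Path(out_dir)
  path.mkdir(parents=True, exist_ok=True)

  def dump(ir_name: str, ir_body: str) -> None:
    # `ir_body` arrives already stringified by `_maybe_dump`.
    (path / f"{ir_name}.txt").write_text(ir_body)

  return dump
```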
@@ -194,7 +208,7 @@ def get_or_create_triton_kernel(
num_warps,
num_stages,
metaparams,
dump: bool,
dump: Callable[[str, str], None] | bool,
) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]:
signature = dict(enumerate(arg_dtypes))
# TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers
@@ -229,8 +243,14 @@ def get_or_create_triton_kernel(
# general.
device = 0
arch = triton_kernel_call_lib.get_compute_capability(device)
_maybe_dump(dump, "py", fn.src)
module = code_gen.ast_to_ttir(
fn, signature, specialization, constants, debug=dump, arch=arch
fn,
signature,
specialization,
constants,
debug=bool(dump),
arch=arch,
)
ttir = str(module)  # `module` is compiled in-place, so copy TTIR here.
ptx, kernel_name, shared_mem_bytes, compute_capability = (
@@ -454,7 +474,7 @@ def triton_call(
zeroed_outputs: Union[
Sequence[int], Callable[[Dict[str, Any]], Sequence[int]]
] = (),
debug: bool = False,
debug: Callable[[str, str], None] | bool = False,
serialized_metadata: bytes = b"",
**metaparams: Any,
) -> Any:
@@ -529,7 +549,9 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
indices, for outputs that should be zeroed before the kernel is launched.
num_warps: The number of warps used to execute the Triton kernel.
num_stages: The number of stages emitted by the Triton compiler.
debug: Prints out intermediate IRs if True for debugging purposes.
debug: A callable `debug(ir_name, ir_body)` that receives each intermediate IR
for debugging. It may also be a bool, in which case `True` dumps the IRs to
stdout and `False` disables dumping.
serialized_metadata: Arbitrary metadata that will be added into the
serialized kernel call.
**metaparams: Additional keyword arguments that will be provided to a `grid`
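End to end, the new `debug` parameter can be exercised much like the `add` example referenced in the docstring above. A hedged sketch follows; the kernel, shapes, and grid are assumptions for illustration and not part of this diff:

```python
import jax
import jax.numpy as jnp
import jax_triton as jt
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, BLOCK_SIZE: tl.constexpr):
  # Each program handles one BLOCK_SIZE-wide slice; sizes are chosen so no mask is needed.
  offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
  tl.store(out_ptr + offsets, tl.load(x_ptr + offsets) + tl.load(y_ptr + offsets))

collected = {}
x = jnp.arange(8, dtype=jnp.float32)
y = jnp.arange(8, dtype=jnp.float32)
out = jt.triton_call(
    x, y,
    kernel=add_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    grid=(1,),
    debug=lambda name, body: collected.__setitem__(name, body),  # capture each IR
    BLOCK_SIZE=8,
)
# With this change, `collected` maps "py", "ttir", "ttgir", "llir", and "ptx"
# to the corresponding IR text.
```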
35 changes: 35 additions & 0 deletions tests/triton_call_test.py
@@ -549,6 +549,41 @@ def test_specialization(self):
# specialize" leaving `stride_{bn,cn}`.
self.assertEqual(specialization.equal_to_1, (8, 10))

def test_debug_callable(self):
emitted_irs = dict()

m, n, k = 128, 128, 128
x, y = create_random_inputs([m, k], [k, n])

def intercept_ir(ir_name, ir_body):
ir_body = str(ir_body)
self.assertNotIn(
ir_name, emitted_irs, f"Attempted to overwrite {ir_name}"
)
self.assertNotEmpty(ir_body, f"IR '{ir_name}' was empty.")
emitted_irs[ir_name] = ir_body

block_size_m, block_size_n, block_size_k = 128, 128, 32
_ = matmul(
x,
y,
debug=intercept_ir,
BLOCK_SIZE_M=block_size_m,
BLOCK_SIZE_N=block_size_n,
BLOCK_SIZE_K=block_size_k,
K_EXACTLY_DIVISIBLE_BY_BLOCK=k % block_size_k == 0,
)

for ir_name in ["py", "ttir", "ttgir", "llir", "ptx"]:
self.assertIn(ir_name, emitted_irs, f"IR '{ir_name}' was not recorded.")
self.assertNotEmpty(emitted_irs[ir_name], f"IR '{ir_name}' was empty.")

self.assertStartsWith(
emitted_irs["py"],
"def",
"Python code is emitted as the strigification of a different object",
)


if __name__ == "__main__":
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"