diff --git a/jax_triton/triton_lib.py b/jax_triton/triton_lib.py index 8e678d8f..55ab5296 100644 --- a/jax_triton/triton_lib.py +++ b/jax_triton/triton_lib.py @@ -143,7 +143,25 @@ def aval_size_bytes(aval): def ptx_get_kernel_name(module) -> str: - return tc.get_kernel_name(module, pattern='// .globl') + return tc.get_kernel_name(module, pattern="// .globl") + + +def _maybe_dump( + dump: Callable[[str, str], None] | bool, ir_name: str, ir_body: Any +): + """Do the right thing w.r.t. logging the IR. + + Args: + dump: if it is a callable use it to log the IR. If it's bool, its value + decieds whether to do nothing or dump the IR in stdout. + ir_name: the name of the IR. + ir_body: the text of the ir. Must implement `__str__()`. + """ + + if callable(dump): + dump(ir_name, str(ir_body)) + elif dump: + print(ir_body) def compile_ttir_to_ptx_inplace( @@ -151,13 +169,12 @@ def compile_ttir_to_ptx_inplace( device: int = 0, num_warps: int = 4, num_stages: Optional[int] = None, - dump: bool = False, + dump: Callable[[str, str], None] | bool = False, ) -> Tuple[str, str, int, int]: compute_capability = triton_kernel_call_lib.get_compute_capability(device) if num_stages is None: num_stages = 3 if compute_capability >= 75 else 2 - if dump: - print(ttir) + _maybe_dump(dump, "ttir", ttir) try: ttir = tc.optimize_ttir(ttir, compute_capability) ttgir = tc.ttir_to_ttgir(ttir, num_warps) @@ -165,8 +182,7 @@ def compile_ttir_to_ptx_inplace( except RuntimeError as e: ttir.dump() raise ValueError("TTIR->TTGIR pass failed!") from e - if dump: - print(ttgir) + _maybe_dump(dump, "ttgir", ttgir) extern_libs = {} try: llir = tc.ttgir_to_llir(ttgir, extern_libs, compute_capability) @@ -174,11 +190,9 @@ def compile_ttir_to_ptx_inplace( ttgir.dump() raise ValueError("TTGIR->LLIR pass failed!") from e shared_mem_bytes = _triton.get_shared_memory_size(ttgir) - if dump: - print(llir) + _maybe_dump(dump, "llir", llir) ptx = tc.llir_to_ptx(llir, compute_capability) - if dump: - print(ptx) + _maybe_dump(dump, "ptx", ptx) name = ptx_get_kernel_name(ptx) return ptx, name, shared_mem_bytes, compute_capability @@ -194,7 +208,7 @@ def get_or_create_triton_kernel( num_warps, num_stages, metaparams, - dump: bool, + dump: Callable[[str, str], None] | bool, ) -> Tuple[triton_kernel_call_lib.TritonKernel, Any]: signature = dict(enumerate(arg_dtypes)) # TODO(sharadmv,zhangqiaorjc): handle differently aligned pointers @@ -229,8 +243,14 @@ def get_or_create_triton_kernel( # general. device = 0 arch = triton_kernel_call_lib.get_compute_capability(device) + _maybe_dump(dump, "py", fn.src) module = code_gen.ast_to_ttir( - fn, signature, specialization, constants, debug=dump, arch=arch + fn, + signature, + specialization, + constants, + debug=dump is not None, + arch=arch, ) ttir = str(module) # `module`` is compiled in-place, so copy TTIR here. ptx, kernel_name, shared_mem_bytes, compute_capability = ( @@ -454,7 +474,7 @@ def triton_call( zeroed_outputs: Union[ Sequence[int], Callable[[Dict[str, Any]], Sequence[int]] ] = (), - debug: bool = False, + debug: Callable[[str, str], None] | bool = False, serialized_metadata: bytes = b"", **metaparams: Any, ) -> Any: @@ -529,7 +549,9 @@ def add(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray: indices, for outputs that should be zeroed before the kernel is launched. num_warps: The number of warps used to execute the Triton kernel. num_stages: The number of stages emitted by the Triton compiler. - debug: Prints out intermediate IRs if True for debugging purposes. + debug: Passes the IRs to a callable `debug(ir_name, ir_body)` for debugging. + It could also be a bool where `True` means dump to stdout and `False` don't + dump anything. serialized_metadata: Arbitrary metadata that will be added into the serialized kernel call. **metaparams: Additional keyword arguments that will be provided to a `grid` diff --git a/tests/triton_call_test.py b/tests/triton_call_test.py index 71a59806..bbb3492d 100644 --- a/tests/triton_call_test.py +++ b/tests/triton_call_test.py @@ -549,6 +549,41 @@ def test_specialization(self): # specialize" leaving `stride_{bn,cn}`. self.assertEqual(specialization.equal_to_1, (8, 10)) + def test_debug_callable(self): + emitted_irs = dict() + + m, n, k = 128, 128, 128 + x, y = create_random_inputs([m, k], [k, n]) + + def intercept_ir(ir_name, ir_body): + ir_body = str(ir_body) + self.assertNotIn( + ir_name, emitted_irs, f"Attempted to overwrite {ir_name}" + ) + self.assertNotEmpty(ir_body, f"IR '{ir_name}' was empty.") + emitted_irs[ir_name] = ir_body + + block_size_m, block_size_n, block_size_k = 128, 128, 32 + _ = matmul( + x, + y, + debug=intercept_ir, + BLOCK_SIZE_M=block_size_m, + BLOCK_SIZE_N=block_size_n, + BLOCK_SIZE_K=block_size_k, + K_EXACTLY_DIVISIBLE_BY_BLOCK=k % block_size_k == 0, + ) + + for ir_name in ["py", "ttir", "ttgir", "llir", "ptx"]: + self.assertIn(ir_name, emitted_irs, f"IR '{ir_name}' was not recorded.") + self.assertNotEmpty(emitted_irs[ir_name], f"IR '{ir_name}' was empty.") + + self.assertStartsWith( + emitted_irs["py"], + "def", + "Python code is emitted as the strigification of a different object", + ) + if __name__ == "__main__": os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.5"