From d4603195d9422059566cb0500b1487737acbb749 Mon Sep 17 00:00:00 2001
From: Sharad Vikram
Date: Wed, 16 Nov 2022 22:22:45 +0000
Subject: [PATCH] Fix dot to use preferred_element_type

---
 jax_triton/pallas/lowering.py   | 36 ++++++++++++---------
 jax_triton/pallas/primitives.py |  6 +++-
 jax_triton/triton_call.py       | 55 ++++++++++++++++++++-------------
 tests/pallas_test.py            |  8 ++---
 4 files changed, 63 insertions(+), 42 deletions(-)

diff --git a/jax_triton/pallas/lowering.py b/jax_triton/pallas/lowering.py
index 2268c4b8..8d0992e4 100644
--- a/jax_triton/pallas/lowering.py
+++ b/jax_triton/pallas/lowering.py
@@ -33,7 +33,6 @@
 from jax._src.state import primitives as sp
 from jax._src.state import discharge
 from jax._src.state import ShapedArrayRef
-from jax_triton.triton_call import get_triton_python_ir
 import jax.numpy as jnp
 import triton
 import triton.language as tl
@@ -42,6 +41,8 @@
 import triton._C.libtriton.triton as _triton
 
 import jax_triton as jt
+from jax_triton.triton_call import get_triton_python_ir
+from jax_triton.triton_call import get_triton_element_type
 from jax_triton.pallas import primitives
 
 map, unsafe_map = util.safe_map, map
@@ -247,13 +248,8 @@ def _convert_element_type_lowering_rule(ctx: TritonLoweringRuleContext, a, *,
                                         new_dtype, weak_type):
   if new_dtype == ctx.avals_in[0].dtype:
     return a
-  if new_dtype == jnp.float32:
-    new_dtype = tl.float32
-  elif new_dtype == jnp.float16:
-    new_dtype = tl.float16
-  elif new_dtype == jnp.bfloat16:
-    new_dtype = tl.bfloat16
-  return tl.semantic.cast(a, new_dtype, ctx.builder)
+  triton_eltype = get_triton_element_type(ctx.avals_in[0].dtype)
+  return tl.semantic.cast(a, triton_eltype, ctx.builder)
 triton_lowering_rules[jax.lax.convert_element_type_p] = _convert_element_type_lowering_rule
 
 def max_lowering_rule(ctx: TritonLoweringRuleContext, a, b):
@@ -312,6 +308,7 @@ def _offset_ptr(ptr, idx: primitives.NDIndexer, shape, builder, is_scalar):
   other_shape = indexer_shape[len(idx.int_indexer_shape):]
   bcast_indices = []
   other_shape_idx = 0
+  dest_shape = map(tl.constexpr, idx.get_indexer_shape())
   for stride, index, dim_size, is_sc in zip(strides, indices, shape, is_scalar):
     if isinstance(index, primitives.Slice):
       index_size = index.size
@@ -344,12 +341,10 @@
         index = tl.broadcast_to(index, desired_shape, _builder=builder)
       else:
         index = tl.reshape(index, desired_shape, _builder=builder)
+    if dest_shape != index.shape:
+      index = tl.broadcast_to(index, dest_shape, _builder=builder)
     stride_size = tl.core._to_tensor(int(stride), builder)
     bcast_indices.append(index.__mul__(stride_size, _builder=builder))
-  dest_shape = map(tl.constexpr, idx.get_indexer_shape())
-  bcast_indices = [
-      tl.broadcast_to(index, dest_shape, _builder=builder) if dest_shape != index.shape
-      else index for index in bcast_indices]
   for bcast_idx in bcast_indices:
     ptr = ptr.__add__(bcast_idx, _builder=builder)
   return ptr
@@ -466,15 +461,26 @@ def _addupdate_lowering_rule(ctx: TritonLoweringRuleContext, ptr, value,
 
 def _dot_general_lowering(ctx: TritonLoweringRuleContext, a, b, *,
                           dimension_numbers, precision, preferred_element_type):
+  if preferred_element_type is None:
+    preferred_element_type = ctx.avals_out[0].dtype
   contract_dims, batch_dims = dimension_numbers
-  assert batch_dims == ((), ())
+  if batch_dims != ((), ()):
+    raise NotImplementedError("Batch dimensions currently unsupported.")
+  if len(contract_dims[0]) != 1 or len(contract_dims[1]) != 1:
+    raise NotImplementedError("Multiple contraction dimensions currently unsupported.")
   a_contract_dim, = contract_dims[0]
   b_contract_dim, = contract_dims[1]
   trans_a = a_contract_dim == 0
   trans_b = b_contract_dim == 1
   allow_tf32 = precision == lax.Precision.HIGH or precision == lax.Precision.DEFAULT
-  return tl.dot(a, b, _builder=ctx.builder, trans_a=trans_a, trans_b=trans_b,
-                allow_tf32=allow_tf32)
+  out = tl.dot(a, b, _builder=ctx.builder, trans_a=trans_a, trans_b=trans_b,
+               allow_tf32=allow_tf32)
+  out_eltype = get_triton_element_type(preferred_element_type)
+  if out_eltype != out.dtype:
+    # `tl.dot` accumulates in f32 by default; cast the result to the dtype the
+    # jaxpr expects.
+    out = tl.semantic.cast(out, out_eltype, ctx.builder)
+  return out
 triton_lowering_rules[jax.lax.dot_general_p] = _dot_general_lowering
 
 def _reduce_lowering(triton_op, ctx: TritonLoweringRuleContext, a, *, axes):
diff --git a/jax_triton/pallas/primitives.py b/jax_triton/pallas/primitives.py
index 6393b724..555b0813 100644
--- a/jax_triton/pallas/primitives.py
+++ b/jax_triton/pallas/primitives.py
@@ -380,7 +380,11 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None:
 def dot(a, b, trans_a=False, trans_b=False, allow_tf32=True):
   rhs_contract_dim = int(trans_b)
   lhs_contract_dim = int(not trans_a)
+  # `pl.dot`, like `tl.dot`, accumulates in f32.
+  preferred_element_type = None
+  if jnp.issubdtype(a.dtype, jnp.floating):
+    preferred_element_type = jnp.dtype("float32")
   return jax.lax.dot_general(
       a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())),
       precision=lax.Precision.HIGH if allow_tf32 else lax.Precision.HIGHEST,
-      preferred_element_type=None)
+      preferred_element_type=preferred_element_type)
diff --git a/jax_triton/triton_call.py b/jax_triton/triton_call.py
index 199057da..dd071507 100644
--- a/jax_triton/triton_call.py
+++ b/jax_triton/triton_call.py
@@ -57,30 +57,41 @@
 triton_type_mappings = {}
 
-def get_triton_type(obj: Any) -> str:
-  type_map = {
-      jnp.dtype("bfloat16"): "bf16",
-      jnp.dtype("float64"): "fp64",
-      jnp.dtype("float32"): "fp32",
-      jnp.dtype("float16"): "fp16",
-      # Triton has 'fp8' as well which Jax doesn't support yet.
-
-      jnp.dtype("int64"): "i64",
-      jnp.dtype("int32"): "i32",
-      jnp.dtype("int16"): "i16",
-      jnp.dtype("int8"): "i8",
-
-      jnp.dtype("uint64"): "u64",
-      jnp.dtype("uint32"): "u32",
-      jnp.dtype("uint16"): "u16",
-      jnp.dtype("uint8"): "u8",
-
-      # Triton defines a 'B' type, which is an alias for both i1 and bool.
-      jnp.dtype("bool"): "B",
-  }
+_element_type_map = {
+    jnp.dtype("bfloat16"): (tl.bfloat16, "bf16"),
+    jnp.dtype("float64"): (tl.float64, "fp64"),
+    jnp.dtype("float32"): (tl.float32, "fp32"),
+    jnp.dtype("float16"): (tl.float16, "fp16"),
+    # Triton has 'fp8' as well which Jax doesn't support yet.
+
+    jnp.dtype("int64"): (tl.int64, "i64"),
+    jnp.dtype("int32"): (tl.int32, "i32"),
+    jnp.dtype("int16"): (tl.int16, "i16"),
+    jnp.dtype("int8"): (tl.int8, "i8"),
+
+    jnp.dtype("uint64"): (tl.uint64, "u64"),
+    jnp.dtype("uint32"): (tl.uint32, "u32"),
+    jnp.dtype("uint16"): (tl.uint16, "u16"),
+    jnp.dtype("uint8"): (tl.uint8, "u8"),
+
+    # Triton defines a 'B' type, which is an alias for both i1 and bool.
+    jnp.dtype("bool"): (tl.int32, "B"),
+}
+
+def get_triton_element_type(dtype: jnp.dtype) -> tl.dtype:
+  if dtype not in _element_type_map:
+    raise NotImplementedError(f"Unknown dtype: {dtype}")
+  return _element_type_map[dtype][0]
+
+def get_triton_element_type_as_str(dtype: jnp.dtype) -> str:
+  if dtype not in _element_type_map:
+    raise NotImplementedError(f"Unknown dtype: {dtype}")
+  return _element_type_map[dtype][1]
+
+def get_triton_type(obj: Any) -> str:
   if isinstance(obj, (jax.core.ShapedArray, state.ShapedArrayRef)):
-    return f"*{type_map[obj.dtype]}"
+    eltype = get_triton_element_type_as_str(obj.dtype)
+    return f"*{eltype}"
   if isinstance(obj, tl.constexpr):
     obj = obj.value
   if isinstance(obj, int):
diff --git a/tests/pallas_test.py b/tests/pallas_test.py
index 59ccef79..9a75b201 100644
--- a/tests/pallas_test.py
+++ b/tests/pallas_test.py
@@ -144,7 +144,7 @@ def body(i, acc_ref):
           jax.lax.broadcast_in_dim(idx_k, (bk, bn), (0,)),
           jax.lax.broadcast_in_dim(idx_n, (bk, bn), (1,)))
       x_block, y_block = x_ref[x_idx], y_ref[y_idx]
-      out = jnp.dot(x_block, y_block)
+      out = pl.dot(x_block, y_block)
       acc_ref[:, :] += out
     acc = for_loop(k // bk, body, acc).astype(o_ref.dtype)
     o_idx = (
@@ -157,7 +157,7 @@ def body(i, acc_ref):
     x = random.normal(k1, (m, k), dtype=dtype)
     y = random.normal(k2, (k, n), dtype=dtype)
     out, expected = matmul(x, y), jnp.matmul(x, y)
-    np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
+    np.testing.assert_allclose(out, expected, atol=0.03, rtol=0.03)
 
   @parameterized.named_parameters(*(
       dict(testcase_name=f"{size}_{dtype}", size=size, dtype=dtype)
@@ -177,13 +177,13 @@ def test_dot(self, size, dtype):
     def dot(x_ref, y_ref, o_ref):
       x = x_ref[:, :]
       y = y_ref[:, :]
-      o_ref[:, :] = pl.dot(x, y)
+      o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)
 
     k1, k2 = random.split(random.PRNGKey(0))
     x = random.normal(k1, (size, size), dtype=dtype)
    y = random.normal(k2, (size, size), dtype=dtype)
     out, expected = dot(x, y), jnp.dot(x, y)
-    np.testing.assert_allclose(out, expected, atol=0.05, rtol=0.05)
+    np.testing.assert_allclose(out, expected, atol=0.02, rtol=0.02)
 
   @parameterized.named_parameters(*(
       dict(testcase_name=f"{batch_size}_{size}_{block_size}_{dtype}",
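
The sketch below is illustrative only (not taken from the repository); it shows the JAX-level
contract the new lowering honors. With floating-point operands, passing
preferred_element_type=float32 to jax.lax.dot_general yields an f32 result even for f16/bf16
inputs, which matches the f32 accumulator that tl.dot produces on the Triton side.

import jax.numpy as jnp
from jax import lax

x = jnp.ones((16, 16), dtype=jnp.float16)
y = jnp.ones((16, 16), dtype=jnp.float16)
dims = (((1,), (0,)), ((), ()))  # contract x's last dim with y's first dim

plain = lax.dot_general(x, y, dims)                          # result dtype: float16
accum = lax.dot_general(x, y, dims,
                        preferred_element_type=jnp.float32)  # result dtype: float32
print(plain.dtype, accum.dtype)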
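Because pl.dot now requests f32 accumulation, a kernel that writes the result into a
lower-precision output ref must cast explicitly, as the updated test_dot does. A minimal
sketch, assuming the `from jax_triton import pallas as pl` import and the pl.pallas_call
entry point used elsewhere in tests/pallas_test.py; the kernel name, shapes, and dtypes here
are arbitrary:

import functools
import jax
import jax.numpy as jnp
from jax_triton import pallas as pl

size, dtype = 64, jnp.float16

@functools.partial(pl.pallas_call,
                   out_shape=jax.ShapeDtypeStruct((size, size), dtype),
                   grid=1)
def dot_kernel(x_ref, y_ref, o_ref):
  x = x_ref[:, :]
  y = y_ref[:, :]
  # pl.dot returns an f32 value; cast before storing into the f16 output ref.
  o_ref[:, :] = pl.dot(x, y).astype(o_ref.dtype)

out = dot_kernel(jnp.ones((size, size), dtype), jnp.ones((size, size), dtype))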
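The triton_call.py change keeps the tl.dtype used by the lowering and the "fp16"-style
signature string in a single table so the two cannot drift apart. A usage sketch, illustrative
only and assuming a jax-triton build with this patch applied:

import jax
import jax.numpy as jnp
from jax_triton.triton_call import (get_triton_element_type,
                                    get_triton_element_type_as_str,
                                    get_triton_type)

dt = jnp.dtype("float16")
get_triton_element_type(dt)          # tl.float16, used by the lowering rules
get_triton_element_type_as_str(dt)   # "fp16", used for kernel signature strings
get_triton_type(jax.core.ShapedArray((8, 8), dt))  # "*fp16": pointer to fp16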