diff --git a/jax_triton/pallas/__init__.py b/jax_triton/pallas/__init__.py index df581891..6ddbc645 100644 --- a/jax_triton/pallas/__init__.py +++ b/jax_triton/pallas/__init__.py @@ -36,8 +36,5 @@ from jax_triton.pallas.primitives import swap from jax_triton.utils import cdiv -try: - from jax_triton.pallas import triton_ir_lowering - del triton_ir_lowering -except (ImportError, ModuleNotFoundError): - pass +from jax_triton.pallas import triton_ir_lowering +del triton_ir_lowering diff --git a/jax_triton/pallas/triton_ir_lowering.py b/jax_triton/pallas/triton_ir_lowering.py index bbca5b54..fbd94db8 100644 --- a/jax_triton/pallas/triton_ir_lowering.py +++ b/jax_triton/pallas/triton_ir_lowering.py @@ -44,7 +44,7 @@ import triton import triton.language as tl from triton.language import ir as tl_ir -import triton.libtriton.triton as _triton +import triton._C.libtriton.triton as _triton from jax_triton import triton_lib from jax_triton import triton_kernel_call_lib