diff --git a/iree/jax/builtins.py b/iree/jax/builtins.py index 0f6f07f..a91b525 100644 --- a/iree/jax/builtins.py +++ b/iree/jax/builtins.py @@ -40,7 +40,7 @@ class jit_kernel(tracing.CallableIntrinsic): def __init__(self, wrapped_f, *, wrap_with_jit: bool = True): self.wrapped_f = wrapped_f - self.jit_f = jax.jit(self.wrapped_f, backend="iree") if wrap_with_jit else self.wrapped_f + self.jit_f = jax.jit(self.wrapped_f, backend="cpu") if wrap_with_jit else self.wrapped_f def __repr__(self): return f""