I have the following problem: I have two different functions that take different amounts of time to run. I want to dispatch to one or the other based on the value of `x`. Basically, what I am trying to achieve is:

```python
if x > 0:
    return slow_func(x)
else:
    return fast_func(x)
```

In JAX, this looks like:

```python
jax.lax.cond(x > 0, slow_func, fast_func, operand=x)
```

The problem is that this does not improve performance: it seems as if JAX evaluates both branches and only returns the value of the branch selected by the predicate.

Is there a way to achieve this dispatching?
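For contrast, here is a minimal sketch of two common ways to write this dispatch. `jnp.where` takes already-computed values, so both functions run and one result is discarded, which is the "evaluate both, keep one" behavior described above. `lax.cond` with a scalar predicate is lowered to a conditional that executes only the selected branch, although both branches are still traced. One caveat from the `lax.cond` documentation: if the predicate is batched under `jax.vmap`, the cond is converted to a select and both branches are executed. The predicate `x.sum() > 0` and the names `select_both` and `cond_one` are hypothetical, chosen because `lax.cond` needs a scalar boolean predicate.

```python
import jax
import jax.numpy as jnp
from jax import lax


def slow_func(x):
    # Expensive branch: sum of the singular values of x.
    return jnp.linalg.svd(x, compute_uv=False).sum()


def fast_func(x):
    # Cheap branch.
    return x.sum()


@jax.jit
def select_both(x):
    # jnp.where receives both values already computed, so
    # slow_func(x) and fast_func(x) both run on every call.
    return jnp.where(x.sum() > 0, slow_func(x), fast_func(x))


@jax.jit
def cond_one(x):
    # Hypothetical scalar predicate (lax.cond requires a scalar bool).
    # Both branches are traced, but only the selected one executes.
    return lax.cond(x.sum() > 0, slow_func, fast_func, x)


pos = jnp.ones((4, 4))
neg = -jnp.ones((4, 4))
print(select_both(pos), cond_one(pos))  # predicate True: slow_func result
print(select_both(neg), cond_one(neg))  # predicate False: fast_func result
```

Both versions return the same values; the difference is only in how much work runs, which is what the timings in the reply measure.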
Replies: 2 comments
JAX traces both functions (so side-effects like `print` will always execute), but in general it only executes the traced code that is needed to compute the result. You can see this in practice using some quick benchmarks:

```python
# Run in IPython/Jupyter: %timeit is an IPython magic.
import numpy as np
import jax.numpy as jnp
from jax import lax

def fast_func(x):
  return x.sum()

def slow_func(x):
  # svd returns (u, s, vh); index 1 is the singular values.
  return jnp.linalg.svd(x)[1].sum()

x = jnp.array(np.random.rand(1000, 1000))

%timeit fast_func(x).block_until_ready()
# 1 loop, best of 5: 2.14 ms per loop
%timeit slow_func(x).block_until_ready()
# 1 loop, best of 5: 522 ms per loop
%timeit lax.cond(True, slow_func, fast_func, x).block_until_ready()
# 1 loop, best of 5: 530 ms per loop
%timeit lax.cond(False, slow_func, fast_func, x).block_until_ready()
# 100 loops, best of 5: 3.94 ms per loop
```

I was wrong. Everything works as it should.
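The point from the first reply, that both branches are traced but only the needed one is executed, can also be checked without timing. A minimal sketch: Python-level side effects such as incrementing a counter happen only while JAX traces a function, so they fire for both branches of `lax.cond`, and only when tracing occurs, not on every call. The `trace_counts` dictionary, the `dispatch` wrapper, and the `x.sum() > 0` predicate are hypothetical names for illustration.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical counters: incremented only when JAX traces each branch.
trace_counts = {"slow": 0, "fast": 0}


def slow_func(x):
    trace_counts["slow"] += 1  # Python side effect: runs at trace time only
    return jnp.linalg.svd(x, compute_uv=False).sum()


def fast_func(x):
    trace_counts["fast"] += 1  # Python side effect: runs at trace time only
    return x.sum()


@jax.jit
def dispatch(x):
    # Hypothetical scalar predicate standing in for "x > 0".
    return lax.cond(x.sum() > 0, slow_func, fast_func, x)


x = jnp.ones((4, 4))
dispatch(x).block_until_ready()
after_first = dict(trace_counts)   # both branches were traced
dispatch(-x).block_until_ready()   # same shape/dtype: cached, no retrace
after_second = dict(trace_counts)
print(after_first, after_second)
```

To observe runtime execution rather than tracing, `jax.debug.print` placed inside a branch prints only when that branch actually runs, whereas a plain `print` fires during tracing for both branches.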