Does JAX evaluate the false-branch of lax.cond? #6142

Answered by jakevdp
helange23 asked this question in Q&A

JAX traces both branch functions (so Python side effects such as print always run), but in general it only executes the traced code of the branch that is needed to compute the result. You can see this in practice using some quick benchmarks:

import numpy as np
import jax.numpy as jnp
from jax import lax

def fast_func(x):
  return x.sum()

def slow_func(x):
  return jnp.linalg.svd(x)[1].sum()

x = jnp.array(np.random.rand(1000, 1000))

%timeit fast_func(x).block_until_ready()
# 1 loop, best of 5: 2.14 ms per loop

%timeit slow_func(x).block_until_ready()
# 1 loop, best of 5: 522 ms per loop

%timeit lax.cond(True, slow_func, fast_func, x).block_until_ready()
# 1 loop, best of 5: 530 ms per loop

%timeit lax.c…
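
The tracing half of the claim can be checked separately from the timings. The following is a minimal sketch, not part of the original answer: the names `traced`, `true_branch`, `false_branch`, and `f` are hypothetical and chosen only for illustration. It records a Python side effect in each branch and wraps the `lax.cond` call in `jax.jit`, so the predicate is a traced value rather than a concrete one.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical bookkeeping list: records which branch functions Python called.
traced = []

def true_branch(x):
    traced.append("true")   # Python side effect, runs while JAX traces this branch
    return x + 1.0

def false_branch(x):
    traced.append("false")  # Python side effect, runs while JAX traces this branch
    return x - 1.0

@jax.jit
def f(pred, x):
    # Under jit, pred is a tracer, so lax.cond has to trace both branches.
    return lax.cond(pred, true_branch, false_branch, x)

x = jnp.ones(3)
out_true = f(True, x)    # first call traces f, and with it both branches
out_false = f(False, x)  # same argument types and shapes as the first call

print(sorted(set(traced)))  # both branch names show up
print(out_true, out_false)  # each result comes from the selected branch only
```

Both branch names end up in `traced` even though each call returns the value of only one branch, which matches the description above: the side effects belong to tracing, while the returned value depends on the predicate at run time.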

Answer selected by helange23