Please see the simple code below for testing vmap together with a branch. It fails with the error shown below. Please help to check. Thanks!

Traceback (most recent call last):
  File "jax_branch_simple.py", line 14, in <module>
    y = f(x)
  File "/.conda/envs/exp/lib/python3.6/site-packages/jax/api.py", line 1232, in batched_fun
    ).call_wrapped(*args_flat)
  File "/.conda/envs/exp/lib/python3.6/site-packages/jax/linear_util.py", line 166, in call_wrapped
    ans = self.f(*args, **dict(self.params, **kwargs))
  File "jax_branch_simple.py", line 6, in branch
    if x > 0:
  File "/.conda/envs/exp/lib/python3.6/site-packages/jax/core.py", line 529, in __bool__
    def __bool__(self): return self.aval._bool(self)
  File "/.conda/envs/exp/lib/python3.6/site-packages/jax/core.py", line 957, in error
    raise ConcretizationTypeError(arg, fname_context)
jax._src.errors.ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(bool[])>with<BatchTrace(level=1/0)>
with val = DeviceArray([False, True, False, False, False, False, False, False,
True, True], dtype=bool)
batch_dim = 0
The problem arose with the `bool` function.
(https://jax.readthedocs.io/en/latest/errors.html#jax.errors.ConcretizationTypeError)

from jax import random, vmap
from functools import partial

def branch(x):
    # return x > 0  # works without the branch
    if x > 0:
        return x
    else:
        return -x

n = 10
x = random.normal(random.PRNGKey(0), (n,))
f = partial(vmap, in_axes=(0))(branch)
y = f(x)
print(x)
print(y)
Answered by bbfrog, Apr 6, 2021
Changing the branch function to use lax.cond makes it work with vmap. Thanks!

from jax import lax

def branch(x):
    return lax.cond(x > 0, lambda x: x, lambda x: -x, x)
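
For anyone landing here with the same error, here is a minimal self-contained sketch of the fix. The names branch_cond and branch_where are only illustrative, and the jnp.where variant is an alternative that was not part of the original answer. The Python `if` fails because under vmap `x > 0` is an abstract tracer with no single truth value. lax.cond and jnp.where both accept a traced predicate.

```python
import jax.numpy as jnp
from jax import lax, random, vmap


def branch_cond(x):
    # lax.cond accepts a traced predicate, unlike a Python `if`.
    # Each lambda receives the operand passed as the last argument.
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)


def branch_where(x):
    # Alternative (assumption, not from the thread): compute both
    # values and select elementwise with jnp.where.
    return jnp.where(x > 0, x, -x)


n = 10
x = random.normal(random.PRNGKey(0), (n,))

y_cond = vmap(branch_cond)(x)    # batched over axis 0
y_where = vmap(branch_where)(x)  # same result via select

print(x)
print(y_cond)
print(y_where)
```

Note that when the predicate is batched by vmap, lax.cond is converted to a select, so both branches are evaluated for every element. For cheap branches like these it does not matter, but it is worth keeping in mind if one branch is expensive.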