-
I am not sure whether this is a bug or just my lack of experience. I have the following code:
import jax
import jax.numpy as jnp
from functools import partial
@partial(jnp.vectorize, signature='(n)->(m)')
def to_homogeneous(v):
    """Append a trailing homogeneous coordinate equal to 1.

    The core signature ``(n)->(m)`` makes this operate on the last axis of
    ``v``; any leading axes are treated as batch axes. Each length-n vector
    becomes a length-(n + 1) vector whose final entry is 1.
    """
    # A one-element float array. Concatenating with it promotes integer
    # inputs to the default float dtype.
    homogeneous_one = jnp.array([1.0])
    return jnp.concatenate((v, homogeneous_one))
# dims (before vmapping): P: [4, 4] xyz: [3] -> [2]
def xyz_to_xy(P, xyz):
    """Project a 3-D point to 2-D using the matrix ``P``.

    Per the shape comment above: ``P`` is [4, 4], ``xyz`` is [3] and the
    result is [2]. The point is lifted to homogeneous coordinates and
    multiplied by ``P``. The first two components of that product are then
    divided by its third component.
    """
    camera = P @ to_homogeneous(xyz)
    return camera[:2] / camera[2]
# Batched wrapper. P has two core dimensions (a, a), xyz has one (b), and the
# result has one (c). The comma in "(a,a)" matters: the original post wrote
# "(aa)", which declares a SINGLE core dimension named "aa". With that
# signature the 4x4 P was mapped over its rows, so each call saw a 1-D P,
# `P @ xyzw_world` collapsed to a scalar, and `xyzw_camera[:2]` raised the
# IndexError shown in the traceback below.
vc_xyz_to_xy = jnp.vectorize(xyz_to_xy, signature='(a,a),(b)->(c)')
P = jnp.eye(4)
xyz = jnp.ones(3)
print('unbatched', xyz_to_xy(P, xyz))
# This call failed with the original "(aa)" signature. That run also logged:
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
print('batched', vc_xyz_to_xy(P, xyz))
unbatched [1. 1.]
Traceback (most recent call last):
File "projection.py", line 22, in <module>
print('batched', vc_xyz_to_xy(P, xyz)) # fails
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/vectorize.py", line 304, in wrapped
return vectorized_func(*vec_args)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/vectorize.py", line 135, in wrapped
out = func(*args)
File "projection.py", line 13, in xyz_to_xy
xy_camera = xyzw_camera[:2] / xyzw_camera[2]
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4384, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4391, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4475, in _index_to_gather
idx = _canonicalize_tuple_index(len(x_shape), idx)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4748, in _canonicalize_tuple_index
raise IndexError(msg.format(len_without_none, arr_ndim))
jax._src.traceback_util.FilteredStackTrace: IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.
The stack trace above excludes JAX-internal frames.
The following is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "projection.py", line 22, in <module>
print('batched', vc_xyz_to_xy(P, xyz)) # fails
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/vectorize.py", line 304, in wrapped
return vectorized_func(*vec_args)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 139, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/api.py", line 1237, in batched_fun
out_flat = batching.batch(
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/linear_util.py", line 166, in call_wrapped
ans = self.f(*args, **dict(self.params, **kwargs))
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/vectorize.py", line 135, in wrapped
out = func(*args)
File "projection.py", line 13, in xyz_to_xy
xy_camera = xyzw_camera[:2] / xyzw_camera[2]
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/core.py", line 552, in __getitem__
def __getitem__(self, idx): return self.aval._getitem(self, idx)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4384, in _rewriting_take
return _gather(arr, treedef, static_idx, dynamic_idx)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4391, in _gather
indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4475, in _index_to_gather
idx = _canonicalize_tuple_index(len(x_shape), idx)
File "/home/jatentaki/Storage/jatentaki/miniconda3/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py", line 4748, in _canonicalize_tuple_index
raise IndexError(msg.format(len_without_none, arr_ndim))
IndexError: Too many indices for array: 1 non-None/Ellipsis indices for dim 0.
Am I doing something wrong, or is this a bug in JAX?
Beta Was this translation helpful? Give feedback.
Answered by
shoyer
Apr 20, 2021
Replies: 1 comment 1 reply
-
You're missing a comma in the `signature` string in `vectorize`. It should be `signature='(a,a),(b)->(c)'`. With that change, your code seems to run correctly.
Beta Was this translation helpful? Give feedback.
1 reply
Answer selected by
shoyer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
You're missing a comma in the `signature` string in `vectorize`. It should be `signature='(a,a),(b)->(c)'`. With that change, your code seems to run correctly.