How do I implement `tf.gather_nd` in JAX? This is similar to #3658, and the answer by @shoyer there seemed simple enough.
Answered by hawkinsp, Mar 18, 2021
Answer selected by AlexiaJM
Without batch dimensions, `x[tuple(jnp.moveaxis(indices, -1, 0))]` should work. If you need batch dimensions, you can use `vmap`. (It's slightly more awkward with many batch dimensions.)
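A minimal sketch of the answer above, wrapped into helpers that loosely mirror `tf.gather_nd`'s `batch_dims` argument. The names `gather_nd` and `gather_nd_batched` are hypothetical. Handling several batch dimensions by nesting `jax.vmap` once per batch dimension is just one way to deal with the "more awkward" case, not necessarily what the answer had in mind. The sketch assumes `indices` is an integer array whose last axis holds index tuples.

```python
import jax
import jax.numpy as jnp


def gather_nd(params, indices):
    """Hypothetical helper: tf.gather_nd-style gather with no batch dims.

    indices has shape (..., K); each length-K row indexes the first K axes
    of params. Result shape: indices.shape[:-1] + params.shape[K:].
    """
    # Move the index-component axis to the front and split it into a tuple
    # of K integer arrays, so NumPy-style advanced indexing does the gather.
    return params[tuple(jnp.moveaxis(indices, -1, 0))]


def gather_nd_batched(params, indices, batch_dims=1):
    """Hypothetical helper: the leading `batch_dims` axes of params and
    indices are treated as shared batch axes."""
    fn = gather_nd
    # Wrap once per batch dimension; each vmap maps over the leading axis
    # of both arguments.
    for _ in range(batch_dims):
        fn = jax.vmap(fn)
    return fn(params, indices)


params = jnp.arange(12).reshape(3, 4)
print(gather_nd(params, jnp.array([[0, 1], [2, 3]])).tolist())  # → [1, 11]
print(gather_nd(params, jnp.array([[1]])).tolist())  # → [[4, 5, 6, 7]]

p = jnp.arange(8).reshape(2, 2, 2)
idx = jnp.array([[[0, 1]], [[1, 0]]])  # one index tuple per batch element
print(gather_nd_batched(p, idx, batch_dims=1).tolist())  # → [[1], [6]]
```

One difference from TensorFlow to watch for: JAX indexing does not raise on out-of-bounds indices. For retrieval it clamps them to the valid range by default, so invalid indices return data silently instead of producing an error.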