Jax equivalent of tf.gather_nd #6119

Answered by hawkinsp
AlexiaJM asked this question in Q&A
Mar 18, 2021 · 1 comment · 2 replies

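Without batch dimensions, x[tuple(jnp.moveaxis(indices, -1, 0))] should work. Moving the last (coordinate) axis of indices to the front yields one integer index array per indexed axis of x, and NumPy-style advanced indexing with that tuple then gathers the requested elements or slices.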
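A minimal sketch of the no-batch case, assuming indices has shape (..., k) where the last axis holds k coordinates into the first k axes of x (the tf.gather_nd convention with batch_dims=0). The name gather_nd is a hypothetical helper, not a JAX API.

```python
import jax.numpy as jnp

def gather_nd(x, indices):
    """Hypothetical helper mirroring tf.gather_nd with batch_dims=0.

    indices has shape (..., k); its last axis indexes the first k axes of x.
    """
    # Move the coordinate axis to the front: this gives k integer arrays,
    # each of shape indices.shape[:-1]. Indexing x with a tuple of arrays
    # triggers NumPy-style advanced indexing (a gather).
    return x[tuple(jnp.moveaxis(indices, -1, 0))]

x = jnp.arange(12).reshape(3, 4)

# k == rank(x): pick the scalar elements x[0, 1] and x[2, 3].
print(gather_nd(x, jnp.array([[0, 1], [2, 3]])))  # [ 1 11]

# k < rank(x): each index selects a whole row (a slice), here rows 2 and 0.
print(gather_nd(x, jnp.array([[2], [0]])))
```

As with tf.gather_nd, when k is smaller than the rank of x the result contains slices rather than scalars.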

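If you need batch dimensions (the batch_dims argument of tf.gather_nd), you can map the same expression over them with jax.vmap. (It's slightly more awkward with many batch dimensions, since vmap maps over one axis at a time.)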
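A sketch of the batched case, assuming x and indices share the same leading batch dimensions (like tf.gather_nd with batch_dims >= 1). jax.vmap maps the unbatched expression over axis 0 of both arguments; for several batch dimensions this sketch simply nests one vmap per dimension, which is one option among others (reshaping the batch dimensions into one is another). Both helper names are hypothetical.

```python
import jax
import jax.numpy as jnp

def gather_nd_unbatched(x, indices):
    # Hypothetical helper: tf.gather_nd semantics with batch_dims=0.
    return x[tuple(jnp.moveaxis(indices, -1, 0))]

def gather_nd_batched(x, indices, batch_dims=1):
    # Hypothetical helper: wrap the unbatched gather in one vmap per
    # leading batch dimension. x and indices must share those dimensions.
    fn = gather_nd_unbatched
    for _ in range(batch_dims):
        fn = jax.vmap(fn)  # each vmap maps over the next leading axis
    return fn(x, indices)

# A batch of 2 matrices, each 3x4.
x = jnp.arange(24).reshape(2, 3, 4)
# For each batch element, gather one (row, col) pair: shape (2, 1, 2).
idx = jnp.array([[[0, 1]], [[2, 3]]])
print(gather_nd_batched(x, idx))  # shape (2, 1): x[0, 0, 1] and x[1, 2, 3]
```

Because the per-example function is identical in both cases, the batched version only differs in how many times it is wrapped in vmap.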

Answer selected by AlexiaJM
Category: Q&A · Labels: None yet · 2 participants