How to correctly bind a primitive that returns a Pytree? #16301

Answered by mattjj
justindomke asked this question in Q&A

Justin, I'm a huge fan of your work, and your prescient blog!

Primitive.bind can't return pytrees. It can only return a single jaxtype result (if multiple_results=False) or a flat sequence of jaxtype results (when multiple_results=True). By "jaxtype" I essentially mean array. (As for the multiple_results flag existing at all, be not afraid! That's exactly what you're meant to use it for. Someday we'll do the cleanup so that multiple_results is always True for all Primitives, but it hasn't been a big priority. So for now, all the jaxpr interpreters have to if/else switch on multiple_results when they bind a primitive.)
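To make the multiple_results switch concrete, here is a minimal sketch, assuming a recent JAX where Primitive and Literal are importable from jax.extend.core (older versions keep them in jax.core, hence the fallback import). `sincos_p` is a hypothetical primitive that returns a flat list of two arrays, and `eval_jaxpr` is a toy interpreter in the style of JAX's custom-interpreter tutorial. Note the if/else on multiple_results when re-binding each equation.

```python
import jax
import jax.numpy as jnp

try:  # recent JAX
    from jax.extend.core import Literal, Primitive
except ImportError:  # older JAX
    from jax.core import Literal, Primitive

# Hypothetical primitive returning two arrays at once: (sin(x), cos(x)).
sincos_p = Primitive("sincos")
sincos_p.multiple_results = True  # bind returns a flat list, not a single array


@sincos_p.def_impl
def _sincos_impl(x):
    return [jnp.sin(x), jnp.cos(x)]


@sincos_p.def_abstract_eval
def _sincos_abstract_eval(x_aval):
    # Both outputs share the input's shape and dtype.
    return [x_aval, x_aval]


def f(x):
    s, c = sincos_p.bind(x)  # unpack the flat list of results
    return s * c + 1.0


def eval_jaxpr(jaxpr, consts, *args):
    """Toy interpreter: re-binds every equation's primitive in order."""
    env = {}

    def read(v):
        # Literals carry their value inline; variables live in env.
        return v.val if isinstance(v, Literal) else env[v]

    def write(v, val):
        env[v] = val

    for v, c in zip(jaxpr.constvars, consts):
        write(v, c)
    for v, a in zip(jaxpr.invars, args):
        write(v, a)
    for eqn in jaxpr.eqns:
        outvals = eqn.primitive.bind(*map(read, eqn.invars), **eqn.params)
        # The switch every interpreter needs: a list of results, or just one.
        if not eqn.primitive.multiple_results:
            outvals = [outvals]
        for v, val in zip(eqn.outvars, outvals):
            write(v, val)
    return [read(v) for v in jaxpr.outvars]


x = jnp.arange(3.0)
closed = jax.make_jaxpr(f)(x)
print(closed)
(out,) = eval_jaxpr(closed.jaxpr, closed.consts, x)
```

Wrapping the single-result case in a one-element list keeps the rest of the loop uniform; that normalization is the cleanup that would become unnecessary if every primitive had multiple_results=True.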

We keep Primitive.bind and indeed all JAX internals from having to know…
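One common way to keep bind pytree-agnostic is to flatten at the API boundary: a wrapper calls jax.tree_util.tree_flatten on the inputs, binds a multiple_results=True primitive on the flat leaves, and rebuilds the output with tree_unflatten. The sketch below assumes the output has the same pytree structure as the input; `double_p` and `double_tree` are hypothetical names for illustration, and the Primitive import path varies across JAX versions.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

try:  # recent JAX
    from jax.extend.core import Primitive
except ImportError:  # older JAX
    from jax.core import Primitive

# Hypothetical primitive: doubles each of its flat array inputs.
double_p = Primitive("double_leaves")
double_p.multiple_results = True


@double_p.def_impl
def _double_impl(*leaves):
    return [2 * leaf for leaf in leaves]


@double_p.def_abstract_eval
def _double_abstract_eval(*avals):
    # Each output has the same shape/dtype as the corresponding input.
    return list(avals)


def double_tree(tree):
    """Pytree in, pytree out; the primitive itself only ever sees flat arrays."""
    leaves, treedef = tree_flatten(tree)
    out_leaves = double_p.bind(*map(jnp.asarray, leaves))
    return tree_unflatten(treedef, out_leaves)


params = {"w": jnp.ones((2, 2)), "b": (jnp.zeros(2), jnp.arange(3.0))}
out = double_tree(params)
```

If the output structure differs from the input's, the wrapper needs the output treedef from somewhere else (for example, by computing it before binding); the primitive still deals only in flat lists.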

Replies: 3 comments 4 replies

0 replies
Answer selected by justindomke
2 replies
@patrick-kidger
@femtomc
2 replies
@jakevdp
@femtomc
Category
Q&A
6 participants