How to correctly bind a primitive that returns a Pytree? #16301
-
Hello all, I'm trying to create some new primitives. It's probably not helpful to discuss the full details of the application, but ultimately, the problem I'm struggling with is how to create new JAX primitives that return multiple results or, more generally, a Pytree of results. Here's a simple example that shows what I'm trying to do. (Of course it's pointless to define new primitives like this, but I'm trying to create the simplest possible example.) Following the tutorial at https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html I wrote this code:

```python
from jax import core
import jax

def make_prim(name, fun):
    p = core.Primitive(name)

    def prim(key, *args):
        return p.bind(key, *args)

    p.def_impl(fun)

    def abstract_eval(*args):
        out_shape = jax.eval_shape(fun, *args)
        # map each ShapeDtypeStruct in the output pytree to an abstract value
        out_abstract = jax.tree_util.tree_map(
            lambda shape: core.ShapedArray(shape.shape, shape.dtype), out_shape)
        return out_abstract

    p.def_abstract_eval(abstract_eval)
    return prim

add = make_prim("add", lambda a, b: a + b)
print(f'{add(2.0, 3.0)}')  # ok
print(f'{jax.make_jaxpr(add)(2.0, 3.0)}')  # ok

mul = make_prim("mul", lambda a, b: a * b)
print(f'{mul(2.0, 3.0)}')  # ok
print(f'{jax.make_jaxpr(mul)(2.0, 3.0)}')  # ok

add_and_mul = make_prim("add_and_mul", lambda a, b: (a + b, a * b))
print(f'{add_and_mul(2.0, 3.0)}')  # ok
print(f'{jax.make_jaxpr(add_and_mul)(2.0, 3.0)}')  # fails with "TypeError: <class 'tuple'>"
```

All of this runs except the last line, which fails with the error shown above. I think the problem is that primitives are supposed to return a single result by default. So, for the case above only, I was seemingly able to fix things as follows:

```python
def make_prim_mult(name, fun):
    p = core.Primitive(name)
    p.multiple_results = True  # <------- eeeeeek

    def prim(key, *args):
        return p.bind(key, *args)

    p.def_impl(fun)

    def abstract_eval(*args):
        out_shape = jax.eval_shape(fun, *args)
        out_abstract = jax.tree_util.tree_map(
            lambda shape: core.ShapedArray(shape.shape, shape.dtype), out_shape)
        return out_abstract

    p.def_abstract_eval(abstract_eval)
    return prim

add_and_mul = make_prim_mult("add_and_mul", lambda a, b: (a + b, a * b))
print(f'{add_and_mul(2.0, 3.0)}')  # ok
print(f'{jax.make_jaxpr(add_and_mul)(2.0, 3.0)=}')  # ok
```

This works, but the assignment to `p.multiple_results` is a bit worrying. Even if we accept that, what if I wanted to create a primitive like the one below, which returns a full Pytree?

```python
add_and_mul_and_div = make_prim_mult("add_and_mul_and_div", lambda a, b: (a + b, (a * b, a / b)))
print(f'{add_and_mul_and_div(2.0, 3.0)}')  # ok
print(f'{jax.make_jaxpr(add_and_mul_and_div)(2.0, 3.0)=}')  # fails with "TypeError: <class 'tuple'>"
```

Does anyone know whether some other incantation for calling `bind` makes this work? Or is there a different intended solution? Thanks in advance!
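A sketch of what the failing case looks like once flattened, using only the public `jax.tree_util` functions and `jax.eval_shape` (no primitives involved). The nested output is really a flat list of scalar leaves plus a tree structure describing the nesting. The second element of the outer tuple is itself a tuple rather than an array, which is plausibly what the abstract evaluation trips over when every output is expected to be array-like. The function name mirrors the example above; the rest is illustrative.

```python
import jax


def add_and_mul_and_div(a, b):
    # Same nested output structure as the question's example: (x, (y, z))
    return (a + b, (a * b, a / b))


out = add_and_mul_and_div(2.0, 3.0)

# Split the nested output into a flat list of leaves plus a PyTreeDef
# that records how they were nested.
leaves, treedef = jax.tree_util.tree_flatten(out)
print(leaves)   # three scalar leaves, no tuples
print(treedef)  # the nesting, with each leaf shown as *

# jax.eval_shape produces the same structure (with ShapeDtypeStruct leaves)
# without running the computation on concrete values.
shape_tree = jax.tree_util.tree_structure(jax.eval_shape(add_and_mul_and_div, 2.0, 3.0))

# Unflattening the flat leaves restores the original nested tuple.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(rebuilt == out, shape_tree == treedef)
```

The structure obtained from `jax.eval_shape` is all a wrapper needs to rebuild the nested result from a flat list.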
-
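Justin, I'm a huge fan of your work, and your prescient blog!

`Primitive.bind` can't return pytrees. It can only return a single jaxtype result (if `multiple_results=False`) or a (flat) sequence of jaxtype results (when `multiple_results=True`). By "jaxtype" I essentially mean array. (As for the `multiple_results` flag existing at all, be not afraid! That's exactly what you're meant to do with it. Someday we'll do the cleanup so that `multiple_results` is always true for all Primitives, but it hasn't been a big priority. So for now all the jaxpr interpreters have to if/else switch on `multiple_results` when they bind a primitive.)

We keep `Primitive.bind`, and indeed all JAX internals, from having to know… So you should handle any pytrees in the Python callable which calls `bind`:

```python
def prim(key, *args):
    # get the output tree structure (fun takes key as its first argument, as in the impl)
    out_shape = jax.eval_shape(fun, key, *args)
    out_tree = jax.tree_util.tree_structure(out_shape)
    # bind the primitive, which returns flat results (requires p.multiple_results = True)
    out_flat = p.bind(key, *args)
    # unflatten results
    return jax.tree_util.tree_unflatten(out_tree, out_flat)
```

You'd additionally need to change the impl and the abstract eval rule to return a flat list of results, rather than pytree-structured results like the nested tuple your `fun` returns. If …

This is sort of a "local" improvement to your code, but depending on what you really want to do, there may be alternative ways to organize things, e.g. to avoid redundant work flattening and unflattening. If you want to say more about what you're doing, I'm happy to listen if it might help!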
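The suggestion above can be assembled into a complete toy primitive. This is a minimal sketch, not an official recipe, and it rests on a few assumptions: the helper name `make_flat_prim` is hypothetical; the `try`/`except` import assumes newer JAX releases expose `Primitive` under `jax.extend.core` while older ones have `jax.core.Primitive`; and the abstract eval rule converts its input avals to `jax.ShapeDtypeStruct` before calling `jax.eval_shape`. Only the wrapper sees the pytree, while `impl` and `abstract_eval` return flat lists on a primitive with `multiple_results = True`.

```python
import jax

try:
    from jax.extend.core import Primitive  # newer JAX releases (assumption)
except ImportError:
    from jax.core import Primitive  # older JAX releases
from jax.core import ShapedArray


def make_flat_prim(name, fun):
    """Hypothetical helper: wrap `fun` (which may return a pytree) in a primitive
    whose bind only ever produces a flat list of array results."""
    p = Primitive(name)
    p.multiple_results = True  # bind returns a flat sequence of results

    def impl(*args):
        # Run the Python function, then flatten its pytree output.
        return jax.tree_util.tree_leaves(fun(*args))

    def abstract_eval(*avals):
        # Describe the inputs as ShapeDtypeStructs and trace `fun` abstractly.
        specs = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals]
        out_shapes = jax.tree_util.tree_leaves(jax.eval_shape(fun, *specs))
        return [ShapedArray(s.shape, s.dtype) for s in out_shapes]

    p.def_impl(impl)
    p.def_abstract_eval(abstract_eval)

    def prim(*args):
        # Only this wrapper knows about the output pytree structure.
        out_tree = jax.tree_util.tree_structure(jax.eval_shape(fun, *args))
        out_flat = p.bind(*args)
        return jax.tree_util.tree_unflatten(out_tree, out_flat)

    return prim


add_and_mul_and_div = make_flat_prim(
    "add_and_mul_and_div", lambda a, b: (a + b, (a * b, a / b)))
print(add_and_mul_and_div(2.0, 3.0))
print(jax.make_jaxpr(add_and_mul_and_div)(2.0, 3.0))
```

Calling `jax.eval_shape` in both the wrapper and the abstract rule is the kind of redundant work the reply alludes to; it keeps each piece independent at the cost of tracing `fun` more than once.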
-
Hi Matthew,

Thanks so much for the reply (and the kind words!). The idea that primitives only deal with lists makes perfect sense (except for the function that does the binding), and it has helped a bunch of things fall into place in terms of understanding how JAX works and why it's designed the way it is. I implemented a version of the above toy code that (I think) does all the flattening and unflattening. It works fine in testing, but I'd be grateful if you could let me know whether this looks like the right approach to you.

```python
from jax import core
import jax

def make_prim(name, fun):
    p = core.Primitive(name)
    p.multiple_results = True

    def prim(*args):
        # flatten inputs
        args_flat, args_tree = jax.tree_util.tree_flatten(args)
        # get output shape
        out_shape = jax.eval_shape(fun, *args)
        out_tree = jax.tree_util.tree_structure(out_shape)
        # call flat function
        out_flat = p.bind(*args_flat, args_tree=args_tree)
        # unflatten outputs
        out_pytree = jax.tree_util.tree_unflatten(out_tree, out_flat)
        return out_pytree

    def impl(*args_flat, args_tree):
        # unflatten inputs
        args = jax.tree_util.tree_unflatten(args_tree, args_flat)
        # call original function to get pytree output
        out_pytree = fun(*args)
        # flatten pytree of outputs
        out_flat, out_tree = jax.tree_util.tree_flatten(out_pytree)
        return out_flat

    p.def_impl(impl)

    def abstract_eval(*args_flat, args_tree):
        # unflatten inputs
        args = jax.tree_util.tree_unflatten(args_tree, args_flat)
        # get pytree of shapes
        out_shape_pytree = jax.eval_shape(fun, *args)
        # flatten pytree of shapes
        out_shape_flat, out_shape_tree = jax.tree_util.tree_flatten(out_shape_pytree)
        # transform into abstract arrays
        out_abstract_flat = [core.ShapedArray(shape.shape, shape.dtype) for shape in out_shape_flat]
        return out_abstract_flat

    p.def_abstract_eval(abstract_eval)
    return prim

# simple function mapping a scalar to a scalar
def fun(a):
    return 2 * a

a = 2.0
print('simple function')
print(f'{fun(a)=}')
prim = make_prim('fun', fun)
print(f'{prim(a)=}')
print(f'{jax.make_jaxpr(prim)(a)=}')

# "complex" function mapping multiple pytrees to multiple pytrees
def fun(a, b):
    return (a[0] + a[1][0] * a[1][1], b, (b + 1.0, (b * 2.0, b * 3.0)))

a = (1.0, (2.0, 3.0))
b = 4.0
print('complex function')
print(f'{fun(a, b)=}')
prim = make_prim('fun', fun)
print(f'{prim(a, b)=}')
print(f'{jax.make_jaxpr(prim)(a, b)=}')
```

I really appreciate the offer to talk more about what I'm doing! Long-term, my goal is to create a library for specifying probabilistic models in JAX, sort of similar to {pymc4, numpyro, tfp, oryx}, but targeting folks developing approximate inference algorithms rather than end users. I'd like to do this by creating program transformations analogous to …
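One way to check the pytree-input version above is to rebuild it compactly and assert on both the eager result and the traced jaxpr. This is a sketch under assumptions: the import fallback guesses where `Primitive` lives in the installed JAX version, the abstract eval rule passes `jax.ShapeDtypeStruct` values to `jax.eval_shape` instead of the raw avals, and `tree_leaves` stands in for `tree_flatten` where only the leaves are needed. The checks (one equation, four flat inputs, five flat outputs, `args_tree` stored as a parameter) are specific to the "complex" example.

```python
import jax

try:
    from jax.extend.core import Primitive  # newer JAX releases (assumption)
except ImportError:
    from jax.core import Primitive  # older JAX releases
from jax.core import ShapedArray


def make_prim(name, fun):
    p = Primitive(name)
    p.multiple_results = True

    def prim(*args):
        # Flatten the input pytrees; pass their structure as a bind parameter.
        args_flat, args_tree = jax.tree_util.tree_flatten(args)
        out_tree = jax.tree_util.tree_structure(jax.eval_shape(fun, *args))
        out_flat = p.bind(*args_flat, args_tree=args_tree)
        return jax.tree_util.tree_unflatten(out_tree, out_flat)

    def impl(*args_flat, args_tree):
        # Rebuild the input pytrees, run fun, return its flat leaves.
        args = jax.tree_util.tree_unflatten(args_tree, args_flat)
        return jax.tree_util.tree_leaves(fun(*args))

    def abstract_eval(*avals_flat, args_tree):
        # Same as impl, but on shape/dtype descriptions instead of values.
        specs_flat = [jax.ShapeDtypeStruct(a.shape, a.dtype) for a in avals_flat]
        specs = jax.tree_util.tree_unflatten(args_tree, specs_flat)
        out_shapes = jax.tree_util.tree_leaves(jax.eval_shape(fun, *specs))
        return [ShapedArray(s.shape, s.dtype) for s in out_shapes]

    p.def_impl(impl)
    p.def_abstract_eval(abstract_eval)
    return prim


def fun(a, b):
    return (a[0] + a[1][0] * a[1][1], b, (b + 1.0, (b * 2.0, b * 3.0)))


a = (1.0, (2.0, 3.0))
b = 4.0
prim = make_prim("fun", fun)

out = prim(a, b)
# The primitive's result has the same pytree structure as calling fun directly.
assert jax.tree_util.tree_structure(out) == jax.tree_util.tree_structure(fun(a, b))

closed = jax.make_jaxpr(prim)(a, b)
eqn = closed.jaxpr.eqns[0]
print(closed)
```

Because `args_tree` is passed to `bind` as a parameter, it is stored on the jaxpr equation (and printed next to the primitive name), while the nested input and output structures exist only outside the primitive.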
-
Hi, I'm also looking at the possibility of binding a primitive with Pytree inputs and outputs.