Why aren't jitted functions JAX objects? #6664
AdrienCorenflos asked this question in Ideas
-
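Does this do what you want?

```python
from functools import partial

from jax import jit


@jit
def f(x):
    return x


# static_argnums=0 marks h as static: it is not traced as an array, it must be
# hashable (jitted functions are), and each distinct h gets its own compilation.
@partial(jit, static_argnums=0)
def g(h, y):
    return h(y)


print(g(f, 5.))
```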
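A hedged sketch of three related points, assuming a recent JAX release. First, passing a function as an ordinary (traced) argument to `jit` fails with a `TypeError`, because `jit` can only trace array-like pytree leaves. Second, marking it static (as above) works, but `jit` traces and compiles again for every distinct function passed. Third, `jax.tree_util.Partial` is an alternative: it is a pytree that keeps the callable in the static tree structure and traces only its bound arguments, so no `static_argnums` is needed. The names `call_plain`, `call_static`, `scale`, `add_one` and `trace_count` are hypothetical, chosen only for illustration.

```python
from functools import partial

import jax
from jax.tree_util import Partial


@jax.jit
def f(x):
    return 2.0 * x


def call_plain(h, y):
    # h is an ordinary argument here, so jit tries to trace it as an array.
    return h(y)


# 1) A function is not a valid JAX type, so tracing it as an argument fails.
raised = False
try:
    jax.jit(call_plain)(f, 5.0)
except TypeError:
    raised = True

# 2) Static argument: h is hashed into jit's cache key instead of being traced.
trace_count = 0  # hypothetical counter, incremented only when jit (re)traces


@partial(jax.jit, static_argnums=0)
def call_static(h, y):
    global trace_count
    trace_count += 1  # Python side effect: runs at trace time, not on cached calls
    return h(y)


add_one = jax.jit(lambda x: x + 1.0)
call_static(f, 1.0)
call_static(f, 2.0)        # same static function and input type: cache hit
call_static(add_one, 1.0)  # different static function: traced and compiled again


# 3) jax.tree_util.Partial is a pytree: the callable lives in the (static) tree
#    structure and any bound arguments become traced leaves.
def scale(a, x):
    return a * x


result = jax.jit(call_plain)(Partial(f), 5.0)
scaled = jax.jit(call_plain)(Partial(scale, 3.0), 5.0)
print(result, scaled)  # 10.0 15.0
```

The trade-off, roughly: a static function is simplest but triggers a retrace per distinct callable, while `Partial` lets bound array arguments (like the `3.0` above) change without retracing, as long as the wrapped callable itself stays the same.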
-
Hi,
There are a number of cases where I have wanted to pass a jitted function as an argument to another jitted function, but this is not possible (yet), unlike in, say, TensorFlow. Does this have something to do with possible closure problems or gradient tracing?
Could it be considered as a future improvement, or is it a no-go?
To fix ideas, I'm talking about doing something like this: