Dear all,
I'd like to use automatic differentiation on a function `f(g(x))`, where `f` can be implemented with pure `jax.numpy` functions but `g` is a complicated function implemented in C. The `jvp` rule for `g` is also known and is implemented in C. Is it possible to add a custom primitive for `g` in JAX? The tutorial states that the `jvp` rule has to be implemented in a JAX-traceable way. How can I achieve that?
Thanks.
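Below is a minimal sketch of one approach, not an authoritative answer. It assumes the C code can be called from Python (for example via ctypes) and that it can return the full Jacobian of `g`. That second point is an assumption, since the question only says a JVP rule exists. `jax.pure_callback` runs the host code, and `jax.custom_jvp` attaches the derivative rule. The names `g_c`, `g_jacobian_c` and `_host_call` are hypothetical, and the first two are NumPy placeholders standing in for the C routines.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical stand-ins for the C routines. In practice these would call
# into the C library (e.g. through ctypes or a small extension module).
def g_c(x: np.ndarray) -> np.ndarray:
    """Primal g(x); elementwise sin used purely as a placeholder."""
    return np.sin(x)


def g_jacobian_c(x: np.ndarray) -> np.ndarray:
    """Dense Jacobian dg_i/dx_j of the placeholder g, shape (n, n)."""
    return np.diag(np.cos(x))


def _host_call(fn, out_type, *args):
    """Run a NumPy/C function on the host and hand the result back to JAX."""
    def callback(*host_args):
        result = fn(*(np.asarray(a) for a in host_args))
        # The result must match the declared shape and dtype exactly.
        return np.asarray(result, dtype=out_type.dtype)
    return jax.pure_callback(callback, out_type, *args)


@jax.custom_jvp
def g(x):
    return _host_call(g_c, jax.ShapeDtypeStruct(x.shape, x.dtype), x)


@g.defjvp
def g_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = g(x)
    # The Jacobian depends only on the primal input, so the callback itself
    # is never differentiated.
    n = x.shape[0]
    jac = _host_call(g_jacobian_c, jax.ShapeDtypeStruct((n, n), x.dtype), x)
    # Applying the tangent with a JAX matmul keeps the rule linear in x_dot
    # and transposable, so reverse mode (jax.grad) works as well as jax.jvp.
    return y, jac @ x_dot


def f(y):
    # Any pure jax.numpy function.
    return jnp.sum(y ** 2)


x = jnp.array([0.1, 0.5, 1.0])
grad_value = jax.jit(jax.grad(lambda x: f(g(x))))(x)
print(grad_value)  # should match 2 * sin(x) * cos(x)
```

Materializing the Jacobian costs O(n²) memory. If the C routine instead computes the JVP product directly (tangent in, tangent out), calling it through `pure_callback` inside the rule should still give forward mode (`jax.jvp`, `jax.jacfwd`), but, as far as I know, `jax.grad` would then fail because JAX cannot transpose a callback. Exposing a VJP from C and using `jax.custom_vjp`, or registering a real XLA custom call through JAX's FFI, are the usual alternatives.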