I think that the … Take a look at those docs to see if you can get what you need from them!
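One JAX-esque option that may help when the expensive estimates are per *argument*: recent JAX versions let `custom_jvp.defjvp` take `symbolic_zeros=True`. An argument that is not being differentiated then reaches the rule as a `jax.custom_derivatives.SymbolicZero` instead of an array of zeros, so its estimator can be skipped with an ordinary Python check. The sketch below assumes such a JAX version. `f`, `estimate_da` and `estimate_db` are hypothetical stand-ins (finite differences on a smooth function), and the `calls` counter exists only to show which estimators run. Note that this works per argument, not per element of a single array: an explicit all-zeros tangent such as `x_dot` passed to `jax.jvp` is an ordinary array, not a symbolic zero.

```python
import jax
import jax.numpy as jnp
from jax.custom_derivatives import SymbolicZero

# Trace-time counters, only to show which (hypothetical) estimators actually run.
calls = {"da": 0, "db": 0}


@jax.custom_jvp
def f(a, b):
    # Hypothetical stand-in for the expensive, non-differentiable function.
    return jnp.sin(a) * jnp.cos(b)


def estimate_da(a, b, eps=1e-3):
    # Hypothetical expensive estimator of df/da (central finite difference).
    calls["da"] += 1
    return (f(a + eps, b) - f(a - eps, b)) / (2 * eps)


def estimate_db(a, b, eps=1e-3):
    # Hypothetical expensive estimator of df/db (central finite difference).
    calls["db"] += 1
    return (f(a, b + eps) - f(a, b - eps)) / (2 * eps)


def f_jvp(primals, tangents):
    a, b = primals
    a_dot, b_dot = tangents
    y = f(a, b)
    y_dot = jnp.zeros_like(y)
    # With symbolic_zeros=True, an argument that is not being differentiated
    # arrives as a SymbolicZero, so its estimator is never called.
    if not isinstance(a_dot, SymbolicZero):
        y_dot = y_dot + estimate_da(a, b) * a_dot
    if not isinstance(b_dot, SymbolicZero):
        y_dot = y_dot + estimate_db(a, b) * b_dot
    return y, y_dot


f.defjvp(f_jvp, symbolic_zeros=True)

a, b = jnp.float32(0.3), jnp.float32(0.5)
g_a = jax.grad(f, argnums=0)(a, b)  # b is not differentiated: estimate_db is skipped
```

The output tangent stays linear in `a_dot` and `b_dot`, so the rule should work under `jax.grad` as well as `jax.jvp`.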
I'm writing a custom JVP for a non-differentiable function. However, the gradient can be estimated, in a computationally expensive way, with respect to each argument. So I would like to restrict the gradient computation to only those components that require it. Is there a recommended way to do this?
Code example:
I appreciate that the for loop would be better expressed as a `lax.scan` / `lax.fori_loop`; I'm just using some pseudocode to get the question across clearly.
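A minimal sketch of the setup described above, under stated assumptions: `f` is a hypothetical stand-in (a sum of `sin`, chosen so the estimates can be checked) and `estimate_partial` is a hypothetical per-component estimator written as a central finite difference. A Python list comprehension plays the role of the for loop. Every partial derivative is estimated on every JVP call, which is exactly the cost the question wants to avoid.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def f(x):
    # Hypothetical stand-in for the expensive, non-differentiable function.
    return jnp.sum(jnp.sin(x))


def estimate_partial(x, i, eps=1e-3):
    # Hypothetical expensive estimator of df/dx[i] (central finite difference).
    e_i = jnp.zeros_like(x).at[i].set(1.0)
    return (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)


@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = f(x)
    # One expensive estimate per component, whether or not x_dot[i] is zero.
    grad_est = jnp.stack([estimate_partial(x, i) for i in range(x.shape[0])])
    # The output tangent is linear in x_dot, so jax.grad can transpose it too.
    y_dot = jnp.dot(grad_est, x_dot)
    return y, y_dot
```

For example, `jax.jvp(f, (x,), (x_dot,))` and `jax.grad(f)(x)` both go through `f_jvp`.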
My question is: what form should `x_i_gradient_needed` take in the above code to avoid the computation when the gradient with respect to `x[i]` is not needed? Is it sufficient for it to be something like `x_dot[i] != 0.0`? Or is there a better JAX-esque way of doing this? Thanks for any help!
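A hedged sketch of the `x_dot[i] != 0.0` idea. Once the rule is traced (for example under `jax.jit`), `x_dot` is an abstract tracer, so a Python `if` on it raises a concretization error; `lax.cond` moves the check to run time instead. `f` and `estimate_partial` are the same hypothetical stand-ins as before. Two caveats, as far as I understand JAX's semantics: under `jax.vmap` with a batched predicate, `lax.cond` is converted to a select that evaluates both branches, and the predicate depends on the tangent itself, so reverse mode (`jax.grad`) will most likely fail to transpose this rule. It is therefore only expected to help in forward mode (`jax.jvp`).

```python
import jax
import jax.numpy as jnp
from jax import lax


@jax.custom_jvp
def f(x):
    # Hypothetical stand-in for the expensive, non-differentiable function.
    return jnp.sum(jnp.sin(x))


def estimate_partial(x, i, eps=1e-3):
    # Hypothetical expensive estimator of df/dx[i] (central finite difference).
    e_i = jnp.zeros_like(x).at[i].set(1.0)
    return (f(x + eps * e_i) - f(x - eps * e_i)) / (2 * eps)


@f.defjvp
def f_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    y = f(x)
    y_dot = jnp.zeros_like(y)
    for i in range(x.shape[0]):
        # Run the expensive estimate only when this tangent component is nonzero;
        # the predicate is a traced value, so lax.cond is used instead of `if`.
        term = lax.cond(
            x_dot[i] != 0.0,
            lambda xx, t, i=i: estimate_partial(xx, i) * t,  # expensive branch
            lambda xx, t: jnp.zeros((), dtype=xx.dtype),     # contributes zero
            x,
            x_dot[i],
        )
        y_dot = y_dot + term
    return y, y_dot


x = jnp.array([0.1, 0.2, 0.3], dtype=jnp.float32)
forward = jax.jit(lambda t: jax.jvp(f, (x,), (t,))[1])
```

Calling `forward(jnp.array([1.0, 0.0, 0.0]))` should then evaluate only the first estimator at run time.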