Consider the following code snippet:

```python
from jax import value_and_grad, numpy as jnp

def bad_func(x):
    # (x**2 + inf)**2 is inf for every finite x
    return (x**2 + jnp.inf)**2

def good_func(x):
    # For x > 0 the forward pass returns x, but bad_func(x) is still evaluated
    return jnp.where(x > 0, x, bad_func(x))

print(value_and_grad(good_func)(.1))
```

Here we created a function … That's very unfortunate. Any suggestions on how to work around this issue?
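A small sketch of where the NaN comes from, assuming the standard reverse-mode behavior of `jnp.where` (the branch it did not select receives a cotangent of 0). On its own, `bad_func` has an infinite derivative at `x = 0.1`, and in IEEE floating point 0 times infinity is NaN. That NaN flows back to `x` and is added to the gradient of the selected branch, so the total gradient is NaN. The names `branch_grad` and `zero_times_inf` are illustrative only.

```python
import jax
import jax.numpy as jnp

def bad_func(x):
    return (x**2 + jnp.inf)**2

# The untaken branch on its own: d/dx (x**2 + inf)**2 = 2 * (x**2 + inf) * 2x,
# which is inf at x = 0.1
branch_grad = jax.grad(bad_func)(0.1)

# jnp.where's backward pass feeds a cotangent of 0 into the untaken branch;
# the chain rule multiplies that 0 by the infinite local derivative 2 * (x**2 + inf)
zero_times_inf = jnp.float32(0.0) * jnp.inf

print(branch_grad)     # inf
print(zero_times_inf)  # nan
```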
There is some information on this topic in the Frequently Asked Questions: https://jax.readthedocs.io/en/latest/faq.html#gradients-contain-nan-where-using-where
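A minimal sketch of the "double `where`" workaround described in that FAQ entry, applied to the example from the question (`good_func_safe` and `safe_x` are hypothetical names). An inner `jnp.where` replaces `x` with a constant before it reaches `bad_func` whenever `x > 0`. The NaN cotangent produced inside `bad_func` then reaches the inner `where`, whose backward pass selects cotangents rather than multiplying them, so for `x > 0` the NaN is routed to the constant and `x` receives 0 from that path. In the FAQ the substituted value is one where the function is well behaved; because `bad_func` here is infinite for every input, the constant `0.0` is only an arbitrary placeholder.

```python
import jax
import jax.numpy as jnp

def bad_func(x):
    return (x**2 + jnp.inf)**2

def good_func_safe(x):
    # Inner where: for x > 0, pass a constant (arbitrary placeholder 0.0) instead of x,
    # so the NaN cotangent from bad_func is sent to the constant, not to x
    safe_x = jnp.where(x > 0, 0.0, x)
    # Outer where: same selection as the original good_func
    return jnp.where(x > 0, x, bad_func(safe_x))

value, grad = jax.value_and_grad(good_func_safe)(0.1)
print(value, grad)  # 0.1 1.0
```

For `x <= 0` the gradient still comes from `bad_func`, so it remains non-finite there; the sketch only removes the spurious NaN on the branch that was actually selected.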