Strange behavior of jax.lax.scan
#24926
I am trying to compute a sum as follows:

```python
import itertools

import jax
import jax.numpy as jnp
import numpy as np

# Function definitions (m, dy, dx, E, T_rev, mu are all fixed numbers/arrays)
ys = jnp.array(list(itertools.product(*((m + 1) * [np.arange(dy)]))))

@jax.jit
def z(y):
    @jax.vmap
    def f(x):
        e = 0
        x_ = x
        for yi in y:
            e += jnp.log(E[yi] @ x_)
            x_ = T_rev @ x_
        # w[tuple(y)] is equivalent to w[*y] but also parses on Python < 3.11
        return jnp.exp(e) * (w[tuple(y)] - v @ x)**2
    return mu @ f(jnp.eye(dx))

# Sum
s = 0.
for y in ys:
    s += z(y)
```

In JAX, this can be implemented as

```python
zs = jax.vmap(z)(ys)
s = zs.sum()
```

However, with `jax.lax.scan`:

```python
c, zs_ = jax.lax.scan(lambda c, y: (c + z(y), z(y)), 0., ys)
```

Here,

```python
# scan returns (carry, stacked_outputs); the outputs are None here
c_, _ = jax.lax.scan(lambda c, y: (c + y, None), 0., zs)
```

returns the correct result (…
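A likely explanation, consistent with the reply below, is float32 rounding in the running carry: once the accumulated sum is large relative to each new term, an addition can round that term away entirely. The sketch below uses hypothetical data (a single 1.0 followed by 10,000 terms of 1e-8, in float32, which is JAX's default dtype), not the poster's, to contrast a sequential `jax.lax.scan` accumulation with a pairwise (tree) reduction. `pairwise_sum` is an illustrative helper written for this example, not a JAX API.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical data: one large term followed by many tiny ones (float32).
xs = jnp.concatenate([jnp.ones(1), jnp.full(10_000, 1e-8)])

# Sequential ("rolling") accumulation: the same pattern as a scan carry c + z(y).
seq_sum, _ = jax.lax.scan(lambda c, x: (c + x, None), jnp.float32(0.0), xs)

def pairwise_sum(v):
    # Tree reduction: pad with zeros to a power of two, then repeatedly
    # add adjacent pairs, so small terms are combined with each other first.
    n = v.shape[0]
    size = 1 << (n - 1).bit_length()
    v = jnp.pad(v, (0, size - n))
    while v.shape[0] > 1:
        v = v[0::2] + v[1::2]
    return v[0]

tree_sum = pairwise_sum(xs)

# Correctly rounded float64 reference for the same float32 inputs.
reference = math.fsum(np.asarray(xs, dtype=np.float64))
```

With this data the sequential carry stays at exactly 1.0, because each 1e-8 is smaller than half a float32 ulp at 1.0, while the pairwise version ends up close to the reference of about 1.0001.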
Replies: 1 comment
Okay, it seems that the error is indeed due to summing many small floating-point numbers. Replacing the naive addition with Kahan (compensated) summation fixes the problem. I can only assume that the last example I show above, `jax.lax.scan(lambda c, y: (c + y, None), 0., zs)`, works because JAX does some clever compilation that turns it into a non-rolling (not strictly sequential) sum.
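Kahan summation, which the reply reports as fixing the problem, fits naturally into `jax.lax.scan` if the carry holds both the running sum and a compensation term. This is a minimal sketch, not the poster's actual code: `kahan_step` and `kahan_sum` are hypothetical names, and the test data (1.0 followed by 10,000 float32 terms of 1e-8) are made up for illustration.

```python
import math

import jax
import jax.numpy as jnp
import numpy as np

def kahan_step(carry, x):
    # carry = (running sum, compensation for low-order bits lost so far)
    s, comp = carry
    y = x - comp        # re-inject the bits lost in earlier additions
    t = s + y           # low-order bits of y may be rounded away here
    # Algebraically zero; in floating point it equals minus the lost bits.
    # This relies on the compiler not reassociating floating-point ops.
    comp = (t - s) - y
    return (t, comp), None

@jax.jit
def kahan_sum(xs):
    zero = jnp.zeros((), dtype=xs.dtype)
    (total, _), _ = jax.lax.scan(kahan_step, (zero, zero), xs)
    return total

# Hypothetical data: one large term followed by many tiny float32 terms.
xs = jnp.concatenate([jnp.ones(1), jnp.full(10_000, 1e-8)])
total = kahan_sum(xs)

# Correctly rounded float64 reference for the same float32 inputs.
reference = math.fsum(np.asarray(xs, dtype=np.float64))
```

In the setting of the question, one could presumably call `kahan_sum(zs)` on the vmapped values, or compute `z(y)` in place of `x` inside the step function to keep the loop fused; neither variant has been tested here.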