-
I noticed that with JAX's `lax.while_loop`, the traced jaxpr doesn't perform any loop-invariant code motion (LICM). In general, does this hurt the quality of the JIT-compiled code? I know XLA has LICM passes, but I'm not sure they're guaranteed to trigger. The pattern is very common in gradient-based optimization, for example: part of the gradient function can be precomputed into an intermediate result, and the optimizer loop can then reuse that result on every iteration.
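A minimal sketch of that pattern, assuming a least-squares objective chosen only for illustration (`fit_least_squares`, `lr`, `tol` and `max_iter` are hypothetical names, not from any library). The invariant terms `X.T @ X / n` and `X.T @ y / n` are computed once outside `lax.while_loop`, so the per-iteration cost no longer depends on whether JAX or XLA hoists them:

```python
import jax
import jax.numpy as jnp
from jax import lax


def fit_least_squares(X, y, lr=0.1, tol=1e-6, max_iter=1000):
    """Gradient descent on 0.5 * ||X w - y||^2 / n, with the invariant part hoisted by hand."""
    n = X.shape[0]
    # Loop-invariant pieces of the gradient X.T @ (X w - y) / n, computed once.
    G = X.T @ X / n
    b = X.T @ y / n

    def cond(carry):
        _, step, i = carry
        # Stop when the update becomes tiny or the iteration budget runs out.
        return (step > tol) & (i < max_iter)

    def body(carry):
        w, _, i = carry
        grad = G @ w - b  # only the w-dependent part runs per iteration
        w_new = w - lr * grad
        return w_new, jnp.max(jnp.abs(w_new - w)), i + 1

    init = (
        jnp.zeros(X.shape[1], X.dtype),   # w
        jnp.asarray(jnp.inf, X.dtype),    # size of the last step
        jnp.asarray(0),                   # iteration counter
    )
    w, _, iters = lax.while_loop(cond, body, init)
    return w, iters


fit = jax.jit(fit_least_squares)
```

Hoisting by hand like this is the portable option: it makes the intent explicit instead of relying on a compiler pass that may or may not fire.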
-
My bad, what I found was wrong. JAX does perform LICM in some form: the computation shared by the gradient and the likelihood function is done only once, before entering the loop. I don't know how that's achieved, though. Maybe through partial evaluation?
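One way to check where a computation actually lands is to inspect the jaxpr with `jax.make_jaxpr` and compare the equations outside the `while` with those in its body. This sketch does not explain the mechanism behind the hoisting observed above; it only shows how to look. `primitive_names`, `while_body_names`, `solve_in_loop` and `solve_hoisted` are hypothetical helpers, and reading the `body_jaxpr` parameter of the `while` equation relies on an internal detail that may change between JAX versions:

```python
import jax
import jax.numpy as jnp
from jax import lax


def primitive_names(jaxpr):
    """Primitive names in `jaxpr`, descending into nested jit calls but not into loops."""
    names = []
    for eqn in jaxpr.eqns:
        names.append(eqn.primitive.name)
        inner = eqn.params.get("jaxpr")  # present on nested jit-style calls
        if inner is not None:
            names.extend(primitive_names(getattr(inner, "jaxpr", inner)))
    return names


def while_body_names(closed_jaxpr):
    """Primitive names inside the body of the first top-level while loop."""
    for eqn in closed_jaxpr.jaxpr.eqns:
        if eqn.primitive.name == "while":
            return primitive_names(eqn.params["body_jaxpr"].jaxpr)
    raise ValueError("no top-level while loop found")


def _run(grad_fn, d, dtype):
    # Plain gradient descent; stops on a small step or after 1000 iterations.
    def cond(carry):
        _, step, i = carry
        return (step > 1e-6) & (i < 1000)

    def body(carry):
        w, _, i = carry
        w_new = w - 0.1 * grad_fn(w)
        return w_new, jnp.max(jnp.abs(w_new - w)), i + 1

    init = (jnp.zeros(d, dtype), jnp.asarray(jnp.inf, dtype), jnp.asarray(0))
    return lax.while_loop(cond, body, init)[0]


def solve_in_loop(X, y):
    n = X.shape[0]
    # The invariant X.T @ X and X.T @ y are written inside the per-step gradient.
    return _run(lambda w: (X.T @ X) @ w / n - X.T @ y / n, X.shape[1], X.dtype)


def solve_hoisted(X, y):
    n = X.shape[0]
    G, b = X.T @ X / n, X.T @ y / n  # hoisted by hand, outside the loop
    return _run(lambda w: G @ w - b, X.shape[1], X.dtype)


X, y = jnp.ones((8, 3)), jnp.ones(8)
for f in (solve_in_loop, solve_hoisted):
    cj = jax.make_jaxpr(f)(X, y)
    outside = primitive_names(cj.jaxpr).count("dot_general")
    inside = while_body_names(cj).count("dot_general")
    print(f.__name__, "dot_general outside loop:", outside, "inside body:", inside)
```

The jaxpr is only what JAX hands to XLA; XLA's own passes run afterwards, so to see the final placement it may be more informative to read the optimized HLO, e.g. `jax.jit(f).lower(X, y).compile().as_text()`.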