Some issues with omnistaging #4810
-
I believe that without omnistaging your code will silently do the wrong thing, because it depends on global state. Here is your original code, executed without omnistaging:

```python
# Only runs on older JAX releases that still provide config.disable_omnistaging().
from jax import config
config.disable_omnistaging()
import jax, jax.numpy as jp, jax.random as rax

init_rng = rax.PRNGKey(1)
params = dict()

def fn(h, x):
    global init_rng
    init_rng, rng = rax.split(init_rng)
    w = params.setdefault("w", rax.uniform(rng, shape=[]))
    return jp.tanh(h + w * x), h

jax.lax.scan(fn, 0., jp.arange(10))
print(params["w"])
# 0.12568676
print(init_rng)
# [2441914641 1384938218]
```

And here is a snippet that performs identical logic (but that works with omnistaging):

```python
import jax, jax.numpy as jp, jax.random as rax

init_rng = rax.PRNGKey(1)
params = dict()
init_rng, rng = rax.split(init_rng)
w = params.setdefault("w", rax.uniform(rng, shape=[]))

def fn(h, x):
    return jp.tanh(h + w * x), h

jax.lax.scan(fn, 0., jp.arange(10))
print(params["w"])
# 0.12568676
print(init_rng)
# [2441914641 1384938218]
```

A benefit of the omnistaging approach is that it tells you, loudly, that your code is not working as you intended, rather than silently giving you unintended results.
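A minimal sketch of why the original code splits the key only once: `lax.scan` calls the Python body while tracing it, not once per loop iteration, so any Python-side side effect (a global counter here, the global `init_rng` in the original) happens at trace time. This assumes a current JAX release, where omnistaging is always on; the `trace_calls` counter is purely illustrative.

```python
import jax
import jax.numpy as jnp

trace_calls = 0  # Python-side counter, mutated as a side effect


def body(h, x):
    global trace_calls
    trace_calls += 1  # runs while scan traces `body`, not once per loop step
    return h + x, h


xs = jnp.arange(10, dtype=jnp.float32)
# Strongly typed float32 carry so the carry type is stable across iterations.
final, hs = jax.lax.scan(body, jnp.float32(0.0), xs)
print(trace_calls)  # smaller than len(xs): the body ran only during tracing
print(final)        # 0 + 1 + ... + 9
```

The loop itself still executes ten times on the device; only the Python code that builds it runs at trace time, which is exactly where the original `rax.split(init_rng)` lived.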
-
So I just updated JAX and discovered omnistaging by way of it breaking my code (which is fine). After learning more about it and figuring out what it would take to make it work for me, I decided to disable it. The pros of omnistaging, as I understand them, include trace-time efficiency and code simplification. I'm desperately in favor of those, but I also find the cons hard to swallow: it breaks valid use cases with no simple workaround, and it fundamentally alters the way JAX works by introducing global mutable state.
Concretely, in my code I often use this sort of pattern to discover and initialize parameters on the fly in an initial dummy pass through the model:
With omnistaging, this doesn't work because it leaks a tracer. In fact, this is pretty much one of the examples of what not to do in the omnistaging note. That note made me believe that anything omnistaging breaks was a bug to begin with, but on reflection I don't think it's nearly so clear-cut. Omnistaging just isn't a straightforward change that only breaks already-broken code. It's a qualitative change that breaks working code by breaking the expected functional semantics, and it has real implications for real things built on JAX.
How would I fix my code above? Supposedly I would use numpy instead of JAX. But then I would have to rewrite a bunch of code in numpy using numpy's random system which has an entirely different interface. Or worse, I would have to write code that calls either numpy or jax depending on context. I'm having trouble understanding why omnistaging means I can no longer just call JAX functions with constant values and get constant results.
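One hedged alternative to switching to NumPy, sketched under assumptions: run the discovery pass eagerly on a single step (outside `scan` or `jit`, so every `jax.random` call returns a concrete array), then freeze the parameters and let the scanned function close over them. The `Params` class and its `get`/`frozen` members are hypothetical helpers, not a JAX API, and the PRNG key is threaded explicitly rather than kept in a global.

```python
import jax
import jax.numpy as jnp


class Params:
    """Hypothetical parameter store: creates parameters on first use, eagerly."""

    def __init__(self, key):
        self.key = key        # explicit PRNG key instead of a global
        self.values = {}
        self.frozen = False   # once True, creating new parameters is an error

    def get(self, name, shape):
        if name not in self.values:
            if self.frozen:
                raise KeyError(f"parameter {name!r} was not created during the init pass")
            self.key, sub = jax.random.split(self.key)
            self.values[name] = jax.random.uniform(sub, shape=shape)
        return self.values[name]


def fn(p, h, x):
    w = p.get("w", ())
    return jnp.tanh(h + w * x), h


xs = jnp.arange(10, dtype=jnp.float32)
h0 = jnp.float32(0.0)

p = Params(jax.random.PRNGKey(1))
fn(p, h0, xs[0])  # eager dummy pass on one step: "w" is created as a concrete array
p.frozen = True   # from here on, fn only reads parameters

# Inside scan, p.get("w") returns the stored concrete array, captured as a constant.
final_h, hs = jax.lax.scan(lambda h, x: fn(p, h, x), h0, xs)
print(p.values["w"], final_h)
```

The `frozen` flag is a guard: if a parameter were first requested inside the traced scan body, the store raises instead of silently saving a tracer.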
More generally, omnistaging marks a departure from the functional style that makes JAX simple (if not always easy). Now some innocuous function like `jax.numpy.zeros` can mean one of two very different things depending on some mutable global state. Observe:

This brings back bad memories of the way Theano and TensorFlow handled symbolic control flow behind the scenes: push some manager onto a global stack and let it take control of every operation created under its reign. This broke on-the-fly parameter construction in the same way, which is why TensorFlow has initialization infrastructure all over the place even though initialization isn't really conceptually a special operation. Theano left its users to deal with it, somehow. This is not a road anyone should wish to go down.
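A minimal sketch of the "two very different things" point, assuming a current JAX release (where omnistaging is the only behavior): the same `jnp.zeros(3)` call yields a concrete array when run directly and a tracer when run inside `jax.jit`, even though it depends on no traced input. The `is_concrete` and `make_zeros` helpers are illustrative, not part of JAX; `is_concrete` relies on tracers raising `jax.errors.TracerArrayConversionError` when passed to `np.asarray`.

```python
import numpy as np
import jax
import jax.numpy as jnp


def is_concrete(x):
    """True if x holds an actual value; False if it is a tracer being staged out."""
    try:
        np.asarray(x)  # tracers refuse conversion to NumPy arrays
        return True
    except jax.errors.TracerArrayConversionError:
        return False


results = {}


def make_zeros(tag):
    z = jnp.zeros(3)
    results[tag] = is_concrete(z)  # stores a Python bool, so no tracer leaks
    return z


make_zeros("eager")                   # called directly: concrete array
jax.jit(lambda: make_zeros("jit"))()  # called while tracing under jit: staged out
print(results)  # {'eager': True, 'jit': False}
```

What decides the outcome is not the arguments to `jnp.zeros` but whether a trace is active when it is called.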
JAX has been a breath of fresh air with its functional style... breaking that sets an ominous precedent. Is omnistaging here to stay? Will it always be optional? Are there further departures from functional in the pipeline?