JAX works in single precision by default, and won't even let you create double precision arrays unless an environment variable is set (JAX_ENABLE_X64=True) or a special command is run when jax is imported (config.update("jax_enable_x64", True)). In order to test double precision, these commands are used in various places in the tests.
Unfortunately, enabling double precision also makes it the default for new arrays, creating situations where tests have different behavior when run on their own versus in the whole suite (because the config is "sticky" and setting it in one test affects others).
All of this may change in a future JAX release (jax-ml/jax#8178), but for now, I propose running all tests with JAX_ENABLE_X64=True JAX_DEFAULT_DTYPE_BITS=32 and removing any config.update calls from the test files.
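One way to apply this without having to remember the variables on every command line might be to set them from a top-level conftest.py, which pytest loads before it imports the test modules. The sketch below is only an illustration: the helper name configure_jax_precision and its placement in conftest.py are assumptions, not anything that exists in the repository. It also assumes that nothing (for example a pytest plugin) has imported jax earlier, since as far as I know JAX reads these variables when it is first imported.

```python
# conftest.py (hypothetical placement at the repository root)
import os


def configure_jax_precision(environ=os.environ):
    """Enable 64-bit support while keeping 32-bit default dtypes.

    Hypothetical helper: it only sets the two environment variables
    proposed above, so it has to run before the first `import jax`.
    """
    environ["JAX_ENABLE_X64"] = "True"
    environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
    return environ


# Set the variables for the whole test session.
configure_jax_precision()
```

With something like this in place, the test files would no longer need their own config.update calls, and a test should see the same configuration whether it is run alone or as part of the full suite.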
While there has been no further discussion in jax-ml/jax#8178 for more than a year, it appears that deprecation of the X64 flag is still being considered. Instead of simply running all the tests with X64 enabled and X32 default, perhaps we should configure that state in scico/__init__.py so that it applies across the code base?