Set GPU memory allocation via Python (programmatically) #6102
Unanswered
fabiannagel
asked this question in
Q&A
Replies: 0 comments
Hi! I've been benchmarking JAX-MD and noticed some strange runtime variance that I have tracked down to JAX GPU memory allocation behavior. For that reason, I want to run JAX in various VRAM allocation modes. I know it works in a terminal with
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
but I'd prefer to do it programmatically via Python. Updating the OS environment variables is not the issue, but I noticed that the JAX module needs to be reloaded for the change to take effect. Applying the sledgehammer via
importlib.reload()
causes some issues with pickling, which is how I save my benchmark results. Therefore my question: Is there a "nice" and simple way to get JAX to reload XLA environment variables without doing a full-blown re-import of the entire module? Thanks for your help :)
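One possible workaround that avoids reloading altogether, sketched under the assumption that the variable only has to be present in the environment before JAX is first imported in a process: run each allocation mode in a fresh Python subprocess. The helper name `run_with_allocator` and the `CHILD` snippet are hypothetical and only illustrate the idea; a real benchmark child would import JAX and print its timings.

```python
import os
import subprocess
import sys


def run_with_allocator(allocator, child_code):
    """Run child_code in a new interpreter with XLA_PYTHON_CLIENT_ALLOCATOR set.

    The variable is placed in a copy of the environment, so the parent
    process (and any JAX module already imported there) is left untouched.
    """
    env = dict(os.environ)  # copy, so os.environ itself is not modified
    env["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator
    result = subprocess.run(
        [sys.executable, "-c", child_code],
        env=env,
        capture_output=True,
        text=True,
        check=True,  # raise CalledProcessError if the child fails
    )
    return result.stdout


# Hypothetical child snippet: a real benchmark would `import jax` here,
# run the workload, and print the measurements (for example as JSON).
CHILD = "import os; print(os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'])"

if __name__ == "__main__":
    print(run_with_allocator("platform", CHILD))
```

Because the results come back as text on stdout, the parent can parse and pickle them without ever reloading a module.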