Is there a way to know how much memory is required for a task? #423
Hi, `jax` seems to reserve all the GPU memory at import, so we cannot see from the NVIDIA panels how much memory the ott package actually uses. Right now, if a problem runs into memory issues, the only thing we can do is reduce the problem size until the error disappears. Is there a more direct way to know how much memory a target task requires? Thanks!
Hi @feiyang-k, please check https://jax.readthedocs.io/en/latest/gpu_memory_allocation.html on how to change the pre-allocation JAX does. In short, you can do:

```python
import os  # must run before importing anything from jax

os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"

import jax
import jax.numpy as jnp
...
```
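For completeness, the same JAX page documents two related environment variables; a minimal sketch of the alternatives (the fraction value here is only illustrative):

```python
import os

# Keep preallocation but cap it at a fraction of total GPU memory
# (JAX preallocates 75% by default).
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '0.50'

# Or allocate and free on demand via the platform allocator; slower,
# but tools such as nvidia-smi then reflect actual usage.
# os.environ['XLA_PYTHON_CLIENT_ALLOCATOR'] = 'platform'

import jax  # the variables must be set before this import
```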
Thanks @michalk8! I tried it and it works exactly as I wished! By the way, I'm using it in a Jupyter notebook, and GPU memory recycling doesn't seem to work fully. Each time an OT problem is computed, the GPU memory is not released. More interestingly, if the computation completes successfully, the memory can be reused for the next problem; but if a problem runs into an error, the allocated memory seems "dead", and the memory available for the next OT problem is only what remains, which can be much smaller. So I need to restart the Jupyter Notebook kernel every time I hit an error with ott. Is this a known issue?

Also, it seems I'm never able to interrupt a cell that is running OT problems; it never responds, and I need to restart the kernel whenever a task looks like it will not finish in a reasonable time. Is this expected? Thanks again!
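A possible way to reclaim device memory between runs without restarting the kernel (a rough sketch, assuming a recent JAX version where `jax.live_arrays` and `jax.clear_caches` are available; not verified against the exact error scenario described above):

```python
import jax

# Delete every array still held on the device. Python references to
# these arrays become unusable, so only do this between independent runs.
for arr in jax.live_arrays():
    arr.delete()

# Also drop compilation caches, which hold on to memory of their own.
jax.clear_caches()
```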
According to the docs, this shouldn't be happening; I will go and investigate this behavior.
I'm not 100% sure, but I would say yes: the code runs on the device, and the interrupt can only be handled once execution is handed back to the host (I will check whether this statement is true). Maybe adding a printing callback (see this tutorial) will make it easier to interrupt an execution.
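A minimal sketch of such a printing callback, using `jax.debug.print` rather than the tutorial's exact code (the solver step shown is made up for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # jax.debug.print stages a host callback; each call briefly hands
    # control back to Python, giving an interrupt a chance to land.
    jax.debug.print("max residual: {r}", r=jnp.max(jnp.abs(x)))
    return 0.5 * x

x = jnp.ones((4,))
for _ in range(100):  # host-side loop: one device call per iteration
    x = step(x)
```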
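Coming back to the original question: one way to estimate, before running it, how much device memory a jitted computation needs is JAX's ahead-of-time compilation API. A sketch under assumptions: it requires a reasonably recent JAX, `memory_analysis()` is backend-dependent and may return None, and `task` is a hypothetical stand-in for the actual OT solve:

```python
import jax
import jax.numpy as jnp

def task(x, y):
    # Hypothetical stand-in for the real computation, e.g. an OT solve.
    return jnp.sum(x @ y)

x = jnp.ones((1_000, 500))
y = jnp.ones((500, 2_000))

# Lower and compile without executing, then query the compiled memory plan.
compiled = jax.jit(task).lower(x, y).compile()
stats = compiled.memory_analysis()
if stats is not None:
    print("arguments  :", stats.argument_size_in_bytes)
    print("outputs    :", stats.output_size_in_bytes)
    print("temporaries:", stats.temp_size_in_bytes)
```

At runtime, `jax.devices()[0].memory_stats()` (available on GPU backends) reports fields such as `bytes_in_use` and `peak_bytes_in_use`, which can be used to check such an estimate.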