Abnormal Memory Demand #7
I am using JAX version 0.4.37 installed via pip.
Hi @ben0i0d, I agree with you that it is unreasonable, but sadly it is working "as expected":

```python
dist_fn = jax.vmap(jax.vmap(tsnex.tsne.euclidean_distance, in_axes=(0, None)), in_axes=(None, 0))
D = dist_fn(x, x)
```

Some napkin calculations: assuming single-precision floats (32 bits), computing D materializes a (2500, 2500, 784) intermediate, which is ~20 GB; together with the rest of the computation, that accounts for the >30 GB you observed. The reason sklearn can get away with it is that, besides all the approximation methods I still need to include (#4, #5), it relies on SciPy distance calculations (when using …).

Now, what can we do? A first, simple solution could be to check the dimension of x and, above a threshold k, split D and compute it k rows at a time (that would save some memory). Ideally, more optimized distance calculations should be implemented, but that takes time. If you feel like contributing, I'd gladly review the "quick" solution above; otherwise I'll give it a go soon. Thanks for pointing this out, and let's try to solve it!
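The "split D every k rows" idea could look roughly like this. This is only a sketch, not the library's actual code: `euclidean_distance` is re-defined locally for self-containment (the real one lives at `tsnex.tsne.euclidean_distance`), and the chunk size is an arbitrary assumption. The point is that the largest intermediate shrinks from (N, N, D) to (chunk, N, D):

```python
import jax
import jax.numpy as jnp

def euclidean_distance(a, b):
    # Distance between two single vectors; stand-in for tsnex.tsne.euclidean_distance.
    return jnp.sqrt(jnp.sum((a - b) ** 2))

def pairwise_chunked(x, chunk=256):
    """Compute the full (N, N) distance matrix `chunk` rows at a time.

    The double vmap below materializes a (chunk, N, D) intermediate per
    iteration instead of the (N, N, D) one produced by the all-at-once call.
    """
    row_fn = jax.vmap(jax.vmap(euclidean_distance, in_axes=(None, 0)), in_axes=(0, None))
    n = x.shape[0]
    rows = [row_fn(x[i:i + chunk], x) for i in range(0, n, chunk)]
    return jnp.concatenate(rows, axis=0)
```

With `chunk=256` on a (2500, 784) input, the per-iteration intermediate is 256 × 2500 × 784 float32 values, roughly 2 GB instead of ~20 GB, at the cost of a Python-level loop over ten chunks.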
Thank you for your reply. I also noticed that this is related to the distance computation. If I have time later, I would be interested in learning more about the optimization methods you mentioned (#4, #5) and trying some code. Since I am working on a ROCm platform and almost all t-SNE acceleration implementations are written for CUDA, I also tried the tSNE-Torch implementation. Unfortunately, it took 2 minutes and 13 seconds on this dataset, whereas this JAX implementation is extremely fast, taking only 0.1 seconds (so I truly appreciate your work).
When using t-SNE on a dataset of shape (2500, 784) on the CPU, I found that the memory usage peaked at around 34.5 GB, which made it impossible to run the program on my local computer.
This is quite unreasonable since the dataset is small. Furthermore, I did not observe such behavior with sklearn or pcax; they showed almost no memory demand.
I have uploaded my code and process to https://github.com/ben0i0d/tsne-test.
If you have time, please take a look.
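For reference, the peak is consistent with a naive pairwise-distance computation: broadcasting a (2500, 784) array against itself materializes a (2500, 2500, 784) float32 intermediate before the reduction. A quick back-of-the-envelope check:

```python
# Size of the (n, n, d) float32 intermediate created by an all-pairs
# elementwise difference before summing over the feature axis.
n, d = 2500, 784          # dataset shape from the report above
bytes_per_float32 = 4
intermediate_gb = n * n * d * bytes_per_float32 / 1e9
print(f"{intermediate_gb:.1f} GB")  # → 19.6 GB
```

That single intermediate is already ~20 GB, so a peak above 30 GB once other buffers are counted is plausible.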