-
Not really an answer, but there's the unreleased grain library: https://github.com/google/grain
-
Hi,
I am wondering how I should handle data loading and preprocessing in JAX in a local, single-GPU setup.
So far I have been doing the preprocessing (e.g. normalization) on the CPU, and then applying random cropping and random flips to minibatches on the GPU in my train_step function.
However, my approach has an issue: so far I have not been able to run the TensorFlow and JAX parts of the code on the GPU at the same time. Could this have something to do with how VRAM is reserved by TensorFlow and JAX?
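For concreteness, here is a minimal sketch of the workaround I would expect to need, assuming (as in my setup) that TensorFlow is only used for the tf.data input pipeline: hide the GPU from TensorFlow entirely, so that only JAX (which by default preallocates most of the VRAM) touches the device.

```python
# Sketch: keep TensorFlow off the GPU so it does not fight JAX for VRAM.
# Assumption: TensorFlow is only used for the tf.data input pipeline,
# so running it CPU-only loses nothing.
import tensorflow as tf

# Must run before any TF op touches the GPU.
tf.config.set_visible_devices([], "GPU")

# tf.data now runs entirely on the CPU and yields NumPy arrays
# that can be handed to jitted JAX functions on the GPU.
ds = tf.data.Dataset.range(8).batch(4)
for batch in ds.as_numpy_iterator():
    print(batch)
```

The alternative direction (letting TensorFlow keep the GPU and taming JAX instead) would be setting `XLA_PYTHON_CLIENT_PREALLOCATE=false` in the environment before importing JAX, but hiding the GPU from TF seems cleaner when TF is input-pipeline-only.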
So my main question is: how can I do both preprocessing and data augmentation on the GPU in an elegant way?
Note: My question is closely related to this unanswered question from 2022: #13339
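For reference, the kind of on-device augmentation I do in train_step looks roughly like the sketch below (names like `augment` are illustrative, and for simplicity the crop location is shared across the batch rather than drawn per example; `jax.vmap` over per-example keys would give per-example crops):

```python
# Sketch: random crop + random horizontal flip inside a jitted step.
# Assumes NHWC image batches; `augment`/`train_step` are illustrative names.
import jax
import jax.numpy as jnp

def augment(key, images, crop=24):
    n, h, w, c = images.shape
    k1, k2, k3 = jax.random.split(key, 3)
    # One crop offset for the whole batch (keeps the sketch simple).
    top = jax.random.randint(k1, (), 0, h - crop + 1)
    left = jax.random.randint(k2, (), 0, w - crop + 1)
    patch = jax.lax.dynamic_slice(
        images, (0, top, left, 0), (n, crop, crop, c)
    )
    # Flip the whole batch horizontally with probability 0.5.
    flip = jax.random.bernoulli(k3)
    return jnp.where(flip, patch[:, :, ::-1, :], patch)

@jax.jit
def train_step(key, images):
    images = augment(key, images)
    # ... the usual loss / grad / update would follow here ...
    return images

batch = jnp.zeros((8, 32, 32, 3))
out = train_step(jax.random.PRNGKey(0), batch)
print(out.shape)  # (8, 24, 24, 3)
```

Since everything is traced inside `jax.jit`, the cropping and flipping run on the GPU with no host round-trip; only the raw batch from the CPU input pipeline crosses the host/device boundary.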