JAX vs. TF MLPerf Benchmark #4488
I recently came across these results (https://cloud.google.com/blog/products/ai-machine-learning/google-breaks-ai-performance-records-in-mlperf-with-worlds-fastest-training-supercomputer) and was wondering why the runtimes differ between JAX and TensorFlow if both use XLA under the hood. I searched for documentation on how each framework uses XLA but came up short. Is there a fundamental difference between how JAX JIT-compiles programs to XLA and how TensorFlow does it? I also looked at XLA dumps of the same methods in JAX (with XLA) and TF (with XLA), and they appear to be fundamentally different. I was hoping to get some more insight into this.
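For anyone who wants to reproduce the JAX side of that comparison, here is a minimal sketch of how the IR can be printed. The function `predict` and its shapes are made up for illustration and are not taken from the MLPerf submissions. The TensorFlow side would need its own dump, which is not shown here.

```python
import jax
import jax.numpy as jnp


def predict(w, x):
    # Hypothetical toy layer (not from the MLPerf code): matmul followed by tanh.
    return jnp.tanh(x @ w)


w = jnp.ones((4, 3), dtype=jnp.float32)
x = jnp.ones((2, 4), dtype=jnp.float32)

# Trace and lower without executing: this is the IR that JAX hands to XLA.
lowered = jax.jit(predict).lower(w, x)
lowered_text = lowered.as_text()

# Compile for the local backend: this is the IR after XLA's optimization passes.
# The exact text depends on the backend and on the JAX/XLA version.
compiled_text = lowered.compile().as_text()

print(lowered_text)
print(compiled_text)
```

Comparing the first printout with the second gives a rough idea of how much of the final program comes from the framework's lowering and how much from XLA itself. Setting `XLA_FLAGS=--xla_dump_to=<dir>` before starting the process is another way to get dumps on disk.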
Replies: 2 comments 4 replies
While JAX and TensorFlow both use XLA as their compiler on TPUs, there are many reasons why similar models implemented in JAX and TensorFlow might not end up with exactly the same XLA HLO: differences between the TF-XLA bridge and JAX-XLA translations, differences between layer implementations in TensorFlow/Keras and JAX neural network libraries like Flax, etc.
But in the MLPerf submissions, we made an effort to produce very similar XLA HLO from both TensorFlow and JAX, so that XLA optimizations could apply equally to both submissions. Instead, end-to-end timing differences between TF and JAX tended to come from a few different sources: