
JAX vs. TF MLPerf Benchmark #4488

While JAX and TensorFlow both use XLA as their compiler on TPUs, there are many reasons why similar models implemented in JAX and TensorFlow might not end up with exactly the same XLA HLO: differences between the TF-XLA bridge and JAX-XLA translations, differences between layer implementations in TensorFlow/Keras and JAX neural network libraries like Flax, etc.
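One way to compare the two lowerings directly is to dump the compiler IR that JAX produces for a model and diff it against the HLO dumped from the TF side. A minimal sketch using JAX's ahead-of-time lowering API; the function `f` here is just an illustrative stand-in for a model's forward pass, not code from the MLPerf submissions:

```python
import jax
import jax.numpy as jnp

def f(x):
    # Toy stand-in for a model's forward computation.
    return jnp.tanh(x @ x.T)

# Lower the jitted function to compiler IR without executing it.
lowered = jax.jit(f).lower(jnp.ones((4, 4), dtype=jnp.float32))

# The textual IR can be diffed against HLO dumped from a TF/XLA program
# (e.g. run with XLA_FLAGS=--xla_dump_to=/tmp/hlo on both sides).
print(lowered.as_text()[:500])
```

Dumping from both frameworks with the same `--xla_dump_to` flag is what makes an apples-to-apples HLO comparison possible.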

But in the MLPerf submissions, we made an effort to produce very similar XLA HLO from both TensorFlow and JAX, so that XLA optimizations could apply equally to both submissions. Instead, end-to-end timing differences between TF and JAX tended to come from a few different sources:

  • differences in startup overhead (while MLPerf doesn…
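Startup and compilation overhead of this kind can be isolated in JAX by timing the first call to a jitted function (which triggers compilation) separately from subsequent steady-state calls. A rough sketch; the `step` function and array shape are illustrative assumptions:

```python
import time
import jax
import jax.numpy as jnp

@jax.jit
def step(x):
    # Toy stand-in for a training step.
    return jnp.tanh(x @ x)

x = jnp.ones((512, 512), dtype=jnp.float32)

# First call includes tracing and XLA compilation (startup overhead).
t0 = time.perf_counter()
step(x).block_until_ready()
first_call = time.perf_counter() - t0

# Later calls hit the compilation cache and measure execution only.
t0 = time.perf_counter()
step(x).block_until_ready()
steady_state = time.perf_counter() - t0

print(f"first call: {first_call:.4f}s, steady state: {steady_state:.4f}s")
```

`block_until_ready()` matters here: JAX dispatches asynchronously, so without it the timer would stop before the device work finished.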

Answer selected by mattjj
This discussion was converted from issue #4488 on October 08, 2020 00:23.