Debugging performance discrepancy between PyTorch and JAX variants of NVDiffrast #21
+29
−30
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
The following notes are modified from the related Notion card
Benchmarking script: b3d/test/test_renderer_fps.py
Before
Output of
python test/test_renderer_fps.py
:First change: Using
lax.scan
instead of for loopThis should let us get rid of some overhead from XLA…
(Note:
lax.while_loop
should achieve similar effect)Related:
jax.lax.scan
andjax.lax.while_loop
Second change: Removing unnecessary
cudaStreamSynchronize(stream)
Disclaimer: I’m not certain about this change, since I’m new to CUDA programming.
It looks like we’re calling
cudaStreamSynchronize(stream)
a lot in the definition of the JAX rasterize wrapper code (e.g.jax_rasterize_gl.cpp
). However, except for debugging, we probably don’t want to block CPU until the stream has finished execution?By deleting
cudaStreamSynchronize(stream)
from the C++ implementations, we can see another performance bump on the JAX rasterizer:Note 1: It seems like removing all Stream synchronization from b3d version of the renderer can result in nondeterministic CUDA error. I haven’t take a super close look into the b3d-version of the renderer (aka
JAX
) to find out what are removable, so the numbers are not included above. Though even when it doesn't error out, we don't see the same performance boost:Note 2: After removing the unnecessary
cudaStreamSynchronize
call, the output of JAX NVDiffrast is still the same as the PyTorch version::) I'm just pushing the code here so people can give it a try. Even though we're only tweaking the rasterization operator here, this can give us some ideas about how to improve the performance on the overall rendering pipeline. @nishadgothoskar