Optimizing XLA Compilation when using a tuple of functions as an argument in a scan function #24283
christophedessers asked this question in Q&A
Hello,
I’m encountering a significant increase in XLA compilation time when passing a tuple of functions as an argument to a scan function that uses a switch to select which function to apply. Specifically, the issue arises when I scale the system by duplicating these functions. If I understand correctly, creating N functions would lead JAX to compile N functions. However, in my case, I’m only working with two distinct functions (fct1, fct2). Both are partially applied, and each has only one value as a "curried" parameter. Therefore, although I replicate these functions many times, I hoped that JAX would see that there are actually only 2 unique functions, not N.
Here is the simplified code:
Here’s where the problem arises: as the number of function duplicates increases, the XLA compilation time for the scan function grows rapidly (for example, with 5000 replicas, compilation takes about 2 minutes). I suspect JAX doesn't recognize that there are only 2 unique functions.
My questions:
Here are the results of the “xla_dump_to” flag, containing the pre-optimization HLO files:
dump_Simple_test.zip
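A minimal sketch of the pattern described above, under stated assumptions: the bodies of fct1 and fct2, the parameter values, n, and the helper names (make_replicated, run_replicated, make_tables, run_two_branch) are hypothetical illustrations, not the code from the question. It contrasts the replicated-partials version, where lax.switch receives one callable per replica and traces each into its own branch, with a variant that keeps only the two unique functions as branches and feeds the curried values in as scanned arrays. Comparing the lengths of the two lowered modules gives a rough sense of how, in this sketch, program size tracks the number of branches rather than the number of unique functions.

```python
import functools

import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical stand-ins for fct1 / fct2: the real bodies are not shown.
# Each takes one "curried" parameter p and the scan carry x.
def fct1(p, x):
    return x + p


def fct2(p, x):
    return x * p


def make_replicated(n):
    """Hypothetical helper: 2*n partials (n copies of each function, each with its own p).

    Every partial is a separate Python callable, so lax.switch traces each
    one into its own branch and the conditional grows with n.
    """
    fns = []
    for i in range(n):
        fns.append(functools.partial(fct1, float(i)))
        fns.append(functools.partial(fct2, 1.0 + 0.001 * i))
    return tuple(fns)


def run_replicated(fns, idx, x0):
    def body(x, i):
        x = lax.switch(i, fns, x)  # one branch per replica
        return x, x

    return lax.scan(body, x0, idx)


def make_tables(n):
    """Hypothetical helper: the same replicas as data (kind 0 -> fct1, kind 1 -> fct2) plus each p."""
    kinds = jnp.array([0, 1] * n, dtype=jnp.int32)
    params = jnp.array(
        [v for i in range(n) for v in (float(i), 1.0 + 0.001 * i)],
        dtype=jnp.float32,
    )
    return kinds, params


def run_two_branch(kinds, params, x0):
    def body(x, kp):
        k, p = kp
        x = lax.switch(k, (fct1, fct2), p, x)  # always exactly two branches
        return x, x

    return lax.scan(body, x0, (kinds, params))


n = 50
idx = jnp.arange(2 * n, dtype=jnp.int32)  # visit each replica once, in order
x0 = jnp.float32(1.0)
fns = make_replicated(n)
kinds, params = make_tables(n)

rep = jax.jit(lambda i, x: run_replicated(fns, i, x))
two = jax.jit(run_two_branch)

_, ys_rep = rep(idx, x0)
_, ys_two = two(kinds[idx], params[idx], x0)

# Length of the lowered (pre-optimization) module text: a rough proxy for the
# amount of program handed to XLA, comparable in spirit to the dumped HLO.
hlo_rep = rep.lower(idx, x0).as_text()
hlo_two = two.lower(kinds[idx], params[idx], x0).as_text()
print(len(hlo_rep), len(hlo_two))
```

Both versions compute the same sequence of values in this sketch; only the structure of the traced program differs. Whether the two-branch layout fits the real fct1/fct2 depends on their curried parameters being expressible as array data, which is an assumption here.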
I’m new to JAX, so I might be missing some key optimization strategies. Any insights or suggestions would be greatly appreciated!
Thank you for your time.