Detect when a new JAX frame starts #6009
Unanswered
dionhaefner asked this question in Q&A
Replies: 0 comments
For mpi4jax, our API would get a lot cleaner if we could keep track of local state within each `jit` block to inject tokens into our custom calls. But for this I need to detect when a new frame starts to clear the local state.

I can make this work with something like this:
Then we can use it like this:
Is there any way to detect that we are now in a different frame than the one that `current_token` originated from that doesn't rely on catching `UnexpectedTracerError`?