Implementation plans for `pallas.dynamic_slice` and `scatter_reduce` ops
#25281
Unanswered
olivier-peltre asked this question in Q&A
Replies: 0 comments
When trying to execute a Pallas kernel that calls `pl.dslice(start, size)` on any accelerator (GPU/TPU), I get a `NotImplementedError` (jaxlib == 0.4.34).
Is support for `pl.dslice` to come soon?
Use case
Are there plans for `scatter_reduce` ops, or any suggestions to go forward?
I noticed that `scatter_add` scales very badly, and I never managed to get `indices_are_sorted=True` to produce a significant difference (in previous attempts, I think I got `indices_are_sorted=False` in the compiled jaxpr even when passing it as a keyword).
I will now try comparing with torch + rusty1s/pytorch_scatter to get an idea of the gains I could possibly hope for.
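For anyone wanting to reproduce the `indices_are_sorted` check, here is a minimal sketch of how the flag can be inspected in a traced jaxpr. It uses the public `.at[...].add` API rather than my original code, and the function name `scatter_add_sorted`, the toy data, and the bin count are made up for illustration. Whether the flag survives may depend on the JAX version and on which API is used to scatter.

```python
import jax
import jax.numpy as jnp


def scatter_add_sorted(values, indices, num_segments):
    # Scatter-add `values` into `num_segments` bins, promising sorted indices.
    out = jnp.zeros((num_segments,), dtype=values.dtype)
    return out.at[indices].add(values, indices_are_sorted=True)


# Toy data (illustrative only): six values falling into three bins.
values = jnp.arange(6, dtype=jnp.float32)
indices = jnp.array([0, 0, 1, 2, 2, 2])  # already sorted

# Trace the function and keep the printed jaxpr, so the parameters of the
# scatter primitive (including indices_are_sorted) can be read off.
jaxpr_text = str(jax.make_jaxpr(lambda v, i: scatter_add_sorted(v, i, 3))(values, indices))
print(jaxpr_text)

# Bin sums: 0 + 1, 2, 3 + 4 + 5.
result = scatter_add_sorted(values, indices, 3)
```

Searching the printed text for `indices_are_sorted` shows the value that actually reached the scatter primitive, which is the quantity I was comparing against the keyword I passed.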
N.B. I'm looking for an efficient way to aggregate values based on a static index array, though I understand there are many constraints that may prevent very efficient dynamic scatter-reduce ops in XLA.
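To make the static-index setting concrete, here is a rough sketch of two plain-JAX baselines: `jax.ops.segment_sum` with the index array closed over as a constant, and a dense one-hot matmul built once on the host. Neither is a Pallas kernel nor a fused scatter-reduce. The names `SEGMENT_IDS`, `NUM_SEGMENTS`, `aggregate_segment_sum` and `aggregate_matmul` are hypothetical, and the one-hot variant is only an assumption about what might be acceptable when the number of segments is moderate, since it stores a `NUM_SEGMENTS x N` matrix.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical static index array, known before tracing (sorted here).
SEGMENT_IDS = np.array([0, 0, 1, 2, 2, 2])
NUM_SEGMENTS = 3

# Dense one-hot matrix of shape (NUM_SEGMENTS, N), built once on the host.
ONE_HOT = (SEGMENT_IDS[None, :] == np.arange(NUM_SEGMENTS)[:, None]).astype(np.float32)


@jax.jit
def aggregate_segment_sum(values):
    # Scatter-based baseline: sum the values that share a segment id.
    return jax.ops.segment_sum(
        values,
        jnp.asarray(SEGMENT_IDS),
        num_segments=NUM_SEGMENTS,
        indices_are_sorted=True,
    )


@jax.jit
def aggregate_matmul(values):
    # Matmul-based alternative: the static indices are baked into ONE_HOT,
    # so the aggregation becomes a single dense matrix-vector product.
    return jnp.matmul(ONE_HOT, values)


values = jnp.arange(6, dtype=jnp.float32)
out_scatter = aggregate_segment_sum(values)
out_matmul = aggregate_matmul(values)
```

Both functions return the same per-segment sums on this toy input. Which one is faster on a given accelerator and problem size is something I have not measured.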
MW Example