[Good First Issue][JAX FE]: Support jax.lax.scatter operation for JAX #26573
Comments
.take Hi, I would like to work on this issue if it's OK. Thanks!
Thank you for looking into this issue! Please let us know if you have any questions or require any help.
I want to work on this issue. Could you please assign it to me? @rkazants
Sorry, I haven't gotten around to this yet. I will finish it soon!
Hey @aymuos15, do you need any support, or will you not have time to finish the task?
.take Hi @rkazants, I would like to work on this issue.
Thank you for looking into this issue! Please let us know if you have any questions or require any help.
Hi @rkazants, I have a question. The jax.lax.scatter operation is very versatile, and the scatter operations supported by OpenVINO can only perform a subset of what it does. For the past few weeks, I have been trying to find a way to implement the full functionality of jax.lax.scatter using the OpenVINO opset, but only a subset of it can be implemented straightforwardly. I do have a solution that implements the full functionality, but the code is quite complex. So my question is: am I required to implement the full functionality, or may I indicate that certain features of jax.lax.scatter are not supported in OpenVINO? If you wish, I can explain my solution in detail.
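To make the "subset" point concrete, here is a minimal JAX-only sketch (it assumes JAX is installed and does not involve OpenVINO). It calls jax.lax.scatter twice on the same 2-D operand: once overwriting whole rows, where the indexed dimension is the leading one (the layout that OpenVINO's ScatterNDUpdate indexes into), and once overwriting whole columns, where the indexed dimension is not leading, so a direct mapping would presumably need extra transposes around the scatter.

```python
import jax.numpy as jnp
from jax import lax

operand = jnp.zeros((3, 4), dtype=jnp.float32)
# Two index vectors of length 1: positions 0 and 2 along the scattered dimension.
indices = jnp.array([[0], [2]], dtype=jnp.int32)

# Case 1: overwrite whole rows. Operand dim 0 is indexed; each update is a
# full row (a window of length 4) stored in update dim 1.
row_dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(1,),
    inserted_window_dims=(0,),
    scatter_dims_to_operand_dims=(0,),
)
row_updates = jnp.arange(8, dtype=jnp.float32).reshape(2, 4)
rows = lax.scatter(operand, indices, row_updates, row_dnums)

# Case 2: overwrite whole columns. Operand dim 1 is indexed; each update is a
# full column (a window of length 3) stored in update dim 0.
col_dnums = lax.ScatterDimensionNumbers(
    update_window_dims=(0,),
    inserted_window_dims=(1,),
    scatter_dims_to_operand_dims=(1,),
)
col_updates = jnp.arange(6, dtype=jnp.float32).reshape(3, 2)
cols = lax.scatter(operand, indices, col_updates, col_dnums)

# Both agree with the equivalent NumPy-style indexed updates.
assert jnp.array_equal(rows, operand.at[jnp.array([0, 2])].set(row_updates))
assert jnp.array_equal(cols, operand.at[:, jnp.array([0, 2])].set(col_updates))

print(rows)
print(cols)
```

The row case has the same shape contract as ScatterNDUpdate (indices of shape [N, 1], updates of shape [N, 4]); the column case computes a similar update but along a different axis, which is one of the ways general scatter dimension numbers go beyond a single OpenVINO scatter op.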
Context
The OpenVINO component responsible for supporting JAX/Flax models is called the JAX Frontend (JAX FE). JAX FE converts a JAX/Flax model, represented by a ClosedJaxpr graph object with operations from the jax.lax opset, into OpenVINO IR containing operations from the OpenVINO opset. To infer JAX/Flax models containing the jax.lax.scatter operation with OpenVINO, JAX FE needs to be extended to support this operation.
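The ClosedJaxpr that JAX FE consumes can be inspected directly with jax.make_jaxpr. This minimal sketch (assuming JAX is installed; the function name overwrite_rows is only illustrative) traces a single jax.lax.scatter call and lists the primitives in the resulting graph, showing the scatter equation and its dimension_numbers parameter that a scatter loader would have to interpret.

```python
import jax
import jax.numpy as jnp
from jax import lax


def overwrite_rows(operand, indices, updates):
    # Hypothetical example function: overwrite whole rows of a 2-D operand.
    dnums = lax.ScatterDimensionNumbers(
        update_window_dims=(1,),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )
    return lax.scatter(operand, indices, updates, dnums)


operand = jnp.zeros((3, 4), dtype=jnp.float32)
indices = jnp.array([[0], [2]], dtype=jnp.int32)
updates = jnp.ones((2, 4), dtype=jnp.float32)

# make_jaxpr traces the function into a ClosedJaxpr (a jaxpr plus its constants).
closed = jax.make_jaxpr(overwrite_rows)(operand, indices, updates)
names = [eqn.primitive.name for eqn in closed.jaxpr.eqns]
scatter_eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "scatter")

print(type(closed).__name__)
print(names)
print(scatter_eqn.params["dimension_numbers"])
```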
What needs to be done?
To support the jax.lax.scatter operation, you need to implement the corresponding loader in the JAX FE op directory and register it in the dictionary of loaders. Each loader is responsible for the conversion (or decomposition) of one type of JAX operation.
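As a concrete example of a decomposition, jax.lax.reshape with a dimensions permutation computes the same result as a transpose followed by a plain reshape. The JAX-only sketch below checks this equivalence (it assumes JAX is installed and does not involve OpenVINO); a loader that emits OpenVINO Transpose and Reshape nodes has to reproduce exactly this behavior.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.arange(24).reshape(2, 3, 4)
perm = (2, 0, 1)      # plays the role of the loader's "dimensions" parameter
target = (4, 6)       # plays the role of the loader's "new_size" parameter

# jax.lax.reshape with a permutation: the operand is transposed, then reshaped.
direct = lax.reshape(x, target, dimensions=perm)

# The same computation as two separate ops, mirroring Transpose + Reshape.
decomposed = jnp.reshape(jnp.transpose(x, perm), target)

print(direct.shape)
print(bool(jnp.array_equal(direct, decomposed)))
```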
Here is an example of a loader implementation for the jax.lax.reshape operation:
In this example, translate_reshape expresses jax.lax.reshape using the OpenVINO opset. Since jax.lax.reshape performs transposition and tensor reshaping according to the JAX documentation, the resulting decomposition contains the OpenVINO Transpose and Reshape operations. For the Transpose and Reshape nodes, this conversion parses the constant parameters dimensions, used to permute the input tensor, and new_size, the target shape of the result.
Once you are done with the implementation of the translator, you need to implement the corresponding layer tests in test_scatter.py and put them into the layer_tests/jax_tests directory. Example of how to run a layer test: