[Good First Issue][JAX FE]: Support jax.lax.scatter operation for JAX #26573

Open
rkazants opened this issue Sep 12, 2024 · 8 comments · May be fixed by #28357
Labels
category: JAX FE OpenVINO JAX FrontEnd good first issue Good for newcomers no_stale Do not mark as stale

Comments

@rkazants
Member

Context

The OpenVINO component responsible for supporting JAX/Flax models is called the JAX Frontend (JAX FE). JAX FE converts a JAX/Flax model, represented as a ClosedJaxpr graph object with operations from the jax.lax opset, into OpenVINO IR containing operations from the OpenVINO opset.

For OpenVINO to infer JAX/Flax models that contain the jax.lax.scatter operation, JAX FE needs to be extended to support this operation.

What needs to be done?

To support jax.lax.scatter, implement the corresponding loader in the JAX FE op directory and register it in the dictionary of loaders. Each loader is responsible for converting (or decomposing) one type of JAX operation.
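The loader-dictionary idea can be illustrated with a small, self-contained sketch. It does not use OpenVINO at all: FakeNodeContext, FakeOutputVector, Translator, get_loaders and convert are hypothetical stand-ins for the real JAX FE NodeContext, OutputVector and registration table, whose exact names and signatures are not shown in this issue. The string returned by the scatter stub is only a placeholder; the real decomposition depends on the scatter's dimension_numbers.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical stand-ins for the real JAX FE types (NodeContext, OutputVector).
struct FakeNodeContext {
    std::string op_type;              // jax.lax primitive name, e.g. "reshape"
    std::vector<std::string> inputs;  // names of the input tensors
};
using FakeOutputVector = std::vector<std::string>;
using Translator = std::function<FakeOutputVector(const FakeNodeContext&)>;

// Each loader handles exactly one JAX primitive and returns its decomposition
// (here just a string describing the OpenVINO subgraph).
FakeOutputVector translate_reshape_stub(const FakeNodeContext& ctx) {
    return {"Reshape(Transpose(" + ctx.inputs.at(0) + "))"};
}

FakeOutputVector translate_scatter_stub(const FakeNodeContext& ctx) {
    return {"ScatterNDUpdate(" + ctx.inputs.at(0) + ", " + ctx.inputs.at(1) + ", " +
            ctx.inputs.at(2) + ")"};
}

// The "dictionary of loaders": primitive name -> loader.
const std::map<std::string, Translator>& get_loaders() {
    static const std::map<std::string, Translator> loaders = {
        {"reshape", translate_reshape_stub},
        {"scatter", translate_scatter_stub},  // the kind of entry this issue asks for
    };
    return loaders;
}

// Conversion dispatches each operation to its registered loader and fails loudly
// when no loader exists for the primitive.
FakeOutputVector convert(const FakeNodeContext& ctx) {
    const auto& loaders = get_loaders();
    auto it = loaders.find(ctx.op_type);
    if (it == loaders.end())
        throw std::runtime_error("no loader registered for jax.lax." + ctx.op_type);
    assert(it->second && "registered loader must be callable");
    return it->second(ctx);
}
```

Calling convert with op_type "scatter" and inputs {"operand", "indices", "updates"} yields "ScatterNDUpdate(operand, indices, updates)", while an unregistered primitive throws. Keeping one loader per primitive in a single map makes adding jax.lax.scatter a local change: one new function plus one new entry.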

Here is an example loader implementation for the jax.lax.reshape operation:

```cpp
OutputVector translate_reshape(const NodeContext& context) {
    num_inputs_check(context, 1, 1);
    Output<Node> input = context.get_input(0);
    // Target shape of the result.
    auto new_sizes = context.const_named_param<std::vector<int64_t>>("new_sizes");
    if (context.has_param("dimensions")) {
        // Optional permutation: transpose the input first.
        auto dimensions = context.const_named_param<std::vector<int64_t>>("dimensions");
        auto permutation_node = std::make_shared<v0::Constant>(element::i64, Shape{dimensions.size()}, dimensions);
        input = std::make_shared<v1::Transpose>(input, permutation_node);
    }

    auto new_shape_node = std::make_shared<v0::Constant>(element::i64, Shape{new_sizes.size()}, new_sizes);
    // special_zero = false: new_sizes is used literally as the output shape.
    Output<Node> res = std::make_shared<v1::Reshape>(input, new_shape_node, false);
    return {res};
}
```

In this example, translate_reshape expresses jax.lax.reshape using the OpenVINO opset. According to the JAX documentation, jax.lax.reshape performs an optional transposition followed by tensor reshaping, so the resulting decomposition contains OpenVINO Transpose and Reshape operations. To build these nodes, the conversion parses two constant parameters: dimensions, the permutation applied to the input tensor, and new_sizes, the target shape of the result.
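As a plain-C++ cross-check of that decomposition (not OpenVINO code), the sketch below applies the same two steps to a row-major buffer: a transpose by dimensions, then a reshape that only reinterprets the shape. Tensor, transpose and reshape are hypothetical names used for illustration only.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Minimal row-major tensor used only for this illustration.
struct Tensor {
    std::vector<int64_t> shape;
    std::vector<float> data;
};

static size_t num_elements(const std::vector<int64_t>& shape) {
    size_t n = 1;
    for (int64_t d : shape) n *= static_cast<size_t>(d);
    return n;
}

// Same convention as OpenVINO Transpose: output axis i is input axis perm[i].
Tensor transpose(const Tensor& in, const std::vector<int64_t>& perm) {
    const size_t rank = in.shape.size();
    assert(perm.size() == rank);
    std::vector<size_t> in_strides(rank, 1);  // row-major strides of the input
    for (size_t i = rank; i-- > 1;) in_strides[i - 1] = in_strides[i] * static_cast<size_t>(in.shape[i]);
    Tensor out;
    for (int64_t p : perm) out.shape.push_back(in.shape[static_cast<size_t>(p)]);
    out.data.resize(in.data.size());
    std::vector<int64_t> idx(rank, 0);  // current multi-index into the output
    for (size_t flat = 0; flat < out.data.size(); ++flat) {
        size_t src = 0;
        for (size_t i = 0; i < rank; ++i)
            src += static_cast<size_t>(idx[i]) * in_strides[static_cast<size_t>(perm[i])];
        out.data[flat] = in.data[src];
        for (size_t i = rank; i-- > 0;) {  // advance idx in row-major order
            if (++idx[i] < out.shape[i]) break;
            idx[i] = 0;
        }
    }
    return out;
}

// jax.lax.reshape as decomposed above: optional Transpose by `dimensions`, then a
// row-major Reshape to `new_sizes`, which changes only the shape, not the data order.
Tensor reshape(const Tensor& in, const std::vector<int64_t>& new_sizes,
               const std::vector<int64_t>* dimensions = nullptr) {
    Tensor t = dimensions ? transpose(in, *dimensions) : in;
    assert(num_elements(new_sizes) == t.data.size());
    t.shape = new_sizes;
    return t;
}
```

For a 2x3 input holding 0..5 with dimensions {1, 0} and new_sizes {6}, the transpose produces the 3x2 tensor [[0, 3], [1, 4], [2, 5]], and the reshape then flattens it to {0, 3, 1, 4, 2, 5}. Without dimensions, the data order is unchanged.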

Once the translator is implemented, write the corresponding layer test test_scatter.py and put it in the layer_tests/jax_tests directory. Here is an example of how to run a layer test:

```sh
export TEST_DEVICE=CPU
export JAX_TRACE_MODE=JAXPR
export 
cd openvino/tests/layer_tests/jax_tests
pytest test_reshape.py
```

Example Pull Requests

Resources

Contact points

  • @openvinotoolkit/openvino-jax-frontend-maintainers
  • @rkazants in GitHub and Discord

Ticket

No response

@rkazants rkazants added good first issue Good for newcomers no_stale Do not mark as stale category: JAX FE OpenVINO JAX FrontEnd labels Sep 12, 2024
@github-project-automation github-project-automation bot moved this to Contributors Needed in Good first issues Sep 12, 2024
@aymuos15

.take

Hi, I would like to work on this issue if it's OK. Thanks!

Contributor

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

@rkazants rkazants moved this from Contributors Needed to Assigned in Good first issues Sep 14, 2024
@Aman-patel-13

I want to work on this issue. Could you please assign it to me? @rkazants

@aymuos15

Sorry I haven't gotten around to this. I will be finishing this soon!

@mlukasze
Contributor

mlukasze commented Nov 6, 2024

hey @aymuos15 - do you need any support, or will you not have time to finish the task?

@rkazants rkazants moved this from Assigned to Contributors Needed in Good first issues Dec 9, 2024
@sumhaj
Contributor

sumhaj commented Dec 15, 2024

.take

Hi @rkazants , I would like to work on this issue.

Contributor

Thank you for looking into this issue! Please let us know if you have any questions or require any help.

@mlukasze mlukasze moved this from Contributors Needed to Assigned in Good first issues Dec 16, 2024
@sumhaj
Contributor

sumhaj commented Jan 4, 2025

Hi @rkazants ,

I have a query.

The jax.lax.scatter operation is very versatile. The scatter operations supported by OpenVINO can only perform a subset of what jax.lax.scatter does. For the past few weeks, I have been trying to come up with a way of implementing the full functionality of jax.lax.scatter using the OpenVINO opset.

Only a subset of the jax.lax.scatter functionality can be implemented straightforwardly. I have a solution for implementing the full functionality, but the code is quite complex. So the question is: am I required to implement the full functionality, or am I allowed to indicate that certain functionality of jax.lax.scatter is not supported in OpenVINO?

If you wish, I can explain my solution in detail.
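To make the "straightforward subset" concrete, here is a sketch under stated assumptions: a jax.lax.scatter call whose dimension_numbers address full slices of the K leading operand axes has the same layout that OpenVINO ScatterNDUpdate expects (indices of shape [..., K], updates of shape indices.shape[:-1] + operand.shape[K:]). ScatterDims, maps_to_scatter_nd_update and scatter_nd_update are hypothetical helpers, not JAX FE or OpenVINO APIs. Batching dimensions, out-of-bounds modes, window-size shape checks and duplicate indices are deliberately ignored, and this is one possible way to carve out the subset, not necessarily the one used in the linked PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical mirror of jax.lax.ScatterDimensionNumbers (batching dims omitted).
struct ScatterDims {
    std::vector<int64_t> update_window_dims;
    std::vector<int64_t> inserted_window_dims;
    std::vector<int64_t> scatter_dims_to_operand_dims;
};

// True when each index vector addresses a full slice of the K leading operand
// axes and the window dims are the trailing axes of `updates`, i.e. the layout
// ScatterNDUpdate expects. Shape checks (window sizes == operand.shape[K:]) are omitted.
bool maps_to_scatter_nd_update(const ScatterDims& dn, size_t operand_rank, size_t updates_rank) {
    const size_t k = dn.scatter_dims_to_operand_dims.size();
    if (k == 0 || k > operand_rank) return false;
    const size_t window_rank = operand_rank - k;
    if (window_rank > updates_rank) return false;
    if (dn.inserted_window_dims.size() != k || dn.update_window_dims.size() != window_rank)
        return false;
    for (size_t i = 0; i < k; ++i) {
        // index component i selects operand axis i, and that axis is collapsed
        if (dn.scatter_dims_to_operand_dims[i] != static_cast<int64_t>(i)) return false;
        if (dn.inserted_window_dims[i] != static_cast<int64_t>(i)) return false;
    }
    const size_t first_window = updates_rank - window_rank;
    for (size_t i = 0; i < window_rank; ++i) {
        if (dn.update_window_dims[i] != static_cast<int64_t>(first_window + i)) return false;
    }
    return true;
}

// Reference ScatterNDUpdate on row-major data, assuming in-bounds, unique indices.
// `indices` holds N index tuples of length k; `updates` holds N slices, each of
// prod(data_shape[k:]) elements.
std::vector<float> scatter_nd_update(std::vector<float> data, const std::vector<int64_t>& data_shape,
                                     const std::vector<int64_t>& indices, size_t k,
                                     const std::vector<float>& updates) {
    assert(k >= 1 && k <= data_shape.size());
    size_t slice = 1;
    for (size_t i = k; i < data_shape.size(); ++i) slice *= static_cast<size_t>(data_shape[i]);
    const size_t n = indices.size() / k;
    assert(indices.size() == n * k && updates.size() == n * slice);
    for (size_t t = 0; t < n; ++t) {
        size_t offset = 0;  // row-major position of the slice among the leading k axes
        for (size_t j = 0; j < k; ++j)
            offset = offset * static_cast<size_t>(data_shape[j]) + static_cast<size_t>(indices[t * k + j]);
        for (size_t e = 0; e < slice; ++e) data[offset * slice + e] = updates[t * slice + e];
    }
    return data;
}
```

For example, scattering the rows {1, 2} and {3, 4} into rows 2 and 0 of a 3x2 zero tensor (dimension_numbers update_window_dims = {1}, inserted_window_dims = {0}, scatter_dims_to_operand_dims = {0}) fits the pattern and gives {3, 4, 0, 0, 1, 2}. A call with inserted_window_dims = {} and update_window_dims = {0, 1} writes a window starting at an index rather than a full slice, so it falls outside this subset and needs a more involved decomposition.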

@sumhaj sumhaj linked a pull request Jan 9, 2025 that will close this issue
Projects
Status: Assigned

5 participants