Skip to content

Commit

Permalink
Merge pull request #3550 from agriyakhetarpal/bump-jax-jaxlib-versions
Browse files Browse the repository at this point in the history
Upgrade to newest versions of `jax` + `jaxlib` and add Windows support for JAX Solver
  • Loading branch information
Saransh-cpp authored Dec 8, 2023
2 parents 93c265b + f41be98 commit 22d1229
Show file tree
Hide file tree
Showing 8 changed files with 74 additions and 44 deletions.
27 changes: 1 addition & 26 deletions .github/workflows/docker.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,7 @@ jobs:
echo "tag=all" >> "$GITHUB_OUTPUT"
fi
- name: Build and push Docker image to Docker Hub (no solvers)
if: matrix.build-args == 'No solvers'
- name: Build and push Docker image to Docker Hub (${{ matrix.build-args }})
uses: docker/build-push-action@v5
with:
context: .
Expand All @@ -58,29 +57,5 @@ jobs:
push: true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ODES and IDAKLU solvers)
if: matrix.build-args == 'ODES' || matrix.build-args == 'IDAKLU'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
platforms: linux/amd64, linux/arm64

- name: Build and push Docker image to Docker Hub (with ALL and JAX solvers)
if: matrix.build-args == 'ALL' || matrix.build-args == 'JAX'
uses: docker/build-push-action@v5
with:
context: .
file: scripts/Dockerfile
tags: pybamm/pybamm:${{ steps.tags.outputs.tag }}
push: true
build-args: ${{ matrix.build-args }}=true
# exclude arm64 for JAX and ALL builds for now, see
# https://github.com/google/jax/issues/13608
platforms: linux/amd64

- name: List built image(s)
run: docker images
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
- Fixed bug in calculation of theoretical energy that made it very slow ([#3506](https://github.com/pybamm-team/PyBaMM/pull/3506))
- The irreversible plating model now increments `f"{Domain} dead lithium concentration [mol.m-3]"`, not `f"{Domain} lithium plating concentration [mol.m-3]"` as it did previously. ([#3485](https://github.com/pybamm-team/PyBaMM/pull/3485))

## Optimizations

- Updated `jax` and `jaxlib` to the latest available versions and added Windows (Python 3.9+) support for the Jax solver ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550))

## Breaking changes

- Dropped support for the `[jax]` extra, i.e., the Jax solver when running on Python 3.8. The Jax solver is now available on Python 3.9 and above ([#3550](https://github.com/pybamm-team/PyBaMM/pull/3550))

# [v23.9](https://github.com/pybamm-team/PyBaMM/tree/v23.9) - 2023-10-31

## Features
Expand Down
5 changes: 4 additions & 1 deletion docs/source/user_guide/installation/GNU-linux.rst
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,10 @@ Optional - JaxSolver
~~~~~~~~~~~~~~~~~~~~

Users can install ``jax`` and ``jaxlib`` to use the Jax solver.
Currently, only GNU/Linux and macOS are supported.

.. note::

The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11.

.. code:: bash
Expand Down
6 changes: 3 additions & 3 deletions docs/source/user_guide/installation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,13 @@ Dependency Minimum Version p
Jax dependencies
^^^^^^^^^^^^^^^^^

Installable with ``pip install "pybamm[jax]"``
Installable with ``pip install "pybamm[jax]"``, currently supported on Python 3.9-3.11.

========================================================================= ================== ================== =======================
Dependency Minimum Version pip extra Notes
========================================================================= ================== ================== =======================
`JAX <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`__ 0.4.8 jax For JAX solvers
`jaxlib <https://pypi.org/project/jaxlib/>`__ 0.4.7 jax Support library for JAX
`JAX <https://jax.readthedocs.io/en/latest/notebooks/quickstart.html>`__ 0.4.20 jax For the JAX solver
`jaxlib <https://pypi.org/project/jaxlib/>`__ 0.4.20 jax Support library for JAX
========================================================================= ================== ================== =======================

.. _install.odes_dependencies:
Expand Down
15 changes: 15 additions & 0 deletions docs/source/user_guide/installation/windows.rst
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,21 @@ installed automatically when you install PyBaMM using ``pip``.
For an introduction to virtual environments, see
(https://realpython.com/python-virtual-environments-a-primer/).

Optional - JaxSolver
~~~~~~~~~~~~~~~~~~~~

Users can install ``jax`` and ``jaxlib`` to use the Jax solver.

.. note::

The Jax solver is not supported on Python 3.8. It is supported on Python 3.9, 3.10, and 3.11.

.. code:: bash
pip install "pybamm[jax]"
The ``pip install "pybamm[jax]"`` command automatically downloads and installs ``pybamm`` and the compatible versions of ``jax`` and ``jaxlib`` on your system. (``pybamm_install_jax`` is deprecated.)

Uninstall PyBaMM
----------------

Expand Down
47 changes: 38 additions & 9 deletions noxfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,12 @@ def run_coverage(session):
set_environment_variables(PYBAMM_ENV, session=session)
session.install("coverage", silent=False)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("coverage", "run", "run-tests.py", "--nosub")
session.run("coverage", "combine")
session.run("coverage", "xml")
Expand All @@ -74,9 +77,12 @@ def run_integration(session):
"""Run the integration tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--integration")


Expand All @@ -92,9 +98,12 @@ def run_unit(session):
"""Run the unit tests."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--unit")


Expand Down Expand Up @@ -144,17 +153,37 @@ def set_dev(session):
external=True,
)
else:
session.run(python, "-m", "pip", "install", "-e", ".[all,dev]", external=True)
if sys.version_info < (3, 9):
session.run(
python,
"-m",
"pip",
"install",
".[all,dev]",
external=True,
)
else:
session.run(
python,
"-m",
"pip",
"install",
".[all,dev,jax]",
external=True,
)


@nox.session(name="tests")
def run_tests(session):
"""Run the unit tests and integration tests sequentially."""
set_environment_variables(PYBAMM_ENV, session=session)
if sys.platform != "win32":
session.install("-e", ".[all,odes,jax]", silent=False)
session.install("-e", ".[all,jax,odes]", silent=False)
else:
session.install("-e", ".[all]", silent=False)
if sys.version_info < (3, 9):
session.install("-e", ".[all]", silent=False)
else:
session.install("-e", ".[all,jax]", silent=False)
session.run("python", "run-tests.py", "--all")


Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = [
"setuptools>=64",
"wheel",
# On Windows, use the CasADi vcpkg registry and CMake bundled from MSVC
"casadi>=3.6.0; platform_system!='Windows'",
"casadi>=3.6.3; platform_system!='Windows'",
"cmake; platform_system!='Windows'",
]
build-backend = "setuptools.build_meta"
Expand Down Expand Up @@ -110,13 +110,13 @@ dev = [
"nbmake",
]
# Reading CSV files
pandas = [
pandas = [
"pandas>=1.5.0",
]
# For the Jax solver. Note: these must be kept in sync with the versions defined in pybamm/util.py.
jax = [
"jax>=0.4,<=0.5",
"jaxlib>=0.4,<=0.5",
"jax==0.4.20; python_version >= '3.9'",
"jaxlib==0.4.20; python_version >= '3.9'",
]
# For the scikits.odes solver
odes = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_evaluator_jax(self):
expr = pybamm.exp(a * b)
evaluator = pybamm.EvaluatorJax(expr)
result = evaluator(t=None, y=np.array([[2], [3]]))
self.assertEqual(result, np.exp(6))
np.testing.assert_array_almost_equal(result, np.exp(6), decimal=15)

# test a constant expression
expr = pybamm.Scalar(2) * pybamm.Scalar(3)
Expand Down

0 comments on commit 22d1229

Please sign in to comment.