From 9d51dc7356a34bf51b020b30855b2343be8051f9 Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Mon, 2 Dec 2024 07:48:46 -0800 Subject: [PATCH 1/2] Add tests using kwargs --- tests/test_sparsejac.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/tests/test_sparsejac.py b/tests/test_sparsejac.py index 3bfbb70..30e4104 100644 --- a/tests/test_sparsejac.py +++ b/tests/test_sparsejac.py @@ -181,6 +181,19 @@ def fn(x): onp.testing.assert_array_equal(expected_jac, result_jac.todense()) onp.testing.assert_array_equal(expected_aux, result_aux) + def test_kwargs(self): + def fn(x, y): + return y * jnp.convolve(x, jnp.asarray([1.0, -2.0, 1.0]), mode="same") ** 2 + + x = jax.random.uniform(jax.random.PRNGKey(0), shape=(_SIZE,)) + i, j = jnp.meshgrid(jnp.arange(_SIZE), jnp.arange(_SIZE), indexing="ij") + sparsity = (i == j) | ((i - 1) == j) | ((i + 1) == j) + sparsity = jsparse.BCOO.fromdense(sparsity) + + result_jac = sparsejac.jacrev(fn, sparsity)(x, y=1) + expected_jac = jax.jacrev(fn)(x, y=1) + onp.testing.assert_array_equal(expected_jac, result_jac.todense()) + class JacfwdTest(unittest.TestCase): def test_sparsity_shape_validation(self): @@ -350,6 +363,19 @@ def fn(x): onp.testing.assert_array_equal(expected_jac, result_jac.todense()) onp.testing.assert_array_equal(expected_aux, result_aux) + def test_kwargs(self): + def fn(x, y): + return y * jnp.convolve(x, jnp.asarray([1.0, -2.0, 1.0]), mode="same") ** 2 + + x = jax.random.uniform(jax.random.PRNGKey(0), shape=(_SIZE,)) + i, j = jnp.meshgrid(jnp.arange(_SIZE), jnp.arange(_SIZE), indexing="ij") + sparsity = (i == j) | ((i - 1) == j) | ((i + 1) == j) + sparsity = jsparse.BCOO.fromdense(sparsity) + + result_jac = sparsejac.jacfwd(fn, sparsity)(x, y=1) + expected_jac = jax.jacfwd(fn)(x, y=1) + onp.testing.assert_array_equal(expected_jac, result_jac.todense()) + class ConnectivityFromSparsityTest(unittest.TestCase): def test_output_connectivity_matches_expected(self): From a46818778db5226b91e06ef534ef4a379601f06d Mon Sep 17 00:00:00 2001 From: Martin Schubert Date: Mon, 2 Dec 2024 07:49:13 -0800 Subject: [PATCH 2/2] Version updated from v0.1.3 to v0.2.0 --- .bumpversion.toml | 2 +- README.md | 2 +- pyproject.toml | 2 +- src/sparsejac/__init__.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.bumpversion.toml b/.bumpversion.toml index 07e0132..afa052e 100644 --- a/.bumpversion.toml +++ b/.bumpversion.toml @@ -1,5 +1,5 @@ [tool.bumpversion] -current_version = "v0.1.3" +current_version = "v0.2.0" commit = true commit_args = "--no-verify" tag = true diff --git a/README.md b/README.md index 5a41960..9a2754a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # sparsejac: Efficient sparse Jacobians using Jax -`v0.1.3` +`v0.2.0` Sparse Jacobians are frequently encountered in the simulation of physical systems. Jax tranformations `jacfwd` and `jacrev` make it easy to compute dense Jacobians, but these are wasteful when the Jacobian is sparse. `sparsejac` provides a function to more efficiently compute the Jacobian if its sparsity is known. It makes use of the recently-introduced `jax.experimental.sparse` module. diff --git a/pyproject.toml b/pyproject.toml index 68bded4..f1db1e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "sparsejac" -version = "v0.1.3" +version = "v0.2.0" description = "Efficient forward- and reverse-mode sparse Jacobians using Jax." keywords = ["jax", "jacobian", "sparse"] readme = "README.md" diff --git a/src/sparsejac/__init__.py b/src/sparsejac/__init__.py index c14030b..a48de17 100644 --- a/src/sparsejac/__init__.py +++ b/src/sparsejac/__init__.py @@ -1,6 +1,6 @@ """sparsejac - Efficient forward- and reverse-mode sparse Jacobians using Jax.""" -__version__ = "v0.1.3" +__version__ = "v0.2.0" __author__ = "Martin Schubert " from sparsejac.sparsejac import jacfwd as jacfwd