From 9db6953dfd28a5bba61067f40d48ce74cdefda36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 26 Apr 2024 09:42:45 +0200 Subject: [PATCH] integration tests for torch layers --- tests/integration_tests/torch/__init__.py | 0 tests/integration_tests/torch/layers.py | 108 ++++++++++++++++++++++ tests/integration_tests/torch/models.py | 13 +++ 3 files changed, 121 insertions(+) create mode 100644 tests/integration_tests/torch/__init__.py create mode 100644 tests/integration_tests/torch/layers.py create mode 100644 tests/integration_tests/torch/models.py diff --git a/tests/integration_tests/torch/__init__.py b/tests/integration_tests/torch/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration_tests/torch/layers.py b/tests/integration_tests/torch/layers.py new file mode 100644 index 00000000..7eb04aaf --- /dev/null +++ b/tests/integration_tests/torch/layers.py @@ -0,0 +1,108 @@ +import jax +from apax.layers.descriptor.basis_functions import GaussianBasis, RadialFunction +from apax.layers.ntk_linear import NTKLinear +from apax.layers.readout import AtomisticReadout +from apax.layers.scaling import PerElementScaleShift +from apax.nn.torch.layers.descriptor.basis import GaussianBasisT, RadialFunctionT +from apax.nn.torch.layers.ntk_linear import NTKLinearT +from apax.nn.torch.layers.readout import AtomisticReadoutT +import jax.numpy as jnp +import numpy as np +import torch + +from apax.nn.torch.layers.scaling import PerElementScaleShiftT + + +def test_i_torch_gaussian_basis(): + linj = GaussianBasis(16) + + inputj = jnp.array(np.random.randn(8)) + + rng_key = jax.random.PRNGKey(0) + params = linj.init(rng_key, inputj) + + inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32)) + lint = GaussianBasisT(params=params["params"]) + + outj = linj.apply(params, inputj) + outt = lint(inputt) + + assert np.allclose(outj, outt) + assert outj.dtype == outt.dtype + + +def test_i_torch_radial_basis(): + linj = RadialFunction(16) + + inputj = jnp.array(np.random.randn(8)) + + rng_key = jax.random.PRNGKey(0) + params = linj.init(rng_key, inputj) + + inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32)) + lint = RadialFunctionT(params=params["params"]) + + outj = linj.apply(params, inputj) + outt = lint(inputt) + + assert np.allclose(outj, outt) + assert outj.dtype == outt.dtype + + + +def test_i_torch_ntk_linear(): + linj = NTKLinear(16) + + inputj = jnp.array(np.random.randn(8)) + + rng_key = jax.random.PRNGKey(0) + params = linj.init(rng_key, inputj) + + inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32)) + lint = NTKLinearT(params=params["params"]) + + outj = linj.apply(params, inputj) + outt = lint(inputt) + + assert np.allclose(outj, outt) + assert outj.dtype == outt.dtype + + +def test_i_torch_readout(): + + linj = AtomisticReadout(16) + + inputj = jnp.array(np.random.randn(8)) + + + rng_key = jax.random.PRNGKey(0) + params = linj.init(rng_key, inputj) + + inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32)) + lint = AtomisticReadoutT(params=params["params"]) + + outj = linj.apply(params, inputj) + outt = lint(inputt) + + assert np.allclose(outj, outt) + assert outj.dtype == outt.dtype + + +def test_i_torch_scaling(): + + linj = PerElementScaleShift(16) + + inputj = jnp.array(np.random.randn(8)) + + rng_key = jax.random.PRNGKey(0) + params = linj.init(rng_key, inputj) + + inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32)) + lint = PerElementScaleShiftT(params=params["params"]) + + outj = linj.apply(params, inputj) + outt = lint(inputt) + + assert np.allclose(outj, outt) + assert outj.dtype == outt.dtype + diff --git a/tests/integration_tests/torch/models.py b/tests/integration_tests/torch/models.py new file mode 100644 index 00000000..67739bb1 --- /dev/null +++ b/tests/integration_tests/torch/models.py @@ -0,0 +1,13 @@ +import jax +from apax.layers.descriptor.basis_functions import GaussianBasis, RadialFunction +from apax.layers.ntk_linear import NTKLinear +from apax.layers.readout import AtomisticReadout +from apax.layers.scaling import PerElementScaleShift +from apax.nn.torch.layers.descriptor.basis import GaussianBasisT, RadialFunctionT +from apax.nn.torch.layers.ntk_linear import NTKLinearT +from apax.nn.torch.layers.readout import AtomisticReadoutT +import jax.numpy as jnp +import numpy as np +import torch + +from apax.nn.torch.layers.scaling import PerElementScaleShiftT