Skip to content

Commit

Permalink
integration tests for torch layers
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 26, 2024
1 parent 7f77443 commit 9db6953
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 0 deletions.
Empty file.
108 changes: 108 additions & 0 deletions tests/integration_tests/torch/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
import jax
from apax.layers.descriptor.basis_functions import GaussianBasis, RadialFunction
from apax.layers.ntk_linear import NTKLinear
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.nn.torch.layers.descriptor.basis import GaussianBasisT, RadialFunctionT
from apax.nn.torch.layers.ntk_linear import NTKLinearT
from apax.nn.torch.layers.readout import AtomisticReadoutT
import jax.numpy as jnp
import numpy as np
import torch

from apax.nn.torch.layers.scaling import PerElementScaleShiftT


def test_i_torch_gaussian_basis():
linj = GaussianBasis(16)

inputj = jnp.array(np.random.randn(8))

rng_key = jax.random.PRNGKey(0)
params = linj.init(rng_key, inputj)

inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32))
lint = GaussianBasisT(params=params["params"])

outj = linj.apply(params, inputj)
outt = lint(inputt)

assert np.allclose(outj, outt)
assert outj.dtype == outt.dtype


def test_i_torch_radial_basis():
linj = RadialFunction(16)

inputj = jnp.array(np.random.randn(8))

rng_key = jax.random.PRNGKey(0)
params = linj.init(rng_key, inputj)

inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32))
lint = RadialFunctionT(params=params["params"])

outj = linj.apply(params, inputj)
outt = lint(inputt)

assert np.allclose(outj, outt)
assert outj.dtype == outt.dtype



def test_i_torch_ntk_linear():
linj = NTKLinear(16)

inputj = jnp.array(np.random.randn(8))

rng_key = jax.random.PRNGKey(0)
params = linj.init(rng_key, inputj)

inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32))
lint = NTKLinearT(params=params["params"])

outj = linj.apply(params, inputj)
outt = lint(inputt)

assert np.allclose(outj, outt)
assert outj.dtype == outt.dtype


def test_i_torch_readout():

linj = AtomisticReadout(16)

inputj = jnp.array(np.random.randn(8))


rng_key = jax.random.PRNGKey(0)
params = linj.init(rng_key, inputj)

inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32))
lint = AtomisticReadoutT(params=params["params"])

outj = linj.apply(params, inputj)
outt = lint(inputt)

assert np.allclose(outj, outt)
assert outj.dtype == outt.dtype


def test_i_torch_scaling():

linj = PerElementScaleShift(16)

inputj = jnp.array(np.random.randn(8))

rng_key = jax.random.PRNGKey(0)
params = linj.init(rng_key, inputj)

inputt = torch.from_numpy(np.asarray(inputj, dtype=np.float32))
lint = PerElementScaleShiftT(params=params["params"])

outj = linj.apply(params, inputj)
outt = lint(inputt)

assert np.allclose(outj, outt)
assert outj.dtype == outt.dtype

13 changes: 13 additions & 0 deletions tests/integration_tests/torch/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import jax
from apax.layers.descriptor.basis_functions import GaussianBasis, RadialFunction
from apax.layers.ntk_linear import NTKLinear
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.nn.torch.layers.descriptor.basis import GaussianBasisT, RadialFunctionT
from apax.nn.torch.layers.ntk_linear import NTKLinearT
from apax.nn.torch.layers.readout import AtomisticReadoutT
import jax.numpy as jnp
import numpy as np
import torch

from apax.nn.torch.layers.scaling import PerElementScaleShiftT

0 comments on commit 9db6953

Please sign in to comment.