Skip to content

Commit

Permalink
linting
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 3, 2024
1 parent 8938c7f commit 5138587
Show file tree
Hide file tree
Showing 16 changed files with 82 additions and 38 deletions.
3 changes: 2 additions & 1 deletion apax/nn/impl/activation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from apax import ops


def swish(x):
out = 1.6765324703310907 * ops.swish(x)
return out


def inverse_softplus(x):
return ops.log(ops.exp(x) - 1.0)
return ops.log(ops.exp(x) - 1.0)
3 changes: 2 additions & 1 deletion apax/nn/impl/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
from apax import ops


def gaussian_basis_impl(dr, shifts, betta, rad_norm):
dr = einops.repeat(dr, "neighbors -> neighbors 1")
# 1 x n_basis, neighbors x 1 -> neighbors x n_basis
Expand Down Expand Up @@ -34,4 +35,4 @@ def radial_basis_impl(basis, Z_i, Z_j, embeddings, embed_norm):
radial_function = einops.einsum(
species_pair_coeffs, basis, "nbrs radial basis, nbrs basis -> nbrs radial"
)
return radial_function
return radial_function
3 changes: 1 addition & 2 deletions apax/nn/impl/gaussian_moment_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


def gaussian_moment_impl(moments, triang_idxs_2d, triang_idxs_3d, n_contr):

contr_0 = moments[0]
contr_1 = ops.einsum("ari, asi -> rsa", moments[1], moments[1])
contr_2 = ops.einsum("arij, asij -> rsa", moments[2], moments[2])
Expand Down Expand Up @@ -49,5 +48,5 @@ def gaussian_moment_impl(moments, triang_idxs_2d, triang_idxs_3d, n_contr):
]

# gaussian_moments shape: n_atoms x n_features
gaussian_moments = ops.concatenate(gaussian_moments[: n_contr], axis=-1)
gaussian_moments = ops.concatenate(gaussian_moments[:n_contr], axis=-1)
return gaussian_moments
4 changes: 3 additions & 1 deletion apax/nn/jax/layers/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import (
GaussianMomentDescriptor,
)

__all__ = ["GaussianMomentDescriptor"]
4 changes: 3 additions & 1 deletion apax/nn/jax/layers/descriptor/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,9 @@ def __call__(self, dr, Z_i, Z_j):
# basis shape: neighbors x n_basis
basis = self.basis_fn(dr)

radial_function = radial_basis_impl(basis, Z_i, Z_j, self.embeddings, self.embed_norm)
radial_function = radial_basis_impl(
basis, Z_i, Z_j, self.embeddings, self.embed_norm
)
cutoff = cosine_cutoff(dr, self.r_max)
radial_function = radial_function * cutoff

Expand Down
5 changes: 3 additions & 2 deletions apax/nn/jax/layers/descriptor/gaussian_moment_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from apax.utils.jax_md_reduced import space



class GaussianMomentDescriptor(nn.Module):
radial_fn: nn.Module = RadialFunction()
n_contr: int = 8
Expand Down Expand Up @@ -52,7 +51,9 @@ def __call__(self, dr_vec, Z, neighbor_idxs):
radial_function = mask_by_neighbor(radial_function, neighbor_idxs)

moments = geometric_moments(radial_function, dn, idx_j, n_atoms)
gaussian_moments = gaussian_moment_impl(moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr)
gaussian_moments = gaussian_moment_impl(
moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr
)

# # gaussian_moments shape: n_atoms x n_features
# gaussian_moments = jnp.concatenate(gaussian_moments[: self.n_contr], axis=-1)
Expand Down
4 changes: 2 additions & 2 deletions apax/nn/jax/layers/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ def setup(self):
dense = []
for ii, n_hidden in enumerate(units):
layer = NTKLinear(
n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}"
)
n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}"
)
dense.append(layer)
if ii < len(units) - 1:
dense.append(self.activation_fn)
Expand Down
14 changes: 8 additions & 6 deletions apax/nn/jax/model/gmnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
import numpy as np
from jax import Array, vmap

from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.layers.empirical import EmpiricalEnergyTerm
from apax.layers.masking import mask_by_atom
from apax.layers.properties import stress_times_vol
from apax.layers.readout import AtomisticReadout
from apax.layers.scaling import PerElementScaleShift
from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import (
GaussianMomentDescriptor,
)
from apax.nn.jax.layers.empirical import EmpiricalEnergyTerm
from apax.nn.jax.layers.masking import mask_by_atom
from apax.nn.jax.layers.properties import stress_times_vol
from apax.nn.jax.layers.readout import AtomisticReadout
from apax.nn.jax.layers.scaling import PerElementScaleShift
from apax.utils.jax_md_reduced import partition, space
from apax.utils.math import fp64_sum

Expand Down
4 changes: 3 additions & 1 deletion apax/nn/torch/layers/descriptor/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor
from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import (
GaussianMomentDescriptor,
)

__all__ = ["GaussianMomentDescriptor"]
26 changes: 21 additions & 5 deletions apax/nn/torch/layers/descriptor/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,13 @@


class GaussianBasis(nn.Module):
def __init__(self, n_basis: int = 7, r_min: float = 0.5, r_max: float = 6.0, dtype: Any = torch.float32) -> None:
def __init__(
self,
n_basis: int = 7,
r_min: float = 0.5,
r_max: float = 6.0,
dtype: Any = torch.float32,
) -> None:
super().__init__()
self.n_basis = n_basis
self.r_min = r_min
Expand All @@ -28,12 +34,20 @@ def __init__(self, n_basis: int = 7, r_min: float = 0.5, r_max: float = 6.0, dty
self.shifts = torch.tensor(shifts, dtype=self.dtype)

def forward(self, dr: torch.Tensor) -> torch.Tensor:
basis = gaussian_basis_impl(dr.type(self.dtype), self.shifts, self.betta, self.rad_norm)
basis = gaussian_basis_impl(
dr.type(self.dtype), self.shifts, self.betta, self.rad_norm
)
return basis


class RadialFunction(nn.Module):
def __init__(self, n_radial: int = 5, basis_fn: nn.Module = GaussianBasis(), n_species: int = 119, dtype: Any = torch.float32) -> None:
def __init__(
self,
n_radial: int = 5,
basis_fn: nn.Module = GaussianBasis(),
n_species: int = 119,
dtype: Any = torch.float32,
) -> None:
super().__init__()
self.n_radial = n_radial
self.basis_fn = basis_fn
Expand All @@ -42,7 +56,7 @@ def __init__(self, n_radial: int = 5, basis_fn: nn.Module = GaussianBasis(), n_s

self.r_max = self.basis_fn.r_max
norm = 1.0 / np.sqrt(self.basis_fn.n_basis)
self.embed_norm = torch.Tensor(norm , dtype=self.dtype)
self.embed_norm = torch.Tensor(norm, dtype=self.dtype)
self.embeddings = None
if self.emb_init is not None:
self.embeddings = nn.Parameter()
Expand All @@ -54,7 +68,9 @@ def forward(self, dr, Z_i, Z_j):
# basis shape: neighbors x n_basis
basis = self.basis_fn(dr)

radial_function = radial_basis_impl(basis, Z_i, Z_j, self.embeddings, self.embed_norm)
radial_function = radial_basis_impl(
basis, Z_i, Z_j, self.embeddings, self.embed_norm
)
cutoff = cosine_cutoff(dr, self.r_max)
radial_function = radial_function * cutoff

Expand Down
16 changes: 12 additions & 4 deletions apax/nn/torch/layers/descriptor/gaussian_moment_descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,12 @@ def distance(dR):


class GaussianMomentDescriptor(nn.Module):

def __init__(self, radial_fn: nn.Module = RadialFunction(), n_contr: int = 8, dtype: Any = torch.float32):
def __init__(
self,
radial_fn: nn.Module = RadialFunction(),
n_contr: int = 8,
dtype: Any = torch.float32,
):
super().__init__()
self.radial_fn = radial_fn
self.n_contr = n_contr
Expand All @@ -32,7 +36,9 @@ def __init__(self, radial_fn: nn.Module = RadialFunction(), n_contr: int = 8, dt
self.triang_idxs_2d = torch.tensor(tril_2d_indices(self.n_radial))
self.triang_idxs_3d = torch.tensor(tril_3d_indices(self.n_radial))

def forward(self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Tensor) -> torch.Tensor:
def forward(
self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Tensor
) -> torch.Tensor:
dr_vec = dr_vec.type(self.dtype)
# Z shape n_atoms
n_atoms = Z.shape[0]
Expand All @@ -53,7 +59,9 @@ def forward(self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Te
radial_function = self.radial_fn(dr, Z_i, Z_j)

moments = geometric_moments(radial_function, dn, idx_j, n_atoms)
gaussian_moments = gaussian_moment_impl(moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr)
gaussian_moments = gaussian_moment_impl(
moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr
)

# # gaussian_moments shape: n_atoms x n_features
# gaussian_moments = jnp.concatenate(gaussian_moments[: self.n_contr], axis=-1)
Expand Down
3 changes: 1 addition & 2 deletions apax/nn/torch/layers/ntk_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class NTKLinear(nn.Module):

def __init__(self, units) -> None:
super().__init__()

Expand All @@ -17,4 +16,4 @@ def __init__(self, units) -> None:
def forward(self, x: torch.Tensor) -> torch.Tensor:
weight_factor = torch.sqrt(1.0 / x.shape[0])
out = F.linear(x, weight_factor * self.w, self.bias_factor * self.b)
return out
return out
7 changes: 5 additions & 2 deletions apax/nn/torch/layers/readout.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
from apax.nn.impl.activation import swish
from apax.nn.torch.layers.ntk_linear import NTKLinear


class AtomisticReadout(nn.Module):
def __init__(self, units: List[int]= [512, 512], activation_fn: Callable = swish) -> None:
super().__init__()
def __init__(
self, units: List[int] = [512, 512], activation_fn: Callable = swish
) -> None:
super().__init__()

units = [u for u in self.units] + [1]
dense = []
Expand Down
10 changes: 7 additions & 3 deletions apax/nn/torch/layers/scaling.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,20 @@


class PerElementScaleShift(nn.Module):

def __init__(self, scale: Union[torch.Tensor, float] = 1.0, shift : Union[torch.Tensor, float] = 0.0, n_species: int = 119, dtype=torch.float32) -> None:
def __init__(
self,
scale: Union[torch.Tensor, float] = 1.0,
shift: Union[torch.Tensor, float] = 0.0,
n_species: int = 119,
dtype=torch.float32,
) -> None:
super().__init__()
self.n_species = n_species

self.scale_param = nn.Parameter(scale)
self.shift_param = nn.Parameter(shift)
self.dtype = dtype


def forward(self, x: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
# x shape: n_atoms x 1
# Z shape: n_atoms
Expand Down
4 changes: 3 additions & 1 deletion apax/ops.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import jax
import torch


def swish(x):
if isinstance(x, jax.Array):
return jax.nn.swish(x)
elif isinstance(x, torch.Tensor):
return torch.nn.functional.silu(x)


def exp(x):
if isinstance(x, jax.Array):
Expand Down Expand Up @@ -56,6 +57,7 @@ def sum(x, axis, keepdims=False):
elif isinstance(x, torch.Tensor):
return torch.sum(x, dim=axis, keepdim=keepdims)


def concatenate(x, axis):
if isinstance(x, jax.Array):
return jax.numpy.concatenate(x, axis=axis)
Expand Down
10 changes: 6 additions & 4 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

Expand Down

0 comments on commit 5138587

Please sign in to comment.