From 513858743b1d02404e067327030aa318feb9795a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 3 Apr 2024 10:39:55 +0200 Subject: [PATCH] linting --- apax/nn/impl/activation.py | 3 ++- apax/nn/impl/basis.py | 3 ++- apax/nn/impl/gaussian_moment_descriptor.py | 3 +-- apax/nn/jax/layers/descriptor/__init__.py | 4 ++- apax/nn/jax/layers/descriptor/basis.py | 4 ++- .../descriptor/gaussian_moment_descriptor.py | 5 ++-- apax/nn/jax/layers/readout.py | 4 +-- apax/nn/jax/model/gmnn.py | 14 +++++----- apax/nn/torch/layers/descriptor/__init__.py | 4 ++- apax/nn/torch/layers/descriptor/basis.py | 26 +++++++++++++++---- .../descriptor/gaussian_moment_descriptor.py | 16 +++++++++--- apax/nn/torch/layers/ntk_linear.py | 3 +-- apax/nn/torch/layers/readout.py | 7 +++-- apax/nn/torch/layers/scaling.py | 10 ++++--- apax/ops.py | 4 ++- apax/train/trainer.py | 10 ++++--- 16 files changed, 82 insertions(+), 38 deletions(-) diff --git a/apax/nn/impl/activation.py b/apax/nn/impl/activation.py index 6ea7fa17..fac989c8 100644 --- a/apax/nn/impl/activation.py +++ b/apax/nn/impl/activation.py @@ -1,9 +1,10 @@ from apax import ops + def swish(x): out = 1.6765324703310907 * ops.swish(x) return out def inverse_softplus(x): - return ops.log(ops.exp(x) - 1.0) \ No newline at end of file + return ops.log(ops.exp(x) - 1.0) diff --git a/apax/nn/impl/basis.py b/apax/nn/impl/basis.py index efb05129..4645ff60 100644 --- a/apax/nn/impl/basis.py +++ b/apax/nn/impl/basis.py @@ -2,6 +2,7 @@ import numpy as np from apax import ops + def gaussian_basis_impl(dr, shifts, betta, rad_norm): dr = einops.repeat(dr, "neighbors -> neighbors 1") # 1 x n_basis, neighbors x 1 -> neighbors x n_basis @@ -34,4 +35,4 @@ def radial_basis_impl(basis, Z_i, Z_j, embeddings, embed_norm): radial_function = einops.einsum( species_pair_coeffs, basis, "nbrs radial basis, nbrs basis -> nbrs radial" ) - return radial_function \ No newline at end of file + return radial_function diff --git a/apax/nn/impl/gaussian_moment_descriptor.py b/apax/nn/impl/gaussian_moment_descriptor.py index ea71de3d..5135caf8 100644 --- a/apax/nn/impl/gaussian_moment_descriptor.py +++ b/apax/nn/impl/gaussian_moment_descriptor.py @@ -4,7 +4,6 @@ def gaussian_moment_impl(moments, triang_idxs_2d, triang_idxs_3d, n_contr): - contr_0 = moments[0] contr_1 = ops.einsum("ari, asi -> rsa", moments[1], moments[1]) contr_2 = ops.einsum("arij, asij -> rsa", moments[2], moments[2]) @@ -49,5 +48,5 @@ def gaussian_moment_impl(moments, triang_idxs_2d, triang_idxs_3d, n_contr): ] # gaussian_moments shape: n_atoms x n_features - gaussian_moments = ops.concatenate(gaussian_moments[: n_contr], axis=-1) + gaussian_moments = ops.concatenate(gaussian_moments[:n_contr], axis=-1) return gaussian_moments diff --git a/apax/nn/jax/layers/descriptor/__init__.py b/apax/nn/jax/layers/descriptor/__init__.py index c04eb1a8..94b6cdd9 100644 --- a/apax/nn/jax/layers/descriptor/__init__.py +++ b/apax/nn/jax/layers/descriptor/__init__.py @@ -1,3 +1,5 @@ -from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor +from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import ( + GaussianMomentDescriptor, +) __all__ = ["GaussianMomentDescriptor"] diff --git a/apax/nn/jax/layers/descriptor/basis.py b/apax/nn/jax/layers/descriptor/basis.py index fa92a43b..b5122f6e 100644 --- a/apax/nn/jax/layers/descriptor/basis.py +++ b/apax/nn/jax/layers/descriptor/basis.py @@ -72,7 +72,9 @@ def __call__(self, dr, Z_i, Z_j): # basis shape: neighbors x n_basis basis = self.basis_fn(dr) - radial_function = radial_basis_impl(basis, Z_i, Z_j, self.embeddings, self.embed_norm) + radial_function = radial_basis_impl( + basis, Z_i, Z_j, self.embeddings, self.embed_norm + ) cutoff = cosine_cutoff(dr, self.r_max) radial_function = radial_function * cutoff diff --git a/apax/nn/jax/layers/descriptor/gaussian_moment_descriptor.py b/apax/nn/jax/layers/descriptor/gaussian_moment_descriptor.py index 85c97074..b7d0bf72 100644 --- a/apax/nn/jax/layers/descriptor/gaussian_moment_descriptor.py +++ b/apax/nn/jax/layers/descriptor/gaussian_moment_descriptor.py @@ -13,7 +13,6 @@ from apax.utils.jax_md_reduced import space - class GaussianMomentDescriptor(nn.Module): radial_fn: nn.Module = RadialFunction() n_contr: int = 8 @@ -52,7 +51,9 @@ def __call__(self, dr_vec, Z, neighbor_idxs): radial_function = mask_by_neighbor(radial_function, neighbor_idxs) moments = geometric_moments(radial_function, dn, idx_j, n_atoms) - gaussian_moments = gaussian_moment_impl(moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr) + gaussian_moments = gaussian_moment_impl( + moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr + ) # # gaussian_moments shape: n_atoms x n_features # gaussian_moments = jnp.concatenate(gaussian_moments[: self.n_contr], axis=-1) diff --git a/apax/nn/jax/layers/readout.py b/apax/nn/jax/layers/readout.py index 796a4cc7..b162cb77 100644 --- a/apax/nn/jax/layers/readout.py +++ b/apax/nn/jax/layers/readout.py @@ -19,8 +19,8 @@ def setup(self): dense = [] for ii, n_hidden in enumerate(units): layer = NTKLinear( - n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}" - ) + n_hidden, b_init=self.b_init, dtype=self.dtype, name=f"dense_{ii}" + ) dense.append(layer) if ii < len(units) - 1: dense.append(self.activation_fn) diff --git a/apax/nn/jax/model/gmnn.py b/apax/nn/jax/model/gmnn.py index 70c057fe..0ea66c30 100644 --- a/apax/nn/jax/model/gmnn.py +++ b/apax/nn/jax/model/gmnn.py @@ -8,12 +8,14 @@ import numpy as np from jax import Array, vmap -from apax.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor -from apax.layers.empirical import EmpiricalEnergyTerm -from apax.layers.masking import mask_by_atom -from apax.layers.properties import stress_times_vol -from apax.layers.readout import AtomisticReadout -from apax.layers.scaling import PerElementScaleShift +from apax.nn.jax.layers.descriptor.gaussian_moment_descriptor import ( + GaussianMomentDescriptor, +) +from apax.nn.jax.layers.empirical import EmpiricalEnergyTerm +from apax.nn.jax.layers.masking import mask_by_atom +from apax.nn.jax.layers.properties import stress_times_vol +from apax.nn.jax.layers.readout import AtomisticReadout +from apax.nn.jax.layers.scaling import PerElementScaleShift from apax.utils.jax_md_reduced import partition, space from apax.utils.math import fp64_sum diff --git a/apax/nn/torch/layers/descriptor/__init__.py b/apax/nn/torch/layers/descriptor/__init__.py index 41c7c1fd..599e133c 100644 --- a/apax/nn/torch/layers/descriptor/__init__.py +++ b/apax/nn/torch/layers/descriptor/__init__.py @@ -1,3 +1,5 @@ -from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import GaussianMomentDescriptor +from apax.nn.torch.layers.descriptor.gaussian_moment_descriptor import ( + GaussianMomentDescriptor, +) __all__ = ["GaussianMomentDescriptor"] diff --git a/apax/nn/torch/layers/descriptor/basis.py b/apax/nn/torch/layers/descriptor/basis.py index 47ec012e..1d2f72f0 100644 --- a/apax/nn/torch/layers/descriptor/basis.py +++ b/apax/nn/torch/layers/descriptor/basis.py @@ -10,7 +10,13 @@ class GaussianBasis(nn.Module): - def __init__(self, n_basis: int = 7, r_min: float = 0.5, r_max: float = 6.0, dtype: Any = torch.float32) -> None: + def __init__( + self, + n_basis: int = 7, + r_min: float = 0.5, + r_max: float = 6.0, + dtype: Any = torch.float32, + ) -> None: super().__init__() self.n_basis = n_basis self.r_min = r_min @@ -28,12 +34,20 @@ def __init__(self, n_basis: int = 7, r_min: float = 0.5, r_max: float = 6.0, dty self.shifts = torch.tensor(shifts, dtype=self.dtype) def forward(self, dr: torch.Tensor) -> torch.Tensor: - basis = gaussian_basis_impl(dr.type(self.dtype), self.shifts, self.betta, self.rad_norm) + basis = gaussian_basis_impl( + dr.type(self.dtype), self.shifts, self.betta, self.rad_norm + ) return basis class RadialFunction(nn.Module): - def __init__(self, n_radial: int = 5, basis_fn: nn.Module = GaussianBasis(), n_species: int = 119, dtype: Any = torch.float32) -> None: + def __init__( + self, + n_radial: int = 5, + basis_fn: nn.Module = GaussianBasis(), + n_species: int = 119, + dtype: Any = torch.float32, + ) -> None: super().__init__() self.n_radial = n_radial self.basis_fn = basis_fn @@ -42,7 +56,7 @@ def __init__(self, n_radial: int = 5, basis_fn: nn.Module = GaussianBasis(), n_s self.r_max = self.basis_fn.r_max norm = 1.0 / np.sqrt(self.basis_fn.n_basis) - self.embed_norm = torch.Tensor(norm , dtype=self.dtype) + self.embed_norm = torch.Tensor(norm, dtype=self.dtype) self.embeddings = None if self.emb_init is not None: self.embeddings = nn.Parameter() @@ -54,7 +68,9 @@ def forward(self, dr, Z_i, Z_j): # basis shape: neighbors x n_basis basis = self.basis_fn(dr) - radial_function = radial_basis_impl(basis, Z_i, Z_j, self.embeddings, self.embed_norm) + radial_function = radial_basis_impl( + basis, Z_i, Z_j, self.embeddings, self.embed_norm + ) cutoff = cosine_cutoff(dr, self.r_max) radial_function = radial_function * cutoff diff --git a/apax/nn/torch/layers/descriptor/gaussian_moment_descriptor.py b/apax/nn/torch/layers/descriptor/gaussian_moment_descriptor.py index 84226920..010a658f 100644 --- a/apax/nn/torch/layers/descriptor/gaussian_moment_descriptor.py +++ b/apax/nn/torch/layers/descriptor/gaussian_moment_descriptor.py @@ -17,8 +17,12 @@ def distance(dR): class GaussianMomentDescriptor(nn.Module): - - def __init__(self, radial_fn: nn.Module = RadialFunction(), n_contr: int = 8, dtype: Any = torch.float32): + def __init__( + self, + radial_fn: nn.Module = RadialFunction(), + n_contr: int = 8, + dtype: Any = torch.float32, + ): super().__init__() self.radial_fn = radial_fn self.n_contr = n_contr @@ -32,7 +36,9 @@ def __init__(self, radial_fn: nn.Module = RadialFunction(), n_contr: int = 8, dt self.triang_idxs_2d = torch.tensor(tril_2d_indices(self.n_radial)) self.triang_idxs_3d = torch.tensor(tril_3d_indices(self.n_radial)) - def forward(self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Tensor) -> torch.Tensor: + def forward( + self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Tensor + ) -> torch.Tensor: dr_vec = dr_vec.type(self.dtype) # Z shape n_atoms n_atoms = Z.shape[0] @@ -53,7 +59,9 @@ def forward(self, dr_vec: torch.Tensor, Z: torch.Tensor, neighbor_idxs: torch.Te radial_function = self.radial_fn(dr, Z_i, Z_j) moments = geometric_moments(radial_function, dn, idx_j, n_atoms) - gaussian_moments = gaussian_moment_impl(moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr) + gaussian_moments = gaussian_moment_impl( + moments, self.triang_idxs_2d, self.triang_idxs_3d, self.n_contr + ) # # gaussian_moments shape: n_atoms x n_features # gaussian_moments = jnp.concatenate(gaussian_moments[: self.n_contr], axis=-1) diff --git a/apax/nn/torch/layers/ntk_linear.py b/apax/nn/torch/layers/ntk_linear.py index b4ff53c3..64e57521 100644 --- a/apax/nn/torch/layers/ntk_linear.py +++ b/apax/nn/torch/layers/ntk_linear.py @@ -4,7 +4,6 @@ class NTKLinear(nn.Module): - def __init__(self, units) -> None: super().__init__() @@ -17,4 +16,4 @@ def __init__(self, units) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: weight_factor = torch.sqrt(1.0 / x.shape[0]) out = F.linear(x, weight_factor * self.w, self.bias_factor * self.b) - return out \ No newline at end of file + return out diff --git a/apax/nn/torch/layers/readout.py b/apax/nn/torch/layers/readout.py index 2482213f..63dbbef3 100644 --- a/apax/nn/torch/layers/readout.py +++ b/apax/nn/torch/layers/readout.py @@ -6,9 +6,12 @@ from apax.nn.impl.activation import swish from apax.nn.torch.layers.ntk_linear import NTKLinear + class AtomisticReadout(nn.Module): - def __init__(self, units: List[int]= [512, 512], activation_fn: Callable = swish) -> None: - super().__init__() + def __init__( + self, units: List[int] = [512, 512], activation_fn: Callable = swish + ) -> None: + super().__init__() units = [u for u in self.units] + [1] dense = [] diff --git a/apax/nn/torch/layers/scaling.py b/apax/nn/torch/layers/scaling.py index 4e622dd9..9490708b 100644 --- a/apax/nn/torch/layers/scaling.py +++ b/apax/nn/torch/layers/scaling.py @@ -5,8 +5,13 @@ class PerElementScaleShift(nn.Module): - - def __init__(self, scale: Union[torch.Tensor, float] = 1.0, shift : Union[torch.Tensor, float] = 0.0, n_species: int = 119, dtype=torch.float32) -> None: + def __init__( + self, + scale: Union[torch.Tensor, float] = 1.0, + shift: Union[torch.Tensor, float] = 0.0, + n_species: int = 119, + dtype=torch.float32, + ) -> None: super().__init__() self.n_species = n_species @@ -14,7 +19,6 @@ def __init__(self, scale: Union[torch.Tensor, float] = 1.0, shift : Union[torch. self.shift_param = nn.Parameter(shift) self.dtype = dtype - def forward(self, x: torch.Tensor, Z: torch.Tensor) -> torch.Tensor: # x shape: n_atoms x 1 # Z shape: n_atoms diff --git a/apax/ops.py b/apax/ops.py index fcdbe129..ea5f01e4 100644 --- a/apax/ops.py +++ b/apax/ops.py @@ -1,12 +1,13 @@ import jax import torch + def swish(x): if isinstance(x, jax.Array): return jax.nn.swish(x) elif isinstance(x, torch.Tensor): return torch.nn.functional.silu(x) - + def exp(x): if isinstance(x, jax.Array): @@ -56,6 +57,7 @@ def sum(x, axis, keepdims=False): elif isinstance(x, torch.Tensor): return torch.sum(x, dim=axis, keepdim=keepdims) + def concatenate(x, axis): if isinstance(x, jax.Array): return jax.numpy.concatenate(x, axis=axis) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index f0e99ef5..5f38a3cd 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -107,10 +107,12 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update({ - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - }) + epoch_metrics.update( + { + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + } + ) epoch_metrics.update({**epoch_loss})