fixed various semantic errors
M-R-Schaefer committed Apr 8, 2024
1 parent 57f1b2e commit 06b965b
Showing 7 changed files with 28 additions and 14 deletions.
4 changes: 2 additions & 2 deletions apax/nn/torch/layers/__init__.py
@@ -1,3 +1,3 @@
-from apax.nn.torch.layers import activation, descriptor, ntk_linear, scaling
+from apax.nn.torch.layers import descriptor, ntk_linear, scaling

-__all__ = ["descriptor", "activation", "ntk_linear", "scaling"]
+__all__ = ["descriptor", "ntk_linear", "scaling"]
13 changes: 13 additions & 0 deletions apax/nn/torch/layers/activation.py
@@ -0,0 +1,13 @@
+import torch.nn as nn
+from apax.nn.impl.activation import swish
+
+
+class SwishT(nn.Module):
+    def __init__(
+        self
+    ) -> None:
+        super().__init__()
+
+    def forward(self, x):
+        h = swish(x)
+        return h
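The new SwishT module wraps the functional swish activation so it can be registered and composed like any other torch layer. A minimal self-contained sketch of the same idea (the swish definition below is an assumption, the usual x * sigmoid(x), not the actual apax.nn.impl.activation implementation):

import torch
import torch.nn as nn

def swish(x):
    # assumed definition; apax imports its own swish from apax.nn.impl.activation
    return x * torch.sigmoid(x)

class SwishT(nn.Module):
    # module wrapper so the activation can be placed inside nn.Sequential
    def forward(self, x):
        return swish(x)

layer = nn.Sequential(nn.Linear(8, 8), SwishT())
print(layer(torch.randn(2, 8)).shape)  # torch.Size([2, 8])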
3 changes: 2 additions & 1 deletion apax/nn/torch/layers/descriptor/basis.py
@@ -62,8 +62,9 @@ def __init__(
         self.embeddings = None
         if self.emb_init is not None:
             self.embeddings = nn.Parameter()
+            self.n_radial = n_radial
         else:
-            self._n_radial = self.basis_fn.n_basis
+            self.n_radial = self.basis_fn.n_basis

     def forward(self, dr, Z_i, Z_j):
         dr = dr.type(self.dtype)
@@ -29,7 +29,7 @@ def __init__(
         self.dtype = dtype

         self.r_max = self.radial_fn.r_max
-        self.n_radial = self.radial_fn._n_radial
+        self.n_radial = self.radial_fn.n_radial

         self.distance = torch.vmap(distance, 0, 0)
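The two hunks above make n_radial a public attribute that is set on both branches of the basis constructor, and update the consumer to read that public name instead of the former private _n_radial. A simplified sketch of the producer/consumer relationship (class names are placeholders, not the apax classes):

import torch.nn as nn

class RadialBasis(nn.Module):
    def __init__(self, n_basis: int):
        super().__init__()
        self.n_radial = n_basis  # public attribute, set on every code path

class Descriptor(nn.Module):
    def __init__(self, radial_fn: RadialBasis):
        super().__init__()
        # reads the public name; no reliance on a private _n_radial
        self.n_radial = radial_fn.n_radial

print(Descriptor(RadialBasis(5)).n_radial)  # 5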
10 changes: 5 additions & 5 deletions apax/nn/torch/layers/readout.py
@@ -3,23 +3,23 @@
 import torch.nn.functional as F
 from typing import Any, Callable, List

-from apax.nn.impl.activation import swish
+from apax.nn.torch.layers.activation import SwishT
 from apax.nn.torch.layers.ntk_linear import NTKLinear


 class AtomisticReadoutT(nn.Module):
     def __init__(
-        self, units: List[int] = [512, 512], activation_fn: Callable = swish
+        self, units: List[int] = [512, 512], activation_fn: Callable = SwishT
     ) -> None:
         super().__init__()

-        units = [u for u in self.units] + [1]
+        units = [u for u in units] + [1]
         dense = []
         for ii, n_hidden in enumerate(units):
             dense.append(NTKLinear(n_hidden))
             if ii < len(units) - 1:
-                dense.append(activation_fn)
-        self.sequential = nn.Sequential(dense)
+                dense.append(activation_fn())
+        self.sequential = nn.Sequential(*dense)

     def forward(self, x):
         h = self.sequential(x)
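The readout fixes address two separate problems: the activation is now a Module class, so it has to be instantiated before being appended, and nn.Sequential takes modules as individual arguments, so the list must be unpacked. A short illustration with stock torch layers standing in for NTKLinear and SwishT:

import torch
import torch.nn as nn

dense = [nn.Linear(4, 4), nn.SiLU(), nn.Linear(4, 1)]

# nn.Sequential(dense) raises a TypeError, because a plain list is not a Module;
# unpacking registers each layer as a submodule.
model = nn.Sequential(*dense)

print(model(torch.randn(2, 4)).shape)  # torch.Size([2, 1])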
4 changes: 2 additions & 2 deletions apax/nn/torch/layers/scaling.py
@@ -15,8 +15,8 @@ def __init__(
         super().__init__()
         self.n_species = n_species

-        self.scale_param = nn.Parameter(scale)
-        self.shift_param = nn.Parameter(shift)
+        self.scale_param = nn.Parameter(torch.tensor(scale))
+        self.shift_param = nn.Parameter(torch.tensor(shift))
         self.dtype = dtype

     def forward(self, x: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
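nn.Parameter expects a torch.Tensor, so plain Python scalars (or array-like values) for scale and shift have to be converted first. A small illustration, not taken from the repository:

import torch
import torch.nn as nn

scale = 1.0  # e.g. a plain float coming from a config

# nn.Parameter(scale) fails for a non-Tensor argument;
# converting with torch.tensor first yields a learnable parameter.
scale_param = nn.Parameter(torch.tensor(scale))

print(scale_param)  # Parameter containing: tensor(1., requires_grad=True)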
6 changes: 3 additions & 3 deletions apax/ops.py
@@ -58,10 +58,10 @@ def sum(x, axis, keepdims=False):
     return torch.sum(x, dim=axis, keepdim=keepdims)


-def concatenate(x, axis):
-    if isinstance(x, jax.Array):
+def concatenate(x: list, axis):
+    if isinstance(x[0], jax.Array):
         return jax.numpy.concatenate(x, axis=axis)
-    elif isinstance(x, torch.Tensor):
+    elif isinstance(x[0], torch.Tensor):
         return torch.concatenate(x, dim=axis)

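concatenate receives a list of arrays rather than a single array, so the backend check has to inspect an element (x[0]) instead of the list itself. A sketch of the same dispatch showing only the torch branch (the jax branch is analogous):

import torch

def concatenate(x: list, axis):
    # inspect the first element to pick the backend; the list itself is never a Tensor
    if isinstance(x[0], torch.Tensor):
        return torch.concatenate(x, dim=axis)
    raise TypeError(f"unsupported element type: {type(x[0])}")

out = concatenate([torch.ones(2, 3), torch.zeros(2, 3)], axis=0)
print(out.shape)  # torch.Size([4, 3])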
