diff --git a/apax/nn/torch/layers/descriptor/basis.py b/apax/nn/torch/layers/descriptor/basis.py
index b33f8dbc..33dc316a 100644
--- a/apax/nn/torch/layers/descriptor/basis.py
+++ b/apax/nn/torch/layers/descriptor/basis.py
@@ -61,8 +61,12 @@ def __init__(
         self.embed_norm = torch.tensor(norm, dtype=self.dtype)
         self.embeddings = None
         if self.emb_init is not None:
-            self.embeddings = nn.Parameter()
             self.n_radial = n_radial
+            emb = torch.rand((self.n_species,
+                              self.n_species,
+                              self.n_radial,
+                              self.basis_fn.n_basis))
+            self.embeddings = nn.Parameter(emb)
         else:
             self.n_radial = self.basis_fn.n_basis
 
diff --git a/apax/nn/torch/layers/ntk_linear.py b/apax/nn/torch/layers/ntk_linear.py
index 64e57521..d36acc3c 100644
--- a/apax/nn/torch/layers/ntk_linear.py
+++ b/apax/nn/torch/layers/ntk_linear.py
@@ -4,16 +4,17 @@
 
 
 class NTKLinear(nn.Module):
-    def __init__(self, units) -> None:
+    def __init__(self, units_in, units_out) -> None:
         super().__init__()
 
         self.bias_factor = 0.1
         # self.weight_factor = torch.sqrt(1.0 / dim_in)
-        self.w = nn.Parameter()
-        self.b = nn.Parameter()
+        self.w = nn.Parameter(torch.rand((units_out, units_in)))
+        self.b = nn.Parameter(torch.rand(units_out))
+        self.one = torch.tensor(1.0)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        weight_factor = torch.sqrt(1.0 / x.shape[0])
+        weight_factor = torch.sqrt(self.one / x.size(-1))  # 1 / dim_in
         out = F.linear(x, weight_factor * self.w, self.bias_factor * self.b)
         return out
 
diff --git a/apax/nn/torch/layers/readout.py b/apax/nn/torch/layers/readout.py
index a798f17d..7ea0ba46 100644
--- a/apax/nn/torch/layers/readout.py
+++ b/apax/nn/torch/layers/readout.py
@@ -13,11 +13,12 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        units = [u for u in units] + [1]
+        units = [360] + [u for u in units] + [1]  # 360: input width from the descriptor
         dense = []
-        for ii, n_hidden in enumerate(units):
-            dense.append(NTKLinear(n_hidden))
-            if ii < len(units) - 1:
+        for ii in range(len(units) - 1):
+            units_in, units_out = units[ii], units[ii + 1]
+            dense.append(NTKLinear(units_in, units_out))
+            if ii < len(units) - 2:
                 dense.append(activation_fn())
         self.sequential = nn.Sequential(*dense)
 
diff --git a/apax/nn/torch/layers/scaling.py b/apax/nn/torch/layers/scaling.py
index 869f8b7d..03591274 100644
--- a/apax/nn/torch/layers/scaling.py
+++ b/apax/nn/torch/layers/scaling.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import numpy as np
 
 from typing import Any, Union
 
@@ -15,8 +16,11 @@ def __init__(
         super().__init__()
 
         self.n_species = n_species
-        self.scale_param = nn.Parameter(torch.tensor(scale))
-        self.shift_param = nn.Parameter(torch.tensor(shift))
+        scale = np.repeat(scale, n_species)
+        shift = np.repeat(shift, n_species)
+
+        self.scale_param = nn.Parameter(torch.from_numpy(scale))
+        self.shift_param = nn.Parameter(torch.from_numpy(shift))
         self.dtype = dtype
 
     def forward(self, x: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
diff --git a/apax/nn/torch/model/gmnn.py b/apax/nn/torch/model/gmnn.py
index c845d95e..790c5fcc 100644
--- a/apax/nn/torch/model/gmnn.py
+++ b/apax/nn/torch/model/gmnn.py
@@ -20,7 +20,7 @@
     ):
         super().__init__()
         self.descriptor = descriptor
-        self.readout = torch.vmap(readout)
+        self.readout = readout  # no vmap needed: the readout consumes the full atom batch
         self.scale_shift = scale_shift
 
     def forward(
@@ -30,13 +30,13 @@
         idx: torch.tensor,
     ) -> torch.tensor:
         gm = self.descriptor(dr_vec, Z, idx)
-        h = self.readout(gm)
+        h = self.readout(gm).squeeze(-1)
         output = self.scale_shift(h, Z)
 
         return output
 
 
-def free_displacement(Ri, Rj, box, perturbation):
+def free_displacement(Ri, Rj):
     return Ri - Rj
 
 
@@ -125,11 +125,11 @@ def forward(
         box: torch.Tensor,
         offsets: torch.Tensor,
     ):
-        R.requires_grad = True
+        R.requires_grad_(True)
         requires_grad = [R]
         if self.calc_stress:
-            eps = torch.zeros((3, 3), torch.float64)
-            eps.requires_grad = True
+            eps = torch.zeros((3, 3), dtype=torch.float64)
+            eps.requires_grad_(True)
             eps_sym = 0.5 * (eps + eps.T)
             identity = torch.eye(3, dtype=torch.float64)
             perturbation = identity + eps_sym
@@ -138,10 +138,10 @@
             perturbation = None
 
         energy = self.energy_model(R, Z, neighbor, box, offsets, perturbation)
 
         grads = autograd.grad(
-            energy, requires_grad, grad_outputs=torch.ones_like(energy), create_graph=True
+            energy, requires_grad, grad_outputs=torch.ones_like(energy), create_graph=False, retain_graph=False, allow_unused=True,
         )
         neg_forces = grads[0]
         forces = -neg_forces
 
diff --git a/apax/ops.py b/apax/ops.py
index 2ce7c747..0f0b8118 100644
--- a/apax/ops.py
+++ b/apax/ops.py
@@ -1,5 +1,6 @@
 import jax
 import torch
+import torch_scatter
 
 
 def swish(x):
@@ -80,9 +81,10 @@
 
 
 def einsum(pattern, *operands, **kwargs):
     if isinstance(operands[0], jax.Array):
         return jax.numpy.einsum(pattern, *operands, **kwargs)
     elif isinstance(operands[0], torch.Tensor):
+        # TODO: consider opt_einsum_fx to optimize the contraction path
         return torch.einsum(pattern, *operands, **kwargs)
 
 
@@ -90,5 +92,5 @@ def segment_sum(x, segment_ids, num_segments=None):
     if isinstance(x, jax.Array):
         return jax.ops.segment_sum(x, segment_ids, num_segments)
     elif isinstance(x, torch.Tensor):
-        # TODO pytorch scatter
-        return None
+        out = torch_scatter.scatter(x, segment_ids, dim=0, dim_size=num_segments, reduce="sum")
+        return out
diff --git a/apax/train/run.py b/apax/train/run.py
index cca77a95..6527adf0 100644
--- a/apax/train/run.py
+++ b/apax/train/run.py
@@ -8,7 +8,7 @@
 from apax.data.initialization import load_data_files
 from apax.data.input_pipeline import dataset_dict
 from apax.data.statistics import compute_scale_shift_parameters
-from apax.model import ModelBuilder
+from apax.nn.jax.model import ModelBuilder
 from apax.optimizer import get_opt
 from apax.train.callbacks import initialize_callbacks
 from apax.train.checkpoints import create_params, create_train_state
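
A quick sanity check of the reworked torch pieces. This is a minimal sketch, not part of the diff: it assumes torch_scatter is installed, the import paths follow the file layout above, and the sizes (7 atoms, 128 hidden units) are illustrative.

    import torch
    from apax.nn.torch.layers.ntk_linear import NTKLinear
    from apax.ops import segment_sum

    # NTKLinear now takes explicit input/output widths and rescales the
    # weights by sqrt(1 / dim_in) inside forward.
    layer = NTKLinear(360, 128)
    x = torch.rand(7, 360)  # 7 atoms, 360 descriptor features
    assert layer(x).shape == (7, 128)

    # segment_sum now dispatches torch tensors to torch_scatter, e.g. to
    # sum per-atom energies into per-structure energies.
    e_atom = torch.rand(7)
    idx = torch.tensor([0, 0, 0, 0, 1, 1, 1])  # atom -> structure assignment
    e_struct = segment_sum(e_atom, idx, num_segments=2)
    assert e_struct.shape == (2,)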