
first working torch model
M-R-Schaefer committed Apr 9, 2024
1 parent 06b965b commit 3a634fa
Showing 7 changed files with 39 additions and 20 deletions.
6 changes: 5 additions & 1 deletion apax/nn/torch/layers/descriptor/basis.py
@@ -61,8 +61,12 @@ def __init__(
         self.embed_norm = torch.tensor(norm, dtype=self.dtype)
         self.embeddings = None
         if self.emb_init is not None:
-            self.embeddings = nn.Parameter()
             self.n_radial = n_radial
+            emb = torch.rand((self.n_species,
+                              self.n_species,
+                              self.n_radial,
+                              self.basis_fn.n_basis))
+            self.embeddings = nn.Parameter(emb)
         else:
             self.n_radial = self.basis_fn.n_basis

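The deleted line created an empty nn.Parameter; the commit now allocates the full per-species-pair coefficient tensor up front. A minimal sketch of how such an embedding is typically contracted with the radial basis; the sizes and the contraction are illustrative assumptions, not taken from the repository:

    import torch
    import torch.nn as nn

    n_species, n_radial, n_basis = 3, 5, 7  # hypothetical sizes

    # One coefficient matrix per (species_i, species_j) pair, as in the diff.
    embeddings = nn.Parameter(torch.rand((n_species, n_species, n_radial, n_basis)))

    basis = torch.rand(n_basis)          # radial basis evaluated at one distance
    Zi, Zj = 1, 2                        # species of a neighbor pair
    radial = embeddings[Zi, Zj] @ basis  # (n_radial, n_basis) @ (n_basis,) -> (n_radial,)
    print(radial.shape)                  # torch.Size([5])
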
9 changes: 5 additions & 4 deletions apax/nn/torch/layers/ntk_linear.py
@@ -4,16 +4,17 @@


 class NTKLinear(nn.Module):
-    def __init__(self, units) -> None:
+    def __init__(self, units_in, units_out) -> None:
         super().__init__()

         self.bias_factor = 0.1
         # self.weight_factor = torch.sqrt(1.0 / dim_in)

-        self.w = nn.Parameter()
-        self.b = nn.Parameter()
+        self.w = nn.Parameter(torch.rand((units_out, units_in)))
+        self.b = nn.Parameter(torch.rand((units_out)))
+        self.one = torch.tensor(1.0)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        weight_factor = torch.sqrt(1.0 / x.shape[0])
+        weight_factor = torch.sqrt(self.one / x.size(0))
         out = F.linear(x, weight_factor * self.w, self.bias_factor * self.b)
         return out
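
NTKLinear keeps its weights at unit scale and rescales the output by sqrt(1/fan_in) at forward time, the usual NTK parameterization. A usage sketch, assuming the import path matches the file location and that inputs are unbatched feature vectors (weight_factor is derived from x.size(0), which equals the fan-in only in that case):

    import torch
    from apax.nn.torch.layers.ntk_linear import NTKLinear  # assumed import path

    layer = NTKLinear(360, 128)  # fan-in 360, fan-out 128
    x = torch.rand(360)          # a single per-atom feature vector
    y = layer(x)                 # sqrt(1/360) * W @ x + 0.1 * b
    print(y.shape)               # torch.Size([128])
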
9 changes: 5 additions & 4 deletions apax/nn/torch/layers/readout.py
@@ -13,11 +13,12 @@ def __init__(
     ) -> None:
         super().__init__()

-        units = [u for u in units] + [1]
+        units = [360] + [u for u in units] + [1]
         dense = []
-        for ii, n_hidden in enumerate(units):
-            dense.append(NTKLinear(n_hidden))
-            if ii < len(units) - 1:
+        for ii in range(len(units)-1):
+            units_in, units_out = units[ii], units[ii+1]
+            dense.append(NTKLinear(units_in, units_out))
+            if ii < len(units) - 2:
                 dense.append(activation_fn())
         self.sequential = nn.Sequential(*dense)

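The rewritten loop walks consecutive (units_in, units_out) pairs and inserts an activation between hidden layers but not after the final scalar output; the leading 360 appears to be the hard-coded descriptor feature width. A standalone sketch of the same construction, with nn.Linear and nn.SiLU standing in for NTKLinear and activation_fn, and hypothetical hidden widths:

    import torch.nn as nn

    units = [360] + [512, 512] + [1]  # input width, hidden widths, scalar energy

    dense = []
    for ii in range(len(units) - 1):
        units_in, units_out = units[ii], units[ii + 1]
        dense.append(nn.Linear(units_in, units_out))  # NTKLinear in the commit
        if ii < len(units) - 2:  # no activation after the last layer
            dense.append(nn.SiLU())
    sequential = nn.Sequential(*dense)
    print(sequential)
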
8 changes: 6 additions & 2 deletions apax/nn/torch/layers/scaling.py
@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+import numpy as np

 from typing import Any, Union

@@ -15,8 +16,11 @@ def __init__(
         super().__init__()
         self.n_species = n_species

-        self.scale_param = nn.Parameter(torch.tensor(scale))
-        self.shift_param = nn.Parameter(torch.tensor(shift))
+        scale = np.repeat(scale, n_species)
+        shift = np.repeat(shift, n_species)
+
+        self.scale_param = nn.Parameter(torch.from_numpy(scale))
+        self.shift_param = nn.Parameter(torch.from_numpy(shift))
         self.dtype = dtype

     def forward(self, x: torch.Tensor, Z: torch.Tensor) -> torch.Tensor:
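
np.repeat expands the global scale and shift statistics into one learnable value per species. A sketch of how such parameters are typically applied per atom; the rule x * scale_param[Z] + shift_param[Z] is an assumption, since the diff does not show the forward pass:

    import numpy as np
    import torch

    n_species = 4
    scale, shift = 1.5, -2.0  # hypothetical dataset statistics

    scale_param = torch.from_numpy(np.repeat(scale, n_species))
    shift_param = torch.from_numpy(np.repeat(shift, n_species))

    Z = torch.tensor([0, 2, 2, 3])          # species index per atom
    h = torch.rand(4, dtype=torch.float64)  # per-atom readout output
    out = h * scale_param[Z] + shift_param[Z]
    print(out)  # four scaled-and-shifted atomic energies
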
17 changes: 11 additions & 6 deletions apax/nn/torch/model/gmnn.py
@@ -20,7 +20,7 @@ def __init__(
     ):
         super().__init__()
         self.descriptor = descriptor
-        self.readout = torch.vmap(readout)
+        self.readout = readout  # readout??
         self.scale_shift = scale_shift

     def forward(
@@ -30,13 +30,15 @@ def forward(
         idx: torch.tensor,
     ) -> torch.tensor:
         gm = self.descriptor(dr_vec, Z, idx)
-        h = self.readout(gm)
+        # print(gm.size())
+        h = self.readout(gm).squeeze()
+        # print(h.size())
         output = self.scale_shift(h, Z)

         return output


-def free_displacement(Ri, Rj, box, perturbation):
+def free_displacement(Ri, Rj):
     return Ri - Rj


@@ -125,11 +127,11 @@ def forward(
         box: torch.Tensor,
         offsets: torch.Tensor,
     ):
-        R.requires_grad = True
+        R.requires_grad_(True)
         requires_grad = [R]
         if self.calc_stress:
             eps = torch.zeros((3, 3), torch.float64)
-            eps.requires_grad = True
+            eps.requires_grad_(True)
             eps_sym = 0.5 * (eps + eps.T)
             identity = torch.eye(3, dtype=torch.float64)
             perturbation = identity + eps_sym
@@ -138,10 +140,13 @@ def forward(
             perturbation = None

         energy = self.energy_model(R, Z, neighbor, box, offsets, perturbation)
+        # print(energy)
+        # quit()

         grads = autograd.grad(
-            energy, requires_grad, grad_outputs=torch.ones_like(energy), create_graph=True
+            energy, requires_grad, grad_outputs=torch.ones_like(energy), create_graph=False, retain_graph=False, allow_unused=True,
         )
+        print(grads)

         neg_forces = grads[0]
         forces = -neg_forces
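Forces come from the same graph as the energy: autograd.grad differentiates the scalar energy with respect to the positions, and the negated gradient is the force. A toy stand-in for energy_model (a quadratic pair term, purely illustrative) shows the pattern:

    import torch
    from torch import autograd

    R = torch.rand((5, 3), dtype=torch.float64)
    R.requires_grad_(True)

    # Stand-in energy: sum of squared pair displacements.
    energy = (R[:, None, :] - R[None, :, :]).pow(2).sum()

    grads = autograd.grad(energy, [R], grad_outputs=torch.ones_like(energy))
    forces = -grads[0]   # F = -dE/dR
    print(forces.shape)  # torch.Size([5, 3])
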
8 changes: 6 additions & 2 deletions apax/ops.py
@@ -1,5 +1,6 @@
 import jax
 import torch
+import torch_scatter


 def swish(x):
@@ -80,15 +81,18 @@ def cast(x, dtype):


 def einsum(pattern, *operands, **kwargs):
+    # print([type(o) for o in operands])
+    # quit()
     if isinstance(operands[0], jax.Array):
         return jax.numpy.einsum(pattern, *operands, **kwargs)
     elif isinstance(operands[0], torch.Tensor):
+        # opt_einsum_fx
         return torch.einsum(pattern, *operands, **kwargs)


 def segment_sum(x, segment_ids, num_segments=None):
     if isinstance(x, jax.Array):
         return jax.ops.segment_sum(x, segment_ids, num_segments)
     elif isinstance(x, torch.Tensor):
-        # TODO pytorch scatter
-        return None
+        out = torch_scatter.scatter(x, segment_ids, dim=0, reduce="sum")
+        return out
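
segment_sum now dispatches to torch_scatter on the torch branch; note that torch_scatter.scatter infers the output size from segment_ids, so num_segments is ignored there. The same reduction can be written with the built-in Tensor.index_add_, shown here as a dependency-free sketch rather than what the commit uses:

    import torch

    x = torch.tensor([1.0, 2.0, 3.0, 4.0])
    segment_ids = torch.tensor([0, 0, 1, 2])
    num_segments = 3

    out = torch.zeros(num_segments, dtype=x.dtype)
    out.index_add_(0, segment_ids, x)  # sums x into out at segment_ids
    print(out)  # tensor([3., 7., 4.])
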
2 changes: 1 addition & 1 deletion apax/train/run.py
@@ -8,7 +8,7 @@
 from apax.data.initialization import load_data_files
 from apax.data.input_pipeline import dataset_dict
 from apax.data.statistics import compute_scale_shift_parameters
-from apax.model import ModelBuilder
+from apax.nn.jax.model import ModelBuilder
 from apax.optimizer import get_opt
 from apax.train.callbacks import initialize_callbacks
 from apax.train.checkpoints import create_params, create_train_state
