Skip to content

Commit

Permalink
implemented atomistic torch model
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 3, 2024
1 parent a1fa924 commit 8938c7f
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions apax/nn/torch/model/gmnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F

from apax.nn.torch.layers.descriptor import GaussianMomentDescriptor
from apax.nn.torch.layers.readout import AtomisticReadout
from apax.nn.torch.layers.scaling import PerElementScaleShift


class AtomisticModel(nn.Module):
def __init__(
self,
descriptor: nn.Module = GaussianMomentDescriptor(),
readout: nn.Module = AtomisticReadout(),
scale_shift: nn.Module = PerElementScaleShift(),
):
super().__init__()
self.descriptor = descriptor
self.readout = torch.vmap(readout)
self.scale_shift = scale_shift

def forward(
self,
dr_vec: torch.tensor,
Z: torch.tensor,
idx: torch.tensor,
) -> torch.tensor:
gm = self.descriptor(dr_vec, Z, idx)
h = self.readout(gm)
output = self.scale_shift(h, Z)

return output

0 comments on commit 8938c7f

Please sign in to comment.