Skip to content

Commit

Permalink
sketch of energy model
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 3, 2024
1 parent 5138587 commit 015aee8
Showing 1 changed file with 63 additions and 0 deletions.
63 changes: 63 additions & 0 deletions apax/nn/torch/model/gmnn.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Any
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -31,3 +33,64 @@ def forward(
output = self.scale_shift(h, Z)

return output


class EnergyModel(nn.Module):
def __init__(
self,
atomistic_model: AtomisticModel = AtomisticModel(),
# corrections: list[EmpiricalEnergyTerm] = field(default_factory=lambda: []),
init_box: np.array = np.array([0.0, 0.0, 0.0]),
inference_disp_fn: Any = None,
):
super().__init__()
self.atomistic_model = atomistic_model
# self.corrections = corrections
self.init_box = init_box
self.inference_disp_fn = inference_disp_fn

if np.all(self.init_box < 1e-6):
# gas phase training and predicting
displacement_fn = space.free()[0]
self.displacement = space.map_bond(displacement_fn)
elif self.inference_disp_fn is None:
# for training on periodic systems
self.displacement = vmap(disp_fn, (0, 0, None, None), 0)
else:
mappable_displacement_fn = get_disp_fn(self.inference_disp_fn)
self.displacement = vmap(mappable_displacement_fn, (0, 0, None, None), 0)

def forward(
self,
R: torch.Tensor,
Z: torch.Tensor,
idx: torch.Tensor,
box,
offsets,
perturbation=None,
):
# Distances
idx_i, idx_j = idx[0], idx[1]

# R shape n_atoms x 3
R = R.type(torch.float64)
Ri = R[idx_i]
Rj = R[idx_j]

# dr_vec shape: neighbors x 3
if np.all(self.init_box < 1e-6):
dr_vec = self.displacement(Rj, Ri)
else:
dr_vec = self.displacement(Rj, Ri, perturbation, box)
dr_vec += offsets

# Model Core
atomic_energies = self.atomistic_model(dr_vec, Z, idx)
total_energy = fp64_sum(atomic_energies)

# Corrections
# for correction in self.corrections:
# energy_correction = correction(dr_vec, Z, idx)
# total_energy = total_energy + energy_correction

return total_energy

0 comments on commit 015aee8

Please sign in to comment.