Skip to content

Commit

Permalink
sketch of batch eval
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Jan 20, 2024
1 parent 93cb39d commit a37d4d1
Showing 1 changed file with 29 additions and 0 deletions.
29 changes: 29 additions & 0 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from ase.calculators.calculator import Calculator, all_changes
from jax_md import partition, quantity, space
from matscipy.neighbours import neighbour_list
from tqdm import trange
from apax.data.initialization import RawDataset, initialize_dataset

from apax.model import ModelBuilder
from apax.train.checkpoints import check_for_ensemble, restore_parameters
Expand Down Expand Up @@ -127,6 +129,7 @@ def __init__(
self.neighbor_fn = None
self.neighbors = None
self.offsets = None
self.model = None

def initialize(self, atoms):
box = jnp.asarray(atoms.cell.array, dtype=jnp.float64)
Expand All @@ -146,6 +149,7 @@ def initialize(self, atoms):
for transformation in self.transformations:
model = transformation.apply(model, self.n_models)

self.model = model
self.step = get_step_fn(model, atoms, self.neigbor_from_jax)
self.neighbor_fn = neighbor_fn

Expand Down Expand Up @@ -215,6 +219,31 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
self.results = {k: np.array(v, dtype=np.float64) for k, v in results.items()}
self.results["energy"] = self.results["energy"].item()

def batch_eval(self, data, batch_size=64, silent=False):
if self.model is None:
self.initialize(data[0])
dataset = initialize_dataset(self.model_config, RawDataset(atoms_list=data), calc_stats=False)
dataset.set_batch_size(batch_size)

features = []
n_data = dataset.n_data
ds = dataset.batch()
batched_model = jax.jit(jax.vmap(self.model,))

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True, disable=silent)
for i, (inputs, _) in enumerate(ds):
results = batched_model(inputs)
unpadded_results = unpad_results(results, inputs)
for j in range(batch_size):
data[i].calc = SinglepointCalculator(atoms=data[i], results={})
pbar.update(batch_size)
pbar.close()







def neighbor_calculable_with_jax(box, r_max):
if np.all(box < 1e-6):
Expand Down

0 comments on commit a37d4d1

Please sign in to comment.