Skip to content

Commit

Permalink
fixed jax nl always allocating in cartesian coords and repeated alloc…
Browse files Browse the repository at this point in the history
…ations
  • Loading branch information
M-R-Schaefer committed Feb 1, 2024
1 parent 8cdc07d commit 9b4cab3
Showing 1 changed file with 12 additions and 14 deletions.
26 changes: 12 additions & 14 deletions apax/md/ase_calc.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,14 +150,23 @@ def initialize(self, atoms):
self.neighbor_fn = neighbor_fn

if self.neigbor_from_jax:
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
self.neighbors = self.neighbor_fn.allocate(positions)
if np.any(atoms.get_cell().lengths() > 1e-6):
positions = jnp.asarray(atoms.positions, dtype=jnp.float64)
box = atoms.cell.array.T
inv_box = jnp.linalg.inv(box)
positions = space.transform(inv_box, positions) # frac coords
self.neighbors = self.neighbor_fn.allocate(positions, box=box)
else:
neighbor = neighbor.allocate(positions)
else:
idxs_i = neighbour_list("i", atoms, self.r_max)
self.padded_length = int(len(idxs_i) * self.padding_factor)

def set_neighbours_and_offsets(self, atoms, box):
idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max)

if len(idxs_i) > self.padded_length:
print("neighbor list overflowed, reallocating.")
print("neighbor list overflowed, extending.")
self.padded_length = int(len(idxs_i) * self.padding_factor)
self.initialize(atoms)

Expand All @@ -178,12 +187,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
if self.step is None:
self.initialize(atoms)

if self.neigbor_from_jax:
self.neighbors = self.neighbor_fn.allocate(positions)
else:
idxs_i = neighbour_list("i", atoms, self.r_max)
self.padded_length = int(len(idxs_i) * self.padding_factor)

elif "numbers" in system_changes:
self.initialize(atoms)

Expand All @@ -202,8 +205,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes):
if self.neighbors.did_buffer_overflow:
print("neighbor list overflowed, reallocating.")
self.initialize(atoms)
self.neighbors = self.neighbor_fn.allocate(positions)

results, self.neighbors = self.step(positions, self.neighbors, box)

else:
Expand Down Expand Up @@ -263,14 +264,11 @@ def step_fn(positions, neighbor, box):
return results, neighbor

else:

@jax.jit
def step_fn(positions, neighbor, box, offsets):
results = model(positions, Z, neighbor, box, offsets)

if "stress" in results.keys():
results = process_stress(results, box)

return results

return step_fn

0 comments on commit 9b4cab3

Please sign in to comment.