From 9b4cab3b9349562adf358afe6500313786cf70de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Thu, 1 Feb 2024 14:59:15 +0100 Subject: [PATCH] fixed jax nl always allocating in cartesian coords and repeated allocations --- apax/md/ase_calc.py | 26 ++++++++++++-------------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 6422b5a5..8d7d2a65 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -150,14 +150,23 @@ def initialize(self, atoms): self.neighbor_fn = neighbor_fn if self.neigbor_from_jax: - positions = jnp.asarray(atoms.positions, dtype=jnp.float64) - self.neighbors = self.neighbor_fn.allocate(positions) + if np.any(atoms.get_cell().lengths() > 1e-6): + positions = jnp.asarray(atoms.positions, dtype=jnp.float64) + box = atoms.cell.array.T + inv_box = jnp.linalg.inv(box) + positions = space.transform(inv_box, positions) # frac coords + self.neighbors = self.neighbor_fn.allocate(positions, box=box) + else: + neighbor = neighbor.allocate(positions) + else: + idxs_i = neighbour_list("i", atoms, self.r_max) + self.padded_length = int(len(idxs_i) * self.padding_factor) def set_neighbours_and_offsets(self, atoms, box): idxs_i, idxs_j, offsets = neighbour_list("ijS", atoms, self.r_max) if len(idxs_i) > self.padded_length: - print("neighbor list overflowed, reallocating.") + print("neighbor list overflowed, extending.") self.padded_length = int(len(idxs_i) * self.padding_factor) self.initialize(atoms) @@ -178,12 +187,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes): if self.step is None: self.initialize(atoms) - if self.neigbor_from_jax: - self.neighbors = self.neighbor_fn.allocate(positions) - else: - idxs_i = neighbour_list("i", atoms, self.r_max) - self.padded_length = int(len(idxs_i) * self.padding_factor) - elif "numbers" in system_changes: self.initialize(atoms) @@ -202,8 +205,6 @@ def calculate(self, atoms, properties=["energy"], system_changes=all_changes): if self.neighbors.did_buffer_overflow: print("neighbor list overflowed, reallocating.") self.initialize(atoms) - self.neighbors = self.neighbor_fn.allocate(positions) - results, self.neighbors = self.step(positions, self.neighbors, box) else: @@ -263,14 +264,11 @@ def step_fn(positions, neighbor, box): return results, neighbor else: - @jax.jit def step_fn(positions, neighbor, box, offsets): results = model(positions, Z, neighbor, box, offsets) - if "stress" in results.keys(): results = process_stress(results, box) - return results return step_fn