Skip to content

Commit

Permalink
fixed missing perturbation arg for ensemble in jaxmd
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Jan 20, 2024
1 parent 4159f6c commit 41884c3
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions apax/md/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@


def create_energy_fn(model, params, numbers, n_models):
def ensemble(params, R, Z, neighbor, box, offsets):
vmodel = jax.vmap(model, (0, None, None, None, None, None), 0)
energies = vmodel(params, R, Z, neighbor, box, offsets)
def ensemble(params, R, Z, neighbor, box, offsets, perturbation=None):
vmodel = jax.vmap(model, (0, None, None, None, None, None, None), 0)
energies = vmodel(params, R, Z, neighbor, box, offsets, perturbation)
energy = jnp.mean(energies)

return energy
Expand Down

0 comments on commit 41884c3

Please sign in to comment.