Skip to content

Commit

Permalink
Merge branch 'dev' into epoch_jit
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer authored Jan 16, 2024
2 parents 008c3a8 + 2e6ec7a commit 0266e44
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions apax/train/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,21 @@ def force_angle_exponential_weight(
return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor


def stress_tril(label, prediction, divisor=1.0):
idxs = jnp.tril_indices(3)
label_tril = label[:, idxs[0], idxs[1]]
prediction_tril = prediction[:, idxs[0], idxs[1]]
return (label_tril - prediction_tril) ** 2 / divisor


loss_functions = {
"molecules": weighted_squared_error,
"structures": weighted_squared_error,
"vibrations": weighted_squared_error,
"cosine_sim": force_angle_loss,
"cosine_sim_div_magnitude": force_angle_div_force_label,
"cosine_sim_exp_magnitude": force_angle_exponential_weight,
"tril": stress_tril,
}


Expand Down Expand Up @@ -101,6 +109,7 @@ def determine_divisor(self, n_atoms: jnp.array) -> jnp.array:
n_atoms, "batch -> batch 1 1"
),
"stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
"stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
"stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"),
}
divisor = divisor_dict.get(divisor_id, jnp.array(1.0))
Expand Down

0 comments on commit 0266e44

Please sign in to comment.