Commit

Merge branch 'dev' into batch_progress
Moritz Schaefer committed Apr 3, 2024
2 parents 1e75155 + 8f6faa7 commit 5a76439
Showing 7 changed files with 90 additions and 36 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -59,6 +59,7 @@ tmp/
.traj
.h5
events.out.*
*.schema.json

# Translations
*.mo
32 changes: 29 additions & 3 deletions README.md
@@ -60,14 +60,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati

In order to train a model, you need to run

```python
```bash
apax train config.yaml
```

We offer some input file templates to get new users started as quickly as possible.
Simply run the following commands and add the appropriate entries in the marked fields:

```python
```bash
apax template train # use --full for a template with all input options
```

@@ -79,7 +79,7 @@ The documentation can conveniently be accessed by running `apax docs`.
There are two ways in which `apax` models can be used for molecular dynamics out of the box.
High-performance NVT simulations using JaxMD can be started with the CLI by running

```python
```bash
apax md config.yaml md_config.yaml
```

@@ -88,6 +88,32 @@ A template command for MD input files is provided as well.
The second way is to use the ASE calculator provided in `apax.md`.
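
For example, a minimal sketch of the calculator route (assuming the class is exported as `ASECalculator` and takes the path to a trained model directory; check `apax.md` for the exact name and signature) could look like this:

```python
from ase.io import read

from apax.md import ASECalculator  # exact class name assumed

# Attach a trained apax model to an ASE Atoms object.
atoms = read("structure.xyz")  # hypothetical input structure
atoms.calc = ASECalculator(model_dir="models/my_model")  # hypothetical model path

print(atoms.get_potential_energy())
print(atoms.get_forces())
```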


## Input File Auto-Completion

Use the following command to generate JSON schemata for the training and MD input files:

```bash
apax schema
```

If you are using VSCode, you can use them to lint and autocomplete your input files by including them in `.vscode/settings.json`:

```json
{
"yaml.schemas": {

"/absolute/path/to/apaxtrain.schema.json": [
"train.yaml"
]
,
"/absolute/path/to/apaxmd.schema.json": [
"md.yaml"
]
}
}
```
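
The schemata are written to the current working directory (hence the `*.schema.json` entry added to `.gitignore` in this commit), so it is worth rerunning `apax schema` after updating `apax` to keep them in sync with the config models.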


## Authors
- Moritz René Schäfer
- Nico Segreto
18 changes: 18 additions & 0 deletions apax/cli/apax_app.py
@@ -1,5 +1,6 @@
import importlib.metadata
import importlib.resources as pkg_resources
import json
import sys
from pathlib import Path

@@ -93,6 +94,23 @@ def docs():
    typer.launch("https://apax.readthedocs.io/en/latest/")


@app.command()
def schema():
    """
    Generate JSON schemata for autocompletion of train/md inputs in VSCode.
    """
    console.print("Generating JSON schema")
    from apax.config import Config, MDConfig

    train_schema = Config.model_json_schema()
    md_schema = MDConfig.model_json_schema()
    with open("./apaxtrain.schema.json", "w") as f:
        f.write(json.dumps(train_schema, indent=2))

    with open("./apaxmd.schema.json", "w") as f:
        f.write(json.dumps(md_schema, indent=2))


@validate_app.command("train")
def validate_train_config(
    config_path: Path = typer.Argument(
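Beyond editor integration, the generated files are plain JSON Schema documents, so they can also be used for offline validation. A minimal sketch, assuming PyYAML and the `jsonschema` package are installed and a `train.yaml` exists in the working directory:

```python
import json

import yaml  # PyYAML, assumed installed
from jsonschema import validate  # jsonschema package, assumed installed

with open("apaxtrain.schema.json") as f:
    schema = json.load(f)

with open("train.yaml") as f:  # hypothetical config file
    config = yaml.safe_load(f)

# Raises jsonschema.exceptions.ValidationError if the config violates the schema.
validate(instance=config, schema=schema)
```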
4 changes: 3 additions & 1 deletion apax/config/train_config.py
@@ -229,8 +229,10 @@ class LossConfig(BaseModel, extra="forbid"):
    """

    name: str
    loss_type: str = "structures"
    loss_type: str = "mse"
    weight: NonNegativeFloat = 1.0
    atoms_exponent: NonNegativeFloat = 1
    parameters: dict = {}


class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
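To illustrate the new fields, a loss section in a training config might look like the sketch below (the top-level `loss:` key and exact nesting are assumed; only the fields themselves are shown in the diff):

```yaml
loss:
  - name: energy
    loss_type: mse      # new default, formerly "structures"
    weight: 1.0
    atoms_exponent: 1   # divisor becomes n_atoms**atoms_exponent
  - name: forces
    loss_type: huber
    weight: 4.0
    parameters:
      delta: 1.0        # required by the Huber loss
```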
59 changes: 34 additions & 25 deletions apax/train/loss.py
@@ -8,7 +8,7 @@


def weighted_squared_error(
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
) -> jnp.array:
    """
    Squared error function that allows weighting of
@@ -17,8 +17,23 @@ def weighted_squared_error(
    return (label - prediction) ** 2 / divisor


def weighted_huber_loss(
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
) -> jnp.array:
    """
    Huber loss function that allows weighting of
    individual contributions by the number of atoms in the system.
    """
    if "delta" not in parameters.keys():
        raise KeyError("Huber loss function requires 'delta' parameter")
    delta = parameters["delta"]
    diff = jnp.abs(label - prediction)
    loss = jnp.where(diff > delta, delta * (diff - 0.5 * delta), 0.5 * diff**2)
    return loss / divisor


def force_angle_loss(
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
) -> jnp.array:
    """
    Cosine similarity loss function. Contributions are summed in `Loss`.
@@ -28,7 +43,7 @@ def force_angle_loss(


def force_angle_div_force_label(
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
):
    """
    Cosine similarity loss function weighted by the norm of the force labels.
@@ -41,7 +56,7 @@


def force_angle_exponential_weight(
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0
    label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {}
) -> jnp.array:
    """
    Cosine similarity loss function exponentially scaled by the norm of the force labels.
@@ -52,17 +67,16 @@
    return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor


def stress_tril(label, prediction, divisor=1.0):
def stress_tril(label, prediction, divisor=1.0, parameters: dict = {}):
    idxs = jnp.tril_indices(3)
    label_tril = label[:, idxs[0], idxs[1]]
    prediction_tril = prediction[:, idxs[0], idxs[1]]
    return (label_tril - prediction_tril) ** 2 / divisor


loss_functions = {
    "molecules": weighted_squared_error,
    "structures": weighted_squared_error,
    "vibrations": weighted_squared_error,
    "mse": weighted_squared_error,
    "huber": weighted_huber_loss,
    "cosine_sim": force_angle_loss,
    "cosine_sim_div_magnitude": force_angle_div_force_label,
    "cosine_sim_exp_magnitude": force_angle_exponential_weight,
@@ -80,6 +94,8 @@ class Loss:
    name: str
    loss_type: str
    weight: float = 1.0
    atoms_exponent: float = 1.0
    parameters: dict = dataclasses.field(default_factory=lambda: {})

    def __post_init__(self):
        if self.loss_type not in loss_functions.keys():
@@ -94,25 +110,18 @@ def __post_init__(self):
    def __call__(self, inputs: dict, prediction: dict, label: dict) -> float:
        # TODO we may want to insert an additional `mask` argument for this method
        divisor = self.determine_divisor(inputs["n_atoms"])
        loss = self.loss_fn(label[self.name], prediction[self.name], divisor=divisor)
        return self.weight * jnp.sum(jnp.mean(loss, axis=0))
        batch_losses = self.loss_fn(
            label[self.name], prediction[self.name], divisor, self.parameters
        )
        loss = self.weight * jnp.sum(jnp.mean(batch_losses, axis=0))
        return loss

    def determine_divisor(self, n_atoms: jnp.array) -> jnp.array:
        divisor_id = self.name + "_" + self.loss_type
        divisor_dict = {
            "energy_structures": n_atoms**2,
            "energy_vibrations": n_atoms,
            "forces_structures": einops.repeat(n_atoms, "batch -> batch 1 1"),
            "forces_cosine_sim": einops.repeat(n_atoms, "batch -> batch 1 1"),
            "cosine_sim_div_magnitude": einops.repeat(n_atoms, "batch -> batch 1 1"),
            "forces_cosine_sim_exp_magnitude": einops.repeat(
                n_atoms, "batch -> batch 1 1"
            ),
            "stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
            "stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"),
            "stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"),
        }
        divisor = divisor_dict.get(divisor_id, jnp.array(1.0))
        # shape: batch
        divisor = n_atoms**self.atoms_exponent

        if self.name in ["forces", "stress"]:
            divisor = einops.repeat(divisor, "batch -> batch 1 1")

        return divisor
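For intuition, a standalone sketch of how the refactored `Loss` ties these pieces together (the toy arrays are invented; class and field names mirror the diff above):

```python
import jax.numpy as jnp

from apax.train.loss import Loss

# Huber loss on forces, normalized per atom via divisor = n_atoms**atoms_exponent.
force_loss = Loss(
    name="forces",
    loss_type="huber",
    weight=1.0,
    atoms_exponent=1.0,
    parameters={"delta": 1.0},
)

inputs = {"n_atoms": jnp.array([2])}                # batch of one structure, 2 atoms
label = {"forces": jnp.zeros((1, 2, 3))}            # reference forces
prediction = {"forces": 0.1 * jnp.ones((1, 2, 3))}  # model output

print(force_loss(inputs, prediction, label))        # scalar loss value
```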
10 changes: 4 additions & 6 deletions apax/train/trainer.py
@@ -117,12 +117,10 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)
epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})

epoch_metrics.update({**epoch_loss})

2 changes: 1 addition & 1 deletion tests/unit_tests/train/test_loss.py
@@ -68,7 +68,7 @@ def test_force_angle_loss():

def test_force_loss():
    name = "forces"
    loss_type = "structures"
    loss_type = "mse"
    weight = 1
    inputs = {
        "n_atoms": jnp.array([2]),
