Merge branch 'dev' into moredocs-nico
Tetracarbonylnickel committed Apr 8, 2024
2 parents 56eba3b + 0b59ae2 commit 261b1ed

Showing 42 changed files with 4,266 additions and 1,658 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
@@ -24,7 +24,7 @@ jobs:
       - name: Install package
         run: |
           poetry --version
-          poetry install
+          poetry install --all-extras
       - name: Unit Tests
         run: |

1 change: 1 addition & 0 deletions .gitignore
@@ -59,6 +59,7 @@ tmp/
 .traj
 .h5
 events.out.*
+*.schema.json
 
 # Translations
 *.mo

2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
       - id: trailing-whitespace
 
   - repo: https://github.com/psf/black
-    rev: 24.1.1
+    rev: 24.3.0
     hooks:
       - id: black
         exclude: ^apax/utils/jax_md_reduced/

32 changes: 29 additions & 3 deletions README.md
@@ -44,14 +44,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati
 
 In order to train a model, you need to run
 
-```python
+```bash
 apax train config.yaml
 ```
 
 We offer some input file templates to get new users started as quickly as possible.
 Simply run the following commands and add the appropriate entries in the marked fields
 
-```python
+```bash
 apax template train # use --full for a template with all input options
 ```
 
@@ -63,7 +63,7 @@ The documentation can convenienty be accessed by running `apax docs`.
 There are two ways in which `apax` models can be used for molecular dynamics out of the box.
 High performance NVT simulations using JaxMD can be started with the CLI by running
 
-```python
+```bash
 apax md config.yaml md_config.yaml
 ```
 
@@ -72,6 +72,32 @@ A template command for MD input files is provided as well.
 The second way is to use the ASE calculator provided in `apax.md`.
 
 
+## Input File Auto-Completion
+
+Use the following command to generate JSON schemata for training and MD input files:
+
+```bash
+apax schema
+```
+
+If you are using VSCode, you can use them to lint and autocomplete your input files by including them in `.vscode/settings.json`:
+
+```json
+{
+    "yaml.schemas": {
+        "/absolute/path/to/apaxtrain.schema.json": [
+            "train.yaml"
+        ],
+        "/absolute/path/to/apaxmd.schema.json": [
+            "md.yaml"
+        ]
+    }
+}
+```
+
+
 ## Authors
 - Moritz René Schäfer
 - Nico Segreto

1 change: 1 addition & 0 deletions apax/__init__.py
@@ -3,6 +3,7 @@
 import jax
 
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 jax.config.update("jax_enable_x64", True)
 from apax.utils.helpers import setup_ase
 

28 changes: 11 additions & 17 deletions apax/bal/api.py
@@ -9,15 +9,14 @@
 from tqdm import trange
 
 from apax.bal import feature_maps, kernel, selection, transforms
-from apax.data.input_pipeline import AtomisticDataset
+from apax.data.input_pipeline import OTFInMemoryDataset
 from apax.model.builder import ModelBuilder
 from apax.model.gmnn import EnergyModel
 from apax.train.checkpoints import (
     canonicalize_energy_model_parameters,
     check_for_ensemble,
     restore_parameters,
 )
-from apax.train.run import initialize_dataset
 
 
 def create_feature_fn(
@@ -61,18 +60,8 @@ def create_feature_fn(
     return feature_fn
 
 
-def compute_features(
-    feature_fn: feature_maps.FeatureMap, dataset: AtomisticDataset
-) -> np.ndarray:
-    """Compute the features of a dataset.
-
-    Attributes
-    ----------
-    feature_fn:
-        Function to compute the features with.
-    dataset:
-        Dataset to compute the features for.
-    """
+def compute_features(feature_fn, dataset: OTFInMemoryDataset):
+    """Compute the features of a dataset."""
     features = []
     n_data = dataset.n_data
     ds = dataset.batch()
@@ -139,10 +128,15 @@ def kernel_selection(
     is_ensemble = n_models > 1
 
     n_train = len(train_atoms)
-    dataset = initialize_dataset(
-        config, train_atoms + pool_atoms, read_labels=False, calc_stats=False
+    dataset = OTFInMemoryDataset(
+        train_atoms + pool_atoms,
+        cutoff=config.model.r_max,
+        bs=processing_batch_size,
+        n_epochs=1,
+        ignore_labels=True,
+        pos_unit=config.data.pos_unit,
+        energy_unit=config.data.energy_unit,
     )
-    dataset.set_batch_size(processing_batch_size)
 
     _, init_box = dataset.init_input()

18 changes: 18 additions & 0 deletions apax/cli/apax_app.py
@@ -1,5 +1,6 @@
 import importlib.metadata
 import importlib.resources as pkg_resources
+import json
 import sys
 from pathlib import Path
 
@@ -93,6 +94,23 @@ def docs():
     typer.launch("https://apax.readthedocs.io/en/latest/")
 
 
+@app.command()
+def schema():
+    """
+    Generate JSON schemata for autocompletion of train/md inputs in VSCode.
+    """
+    console.print("Generating JSON schema")
+    from apax.config import Config, MDConfig
+
+    train_schema = Config.model_json_schema()
+    md_schema = MDConfig.model_json_schema()
+    with open("./apaxtrain.schema.json", "w") as f:
+        f.write(json.dumps(train_schema, indent=2))
+
+    with open("./apaxmd.schema.json", "w") as f:
+        f.write(json.dumps(md_schema, indent=2))
+
+
 @validate_app.command("train")
 def validate_train_config(
     config_path: Path = typer.Argument(

1 change: 0 additions & 1 deletion apax/cli/templates/train_config_full.yaml
@@ -78,4 +78,3 @@ checkpoints:
 
 progress_bar:
   disable_epoch_pbar: false
-  disable_nl_pbar: false

15 changes: 15 additions & 0 deletions apax/config/common.py
@@ -1,5 +1,6 @@
 import logging
 import os
+from collections.abc import MutableMapping
 from typing import Union
 
 import yaml
@@ -28,3 +29,17 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
         config = MDConfig.model_validate(config)
 
     return config
+
+
+def flatten(dictionary, parent_key="", separator="_"):
+    """https://stackoverflow.com/questions/6027558/
+    flatten-nested-dictionaries-compressing-keys
+    """
+    items = []
+    for key, value in dictionary.items():
+        new_key = parent_key + separator + key if parent_key else key
+        if isinstance(value, MutableMapping):
+            items.extend(flatten(value, new_key, separator=separator).items())
+        else:
+            items.append((new_key, value))
+    return dict(items)
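
For context, a minimal usage sketch of the new helper: the behavior below follows directly from the `flatten` definition added above, while the example dictionary itself is illustrative.

```python
# Minimal sketch: flatten collapses nested mappings into a single dict
# with underscore-joined keys. The example input is illustrative.
from apax.config.common import flatten

nested = {"data": {"directory": "project", "batch_size": 32}, "seed": 1}
print(flatten(nested))
# {'data_directory': 'project', 'data_batch_size': 32, 'seed': 1}
```
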
57 changes: 48 additions & 9 deletions apax/config/train_config.py
@@ -1,18 +1,20 @@
 import logging
 import os
 from pathlib import Path
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 import yaml
 from pydantic import (
     BaseModel,
     ConfigDict,
+    Field,
     NonNegativeFloat,
     PositiveFloat,
     PositiveInt,
     create_model,
     model_validator,
 )
+from typing_extensions import Annotated
 
 from apax.data.statistics import scale_method_list, shift_method_list
 
@@ -50,6 +52,7 @@ class DataConfig(BaseModel, extra="forbid"):
 
     directory: str
     experiment: str
+    ds_type: Literal["cached", "otf"] = "cached"
     data_path: Optional[str] = None
     train_data_path: Optional[str] = None
     val_data_path: Optional[str] = None
@@ -228,20 +231,53 @@ class LossConfig(BaseModel, extra="forbid"):
     """
 
     name: str
-    loss_type: str = "structures"
+    loss_type: str = "mse"
     weight: NonNegativeFloat = 1.0
+    atoms_exponent: NonNegativeFloat = 1
     parameters: dict = {}
 
 
-class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
+class CSVCallback(BaseModel, frozen=True, extra="forbid"):
     """
-    Configuration of the training callbacks.
+    Configuration of the CSVCallback.
+
     Parameters
     ----------
-    name: Keyword of the callback used. Currently we implement "csv" and "tensorboard".
+    name: Keyword of the callback used.
     """
 
-    name: str
+    name: Literal["csv"]
+
+
+class TBCallback(BaseModel, frozen=True, extra="forbid"):
+    """
+    Configuration of the TensorBoard callback.
+
+    Parameters
+    ----------
+    name: Keyword of the callback used.
+    """
+
+    name: Literal["tensorboard"]
+
+
+class MLFlowCallback(BaseModel, frozen=True, extra="forbid"):
+    """
+    Configuration of the MLFlow callback.
+
+    Parameters
+    ----------
+    name: Keyword of the callback used.
+    experiment: Path to the MLFlow experiment, e.g. /Users/<user>/<my_experiment>
+    """
+
+    name: Literal["mlflow"]
+    experiment: str
+
+
+CallBack = Annotated[
+    Union[CSVCallback, TBCallback, MLFlowCallback], Field(discriminator="name")
+]
 
 
 class TrainProgressbarConfig(BaseModel, extra="forbid"):
class TrainProgressbarConfig(BaseModel, extra="forbid"):
Expand All @@ -251,11 +287,11 @@ class TrainProgressbarConfig(BaseModel, extra="forbid"):
Parameters
----------
disable_epoch_pbar: Set to True to disable the epoch progress bar.
disable_nl_pbar: Set to True to disable the NL precomputation progress bar.
disable_batch_pbar: Set to True to disable the batch progress bar.
"""

disable_epoch_pbar: bool = False
disable_nl_pbar: bool = False
disable_batch_pbar: bool = True


class CheckpointConfig(BaseModel, extra="forbid"):
@@ -295,20 +331,23 @@ class Config(BaseModel, frozen=True, extra="forbid"):
     callbacks: List of :class:`callback` <config.CallbackConfig> configurations.
     progress_bar: Progressbar configuration.
     checkpoints: Checkpoint configuration.
+    data_parallel: Automatically uses all available GPUs for data parallel training.
+        Set to false to force single device training.
     """
 
     n_epochs: PositiveInt
     patience: Optional[PositiveInt] = None
     seed: int = 1
     n_models: int = 1
     n_jitted_steps: int = 1
+    data_parallel: bool = True
 
     data: DataConfig
     model: ModelConfig = ModelConfig()
     metrics: List[MetricsConfig] = []
     loss: List[LossConfig]
     optimizer: OptimizerConfig = OptimizerConfig()
-    callbacks: List[CallbackConfig] = [CallbackConfig(name="csv")]
+    callbacks: List[CallBack] = [CSVCallback(name="csv")]
     progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
     checkpoints: CheckpointConfig = CheckpointConfig()

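To make the new discriminated union concrete, here is a minimal sketch, assuming pydantic v2 (implied by the `model_json_schema`/`model_validate` calls elsewhere in this commit); the experiment path is illustrative.

```python
# Minimal sketch of the CallBack discriminated union, assuming pydantic v2.
# The "name" field is the discriminator and selects which callback model
# validates the input.
from pydantic import TypeAdapter

from apax.config.train_config import CallBack

adapter = TypeAdapter(CallBack)
cb = adapter.validate_python({"name": "mlflow", "experiment": "/Users/me/exp"})
print(type(cb).__name__)  # -> MLFlowCallback

# A name outside {"csv", "tensorboard", "mlflow"} raises a ValidationError
# instead of being silently accepted.
```
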
50 changes: 0 additions & 50 deletions apax/data/initialization.py
@@ -2,9 +2,6 @@
 
 import numpy as np
 
-from apax.data.input_pipeline import AtomisticDataset, process_inputs
-from apax.data.statistics import compute_scale_shift_parameters
-from apax.utils.convert import atoms_to_labels
 from apax.utils.data import load_data, split_atoms, split_idxs
 
 log = logging.getLogger(__name__)
@@ -36,50 +33,3 @@ def load_data_files(data_config):
         raise ValueError("input data path/paths not defined")
 
     return train_atoms_list, val_atoms_list
-
-
-def initialize_dataset(
-    config,
-    atoms_list,
-    read_labels: bool = True,
-    calc_stats: bool = True,
-):
-    if calc_stats and not read_labels:
-        raise ValueError(
-            "Cannot calculate scale/shift parameters without reading labels."
-        )
-    inputs = process_inputs(
-        atoms_list,
-        r_max=config.model.r_max,
-        disable_pbar=config.progress_bar.disable_nl_pbar,
-        pos_unit=config.data.pos_unit,
-    )
-    labels = atoms_to_labels(
-        atoms_list,
-        additional_properties_info=config.data.additional_properties_info,
-        read_labels=read_labels,
-        pos_unit=config.data.pos_unit,
-        energy_unit=config.data.energy_unit,
-    )
-
-    if calc_stats:
-        ds_stats = compute_scale_shift_parameters(
-            inputs,
-            labels,
-            config.data.shift_method,
-            config.data.scale_method,
-            config.data.shift_options,
-            config.data.scale_options,
-        )
-
-    dataset = AtomisticDataset(
-        inputs,
-        config.n_epochs,
-        labels=labels,
-        buffer_size=config.data.shuffle_buffer_size,
-    )
-
-    if calc_stats:
-        return dataset, ds_stats
-    else:
-        return dataset