diff --git a/.gitignore b/.gitignore
index e6981eda..041b225d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,6 +59,7 @@ tmp/
 .traj
 .h5
 events.out.*
+*.schema.json
 
 # Translations
 *.mo
diff --git a/README.md b/README.md
index 16d78cca..c3ffc31f 100644
--- a/README.md
+++ b/README.md
@@ -60,14 +60,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati
 
 In order to train a model, you need to run
 
-```python
+```bash
 apax train config.yaml
 ```
 
 We offer some input file templates to get new users started as quickly as possible.
 Simply run the following commands and add the appropriate entries in the marked fields
 
-```python
+```bash
 apax template train # use --full for a template with all input options
 ```
 
@@ -79,7 +79,7 @@ The documentation can convenienty be accessed by running `apax docs`.
 There are two ways in which `apax` models can be used for molecular dynamics out of the box.
 High performance NVT simulations using JaxMD can be started with the CLI by running
 
-```python
+```bash
 apax md config.yaml md_config.yaml
 ```
 
@@ -88,6 +88,30 @@ A template command for MD input files is provided as well.
 
 The second way is to use the ASE calculator provided in `apax.md`.
 
+## Input File Auto-Completion
+
+Use the following command to generate JSON schemata for the training and MD configuration files:
+
+```bash
+apax schema
+```
+
+If you are using VSCode, you can use them to lint and autocomplete your input files by including them in `.vscode/settings.json`:
+
+```json
+{
+    "yaml.schemas": {
+        "/absolute/path/to/apaxtrain.schema.json": [
+            "train.yaml"
+        ],
+        "/absolute/path/to/apaxmd.schema.json": [
+            "md.yaml"
+        ]
+    }
+}
+```
+
+
 ## Authors
 - Moritz René Schäfer
 - Nico Segreto
diff --git a/apax/__init__.py b/apax/__init__.py
index 7438bf4a..de4b9e6e 100644
--- a/apax/__init__.py
+++ b/apax/__init__.py
@@ -3,6 +3,7 @@
 import jax
 
 os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 jax.config.update("jax_enable_x64", True)
 
 from apax.utils.helpers import setup_ase
diff --git a/apax/bal/api.py b/apax/bal/api.py
index 4a50463d..f891c5c9 100644
--- a/apax/bal/api.py
+++ b/apax/bal/api.py
@@ -8,7 +8,7 @@
 from tqdm import trange
 
 from apax.bal import feature_maps, kernel, selection, transforms
-from apax.data.input_pipeline import InMemoryDataset
+from apax.data.input_pipeline import OTFInMemoryDataset
 from apax.model.builder import ModelBuilder
 from apax.model.gmnn import EnergyModel
 from apax.train.checkpoints import (
@@ -46,7 +46,7 @@ def create_feature_fn(
     return feature_fn
 
 
-def compute_features(feature_fn, dataset: InMemoryDataset):
+def compute_features(feature_fn, dataset: OTFInMemoryDataset):
     """Compute the features of a dataset."""
     features = []
     n_data = dataset.n_data
@@ -85,7 +85,7 @@ def kernel_selection(
     is_ensemble = n_models > 1
     n_train = len(train_atoms)
 
-    dataset = InMemoryDataset(
+    dataset = OTFInMemoryDataset(
         train_atoms + pool_atoms,
         cutoff=config.model.r_max,
         bs=processing_batch_size,
diff --git a/apax/cli/apax_app.py b/apax/cli/apax_app.py
index 60321662..cd09ea7d 100644
--- a/apax/cli/apax_app.py
+++ b/apax/cli/apax_app.py
@@ -1,5 +1,6 @@
 import importlib.metadata
 import importlib.resources as pkg_resources
+import json
 import sys
 from pathlib import Path
 
@@ -93,6 +94,23 @@ def docs():
     typer.launch("https://apax.readthedocs.io/en/latest/")
 
 
+@app.command()
+def schema():
+    """
+    Generate JSON schemata for autocompletion of train/md inputs in VSCode.
+ """ + console.print("Generating JSON schema") + from apax.config import Config, MDConfig + + train_schema = Config.model_json_schema() + md_schema = MDConfig.model_json_schema() + with open("./apaxtrain.schema.json", "w") as f: + f.write(json.dumps(train_schema, indent=2)) + + with open("./apaxmd.schema.json", "w") as f: + f.write(json.dumps(md_schema, indent=2)) + + @validate_app.command("train") def validate_train_config( config_path: Path = typer.Argument( diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 350849d0..3cd5f57c 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -50,6 +50,7 @@ class DataConfig(BaseModel, extra="forbid"): directory: str experiment: str + ds_type: Literal["cached", "otf"] = "cached" data_path: Optional[str] = None train_data_path: Optional[str] = None val_data_path: Optional[str] = None @@ -228,8 +229,10 @@ class LossConfig(BaseModel, extra="forbid"): """ name: str - loss_type: str = "structures" + loss_type: str = "mse" weight: NonNegativeFloat = 1.0 + atoms_exponent: NonNegativeFloat = 1 + parameters: dict = {} class CallbackConfig(BaseModel, frozen=True, extra="forbid"): diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index d6547cfb..082415a0 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -1,5 +1,7 @@ import logging +import uuid from collections import deque +from pathlib import Path from random import shuffle from typing import Dict, Iterator @@ -44,6 +46,7 @@ def __init__( n_jit_steps=1, pre_shuffle=False, ignore_labels=False, + cache_path=".", ) -> None: if pre_shuffle: shuffle(atoms) @@ -68,6 +71,7 @@ def __init__( self.buffer = deque() self.batch_size = self.validate_batch_size(bs) self.n_jit_steps = n_jit_steps + self.file = Path(cache_path) / str(uuid.uuid4()) self.enqueue(min(self.buffer_size, self.n_data)) @@ -85,7 +89,6 @@ def validate_batch_size(self, batch_size: int) -> int: f"requested batch size {batch_size} is larger than the number of data" f" points {self.n_data}. Setting batch size = {self.n_data}" ) - print("Warning: " + msg) log.warning(msg) batch_size = self.n_data return batch_size @@ -125,20 +128,6 @@ def enqueue(self, num_elements): self.buffer.append(data) self.count += 1 - def __iter__(self): - epoch = 0 - while epoch < self.n_epochs or len(self.buffer) > 0: - yield self.buffer.popleft() - - space = self.buffer_size - len(self.buffer) - if self.count + space > self.n_data: - space = self.n_data - self.count - - if self.count >= self.n_data and epoch < self.n_epochs: - epoch += 1 - self.count = 0 - self.enqueue(space) - def make_signature(self) -> tf.TensorSpec: input_signature = {} input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") @@ -189,6 +178,89 @@ def init_input(self) -> Dict[str, np.ndarray]: inputs = jax.tree_map(lambda x: jnp.array(x), inputs) return inputs, np.array(box) + def __iter__(self): + raise NotImplementedError + + def shuffle_and_batch(self): + raise NotImplementedError + + def batch(self) -> Iterator[jax.Array]: + raise NotImplementedError + + def cleanup(self): + pass + + +class CachedInMemoryDataset(InMemoryDataset): + def __iter__(self): + while self.count < self.n_data or len(self.buffer) > 0: + yield self.buffer.popleft() + + space = self.buffer_size - len(self.buffer) + if self.count + space > self.n_data: + space = self.n_data - self.count + self.enqueue(space) + + def shuffle_and_batch(self): + """Shuffles and batches the inputs/labels. 
This function prepares the + inputs and labels for the whole training and prefetches the data. + + Returns + ------- + ds : + Iterator that returns inputs and labels of one batch in each step. + """ + ds = ( + tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + .cache(self.file.as_posix()) + .repeat(self.n_epochs) + ) + + ds = ds.shuffle( + buffer_size=self.buffer_size, reshuffle_each_iteration=True + ).batch(batch_size=self.batch_size) + if self.n_jit_steps > 1: + ds = ds.batch(batch_size=self.n_jit_steps) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds + + def batch(self) -> Iterator[jax.Array]: + ds = ( + tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + .cache(self.file.as_posix()) + .repeat(self.n_epochs) + ) + ds = ds.batch(batch_size=self.batch_size) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds + + def cleanup(self): + for p in self.file.parent.glob(f"{self.file.name}.data*"): + p.unlink() + + index_file = self.file.parent / f"{self.file.name}.index" + index_file.unlink() + + +class OTFInMemoryDataset(InMemoryDataset): + def __iter__(self): + epoch = 0 + while epoch < self.n_epochs or len(self.buffer) > 0: + yield self.buffer.popleft() + + space = self.buffer_size - len(self.buffer) + if self.count + space > self.n_data: + space = self.n_data - self.count + + if self.count >= self.n_data and epoch < self.n_epochs: + epoch += 1 + self.count = 0 + self.enqueue(space) + def shuffle_and_batch(self): """Shuffles and batches the inputs/labels. This function prepares the inputs and labels for the whole training and prefetches the data. @@ -217,3 +289,9 @@ def batch(self) -> Iterator[jax.Array]: ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds + + +dataset_dict = { + "cached": CachedInMemoryDataset, + "otf": OTFInMemoryDataset, +} diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 73f27658..69ef95f6 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -11,7 +11,7 @@ from matscipy.neighbours import neighbour_list from tqdm import trange -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import OTFInMemoryDataset from apax.model import ModelBuilder from apax.train.checkpoints import check_for_ensemble, restore_parameters from apax.utils.jax_md_reduced import partition, quantity, space @@ -256,7 +256,7 @@ def batch_eval( """ if self.model is None: self.initialize(atoms_list[0]) - dataset = InMemoryDataset( + dataset = OTFInMemoryDataset( atoms_list, self.model_config.model.r_max, batch_size, diff --git a/apax/train/eval.py b/apax/train/eval.py index b7d69700..f15a6da3 100644 --- a/apax/train/eval.py +++ b/apax/train/eval.py @@ -8,7 +8,7 @@ from tqdm import trange from apax.config import parse_config -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import OTFInMemoryDataset from apax.model import ModelBuilder from apax.train.callbacks import initialize_callbacks from apax.train.checkpoints import restore_single_parameters @@ -122,7 +122,7 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"): Metrics = initialize_metrics(config.metrics) atoms_list = load_test_data(config, model_version_path, eval_path, n_test) - test_ds = InMemoryDataset( + test_ds = OTFInMemoryDataset( atoms_list, config.model.r_max, config.data.valid_batch_size ) diff --git a/apax/train/loss.py b/apax/train/loss.py 
index cd28786e..c2e575d3 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -8,7 +8,7 @@ def weighted_squared_error( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Squared error function that allows weighting of @@ -17,8 +17,23 @@ def weighted_squared_error( return (label - prediction) ** 2 / divisor +def weighted_huber_loss( + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} +) -> jnp.array: + """ + Huber loss function that allows weighting of + individual contributions by the number of atoms in the system. + """ + if "delta" not in parameters.keys(): + raise KeyError("Huber loss function requires 'delta' parameter") + delta = parameters["delta"] + diff = jnp.abs(label - prediction) + loss = jnp.where(diff > delta, delta * (diff - 0.5 * delta), 0.5 * diff**2) + return loss / divisor + + def force_angle_loss( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Consine similarity loss function. Contributions are summed in `Loss`. @@ -28,7 +43,7 @@ def force_angle_loss( def force_angle_div_force_label( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ): """ Consine similarity loss function weighted by the norm of the force labels. @@ -41,7 +56,7 @@ def force_angle_div_force_label( def force_angle_exponential_weight( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Consine similarity loss function exponentially scaled by the norm of the force labels. 
@@ -52,7 +67,7 @@ def force_angle_exponential_weight( return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor -def stress_tril(label, prediction, divisor=1.0): +def stress_tril(label, prediction, divisor=1.0, parameters: dict = {}): idxs = jnp.tril_indices(3) label_tril = label[:, idxs[0], idxs[1]] prediction_tril = prediction[:, idxs[0], idxs[1]] @@ -60,9 +75,8 @@ def stress_tril(label, prediction, divisor=1.0): loss_functions = { - "molecules": weighted_squared_error, - "structures": weighted_squared_error, - "vibrations": weighted_squared_error, + "mse": weighted_squared_error, + "huber": weighted_huber_loss, "cosine_sim": force_angle_loss, "cosine_sim_div_magnitude": force_angle_div_force_label, "cosine_sim_exp_magnitude": force_angle_exponential_weight, @@ -80,6 +94,8 @@ class Loss: name: str loss_type: str weight: float = 1.0 + atoms_exponent: float = 1.0 + parameters: dict = dataclasses.field(default_factory=lambda: {}) def __post_init__(self): if self.loss_type not in loss_functions.keys(): @@ -94,25 +110,18 @@ def __post_init__(self): def __call__(self, inputs: dict, prediction: dict, label: dict) -> float: # TODO we may want to insert an additional `mask` argument for this method divisor = self.determine_divisor(inputs["n_atoms"]) - loss = self.loss_fn(label[self.name], prediction[self.name], divisor=divisor) - return self.weight * jnp.sum(jnp.mean(loss, axis=0)) + batch_losses = self.loss_fn( + label[self.name], prediction[self.name], divisor, self.parameters + ) + loss = self.weight * jnp.sum(jnp.mean(batch_losses, axis=0)) + return loss def determine_divisor(self, n_atoms: jnp.array) -> jnp.array: - divisor_id = self.name + "_" + self.loss_type - divisor_dict = { - "energy_structures": n_atoms**2, - "energy_vibrations": n_atoms, - "forces_structures": einops.repeat(n_atoms, "batch -> batch 1 1"), - "forces_cosine_sim": einops.repeat(n_atoms, "batch -> batch 1 1"), - "cosine_sim_div_magnitude": einops.repeat(n_atoms, "batch -> batch 1 1"), - "forces_cosine_sim_exp_magnitude": einops.repeat( - n_atoms, "batch -> batch 1 1" - ), - "stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"), - "stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"), - "stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"), - } - divisor = divisor_dict.get(divisor_id, jnp.array(1.0)) + # shape: batch + divisor = n_atoms**self.atoms_exponent + + if self.name in ["forces", "stress"]: + divisor = einops.repeat(divisor, "batch -> batch 1 1") return divisor diff --git a/apax/train/run.py b/apax/train/run.py index f1dbeb39..fe408ab6 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -4,9 +4,9 @@ import jax -from apax.config import LossConfig, parse_config +from apax.config import Config, LossConfig, parse_config from apax.data.initialization import load_data_files -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import dataset_dict from apax.data.statistics import compute_scale_shift_parameters from apax.model import ModelBuilder from apax.optimizer import get_opt @@ -51,24 +51,12 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: return LossCollection(loss_funcs) -def run(user_config, log_level="error"): - config = parse_config(user_config) - - seed_py_np_tf(config.seed) - rng_key = jax.random.PRNGKey(config.seed) - - log.info("Initializing directories") - config.data.model_version_path.mkdir(parents=True, exist_ok=True) - setup_logging(config.data.model_version_path / "train.log", log_level) - 
config.dump_config(config.data.model_version_path) - - callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) - loss_fn = initialize_loss_fn(config.loss) - Metrics = initialize_metrics(config.metrics) - +def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) - train_ds = InMemoryDataset( + Dataset = dataset_dict[config.data.ds_type] + + train_ds = Dataset( train_raw_ds, config.model.r_max, config.data.batch_size, @@ -76,9 +64,14 @@ def run(user_config, log_level="error"): config.data.shuffle_buffer_size, config.n_jitted_steps, pre_shuffle=True, + cache_path=config.data.model_version_path, ) - val_ds = InMemoryDataset( - val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs + val_ds = Dataset( + val_raw_ds, + config.model.r_max, + config.data.valid_batch_size, + config.n_epochs, + cache_path=config.data.model_version_path, ) ds_stats = compute_scale_shift_parameters( train_ds.inputs, @@ -88,6 +81,25 @@ def run(user_config, log_level="error"): config.data.shift_options, config.data.scale_options, ) + return train_ds, val_ds, ds_stats + + +def run(user_config, log_level="error"): + config = parse_config(user_config) + + seed_py_np_tf(config.seed) + rng_key = jax.random.PRNGKey(config.seed) + + log.info("Initializing directories") + config.data.model_version_path.mkdir(parents=True, exist_ok=True) + setup_logging(config.data.model_version_path / "train.log", log_level) + config.dump_config(config.data.model_version_path) + + callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) + loss_fn = initialize_loss_fn(config.loss) + Metrics = initialize_metrics(config.metrics) + + train_ds, val_ds, ds_stats = initialize_datasets(config) log.info("Initializing Model") sample_input, init_box = train_ds.init_input() diff --git a/apax/train/trainer.py b/apax/train/trainer.py index f0e99ef5..8c040a3f 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -142,6 +142,10 @@ def fit( epoch_pbar.close() callbacks.on_train_end() + train_ds.cleanup() + if val_ds: + val_ds.cleanup() + def global_norm(updates) -> jnp.ndarray: """Returns the l2 norm of the input. diff --git a/apax/utils/convert.py b/apax/utils/convert.py index 166535ad..b6bfd76e 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -89,7 +89,7 @@ def atoms_to_inputs( frac_pos = space.transform(inv_box, pos) inputs["positions"].append(np.array(frac_pos)) - inputs["numbers"].append(atoms.numbers) + inputs["numbers"].append(atoms.numbers.astype(np.int16)) inputs["n_atoms"].append(len(atoms)) inputs = prune_dict(inputs) diff --git a/tests/unit_tests/train/test_loss.py b/tests/unit_tests/train/test_loss.py index 26c18bc1..b39aac4c 100644 --- a/tests/unit_tests/train/test_loss.py +++ b/tests/unit_tests/train/test_loss.py @@ -68,7 +68,7 @@ def test_force_angle_loss(): def test_force_loss(): name = "forces" - loss_type = "structures" + loss_type = "mse" weight = 1 inputs = { "n_atoms": jnp.array([2]),
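
The options introduced in `DataConfig` and `LossConfig` above are read from the training YAML. Below is a minimal sketch of how they might appear in a `train.yaml`, assuming the usual top-level `data` and `loss` sections; all other required keys (model, paths, etc.) are omitted and the `delta` value is only an illustrative placeholder:

```yaml
# Sketch of the new options in a train.yaml (partial, assumed layout).
data:
  ds_type: cached           # "cached" (default) caches preprocessed data on disk; "otf" processes it on the fly each epoch

loss:
  - name: energy
    loss_type: huber        # replaces the removed "molecules"/"structures"/"vibrations" types
    weight: 1.0
    atoms_exponent: 2       # divisor n_atoms**2
    parameters:
      delta: 0.01           # illustrative value; "delta" is required by the Huber loss
  - name: forces
    loss_type: mse
    weight: 1.0
    atoms_exponent: 1       # divisor n_atoms, broadcast over the force components
```

Setting `atoms_exponent: 2` for the energy and `1` for the forces reproduces the weighting that the removed `divisor_dict` hardcoded for the old `structures` loss type, while leaving other exponents open to the user.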