From 4dd0082d9734ef921179a6dddc49f64abd23ce9a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 27 Mar 2024 10:24:49 +0100 Subject: [PATCH 01/12] switched to cached dataset --- apax/data/input_pipeline.py | 201 ++++++++++++++++++++++++++++++++++-- 1 file changed, 192 insertions(+), 9 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index d6547cfb..260d70a0 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -2,6 +2,7 @@ from collections import deque from random import shuffle from typing import Dict, Iterator +import uuid import jax import jax.numpy as jnp @@ -33,6 +34,192 @@ def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: return max_atoms, max_nbrs +# class InMemoryDataset: +# def __init__( +# self, +# atoms, +# cutoff, +# bs, +# n_epochs, +# buffer_size=1000, +# n_jit_steps=1, +# pre_shuffle=False, +# ignore_labels=False, +# ) -> None: +# if pre_shuffle: +# shuffle(atoms) +# self.sample_atoms = atoms[0] +# self.inputs = atoms_to_inputs(atoms) + +# self.n_epochs = n_epochs +# self.buffer_size = buffer_size + +# max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff) +# self.max_atoms = max_atoms +# self.max_nbrs = max_nbrs + +# if atoms[0].calc and not ignore_labels: +# self.labels = atoms_to_labels(atoms) +# else: +# self.labels = None + +# self.n_data = len(atoms) +# self.count = 0 +# self.cutoff = cutoff +# self.buffer = deque() +# self.batch_size = self.validate_batch_size(bs) +# self.n_jit_steps = n_jit_steps + +# self.enqueue(min(self.buffer_size, self.n_data)) + +# def steps_per_epoch(self) -> int: +# """Returns the number of steps per epoch dependent on the number of data and the +# batch size. Steps per epoch are calculated in a way that all epochs have the same +# number of steps, and all batches have the same length. To do so, some training +# data are dropped in each epoch. +# """ +# return self.n_data // self.batch_size // self.n_jit_steps + +# def validate_batch_size(self, batch_size: int) -> int: +# if batch_size > self.n_data: +# msg = ( +# f"requested batch size {batch_size} is larger than the number of data" +# f" points {self.n_data}. 
Setting batch size = {self.n_data}" +# ) +# print("Warning: " + msg) +# log.warning(msg) +# batch_size = self.n_data +# return batch_size + +# def prepare_data(self, i): +# inputs = {k: v[i] for k, v in self.inputs.items()} +# idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) +# inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) + +# zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] +# inputs["positions"] = np.pad( +# inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant" +# ) +# inputs["numbers"] = np.pad( +# inputs["numbers"], (0, zeros_to_add), "constant" +# ).astype(np.int16) +# inputs["n_atoms"] = np.pad( +# inputs["n_atoms"], (0, zeros_to_add), "constant" +# ).astype(np.int16) + +# if not self.labels: +# return inputs + +# labels = {k: v[i] for k, v in self.labels.items()} +# if "forces" in labels: +# labels["forces"] = np.pad( +# labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" +# ) + +# inputs = {k: tf.constant(v) for k, v in inputs.items()} +# labels = {k: tf.constant(v) for k, v in labels.items()} +# return (inputs, labels) + +# def enqueue(self, num_elements): +# for _ in range(num_elements): +# data = self.prepare_data(self.count) +# self.buffer.append(data) +# self.count += 1 + +# def __iter__(self): +# epoch = 0 +# while epoch < self.n_epochs or len(self.buffer) > 0: +# yield self.buffer.popleft() + +# space = self.buffer_size - len(self.buffer) +# if self.count + space > self.n_data: +# space = self.n_data - self.count + +# if self.count >= self.n_data and epoch < self.n_epochs: +# epoch += 1 +# self.count = 0 +# self.enqueue(space) + +# def make_signature(self) -> tf.TensorSpec: +# input_signature = {} +# input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") +# input_signature["numbers"] = tf.TensorSpec( +# (self.max_atoms,), dtype=tf.int16, name="numbers" +# ) +# input_signature["positions"] = tf.TensorSpec( +# (self.max_atoms, 3), dtype=tf.float64, name="positions" +# ) +# input_signature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box") +# input_signature["idx"] = tf.TensorSpec( +# (2, self.max_nbrs), dtype=tf.int16, name="idx" +# ) +# input_signature["offsets"] = tf.TensorSpec( +# (self.max_nbrs, 3), dtype=tf.float64, name="offsets" +# ) + +# if not self.labels: +# return input_signature + +# label_signature = {} +# if "energy" in self.labels.keys(): +# label_signature["energy"] = tf.TensorSpec((), dtype=tf.float64, name="energy") +# if "forces" in self.labels.keys(): +# label_signature["forces"] = tf.TensorSpec( +# (self.max_atoms, 3), dtype=tf.float64, name="forces" +# ) +# if "stress" in self.labels.keys(): +# label_signature["stress"] = tf.TensorSpec( +# (3, 3), dtype=tf.float64, name="stress" +# ) +# signature = (input_signature, label_signature) +# return signature + +# def init_input(self) -> Dict[str, np.ndarray]: +# """Returns first batch of inputs and labels to init the model.""" +# positions = self.sample_atoms.positions +# box = self.sample_atoms.cell.array +# idx, offsets = compute_nl(positions, box, self.cutoff) +# inputs = ( +# positions, +# self.sample_atoms.numbers, +# idx, +# box, +# offsets, +# ) + +# inputs = jax.tree_map(lambda x: jnp.array(x), inputs) +# return inputs, np.array(box) + +# def shuffle_and_batch(self): +# """Shuffles and batches the inputs/labels. This function prepares the +# inputs and labels for the whole training and prefetches the data. 
+ +# Returns +# ------- +# ds : +# Iterator that returns inputs and labels of one batch in each step. +# """ +# ds = tf.data.Dataset.from_generator( +# lambda: self, output_signature=self.make_signature() +# ) + +# ds = ds.shuffle( +# buffer_size=self.buffer_size, reshuffle_each_iteration=True +# ).batch(batch_size=self.batch_size) +# if self.n_jit_steps > 1: +# ds = ds.batch(batch_size=self.n_jit_steps) +# ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) +# return ds + +# def batch(self) -> Iterator[jax.Array]: +# ds = tf.data.Dataset.from_generator( +# lambda: self, output_signature=self.make_signature() +# ) +# ds = ds.batch(batch_size=self.batch_size) +# ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) +# return ds + + class InMemoryDataset: def __init__( self, @@ -68,6 +255,7 @@ def __init__( self.buffer = deque() self.batch_size = self.validate_batch_size(bs) self.n_jit_steps = n_jit_steps + self.name = str(uuid.uuid4()) self.enqueue(min(self.buffer_size, self.n_data)) @@ -126,17 +314,12 @@ def enqueue(self, num_elements): self.count += 1 def __iter__(self): - epoch = 0 - while epoch < self.n_epochs or len(self.buffer) > 0: + while self.count < self.n_data or len(self.buffer) > 0: yield self.buffer.popleft() space = self.buffer_size - len(self.buffer) if self.count + space > self.n_data: space = self.n_data - self.count - - if self.count >= self.n_data and epoch < self.n_epochs: - epoch += 1 - self.count = 0 self.enqueue(space) def make_signature(self) -> tf.TensorSpec: @@ -200,7 +383,7 @@ def shuffle_and_batch(self): """ ds = tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() - ) + ).cache(self.name).repeat(self.n_epochs) ds = ds.shuffle( buffer_size=self.buffer_size, reshuffle_each_iteration=True @@ -213,7 +396,7 @@ def shuffle_and_batch(self): def batch(self) -> Iterator[jax.Array]: ds = tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() - ) + ).cache(self.name).repeat(self.n_epochs) ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) - return ds + return ds \ No newline at end of file From 83a0d3436516e2424b5df014a45012299a04c8d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 27 Mar 2024 18:37:24 +0100 Subject: [PATCH 02/12] implemented automatic handling of cache files. --- apax/data/input_pipeline.py | 18 ++++++++++++++---- apax/train/run.py | 3 ++- apax/train/trainer.py | 5 +++++ 3 files changed, 21 insertions(+), 5 deletions(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 260d70a0..1d93dc5e 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -3,6 +3,7 @@ from random import shuffle from typing import Dict, Iterator import uuid +from pathlib import Path import jax import jax.numpy as jnp @@ -231,6 +232,7 @@ def __init__( n_jit_steps=1, pre_shuffle=False, ignore_labels=False, + cache_path = "." 
) -> None: if pre_shuffle: shuffle(atoms) @@ -255,7 +257,7 @@ def __init__( self.buffer = deque() self.batch_size = self.validate_batch_size(bs) self.n_jit_steps = n_jit_steps - self.name = str(uuid.uuid4()) + self.file = Path(cache_path) / str(uuid.uuid4()) self.enqueue(min(self.buffer_size, self.n_data)) @@ -383,7 +385,7 @@ def shuffle_and_batch(self): """ ds = tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() - ).cache(self.name).repeat(self.n_epochs) + ).cache(self.file.as_posix()).repeat(self.n_epochs) ds = ds.shuffle( buffer_size=self.buffer_size, reshuffle_each_iteration=True @@ -396,7 +398,15 @@ def shuffle_and_batch(self): def batch(self) -> Iterator[jax.Array]: ds = tf.data.Dataset.from_generator( lambda: self, output_signature=self.make_signature() - ).cache(self.name).repeat(self.n_epochs) + ).cache(self.file.as_posix()).repeat(self.n_epochs) ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) - return ds \ No newline at end of file + return ds + + def cleanup(self): + for p in self.file.parent.glob(f"{self.file.name}.data*"): + p.unlink() + + index_file = self.file.parent / f"{self.file.name}.index" + index_file.unlink() + \ No newline at end of file diff --git a/apax/train/run.py b/apax/train/run.py index f1dbeb39..6606348e 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -76,9 +76,10 @@ def run(user_config, log_level="error"): config.data.shuffle_buffer_size, config.n_jitted_steps, pre_shuffle=True, + cache_path=config.data.model_version_path, ) val_ds = InMemoryDataset( - val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs + val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs, cache_path=config.data.model_version_path, ) ds_stats = compute_scale_shift_parameters( train_ds.inputs, diff --git a/apax/train/trainer.py b/apax/train/trainer.py index f0e99ef5..7575d23e 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -1,5 +1,6 @@ import functools import logging +from pathlib import Path import time from functools import partial from typing import Callable, Optional @@ -142,6 +143,10 @@ def fit( epoch_pbar.close() callbacks.on_train_end() + train_ds.cleanup() + if val_ds: + val_ds.cleanup() + def global_norm(updates) -> jnp.ndarray: """Returns the l2 norm of the input. 
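
The cache handling added in [PATCH 02/12] ties the lifetime of the tf.data cache files to a single training run: each dataset writes its cache under a fresh UUID inside `cache_path`, and the trainer calls `cleanup()` once fitting ends, so concurrent runs pointed at the same directory cannot collide on cache files. Below is a minimal, self-contained sketch of that lifecycle; the toy `range` dataset and the current-directory cache location are stand-ins for illustration only, while the file naming and the glob/unlink pattern follow the `cleanup()` method from the diff above.

```python
import uuid
from pathlib import Path

import tensorflow as tf

# Fresh cache file per dataset instance, as in InMemoryDataset.__init__ above.
cache_file = Path(".") / str(uuid.uuid4())

# tf.data materializes the cache during the first full pass over the data,
# producing `<name>.index` plus one or more `<name>.data-*` shard files.
ds = tf.data.Dataset.range(10).cache(cache_file.as_posix()).repeat(2)
for _ in ds.as_numpy_iterator():
    pass

# Counterpart of the cleanup() method: remove the shards, then the index.
for p in cache_file.parent.glob(f"{cache_file.name}.data*"):
    p.unlink()
index_file = cache_file.parent / f"{cache_file.name}.index"
if index_file.exists():
    index_file.unlink()
```
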
From ab55328361f0645e28210d755e1393aeee3d4ad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 27 Mar 2024 19:35:20 +0100 Subject: [PATCH 03/12] moved dataset initialization to separate function --- apax/train/run.py | 37 +++++++++++++++++++++---------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/apax/train/run.py b/apax/train/run.py index 6606348e..0dee0f08 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -50,22 +50,7 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: loss_funcs.append(Loss(**loss.model_dump())) return LossCollection(loss_funcs) - -def run(user_config, log_level="error"): - config = parse_config(user_config) - - seed_py_np_tf(config.seed) - rng_key = jax.random.PRNGKey(config.seed) - - log.info("Initializing directories") - config.data.model_version_path.mkdir(parents=True, exist_ok=True) - setup_logging(config.data.model_version_path / "train.log", log_level) - config.dump_config(config.data.model_version_path) - - callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) - loss_fn = initialize_loss_fn(config.loss) - Metrics = initialize_metrics(config.metrics) - +def initialize_datasets(config): train_raw_ds, val_raw_ds = load_data_files(config.data) train_ds = InMemoryDataset( @@ -89,6 +74,26 @@ def run(user_config, log_level="error"): config.data.shift_options, config.data.scale_options, ) + return train_ds, val_ds, ds_stats + + + +def run(user_config, log_level="error"): + config = parse_config(user_config) + + seed_py_np_tf(config.seed) + rng_key = jax.random.PRNGKey(config.seed) + + log.info("Initializing directories") + config.data.model_version_path.mkdir(parents=True, exist_ok=True) + setup_logging(config.data.model_version_path / "train.log", log_level) + config.dump_config(config.data.model_version_path) + + callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path) + loss_fn = initialize_loss_fn(config.loss) + Metrics = initialize_metrics(config.metrics) + + train_ds, val_ds, ds_stats = initialize_datasets(config) log.info("Initializing Model") sample_input, init_box = train_ds.init_input() From ec3b2cb47d4503963c7f3c8dfa325f8c92bacef5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 29 Mar 2024 12:55:14 +0100 Subject: [PATCH 04/12] convert atomic numbers to int16 --- apax/utils/convert.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/utils/convert.py b/apax/utils/convert.py index 166535ad..b6bfd76e 100644 --- a/apax/utils/convert.py +++ b/apax/utils/convert.py @@ -89,7 +89,7 @@ def atoms_to_inputs( frac_pos = space.transform(inv_box, pos) inputs["positions"].append(np.array(frac_pos)) - inputs["numbers"].append(atoms.numbers) + inputs["numbers"].append(atoms.numbers.astype(np.int16)) inputs["n_atoms"].append(len(atoms)) inputs = prune_dict(inputs) From 79a9e69e703bf72b3a279907df444f94db11a98f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 29 Mar 2024 12:57:54 +0100 Subject: [PATCH 05/12] implemented OTFInMemoryDataset. 
Usage for single epoch tasks --- apax/bal/api.py | 6 +- apax/config/train_config.py | 1 + apax/data/input_pipeline.py | 278 +++++++++++------------------------- apax/md/ase_calc.py | 4 +- apax/train/eval.py | 4 +- apax/train/run.py | 12 +- 6 files changed, 95 insertions(+), 210 deletions(-) diff --git a/apax/bal/api.py b/apax/bal/api.py index 4a50463d..f891c5c9 100644 --- a/apax/bal/api.py +++ b/apax/bal/api.py @@ -8,7 +8,7 @@ from tqdm import trange from apax.bal import feature_maps, kernel, selection, transforms -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import OTFInMemoryDataset from apax.model.builder import ModelBuilder from apax.model.gmnn import EnergyModel from apax.train.checkpoints import ( @@ -46,7 +46,7 @@ def create_feature_fn( return feature_fn -def compute_features(feature_fn, dataset: InMemoryDataset): +def compute_features(feature_fn, dataset: OTFInMemoryDataset): """Compute the features of a dataset.""" features = [] n_data = dataset.n_data @@ -85,7 +85,7 @@ def kernel_selection( is_ensemble = n_models > 1 n_train = len(train_atoms) - dataset = InMemoryDataset( + dataset = OTFInMemoryDataset( train_atoms + pool_atoms, cutoff=config.model.r_max, bs=processing_batch_size, diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 350849d0..5733dec8 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -50,6 +50,7 @@ class DataConfig(BaseModel, extra="forbid"): directory: str experiment: str + ds_type: Literal["cached", "otf"] = "cached" data_path: Optional[str] = None train_data_path: Optional[str] = None val_data_path: Optional[str] = None diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 1d93dc5e..1a014c55 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -35,192 +35,6 @@ def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]: return max_atoms, max_nbrs -# class InMemoryDataset: -# def __init__( -# self, -# atoms, -# cutoff, -# bs, -# n_epochs, -# buffer_size=1000, -# n_jit_steps=1, -# pre_shuffle=False, -# ignore_labels=False, -# ) -> None: -# if pre_shuffle: -# shuffle(atoms) -# self.sample_atoms = atoms[0] -# self.inputs = atoms_to_inputs(atoms) - -# self.n_epochs = n_epochs -# self.buffer_size = buffer_size - -# max_atoms, max_nbrs = find_largest_system(self.inputs, cutoff) -# self.max_atoms = max_atoms -# self.max_nbrs = max_nbrs - -# if atoms[0].calc and not ignore_labels: -# self.labels = atoms_to_labels(atoms) -# else: -# self.labels = None - -# self.n_data = len(atoms) -# self.count = 0 -# self.cutoff = cutoff -# self.buffer = deque() -# self.batch_size = self.validate_batch_size(bs) -# self.n_jit_steps = n_jit_steps - -# self.enqueue(min(self.buffer_size, self.n_data)) - -# def steps_per_epoch(self) -> int: -# """Returns the number of steps per epoch dependent on the number of data and the -# batch size. Steps per epoch are calculated in a way that all epochs have the same -# number of steps, and all batches have the same length. To do so, some training -# data are dropped in each epoch. -# """ -# return self.n_data // self.batch_size // self.n_jit_steps - -# def validate_batch_size(self, batch_size: int) -> int: -# if batch_size > self.n_data: -# msg = ( -# f"requested batch size {batch_size} is larger than the number of data" -# f" points {self.n_data}. 
Setting batch size = {self.n_data}" -# ) -# print("Warning: " + msg) -# log.warning(msg) -# batch_size = self.n_data -# return batch_size - -# def prepare_data(self, i): -# inputs = {k: v[i] for k, v in self.inputs.items()} -# idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) -# inputs["idx"], inputs["offsets"] = pad_nl(idx, offsets, self.max_nbrs) - -# zeros_to_add = self.max_atoms - inputs["numbers"].shape[0] -# inputs["positions"] = np.pad( -# inputs["positions"], ((0, zeros_to_add), (0, 0)), "constant" -# ) -# inputs["numbers"] = np.pad( -# inputs["numbers"], (0, zeros_to_add), "constant" -# ).astype(np.int16) -# inputs["n_atoms"] = np.pad( -# inputs["n_atoms"], (0, zeros_to_add), "constant" -# ).astype(np.int16) - -# if not self.labels: -# return inputs - -# labels = {k: v[i] for k, v in self.labels.items()} -# if "forces" in labels: -# labels["forces"] = np.pad( -# labels["forces"], ((0, zeros_to_add), (0, 0)), "constant" -# ) - -# inputs = {k: tf.constant(v) for k, v in inputs.items()} -# labels = {k: tf.constant(v) for k, v in labels.items()} -# return (inputs, labels) - -# def enqueue(self, num_elements): -# for _ in range(num_elements): -# data = self.prepare_data(self.count) -# self.buffer.append(data) -# self.count += 1 - -# def __iter__(self): -# epoch = 0 -# while epoch < self.n_epochs or len(self.buffer) > 0: -# yield self.buffer.popleft() - -# space = self.buffer_size - len(self.buffer) -# if self.count + space > self.n_data: -# space = self.n_data - self.count - -# if self.count >= self.n_data and epoch < self.n_epochs: -# epoch += 1 -# self.count = 0 -# self.enqueue(space) - -# def make_signature(self) -> tf.TensorSpec: -# input_signature = {} -# input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") -# input_signature["numbers"] = tf.TensorSpec( -# (self.max_atoms,), dtype=tf.int16, name="numbers" -# ) -# input_signature["positions"] = tf.TensorSpec( -# (self.max_atoms, 3), dtype=tf.float64, name="positions" -# ) -# input_signature["box"] = tf.TensorSpec((3, 3), dtype=tf.float64, name="box") -# input_signature["idx"] = tf.TensorSpec( -# (2, self.max_nbrs), dtype=tf.int16, name="idx" -# ) -# input_signature["offsets"] = tf.TensorSpec( -# (self.max_nbrs, 3), dtype=tf.float64, name="offsets" -# ) - -# if not self.labels: -# return input_signature - -# label_signature = {} -# if "energy" in self.labels.keys(): -# label_signature["energy"] = tf.TensorSpec((), dtype=tf.float64, name="energy") -# if "forces" in self.labels.keys(): -# label_signature["forces"] = tf.TensorSpec( -# (self.max_atoms, 3), dtype=tf.float64, name="forces" -# ) -# if "stress" in self.labels.keys(): -# label_signature["stress"] = tf.TensorSpec( -# (3, 3), dtype=tf.float64, name="stress" -# ) -# signature = (input_signature, label_signature) -# return signature - -# def init_input(self) -> Dict[str, np.ndarray]: -# """Returns first batch of inputs and labels to init the model.""" -# positions = self.sample_atoms.positions -# box = self.sample_atoms.cell.array -# idx, offsets = compute_nl(positions, box, self.cutoff) -# inputs = ( -# positions, -# self.sample_atoms.numbers, -# idx, -# box, -# offsets, -# ) - -# inputs = jax.tree_map(lambda x: jnp.array(x), inputs) -# return inputs, np.array(box) - -# def shuffle_and_batch(self): -# """Shuffles and batches the inputs/labels. This function prepares the -# inputs and labels for the whole training and prefetches the data. 
- -# Returns -# ------- -# ds : -# Iterator that returns inputs and labels of one batch in each step. -# """ -# ds = tf.data.Dataset.from_generator( -# lambda: self, output_signature=self.make_signature() -# ) - -# ds = ds.shuffle( -# buffer_size=self.buffer_size, reshuffle_each_iteration=True -# ).batch(batch_size=self.batch_size) -# if self.n_jit_steps > 1: -# ds = ds.batch(batch_size=self.n_jit_steps) -# ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) -# return ds - -# def batch(self) -> Iterator[jax.Array]: -# ds = tf.data.Dataset.from_generator( -# lambda: self, output_signature=self.make_signature() -# ) -# ds = ds.batch(batch_size=self.batch_size) -# ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) -# return ds - - class InMemoryDataset: def __init__( self, @@ -268,7 +82,7 @@ def steps_per_epoch(self) -> int: data are dropped in each epoch. """ return self.n_data // self.batch_size // self.n_jit_steps - + def validate_batch_size(self, batch_size: int) -> int: if batch_size > self.n_data: msg = ( @@ -279,7 +93,7 @@ def validate_batch_size(self, batch_size: int) -> int: log.warning(msg) batch_size = self.n_data return batch_size - + def prepare_data(self, i): inputs = {k: v[i] for k, v in self.inputs.items()} idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) @@ -315,15 +129,6 @@ def enqueue(self, num_elements): self.buffer.append(data) self.count += 1 - def __iter__(self): - while self.count < self.n_data or len(self.buffer) > 0: - yield self.buffer.popleft() - - space = self.buffer_size - len(self.buffer) - if self.count + space > self.n_data: - space = self.n_data - self.count - self.enqueue(space) - def make_signature(self) -> tf.TensorSpec: input_signature = {} input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms") @@ -373,6 +178,32 @@ def init_input(self) -> Dict[str, np.ndarray]: inputs = jax.tree_map(lambda x: jnp.array(x), inputs) return inputs, np.array(box) + + def __iter__(self): + raise NotImplementedError + + def shuffle_and_batch(self): + raise NotImplementedError + + def batch(self) -> Iterator[jax.Array]: + raise NotImplementedError + + def cleanup(self): + pass + + +class CachedInMemoryDataset(InMemoryDataset): + + def __iter__(self): + while self.count < self.n_data or len(self.buffer) > 0: + yield self.buffer.popleft() + + space = self.buffer_size - len(self.buffer) + if self.count + space > self.n_data: + space = self.n_data - self.count + self.enqueue(space) + + def shuffle_and_batch(self): """Shuffles and batches the inputs/labels. This function prepares the @@ -409,4 +240,55 @@ def cleanup(self): index_file = self.file.parent / f"{self.file.name}.index" index_file.unlink() - \ No newline at end of file + + +class OTFInMemoryDataset(InMemoryDataset): + + def __iter__(self): + epoch = 0 + while epoch < self.n_epochs or len(self.buffer) > 0: + yield self.buffer.popleft() + + space = self.buffer_size - len(self.buffer) + if self.count + space > self.n_data: + space = self.n_data - self.count + + if self.count >= self.n_data and epoch < self.n_epochs: + epoch += 1 + self.count = 0 + self.enqueue(space) + + def shuffle_and_batch(self): + """Shuffles and batches the inputs/labels. This function prepares the + inputs and labels for the whole training and prefetches the data. + + Returns + ------- + ds : + Iterator that returns inputs and labels of one batch in each step. 
+ """ + ds = tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + + ds = ds.shuffle( + buffer_size=self.buffer_size, reshuffle_each_iteration=True + ).batch(batch_size=self.batch_size) + if self.n_jit_steps > 1: + ds = ds.batch(batch_size=self.n_jit_steps) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds + + def batch(self) -> Iterator[jax.Array]: + ds = tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + ds = ds.batch(batch_size=self.batch_size) + ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) + return ds + + +dataset_dict = { + "cached": CachedInMemoryDataset, + "otf": OTFInMemoryDataset, +} \ No newline at end of file diff --git a/apax/md/ase_calc.py b/apax/md/ase_calc.py index 73f27658..69ef95f6 100644 --- a/apax/md/ase_calc.py +++ b/apax/md/ase_calc.py @@ -11,7 +11,7 @@ from matscipy.neighbours import neighbour_list from tqdm import trange -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import OTFInMemoryDataset from apax.model import ModelBuilder from apax.train.checkpoints import check_for_ensemble, restore_parameters from apax.utils.jax_md_reduced import partition, quantity, space @@ -256,7 +256,7 @@ def batch_eval( """ if self.model is None: self.initialize(atoms_list[0]) - dataset = InMemoryDataset( + dataset = OTFInMemoryDataset( atoms_list, self.model_config.model.r_max, batch_size, diff --git a/apax/train/eval.py b/apax/train/eval.py index b7d69700..f15a6da3 100644 --- a/apax/train/eval.py +++ b/apax/train/eval.py @@ -8,7 +8,7 @@ from tqdm import trange from apax.config import parse_config -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import OTFInMemoryDataset from apax.model import ModelBuilder from apax.train.callbacks import initialize_callbacks from apax.train.checkpoints import restore_single_parameters @@ -122,7 +122,7 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"): Metrics = initialize_metrics(config.metrics) atoms_list = load_test_data(config, model_version_path, eval_path, n_test) - test_ds = InMemoryDataset( + test_ds = OTFInMemoryDataset( atoms_list, config.model.r_max, config.data.valid_batch_size ) diff --git a/apax/train/run.py b/apax/train/run.py index 0dee0f08..cb7fb60a 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -4,9 +4,9 @@ import jax -from apax.config import LossConfig, parse_config +from apax.config import LossConfig, parse_config, Config from apax.data.initialization import load_data_files -from apax.data.input_pipeline import InMemoryDataset +from apax.data.input_pipeline import dataset_dict from apax.data.statistics import compute_scale_shift_parameters from apax.model import ModelBuilder from apax.optimizer import get_opt @@ -50,10 +50,12 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: loss_funcs.append(Loss(**loss.model_dump())) return LossCollection(loss_funcs) -def initialize_datasets(config): +def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) - train_ds = InMemoryDataset( + Dataset = dataset_dict[config.data.ds_type] + + train_ds = Dataset( train_raw_ds, config.model.r_max, config.data.batch_size, @@ -63,7 +65,7 @@ def initialize_datasets(config): pre_shuffle=True, cache_path=config.data.model_version_path, ) - val_ds = InMemoryDataset( + val_ds = Dataset( val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs, 
cache_path=config.data.model_version_path, ) ds_stats = compute_scale_shift_parameters( From 90c9f6918719170ac1a22827f2d2d8ae79367659 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 29 Mar 2024 13:19:35 +0100 Subject: [PATCH 06/12] remove douple prinitng of BS warning --- apax/data/input_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 1a014c55..07cee803 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -89,7 +89,6 @@ def validate_batch_size(self, batch_size: int) -> int: f"requested batch size {batch_size} is larger than the number of data" f" points {self.n_data}. Setting batch size = {self.n_data}" ) - print("Warning: " + msg) log.warning(msg) batch_size = self.n_data return batch_size From 07c86b55b417e4c59bdb33289f024f66eb6338c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 29 Mar 2024 13:20:01 +0100 Subject: [PATCH 07/12] filter erroneos TF warning --- apax/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/apax/__init__.py b/apax/__init__.py index 7438bf4a..bc72f84c 100644 --- a/apax/__init__.py +++ b/apax/__init__.py @@ -3,6 +3,7 @@ import jax os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' jax.config.update("jax_enable_x64", True) from apax.utils.helpers import setup_ase From 731592c7456a723dfa36afc95eae1f5f0d643417 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Fri, 29 Mar 2024 13:22:12 +0100 Subject: [PATCH 08/12] linting --- apax/__init__.py | 2 +- apax/data/input_pipeline.py | 42 ++++++++++++++++++++----------------- apax/train/run.py | 10 ++++++--- apax/train/trainer.py | 11 +++++----- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/apax/__init__.py b/apax/__init__.py index bc72f84c..de4b9e6e 100644 --- a/apax/__init__.py +++ b/apax/__init__.py @@ -3,7 +3,7 @@ import jax os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" jax.config.update("jax_enable_x64", True) from apax.utils.helpers import setup_ase diff --git a/apax/data/input_pipeline.py b/apax/data/input_pipeline.py index 07cee803..082415a0 100644 --- a/apax/data/input_pipeline.py +++ b/apax/data/input_pipeline.py @@ -1,9 +1,9 @@ import logging +import uuid from collections import deque +from pathlib import Path from random import shuffle from typing import Dict, Iterator -import uuid -from pathlib import Path import jax import jax.numpy as jnp @@ -46,7 +46,7 @@ def __init__( n_jit_steps=1, pre_shuffle=False, ignore_labels=False, - cache_path = "." + cache_path=".", ) -> None: if pre_shuffle: shuffle(atoms) @@ -82,7 +82,7 @@ def steps_per_epoch(self) -> int: data are dropped in each epoch. 
""" return self.n_data // self.batch_size // self.n_jit_steps - + def validate_batch_size(self, batch_size: int) -> int: if batch_size > self.n_data: msg = ( @@ -92,7 +92,7 @@ def validate_batch_size(self, batch_size: int) -> int: log.warning(msg) batch_size = self.n_data return batch_size - + def prepare_data(self, i): inputs = {k: v[i] for k, v in self.inputs.items()} idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff) @@ -177,7 +177,7 @@ def init_input(self) -> Dict[str, np.ndarray]: inputs = jax.tree_map(lambda x: jnp.array(x), inputs) return inputs, np.array(box) - + def __iter__(self): raise NotImplementedError @@ -192,7 +192,6 @@ def cleanup(self): class CachedInMemoryDataset(InMemoryDataset): - def __iter__(self): while self.count < self.n_data or len(self.buffer) > 0: yield self.buffer.popleft() @@ -202,8 +201,6 @@ def __iter__(self): space = self.n_data - self.count self.enqueue(space) - - def shuffle_and_batch(self): """Shuffles and batches the inputs/labels. This function prepares the inputs and labels for the whole training and prefetches the data. @@ -213,9 +210,13 @@ def shuffle_and_batch(self): ds : Iterator that returns inputs and labels of one batch in each step. """ - ds = tf.data.Dataset.from_generator( - lambda: self, output_signature=self.make_signature() - ).cache(self.file.as_posix()).repeat(self.n_epochs) + ds = ( + tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + .cache(self.file.as_posix()) + .repeat(self.n_epochs) + ) ds = ds.shuffle( buffer_size=self.buffer_size, reshuffle_each_iteration=True @@ -226,13 +227,17 @@ def shuffle_and_batch(self): return ds def batch(self) -> Iterator[jax.Array]: - ds = tf.data.Dataset.from_generator( - lambda: self, output_signature=self.make_signature() - ).cache(self.file.as_posix()).repeat(self.n_epochs) + ds = ( + tf.data.Dataset.from_generator( + lambda: self, output_signature=self.make_signature() + ) + .cache(self.file.as_posix()) + .repeat(self.n_epochs) + ) ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds - + def cleanup(self): for p in self.file.parent.glob(f"{self.file.name}.data*"): p.unlink() @@ -242,7 +247,6 @@ def cleanup(self): class OTFInMemoryDataset(InMemoryDataset): - def __iter__(self): epoch = 0 while epoch < self.n_epochs or len(self.buffer) > 0: @@ -285,9 +289,9 @@ def batch(self) -> Iterator[jax.Array]: ds = ds.batch(batch_size=self.batch_size) ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2) return ds - + dataset_dict = { "cached": CachedInMemoryDataset, "otf": OTFInMemoryDataset, -} \ No newline at end of file +} diff --git a/apax/train/run.py b/apax/train/run.py index cb7fb60a..fe408ab6 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -4,7 +4,7 @@ import jax -from apax.config import LossConfig, parse_config, Config +from apax.config import Config, LossConfig, parse_config from apax.data.initialization import load_data_files from apax.data.input_pipeline import dataset_dict from apax.data.statistics import compute_scale_shift_parameters @@ -50,6 +50,7 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection: loss_funcs.append(Loss(**loss.model_dump())) return LossCollection(loss_funcs) + def initialize_datasets(config: Config): train_raw_ds, val_raw_ds = load_data_files(config.data) @@ -66,7 +67,11 @@ def initialize_datasets(config: Config): cache_path=config.data.model_version_path, ) val_ds = Dataset( - val_raw_ds, config.model.r_max, 
config.data.valid_batch_size, config.n_epochs, cache_path=config.data.model_version_path, + val_raw_ds, + config.model.r_max, + config.data.valid_batch_size, + config.n_epochs, + cache_path=config.data.model_version_path, ) ds_stats = compute_scale_shift_parameters( train_ds.inputs, @@ -79,7 +84,6 @@ def initialize_datasets(config: Config): return train_ds, val_ds, ds_stats - def run(user_config, log_level="error"): config = parse_config(user_config) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 7575d23e..6d6bc0f0 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -1,6 +1,5 @@ import functools import logging -from pathlib import Path import time from functools import partial from typing import Callable, Optional @@ -108,10 +107,12 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update({ - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - }) + epoch_metrics.update( + { + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + } + ) epoch_metrics.update({**epoch_loss}) From af320db5f4e7048b9b88aa31e447874d20f61747 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 2 Apr 2024 11:01:19 +0200 Subject: [PATCH 09/12] implemented huber loss --- apax/config/train_config.py | 4 +- apax/train/loss.py | 59 +++++++++++++++++------------ tests/unit_tests/train/test_loss.py | 2 +- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 5733dec8..3cd5f57c 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -229,8 +229,10 @@ class LossConfig(BaseModel, extra="forbid"): """ name: str - loss_type: str = "structures" + loss_type: str = "mse" weight: NonNegativeFloat = 1.0 + atoms_exponent: NonNegativeFloat = 1 + parameters: dict = {} class CallbackConfig(BaseModel, frozen=True, extra="forbid"): diff --git a/apax/train/loss.py b/apax/train/loss.py index cd28786e..c2e575d3 100644 --- a/apax/train/loss.py +++ b/apax/train/loss.py @@ -8,7 +8,7 @@ def weighted_squared_error( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Squared error function that allows weighting of @@ -17,8 +17,23 @@ def weighted_squared_error( return (label - prediction) ** 2 / divisor +def weighted_huber_loss( + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} +) -> jnp.array: + """ + Huber loss function that allows weighting of + individual contributions by the number of atoms in the system. + """ + if "delta" not in parameters.keys(): + raise KeyError("Huber loss function requires 'delta' parameter") + delta = parameters["delta"] + diff = jnp.abs(label - prediction) + loss = jnp.where(diff > delta, delta * (diff - 0.5 * delta), 0.5 * diff**2) + return loss / divisor + + def force_angle_loss( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Consine similarity loss function. Contributions are summed in `Loss`. 
@@ -28,7 +43,7 @@ def force_angle_loss( def force_angle_div_force_label( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ): """ Consine similarity loss function weighted by the norm of the force labels. @@ -41,7 +56,7 @@ def force_angle_div_force_label( def force_angle_exponential_weight( - label: jnp.array, prediction: jnp.array, divisor: float = 1.0 + label: jnp.array, prediction: jnp.array, divisor: float = 1.0, parameters: dict = {} ) -> jnp.array: """ Consine similarity loss function exponentially scaled by the norm of the force labels. @@ -52,7 +67,7 @@ def force_angle_exponential_weight( return (1.0 - dotp) * jnp.exp(-F_0_norm) / divisor -def stress_tril(label, prediction, divisor=1.0): +def stress_tril(label, prediction, divisor=1.0, parameters: dict = {}): idxs = jnp.tril_indices(3) label_tril = label[:, idxs[0], idxs[1]] prediction_tril = prediction[:, idxs[0], idxs[1]] @@ -60,9 +75,8 @@ def stress_tril(label, prediction, divisor=1.0): loss_functions = { - "molecules": weighted_squared_error, - "structures": weighted_squared_error, - "vibrations": weighted_squared_error, + "mse": weighted_squared_error, + "huber": weighted_huber_loss, "cosine_sim": force_angle_loss, "cosine_sim_div_magnitude": force_angle_div_force_label, "cosine_sim_exp_magnitude": force_angle_exponential_weight, @@ -80,6 +94,8 @@ class Loss: name: str loss_type: str weight: float = 1.0 + atoms_exponent: float = 1.0 + parameters: dict = dataclasses.field(default_factory=lambda: {}) def __post_init__(self): if self.loss_type not in loss_functions.keys(): @@ -94,25 +110,18 @@ def __post_init__(self): def __call__(self, inputs: dict, prediction: dict, label: dict) -> float: # TODO we may want to insert an additional `mask` argument for this method divisor = self.determine_divisor(inputs["n_atoms"]) - loss = self.loss_fn(label[self.name], prediction[self.name], divisor=divisor) - return self.weight * jnp.sum(jnp.mean(loss, axis=0)) + batch_losses = self.loss_fn( + label[self.name], prediction[self.name], divisor, self.parameters + ) + loss = self.weight * jnp.sum(jnp.mean(batch_losses, axis=0)) + return loss def determine_divisor(self, n_atoms: jnp.array) -> jnp.array: - divisor_id = self.name + "_" + self.loss_type - divisor_dict = { - "energy_structures": n_atoms**2, - "energy_vibrations": n_atoms, - "forces_structures": einops.repeat(n_atoms, "batch -> batch 1 1"), - "forces_cosine_sim": einops.repeat(n_atoms, "batch -> batch 1 1"), - "cosine_sim_div_magnitude": einops.repeat(n_atoms, "batch -> batch 1 1"), - "forces_cosine_sim_exp_magnitude": einops.repeat( - n_atoms, "batch -> batch 1 1" - ), - "stress_structures": einops.repeat(n_atoms**2, "batch -> batch 1 1"), - "stress_tril": einops.repeat(n_atoms**2, "batch -> batch 1 1"), - "stress_vibrations": einops.repeat(n_atoms, "batch -> batch 1 1"), - } - divisor = divisor_dict.get(divisor_id, jnp.array(1.0)) + # shape: batch + divisor = n_atoms**self.atoms_exponent + + if self.name in ["forces", "stress"]: + divisor = einops.repeat(divisor, "batch -> batch 1 1") return divisor diff --git a/tests/unit_tests/train/test_loss.py b/tests/unit_tests/train/test_loss.py index 26c18bc1..b39aac4c 100644 --- a/tests/unit_tests/train/test_loss.py +++ b/tests/unit_tests/train/test_loss.py @@ -68,7 +68,7 @@ def test_force_angle_loss(): def test_force_loss(): name = "forces" - loss_type = "structures" + loss_type = "mse" weight = 1 inputs = { "n_atoms": jnp.array([2]), From 
dd8f601d21c7266a2bd96752f28fe7435a105018 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:09:06 +0000 Subject: [PATCH 10/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/train/trainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 6d6bc0f0..8c040a3f 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -107,12 +107,10 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update( - { - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - } - ) + epoch_metrics.update({ + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + }) epoch_metrics.update({**epoch_loss}) From 69becc1af251559eec871064c2b315ac65d88c4e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Tue, 2 Apr 2024 11:27:06 +0200 Subject: [PATCH 11/12] implemented yaml autocompletion via PyDantic json schemata --- .gitignore | 1 + README.md | 32 +++++++++++++++++++++++++++++--- apax/cli/apax_app.py | 18 ++++++++++++++++++ 3 files changed, 48 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index e6981eda..041b225d 100644 --- a/.gitignore +++ b/.gitignore @@ -59,6 +59,7 @@ tmp/ .traj .h5 events.out.* +*.schema.json # Translations *.mo diff --git a/README.md b/README.md index 16d78cca..dc4b76c8 100644 --- a/README.md +++ b/README.md @@ -60,14 +60,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati In order to train a model, you need to run -```python +```bash apax train config.yaml ``` We offer some input file templates to get new users started as quickly as possible. Simply run the following commands and add the appropriate entries in the marked fields -```python +```bash apax template train # use --full for a template with all input options ``` @@ -79,7 +79,7 @@ The documentation can convenienty be accessed by running `apax docs`. There are two ways in which `apax` models can be used for molecular dynamics out of the box. High performance NVT simulations using JaxMD can be started with the CLI by running -```python +```bash apax md config.yaml md_config.yaml ``` @@ -88,6 +88,32 @@ A template command for MD input files is provided as well. The second way is to use the ASE calculator provided in `apax.md`. +## Input File Auto-Completion + +use the following command to generate JSON schemata for training and validation files: + +```bash +apax schema +``` + +If you are using VSCode, you can utilize them to lint and autocomplete your input files by including them in `.vscode/settings.json` + +```json +{ + "yaml.schemas": { + + "/absolute/path/to/apaxtrain.schema.json": [ + "train.yaml" + ] + , + "/absolute/path/to/apaxmd.schema.json": [ + "md.yaml" + ] + } +} +``` + + ## Authors - Moritz René Schäfer - Nico Segreto diff --git a/apax/cli/apax_app.py b/apax/cli/apax_app.py index 60321662..cd09ea7d 100644 --- a/apax/cli/apax_app.py +++ b/apax/cli/apax_app.py @@ -1,5 +1,6 @@ import importlib.metadata import importlib.resources as pkg_resources +import json import sys from pathlib import Path @@ -93,6 +94,23 @@ def docs(): typer.launch("https://apax.readthedocs.io/en/latest/") +@app.command() +def schema(): + """ + Generating JSON schemata for autocompletion of train/md inputs in VSCode. 
+ """ + console.print("Generating JSON schema") + from apax.config import Config, MDConfig + + train_schema = Config.model_json_schema() + md_schema = MDConfig.model_json_schema() + with open("./apaxtrain.schema.json", "w") as f: + f.write(json.dumps(train_schema, indent=2)) + + with open("./apaxmd.schema.json", "w") as f: + f.write(json.dumps(md_schema, indent=2)) + + @validate_app.command("train") def validate_train_config( config_path: Path = typer.Argument( From d5cafc6c4ec0703c2df4faabaebb3fd232139320 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:29:03 +0000 Subject: [PATCH 12/12] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index dc4b76c8..c3ffc31f 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,7 @@ If you are using VSCode, you can utilize them to lint and autocomplete your inpu ```json { "yaml.schemas": { - + "/absolute/path/to/apaxtrain.schema.json": [ "train.yaml" ]
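
The Huber loss introduced in [PATCH 09/12] is selected through the same loss machinery as the existing mean-squared error: `loss_type` picks the function and `parameters` must supply the `delta` threshold, otherwise `weighted_huber_loss` raises a `KeyError`. The sketch below exercises it directly on a made-up two-atom system; the numeric values are demo inputs, and only the field names, the `"huber"` key, and the `delta` requirement come from the loss.py diff.

```python
import jax.numpy as jnp

from apax.train.loss import Loss

# Force loss with the new Huber option; delta is the quadratic-to-linear crossover.
huber_forces = Loss(name="forces", loss_type="huber", weight=1.0, parameters={"delta": 0.2})

inputs = {"n_atoms": jnp.array([2])}
label = {"forces": jnp.array([[[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]]])}
prediction = {"forces": jnp.array([[[0.1, 1.0, 0.0], [1.0, 0.3, 0.0]]])}

# Per-component Huber terms are divided by n_atoms**atoms_exponent,
# averaged over the batch and summed, as in Loss.__call__.
print(huber_forces(inputs, prediction, label))
```

In a `train.yaml` this presumably corresponds to a `loss` entry along the lines of `name: forces`, `loss_type: huber`, `parameters: {delta: 0.2}`, since the same patch extends `LossConfig` with the new `parameters` and `atoms_exponent` fields.
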