Merge pull request #248 from apax-hub/cache_data
Cache dataset
M-R-Schaefer authored Apr 3, 2024
2 parents 10470b1 + 2f5ec10 commit 8f6faa7
Showing 14 changed files with 225 additions and 73 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -59,6 +59,7 @@ tmp/
.traj
.h5
events.out.*
*.schema.json

# Translations
*.mo
32 changes: 29 additions & 3 deletions README.md
@@ -60,14 +60,14 @@ See the [Jax installation instructions](https://github.com/google/jax#installati

To train a model, run:

```python
```bash
apax train config.yaml
```

We offer input file templates to get new users started as quickly as possible.
Simply run the following command and fill in the marked fields:

```python
```bash
apax template train # use --full for a template with all input options
```

@@ -79,7 +79,7 @@ The documentation can conveniently be accessed by running `apax docs`.
There are two ways in which `apax` models can be used for molecular dynamics out of the box.
High-performance NVT simulations using JaxMD can be started from the CLI by running:

```python
```bash
apax md config.yaml md_config.yaml
```

@@ -88,6 +88,32 @@ A template command for MD input files is provided as well.
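By analogy with the training template, this is presumably invoked as follows (the subcommand name is an assumption, not taken from this diff; check `apax template --help`):

```bash
apax template md  # assumed to mirror `apax template train`
```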
The second way is to use the ASE calculator provided in `apax.md`.
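A minimal sketch of that route, assuming a finished training run in `models/my_model` and the `ASECalculator` import path (both assumptions, not shown in this diff):

```python
from ase.io import read

from apax.md import ASECalculator

# hypothetical path to a trained model directory
calc = ASECalculator(model_dir="models/my_model")

atoms = read("structure.xyz")  # any ASE-readable geometry
atoms.calc = calc

print(atoms.get_potential_energy())
print(atoms.get_forces())
```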


## Input File Auto-Completion

Use the following command to generate JSON schemata for the training and MD input files:

```bash
apax schema
```

If you are using VSCode, you can use these schemata to lint and autocomplete your input files by adding them to `.vscode/settings.json`:

```json
{
    "yaml.schemas": {
        "/absolute/path/to/apaxtrain.schema.json": ["train.yaml"],
        "/absolute/path/to/apaxmd.schema.json": ["md.yaml"]
    }
}
```


## Authors
- Moritz René Schäfer
- Nico Segreto
1 change: 1 addition & 0 deletions apax/__init__.py
@@ -3,6 +3,7 @@
import jax

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
jax.config.update("jax_enable_x64", True)
from apax.utils.helpers import setup_ase

6 changes: 3 additions & 3 deletions apax/bal/api.py
@@ -8,7 +8,7 @@
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import (
@@ -46,7 +46,7 @@ def create_feature_fn(
return feature_fn


def compute_features(feature_fn, dataset: InMemoryDataset):
def compute_features(feature_fn, dataset: OTFInMemoryDataset):
"""Compute the features of a dataset."""
features = []
n_data = dataset.n_data
@@ -85,7 +85,7 @@ def kernel_selection(
is_ensemble = n_models > 1

n_train = len(train_atoms)
dataset = InMemoryDataset(
dataset = OTFInMemoryDataset(
train_atoms + pool_atoms,
cutoff=config.model.r_max,
bs=processing_batch_size,
18 changes: 18 additions & 0 deletions apax/cli/apax_app.py
@@ -1,5 +1,6 @@
import importlib.metadata
import importlib.resources as pkg_resources
import json
import sys
from pathlib import Path

@@ -93,6 +94,23 @@ def docs():
typer.launch("https://apax.readthedocs.io/en/latest/")


@app.command()
def schema():
"""
Generate JSON schemata for autocompletion of train/md inputs in VSCode.
"""
console.print("Generating JSON schema")
from apax.config import Config, MDConfig

train_schema = Config.model_json_schema()
md_schema = MDConfig.model_json_schema()
with open("./apaxtrain.schema.json", "w") as f:
f.write(json.dumps(train_schema, indent=2))

with open("./apaxmd.schema.json", "w") as f:
f.write(json.dumps(md_schema, indent=2))


@validate_app.command("train")
def validate_train_config(
config_path: Path = typer.Argument(
5 changes: 4 additions & 1 deletion apax/config/train_config.py
@@ -50,6 +50,7 @@ class DataConfig(BaseModel, extra="forbid"):

directory: str
experiment: str
ds_type: Literal["cached", "otf"] = "cached"
data_path: Optional[str] = None
train_data_path: Optional[str] = None
val_data_path: Optional[str] = None
@@ -228,8 +229,10 @@ class LossConfig(BaseModel, extra="forbid"):
"""

name: str
loss_type: str = "structures"
loss_type: str = "mse"
weight: NonNegativeFloat = 1.0
atoms_exponent: NonNegativeFloat = 1
parameters: dict = {}


class CallbackConfig(BaseModel, frozen=True, extra="forbid"):
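Taken together, the new `ds_type` switch and the reworked loss fields would show up in a training YAML roughly like this (a sketch assuming apax's list-of-losses layout; every value besides the new fields is illustrative):

```yaml
data:
  directory: models/          # illustrative
  experiment: my_experiment   # illustrative
  ds_type: cached             # new field: "cached" (default) or "otf"
loss:
  - name: energy
    loss_type: mse            # new default, formerly "structures"
    weight: 1.0
    atoms_exponent: 1         # new field
```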
108 changes: 93 additions & 15 deletions apax/data/input_pipeline.py
@@ -1,5 +1,7 @@
import logging
import uuid
from collections import deque
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator

@@ -44,6 +46,7 @@ def __init__(
n_jit_steps=1,
pre_shuffle=False,
ignore_labels=False,
cache_path=".",
) -> None:
if pre_shuffle:
shuffle(atoms)
@@ -68,6 +71,7 @@ def __init__(
self.buffer = deque()
self.batch_size = self.validate_batch_size(bs)
self.n_jit_steps = n_jit_steps
self.file = Path(cache_path) / str(uuid.uuid4())

self.enqueue(min(self.buffer_size, self.n_data))

Expand All @@ -85,7 +89,6 @@ def validate_batch_size(self, batch_size: int) -> int:
f"requested batch size {batch_size} is larger than the number of data"
f" points {self.n_data}. Setting batch size = {self.n_data}"
)
print("Warning: " + msg)
log.warning(msg)
batch_size = self.n_data
return batch_size
@@ -125,20 +128,6 @@ def enqueue(self, num_elements):
self.buffer.append(data)
self.count += 1

def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count

if self.count >= self.n_data and epoch < self.n_epochs:
epoch += 1
self.count = 0
self.enqueue(space)

def make_signature(self) -> tf.TensorSpec:
input_signature = {}
input_signature["n_atoms"] = tf.TensorSpec((), dtype=tf.int16, name="n_atoms")
@@ -189,6 +178,89 @@ def init_input(self) -> Dict[str, np.ndarray]:
inputs = jax.tree_map(lambda x: jnp.array(x), inputs)
return inputs, np.array(box)

def __iter__(self):
raise NotImplementedError

def shuffle_and_batch(self):
raise NotImplementedError

def batch(self) -> Iterator[jax.Array]:
raise NotImplementedError

def cleanup(self):
pass


class CachedInMemoryDataset(InMemoryDataset):
def __iter__(self):
while self.count < self.n_data or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count
self.enqueue(space)

def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
Returns
-------
ds :
Iterator that returns inputs and labels of one batch in each step.
"""
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)

ds = ds.shuffle(
buffer_size=self.buffer_size, reshuffle_each_iteration=True
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def batch(self) -> Iterator[jax.Array]:
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def cleanup(self):
for p in self.file.parent.glob(f"{self.file.name}.data*"):
p.unlink()

index_file = self.file.parent / f"{self.file.name}.index"
index_file.unlink()


class OTFInMemoryDataset(InMemoryDataset):
def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count

if self.count >= self.n_data and epoch < self.n_epochs:
epoch += 1
self.count = 0
self.enqueue(space)

def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -217,3 +289,9 @@ def batch(self) -> Iterator[jax.Array]:
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds


dataset_dict = {
"cached": CachedInMemoryDataset,
"otf": OTFInMemoryDataset,
}
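This registry is presumably how the `ds_type` option from `train_config.py` gets resolved to a dataset class; a minimal sketch of the lookup (the config access paths are assumptions, not shown in this diff):

```python
# pick the dataset implementation selected in the user's config
Dataset = dataset_dict[config.data.ds_type]  # "cached" or "otf"
train_ds = Dataset(
    atoms_list,                   # list of ase.Atoms structures
    cutoff=config.model.r_max,
    bs=config.data.batch_size,
)
```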
4 changes: 2 additions & 2 deletions apax/md/ase_calc.py
@@ -11,7 +11,7 @@
from matscipy.neighbours import neighbour_list
from tqdm import trange

from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model import ModelBuilder
from apax.train.checkpoints import check_for_ensemble, restore_parameters
from apax.utils.jax_md_reduced import partition, quantity, space
@@ -256,7 +256,7 @@ def batch_eval(
"""
if self.model is None:
self.initialize(atoms_list[0])
dataset = InMemoryDataset(
dataset = OTFInMemoryDataset(
atoms_list,
self.model_config.model.r_max,
batch_size,
4 changes: 2 additions & 2 deletions apax/train/eval.py
@@ -8,7 +8,7 @@
from tqdm import trange

from apax.config import parse_config
from apax.data.input_pipeline import InMemoryDataset
from apax.data.input_pipeline import OTFInMemoryDataset
from apax.model import ModelBuilder
from apax.train.callbacks import initialize_callbacks
from apax.train.checkpoints import restore_single_parameters
@@ -122,7 +122,7 @@ def eval_model(config_path, n_test=-1, log_file="eval.log", log_level="error"):
Metrics = initialize_metrics(config.metrics)

atoms_list = load_test_data(config, model_version_path, eval_path, n_test)
test_ds = InMemoryDataset(
test_ds = OTFInMemoryDataset(
atoms_list, config.model.r_max, config.data.valid_batch_size
)
