Commit 095cfe4: Merge branch 'dev' into moredocs

M-R-Schaefer committed Apr 10, 2024
2 parents 0be92aa + 0b59ae2

Showing 25 changed files with 2,358 additions and 159 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yaml
@@ -24,7 +24,7 @@ jobs:
- name: Install package
run: |
poetry --version
-          poetry install
+          poetry install --all-extras
- name: Unit Tests
run: |
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- id: trailing-whitespace

- repo: https://github.com/psf/black
-    rev: 24.1.1
+    rev: 24.3.0
hooks:
- id: black
exclude: ^apax/utils/jax_md_reduced/
3 changes: 3 additions & 0 deletions apax/config/train_config.py
@@ -331,13 +331,16 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List of :class: `callback` <config.CallbackConfig> configurations.
progress_bar: Progressbar configuration.
checkpoints: Checkpoint configuration.
+        data_parallel: Automatically uses all available GPUs for data parallel training.
+            Set to false to force single device training.
"""

n_epochs: PositiveInt
patience: Optional[PositiveInt] = None
seed: int = 1
n_models: int = 1
n_jitted_steps: int = 1
+    data_parallel: int = True

data: DataConfig
model: ModelConfig = ModelConfig()
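The new data_parallel option sits alongside the other top-level training settings. A minimal sketch of how it might look in a training YAML (hedged: only the keys shown in this hunk are confirmed by the diff; the values are illustrative):

    n_epochs: 100
    seed: 1
    n_models: 1
    n_jitted_steps: 1
    data_parallel: false  # force single-device training; defaults to true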
51 changes: 28 additions & 23 deletions apax/data/input_pipeline.py
@@ -23,12 +23,13 @@ def pad_nl(idx, offsets, max_neighbors):
return idx, offsets


-def find_largest_system(inputs: dict[str, np.ndarray], r_max) -> tuple[int]:
+def find_largest_system(inputs, r_max) -> tuple[int]:
+    positions, boxes = inputs["positions"], inputs["box"]
max_atoms = np.max(inputs["n_atoms"])

max_nbrs = 0
-    for position, box in zip(inputs["positions"], inputs["box"]):
-        neighbor_idxs, _ = compute_nl(position, box, r_max)
+    for pos, box in zip(positions, boxes):
+        neighbor_idxs, _ = compute_nl(pos, box, r_max)
n_neighbors = neighbor_idxs.shape[1]
max_nbrs = max(max_nbrs, n_neighbors)
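find_largest_system scans the whole dataset for the largest atom and neighbor counts so every sample can be padded to a single static shape, which keeps the jitted training step from recompiling per structure. A hedged usage sketch (toy arrays; the dict keys are the ones used above):

    import numpy as np

    rng = np.random.default_rng(0)
    inputs = {
        "positions": [rng.uniform(0, 10, (3, 3)), rng.uniform(0, 10, (5, 3))],
        "box": [np.zeros((3, 3))] * 2,  # all-zero box -> treated as non-periodic
        "n_atoms": np.array([3, 5]),
    }
    max_atoms, max_nbrs = find_largest_system(inputs, r_max=5.0)  # max_atoms == 5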

@@ -38,7 +39,7 @@ def find_largest_system(inputs, r_max) -> tuple[int]:
class InMemoryDataset:
def __init__(
self,
-        atoms,
+        atoms_list,
cutoff,
bs,
n_epochs,
@@ -54,21 +55,20 @@ def __init__(
self.cutoff = cutoff
self.n_jit_steps = n_jit_steps
self.buffer_size = buffer_size
-        self.n_data = len(atoms)
+        self.n_data = len(atoms_list)
self.batch_size = self.validate_batch_size(bs)
self.pos_unit = pos_unit

if pre_shuffle:
-            shuffle(atoms)
-        self.sample_atoms = atoms[0]
-        self.inputs = atoms_to_inputs(atoms, self.pos_unit)
+            shuffle(atoms_list)
+        self.sample_atoms = atoms_list[0]
+        self.inputs = atoms_to_inputs(atoms_list, pos_unit)

max_atoms, max_nbrs = find_largest_system(self.inputs, self.cutoff)
self.max_atoms = max_atoms
self.max_nbrs = max_nbrs

-        if atoms[0].calc and not ignore_labels:
-            self.labels = atoms_to_labels(atoms, self.pos_unit, energy_unit)
+        if atoms_list[0].calc and not ignore_labels:
+            self.labels = atoms_to_labels(atoms_list, pos_unit, energy_unit)
else:
self.labels = None

@@ -108,9 +108,6 @@ def prepare_data(self, i):
inputs["numbers"] = np.pad(
inputs["numbers"], (0, zeros_to_add), "constant"
).astype(np.int16)
inputs["n_atoms"] = np.pad(
inputs["n_atoms"], (0, zeros_to_add), "constant"
).astype(np.int16)

if not self.labels:
return inputs
@@ -120,7 +117,6 @@ def prepare_data(self, i):
labels["forces"] = np.pad(
labels["forces"], ((0, zeros_to_add), (0, 0)), "constant"
            )
-
inputs = {k: tf.constant(v) for k, v in inputs.items()}
labels = {k: tf.constant(v) for k, v in labels.items()}
return (inputs, labels)
@@ -169,6 +165,7 @@ def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
positions = self.sample_atoms.positions * unit_dict[self.pos_unit]
box = self.sample_atoms.cell.array * unit_dict[self.pos_unit]
+        # For an input sample, it does not matter whether pos is fractional or cartesian
idx, offsets = compute_nl(positions, box, self.cutoff)
inputs = (
positions,
@@ -204,7 +201,7 @@ def __iter__(self):
space = self.n_data - self.count
self.enqueue(space)

-    def shuffle_and_batch(self):
+    def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -226,10 +223,12 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(
+            ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
+        )
return ds

-    def batch(self) -> Iterator[jax.Array]:
+    def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
@@ -238,7 +237,9 @@ def batch(self) -> Iterator[jax.Array]:
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(
+            ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
+        )
return ds

def cleanup(self):
@@ -265,7 +266,7 @@ def __iter__(self):
self.enqueue(space)
outer_count += 1

-    def shuffle_and_batch(self):
+    def shuffle_and_batch(self, sharding=None):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -283,15 +284,19 @@ def shuffle_and_batch(self):
).batch(batch_size=self.batch_size)
if self.n_jit_steps > 1:
ds = ds.batch(batch_size=self.n_jit_steps)
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(
+            ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
+        )
return ds

-    def batch(self) -> Iterator[jax.Array]:
+    def batch(self, sharding=None) -> Iterator[jax.Array]:
ds = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
ds = ds.batch(batch_size=self.batch_size)
-        ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
+        ds = prefetch_to_single_device(
+            ds.as_numpy_iterator(), 2, sharding, n_step_jit=self.n_jit_steps > 1
+        )
return ds


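Both batch methods now accept an optional sharding that is threaded through to prefetch_to_single_device. A minimal caller sketch, assuming a jax.sharding.PositionalSharding (the .reshape and ._devices uses in the next file point to that type; the dataset variable is illustrative):

    import jax
    from jax.sharding import PositionalSharding

    # one entry per local device, laid along the batch axis
    sharding = PositionalSharding(jax.devices())

    batches = dataset.shuffle_and_batch(sharding=sharding)
    inputs, labels = next(batches)

Passing sharding=None keeps the previous single-device behavior.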
46 changes: 32 additions & 14 deletions apax/data/preprocessing.py
@@ -10,31 +10,33 @@
log = logging.getLogger(__name__)


-def compute_nl(position, box, r_max):
+def compute_nl(positions, box, r_max):
+    """Computes the NL for a single structure.
+    For periodic systems, positions are assumed to be in
+    fractional coordinates.
+    """
if np.all(box < 1e-6):
-        cell, cell_origin = get_shrink_wrapped_cell(position)
+        box, box_origin = get_shrink_wrapped_cell(positions)
idxs_i, idxs_j = neighbour_list(
"ij",
-            positions=position,
+            positions=positions,
cutoff=r_max,
-            cell=cell,
-            cell_origin=cell_origin,
+            cell=box,
+            cell_origin=box_origin,
pbc=[False, False, False],
)

-        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
+        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)

n_neighbors = neighbor_idxs.shape[1]
offsets = np.full([n_neighbors, 3], 0)

else:
+        positions = positions @ box
        idxs_i, idxs_j, offsets = neighbour_list(
-            "ijS",
-            positions=position,
-            cutoff=r_max,
-            cell=box,
+            "ijS", positions=positions, cutoff=r_max, cell=box, pbc=[True, True, True]
)
-        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int32)
+        neighbor_idxs = np.array([idxs_i, idxs_j], dtype=np.int16)
offsets = np.matmul(offsets, box)
return neighbor_idxs, offsets
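Since periodic positions are now taken to be fractional, a caller converts Cartesian coordinates before the call. A hedged sketch (toy values; compute_nl as defined above, with cart = frac @ box as the row-vector convention):

    import numpy as np

    box = np.diag([10.0, 10.0, 10.0])                # cell matrix in Angstrom
    cart = np.array([[1.0, 1.0, 1.0],
                     [9.5, 9.5, 9.5]])               # neighbors across the boundary
    frac = cart @ np.linalg.inv(box)                 # to fractional coordinates
    idxs, offsets = compute_nl(frac, box, r_max=3.0)
    # idxs: (2, n_pairs) pair indices; offsets: Cartesian shift vectors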

@@ -53,16 +55,32 @@ def get_shrink_wrapped_cell(positions):
return cell, cell_origin


-def prefetch_to_single_device(iterator, size: int):
+def prefetch_to_single_device(iterator, size: int, sharding=None, n_step_jit=False):
"""
inspired by
https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
except it does not shard the data.
"""
queue = collections.deque()

-    def _prefetch(x):
-        return jnp.asarray(x)
+    if sharding:
+        n_devices = len(sharding._devices)
+        slice_start = 1
+        shape = [n_devices]
+        if n_step_jit:
+            # replicate over multi-batch axis
+            # data shape: njit x bs x ...
+            slice_start = 2
+            shape.insert(0, 1)
+
+    def _prefetch(x: jax.Array):
+        if sharding:
+            remaining_axes = [1] * len(x.shape[slice_start:])
+            final_shape = tuple(shape + remaining_axes)
+            x = jax.device_put(x, sharding.reshape(final_shape))
+        else:
+            x = jnp.asarray(x)
+        return x

def enqueue(n):
for data in itertools.islice(iterator, n):
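To make the reshape concrete: with four devices and a batch of shape (bs, n_atoms, 3), the sharding is reshaped to (4, 1, 1), so only the batch axis is split across devices; with n_step_jit, the leading n_jitted_steps axis gets an extra 1 inserted, matching the comment above. A sketch under those assumptions:

    import jax
    import jax.numpy as jnp
    from jax.sharding import PositionalSharding

    sharding = PositionalSharding(jax.devices())         # e.g. 4 GPUs
    x = jnp.zeros((32, 17, 3))                           # (batch, atoms, xyz)
    y = jax.device_put(x, sharding.reshape((4, 1, 1)))   # batch axis split 4 ways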
4 changes: 4 additions & 0 deletions apax/nodes/__init__.py
@@ -0,0 +1,4 @@
from .md import ApaxJaxMD
from .model import Apax, ApaxEnsemble

__all__ = ["Apax", "ApaxEnsemble", "ApaxJaxMD"]
87 changes: 87 additions & 0 deletions apax/nodes/md.py
@@ -0,0 +1,87 @@
import functools
import logging
import pathlib
import typing

import ase.io
import h5py
import yaml
import znh5md
import zntrack.utils

from apax.md.nvt import run_md

from .model import Apax
from .utils import check_duplicate_keys

log = logging.getLogger(__name__)


class ApaxJaxMD(zntrack.Node):
"""Class to run a more performant JaxMD simulation with a apax Model.
Attributes
----------
data: list[ase.Atoms]
MD starting structure
data_id: int, default=-1
index of the configuration from the data list to use
model: ApaxModel
model to use for the simulation
repeat: float
number of repeats
config: str
path to the MD simulation parameter file
"""

data: list[ase.Atoms] = zntrack.deps()
data_id: int = zntrack.params(-1)

model: Apax = zntrack.deps()
repeat = zntrack.params(None)

config: str = zntrack.params_path(None)

sim_dir: pathlib.Path = zntrack.outs_path(zntrack.nwd / "md")
init_struc_dir: pathlib.Path = zntrack.outs_path(
zntrack.nwd / "initial_structure.extxyz"
)

_parameter: dict = None

def _post_load_(self) -> None:
self._handle_parameter_file()

def _handle_parameter_file(self):
with self.state.use_tmp_path():
self._parameter = yaml.safe_load(pathlib.Path(self.config).read_text())

custom_parameters = {
"sim_dir": self.sim_dir.as_posix(),
"initial_structure": self.init_struc_dir.as_posix(),
}
check_duplicate_keys(custom_parameters, self._parameter, log)
self._parameter.update(custom_parameters)

def run(self):
"""Primary method to run which executes all steps of the model training"""

atoms = self.data[self.data_id]
if self.repeat is not None:
atoms = atoms.repeat(self.repeat)
ase.io.write(self.init_struc_dir.as_posix(), atoms)

run_md(self.model._parameter, self._parameter)

@functools.cached_property
def atoms(self) -> typing.List[ase.Atoms]:
def file_handle(filename):
file = self.state.fs.open(filename, "rb")
return h5py.File(file)

return znh5md.ASEH5MD(
self.sim_dir / "md.h5",
format_handler=functools.partial(
znh5md.FormatHandler, file_handle=file_handle
),
).get_atoms_list()
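The new nodes plug into a zntrack graph. A hedged usage sketch (assuming zntrack's Project API; the Apax training node's arguments here are hypothetical, only the ApaxJaxMD attributes are confirmed by this diff):

    import zntrack
    from apax.nodes import Apax, ApaxJaxMD

    project = zntrack.Project()
    with project:
        model = Apax(data=train_frames, config="train.yaml")  # hypothetical arguments
        md = ApaxJaxMD(data=train_frames, model=model, config="md.yaml")
    project.run()

    frames = md.atoms  # trajectory read back from sim_dir / "md.h5"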