Merge pull request #188 from apax-hub/val_bs_fix
Val bs fix and distance computation refactor
M-R-Schaefer authored Oct 31, 2023
2 parents c6a6fe3 + 655db44 commit f632d22
Showing 18 changed files with 373 additions and 406 deletions.
9 changes: 5 additions & 4 deletions apax/bal/api.py
@@ -8,7 +8,7 @@
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
- from apax.data.input_pipeline import TFPipeline
+ from apax.data.input_pipeline import AtomisticDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import restore_parameters
@@ -43,11 +43,11 @@ def create_feature_fn(
return feature_fn


- def compute_features(feature_fn, dataset: TFPipeline, processing_batch_size: int):
+ def compute_features(feature_fn, dataset: AtomisticDataset):
"""Compute the features of a dataset."""
features = []
n_data = dataset.n_data
- ds = dataset.batch(processing_batch_size)
+ ds = dataset.batch()

pbar = trange(n_data, desc="Computing features", ncols=100, leave=True)
for i, (inputs, _) in enumerate(ds):
@@ -83,6 +83,7 @@ def kernel_selection(

n_train = len(train_atoms)
dataset = initialize_dataset(config, RawDataset(atoms_list=train_atoms + pool_atoms))
+ dataset.set_batch_size(processing_batch_size)

init_box = dataset.init_input()["box"][0]

@@ -92,7 +93,7 @@
feature_fn = create_feature_fn(
model, params, base_feature_map, feature_transforms, is_ensemble
)
- g = compute_features(feature_fn, dataset, processing_batch_size)
+ g = compute_features(feature_fn, dataset)
km = kernel.KernelMatrix(g, n_train)
new_indices = selection_fn(km, selection_batch_size)

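Note on the new call pattern: the processing batch size now lives on the dataset itself instead of being threaded through `compute_features`. A minimal sketch, assuming `dataset` is the AtomisticDataset built above and `feature_fn` comes from `create_feature_fn`:

# The batch size is set once on the dataset; compute_features then iterates
# dataset.batch() without a size argument.
dataset.set_batch_size(processing_batch_size)
g = compute_features(feature_fn, dataset)

# Skipping set_batch_size would make dataset.batch() raise
# ValueError("Dataset Batch Size has not been set yet"); see input_pipeline.py below.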
2 changes: 0 additions & 2 deletions apax/cli/templates/train_config_full.yaml
@@ -79,5 +79,3 @@ checkpoints:
progress_bar:
disable_epoch_pbar: false
disable_nl_pbar: false
-
- maximize_l2_cache: true
2 changes: 0 additions & 2 deletions apax/config/train_config.py
@@ -286,7 +286,6 @@ class Config(BaseModel, frozen=True, extra="forbid"):
callbacks: List of :class: `callback` <config.CallbackConfig> configurations.
progress_bar: Progressbar configuration.
checkpoints: Checkpoint configuration.
- maximize_l2_cache: Whether or not to maximize GPU L2 cache.
"""

n_epochs: PositiveInt
@@ -301,7 +300,6 @@
callbacks: List[CallbackConfig] = [CallbackConfig(name="csv")]
progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
checkpoints: CheckpointConfig = CheckpointConfig()
- maximize_l2_cache: bool = False

def dump_config(self, save_path):
"""
86 changes: 86 additions & 0 deletions apax/data/initialization.py
@@ -0,0 +1,86 @@
import dataclasses
import logging
import os
from typing import Optional

import numpy as np
from ase import Atoms

from apax.data.input_pipeline import AtomisticDataset, create_dict_dataset
from apax.data.statistics import compute_scale_shift_parameters
from apax.utils.data import load_data, split_atoms, split_idxs, split_label

log = logging.getLogger(__name__)


@dataclasses.dataclass
class RawDataset:
atoms_list: list[Atoms]
additional_labels: Optional[dict] = None


def load_data_files(data_config, model_version_path):
log.info("Running Input Pipeline")
if data_config.data_path is not None:
log.info(f"Read data file {data_config.data_path}")
atoms_list, label_dict = load_data(data_config.data_path)

train_idxs, val_idxs = split_idxs(
atoms_list, data_config.n_train, data_config.n_valid
)
train_atoms_list, val_atoms_list = split_atoms(atoms_list, train_idxs, val_idxs)
train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)

np.savez(
os.path.join(model_version_path, "train_val_idxs"),
train_idxs=train_idxs,
val_idxs=val_idxs,
)

elif data_config.train_data_path and data_config.val_data_path is not None:
log.info(f"Read training data file {data_config.train_data_path}")
log.info(f"Read validation data file {data_config.val_data_path}")
train_atoms_list, train_label_dict = load_data(data_config.train_data_path)
val_atoms_list, val_label_dict = load_data(data_config.val_data_path)
else:
raise ValueError("input data path/paths not defined")

train_raw_ds = RawDataset(
atoms_list=train_atoms_list, additional_labels=train_label_dict
)
val_raw_ds = RawDataset(atoms_list=val_atoms_list, additional_labels=val_label_dict)

return train_raw_ds, val_raw_ds


def initialize_dataset(config, raw_ds, calc_stats: bool = True):
inputs, labels = create_dict_dataset(
raw_ds.atoms_list,
r_max=config.model.r_max,
external_labels=raw_ds.additional_labels,
disable_pbar=config.progress_bar.disable_nl_pbar,
pos_unit=config.data.pos_unit,
energy_unit=config.data.energy_unit,
)

if calc_stats:
ds_stats = compute_scale_shift_parameters(
inputs,
labels,
config.data.shift_method,
config.data.scale_method,
config.data.shift_options,
config.data.scale_options,
)

dataset = AtomisticDataset(
inputs,
labels,
config.n_epochs,
buffer_size=config.data.shuffle_buffer_size,
)

if calc_stats:
return dataset, ds_stats
else:
return dataset
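A sketch of how the two helpers above compose; the call sites are not part of this diff, and the batch-size fields on `config.data` are assumptions about the apax config rather than something shown here:

# Hypothetical usage -- config is an apax Config; field names on config.data are assumed.
train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)

# calc_stats=True (the default) also returns the scale/shift statistics,
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
# while calc_stats=False returns only the dataset.
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

# Batch sizes are attached afterwards (see AtomisticDataset.set_batch_size below).
train_ds.set_batch_size(config.data.batch_size)      # assumed config field
val_ds.set_batch_size(config.data.valid_batch_size)  # assumed config field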
86 changes: 51 additions & 35 deletions apax/data/input_pipeline.py
@@ -1,5 +1,7 @@
import logging
+ from typing import Dict, Iterator

+ import jax
import numpy as np
import tensorflow as tf

@@ -9,15 +11,15 @@
log = logging.getLogger(__name__)


- def find_largest_system(inputs):
+ def find_largest_system(inputs: dict[str, np.ndarray]) -> tuple[int]:
max_atoms = np.max(inputs["fixed"]["n_atoms"])
nbr_shapes = [idx.shape[1] for idx in inputs["ragged"]["idx"]]
max_nbrs = np.max(nbr_shapes)
return max_atoms, max_nbrs


class PadToSpecificSize:
- def __init__(self, max_atoms=None, max_nbrs=None) -> None:
+ def __init__(self, max_atoms: int, max_nbrs: int) -> None:
"""Function is padding all input and label dicts that values are of type ragged
to largest element in the batch. Afterward, the distinction between ragged
and fixed inputs/labels is not needed and all inputs/labels are updated to
@@ -95,7 +97,7 @@ def create_dict_dataset(
disable_pbar=False,
pos_unit: str = "Ang",
energy_unit: str = "eV",
- ) -> None:
+ ) -> tuple[dict]:
inputs, labels = atoms_to_arrays(atoms_list, pos_unit, energy_unit)

if external_labels:
@@ -115,15 +117,41 @@
return inputs, labels


- class TFPipeline:
+ def dataset_from_dicts(
+ inputs: Dict[str, np.ndarray], labels: Dict[str, np.ndarray]
+ ) -> tf.data.Dataset:
+ # tf.RaggedTensors should be created from `tf.ragged.stack`
+ # instead of `tf.ragged.constant` for performance reasons.
+ # See https://github.com/tensorflow/tensorflow/issues/47853
+ for key, val in inputs["ragged"].items():
+ inputs["ragged"][key] = tf.ragged.stack(val)
+ for key, val in inputs["fixed"].items():
+ inputs["fixed"][key] = tf.constant(val)
+
+ for key, val in labels["ragged"].items():
+ labels["ragged"][key] = tf.ragged.stack(val)
+ for key, val in labels["fixed"].items():
+ labels["fixed"][key] = tf.constant(val)
+
+ ds = tf.data.Dataset.from_tensor_slices(
+ (
+ inputs["ragged"],
+ inputs["fixed"],
+ labels["ragged"],
+ labels["fixed"],
+ )
+ )
+ return ds
+
+
+ class AtomisticDataset:
"""Class processes inputs/labels and makes them accessible for training."""

def __init__(
self,
inputs,
labels,
n_epoch: int,
- batch_size: int,
buffer_size: int = 1000,
) -> None:
"""Processes inputs/labels and makes them accessible for training.
@@ -144,7 +172,7 @@ def __init__(
value.
"""
self.n_epoch = n_epoch
- self.batch_size = batch_size
+ self.batch_size = None
self.buffer_size = buffer_size

max_atoms, max_nbrs = find_largest_system(inputs)
@@ -153,6 +181,16 @@

self.n_data = len(inputs["fixed"]["n_atoms"])

+ self.ds = dataset_from_dicts(inputs, labels)
+
+ def set_batch_size(self, batch_size: int):
+ self.batch_size = self.validate_batch_size(batch_size)
+
+ def _check_batch_size(self):
+ if self.batch_size is None:
+ raise ValueError("Dataset Batch Size has not been set yet")
+
+ def validate_batch_size(self, batch_size: int) -> int:
if batch_size > self.n_data:
msg = (
f"requested batch size {batch_size} is larger than the number of data"
@@ -161,29 +199,7 @@ def __init__(
print("Warning: " + msg)
log.warning(msg)
batch_size = self.n_data
- self.batch_size = batch_size
-
- # tf.RaggedTensors should be created from `tf.ragged.stack`
- # instead of `tf.ragged.constant` for performance reasons.
- # See https://github.com/tensorflow/tensorflow/issues/47853
- for key, val in inputs["ragged"].items():
- inputs["ragged"][key] = tf.ragged.stack(val)
- for key, val in inputs["fixed"].items():
- inputs["fixed"][key] = tf.constant(val)
-
- for key, val in labels["ragged"].items():
- labels["ragged"][key] = tf.ragged.stack(val)
- for key, val in labels["fixed"].items():
- labels["fixed"][key] = tf.constant(val)
-
- self.ds = tf.data.Dataset.from_tensor_slices(
- (
- inputs["ragged"],
- inputs["fixed"],
- labels["ragged"],
- labels["fixed"],
- )
- )
+ return batch_size

def steps_per_epoch(self) -> int:
"""Returns the number of steps per epoch dependent on the number of data and the
@@ -193,7 +209,7 @@ def steps_per_epoch(self) -> int:
"""
return self.n_data // self.batch_size

- def init_input(self):
+ def init_input(self) -> Dict[str, np.ndarray]:
"""Returns first batch of inputs and labels to init the model."""
inputs, _ = next(
self.ds.batch(1)
@@ -203,7 +219,7 @@ def init_input(self):
)
return inputs

- def shuffle_and_batch(self):
+ def shuffle_and_batch(self) -> Iterator[jax.Array]:
"""Shuffles, batches, and pads the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
@@ -212,6 +228,7 @@ def shuffle_and_batch(self):
shuffled_ds :
Iterator that returns inputs and labels of one batch in each step.
"""
+ self._check_batch_size()
shuffled_ds = (
self.ds.shuffle(buffer_size=self.buffer_size)
.repeat(self.n_epoch)
@@ -222,10 +239,9 @@ def shuffle_and_batch(self):
shuffled_ds = prefetch_to_single_device(shuffled_ds.as_numpy_iterator(), 2)
return shuffled_ds

- def batch(self, batch_size):
- # TODO: the batch size here overrides self.batch_size
- # we should find a better abstraction
- ds = self.ds.batch(batch_size=batch_size).map(
+ def batch(self) -> Iterator[jax.Array]:
+ self._check_batch_size()
+ ds = self.ds.batch(batch_size=self.batch_size).map(
PadToSpecificSize(self.max_atoms, self.max_nbrs)
)

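The practical effect of the TFPipeline -> AtomisticDataset rewrite is that a dataset is now constructed without a batch size and refuses to batch until one is set. A minimal sketch of that contract, assuming `inputs` and `labels` come from `create_dict_dataset` as above:

ds = AtomisticDataset(inputs, labels, n_epoch=config.n_epochs, buffer_size=1000)

# Calling ds.batch() or ds.shuffle_and_batch() here would raise
# ValueError("Dataset Batch Size has not been set yet").

ds.set_batch_size(4)  # values larger than ds.n_data are clamped to n_data with a warning
for batch_inputs, batch_labels in ds.batch():  # batches padded to max_atoms / max_nbrs
    ...  # e.g. feed the padded batch to a model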
7 changes: 4 additions & 3 deletions apax/data/preprocessing.py
@@ -1,10 +1,12 @@
import collections
import itertools
import logging
+ from typing import Callable

import jax
import jax.numpy as jnp
import numpy as np
+ from ase import Atoms
from jax_md import partition, space
from matscipy.neighbours import neighbour_list
from tqdm import trange
@@ -14,7 +16,7 @@
log = logging.getLogger(__name__)


- def initialize_nbr_fn(atoms, cutoff):
+ def initialize_nbr_fn(atoms: Atoms, cutoff: float) -> Callable:
neighbor_fn = None
default_box = 100
box = jnp.asarray(atoms.cell.array)
@@ -36,7 +38,6 @@ def initialize_nbr_fn(atoms, cutoff):

@jax.jit
def extract_nl(neighbors, position):
- # vmapped neighborlist probably only useful for larger structures
neighbors = neighbors.update(position)
return neighbors

@@ -115,7 +116,7 @@ def get_shrink_wrapped_cell(positions):
return cell, cell_origin


- def prefetch_to_single_device(iterator, size):
+ def prefetch_to_single_device(iterator, size: int):
"""
inspired by
https://flax.readthedocs.io/en/latest/_modules/flax/jax_utils.html#prefetch_to_device
4 changes: 1 addition & 3 deletions apax/data/statistics.py
@@ -10,9 +10,7 @@
class DatasetStats:
elemental_shift: np.array = None
elemental_scale: float = None
n_atoms: int = 0
n_species: int = 119
displacement_fn = None


class PerElementRegressionShift:
@@ -36,7 +34,7 @@ def compute(inputs, labels, shift_options) -> np.ndarray:
n_atoms_total = np.sum(system_sizes)

mean_energy = ds_energy / n_atoms_total
- n_species = 119  # max([max(n) for n in numbers]) + 1
+ n_species = 119  # for simplicity, we assume any element could be in the dataset
X = np.zeros(shape=(energies.shape[0], n_species))
y = np.zeros(energies.shape[0])

(Diffs for the remaining 11 changed files are not shown here.)
