Skip to content

Commit

Permalink
Merge pull request #245 from apax-hub/otf_nl
Browse files Browse the repository at this point in the history
Input Pipeline rework
  • Loading branch information
M-R-Schaefer authored Mar 26, 2024
2 parents 2fd404c + 5811501 commit 10470b1
Show file tree
Hide file tree
Showing 13 changed files with 364 additions and 589 deletions.
14 changes: 8 additions & 6 deletions apax/bal/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,14 @@
from tqdm import trange

from apax.bal import feature_maps, kernel, selection, transforms
from apax.data.input_pipeline import AtomisticDataset
from apax.data.input_pipeline import InMemoryDataset
from apax.model.builder import ModelBuilder
from apax.model.gmnn import EnergyModel
from apax.train.checkpoints import (
canonicalize_energy_model_parameters,
check_for_ensemble,
restore_parameters,
)
from apax.train.run import initialize_dataset


def create_feature_fn(
Expand Down Expand Up @@ -47,7 +46,7 @@ def create_feature_fn(
return feature_fn


def compute_features(feature_fn, dataset: AtomisticDataset):
def compute_features(feature_fn, dataset: InMemoryDataset):
"""Compute the features of a dataset."""
features = []
n_data = dataset.n_data
Expand Down Expand Up @@ -86,10 +85,13 @@ def kernel_selection(
is_ensemble = n_models > 1

n_train = len(train_atoms)
dataset = initialize_dataset(
config, train_atoms + pool_atoms, read_labels=False, calc_stats=False
dataset = InMemoryDataset(
train_atoms + pool_atoms,
cutoff=config.model.r_max,
bs=processing_batch_size,
n_epochs=1,
ignore_labels=True,
)
dataset.set_batch_size(processing_batch_size)

_, init_box = dataset.init_input()

Expand Down
50 changes: 0 additions & 50 deletions apax/data/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,6 @@

import numpy as np

from apax.data.input_pipeline import AtomisticDataset, process_inputs
from apax.data.statistics import compute_scale_shift_parameters
from apax.utils.convert import atoms_to_labels
from apax.utils.data import load_data, split_atoms, split_idxs

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -36,50 +33,3 @@ def load_data_files(data_config):
raise ValueError("input data path/paths not defined")

return train_atoms_list, val_atoms_list


def initialize_dataset(
    config,
    atoms_list,
    read_labels: bool = True,
    calc_stats: bool = True,
):
    """Build an ``AtomisticDataset`` from a list of structures.

    The structures are passed through ``process_inputs`` and
    ``atoms_to_labels`` using settings read from ``config``; optionally the
    scale/shift parameters are computed from the results as well.

    Args:
        config: Configuration object. The attributes ``model.r_max``,
            ``progress_bar.disable_nl_pbar``, ``n_epochs`` and several
            fields of ``data`` are read.
        atoms_list: Structures to convert into a dataset.
        read_labels: Forwarded to ``atoms_to_labels``.
        calc_stats: When true, ``compute_scale_shift_parameters`` is called
            and its result is returned alongside the dataset.

    Returns:
        ``(dataset, ds_stats)`` when ``calc_stats`` is true, otherwise only
        ``dataset``.

    Raises:
        ValueError: If ``calc_stats`` is requested while ``read_labels`` is
            false.
    """
    stats_without_labels = calc_stats and not read_labels
    if stats_without_labels:
        raise ValueError(
            "Cannot calculate scale/shift parameters without reading labels."
        )

    processed_inputs = process_inputs(
        atoms_list,
        r_max=config.model.r_max,
        disable_pbar=config.progress_bar.disable_nl_pbar,
        pos_unit=config.data.pos_unit,
    )

    data_cfg = config.data
    label_kwargs = dict(
        additional_properties_info=data_cfg.additional_properties_info,
        read_labels=read_labels,
        pos_unit=data_cfg.pos_unit,
        energy_unit=data_cfg.energy_unit,
    )
    structure_labels = atoms_to_labels(atoms_list, **label_kwargs)

    # Statistics are computed before the dataset is constructed, from the
    # same processed inputs and labels that the dataset receives.
    ds_stats = None
    if calc_stats:
        ds_stats = compute_scale_shift_parameters(
            processed_inputs,
            structure_labels,
            data_cfg.shift_method,
            data_cfg.scale_method,
            data_cfg.shift_options,
            data_cfg.scale_options,
        )

    dataset = AtomisticDataset(
        processed_inputs,
        config.n_epochs,
        labels=structure_labels,
        buffer_size=data_cfg.shuffle_buffer_size,
    )

    if not calc_stats:
        return dataset
    return dataset, ds_stats
Loading

0 comments on commit 10470b1

Please sign in to comment.