Commit

linting
M-R-Schaefer committed Mar 29, 2024
1 parent 07c86b5 commit 731592c
Showing 4 changed files with 37 additions and 28 deletions.
2 changes: 1 addition & 1 deletion apax/__init__.py
@@ -3,7 +3,7 @@
import jax

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
jax.config.update("jax_enable_x64", True)
from apax.utils.helpers import setup_ase

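The two environment variables in this hunk only take effect if they are set before JAX and TensorFlow initialise their backends, which is why apax sets them at package-import time, ahead of jax.config.update. A minimal standalone sketch of that ordering (assuming a fresh interpreter with apax's dependencies installed; the final print only confirms that x64 mode is active):

import os

# Must be set before the first JAX/TensorFlow backend initialisation.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # disable upfront allocation of most of the GPU memory
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"               # silence TensorFlow's C++ logging

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision, as in apax/__init__.py
print(jnp.ones(1).dtype)  # float64 once x64 is enabled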
42 changes: 23 additions & 19 deletions apax/data/input_pipeline.py
@@ -1,9 +1,9 @@
import logging
import uuid
from collections import deque
from pathlib import Path
from random import shuffle
from typing import Dict, Iterator
import uuid
from pathlib import Path

import jax
import jax.numpy as jnp
@@ -46,7 +46,7 @@ def __init__(
n_jit_steps=1,
pre_shuffle=False,
ignore_labels=False,
cache_path = "."
cache_path=".",
) -> None:
if pre_shuffle:
shuffle(atoms)
@@ -82,7 +82,7 @@ def steps_per_epoch(self) -> int:
data are dropped in each epoch.
"""
return self.n_data // self.batch_size // self.n_jit_steps
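(A worked example of the division above, with illustrative numbers: 1000 structures with batch_size=32 and n_jit_steps=2 give 1000 // 32 // 2 == 15 jitted steps per epoch; the 1000 - 15 * 2 * 32 = 40 leftover structures are the ones dropped each epoch.)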

def validate_batch_size(self, batch_size: int) -> int:
if batch_size > self.n_data:
msg = (
@@ -92,7 +92,7 @@ def validate_batch_size(self, batch_size: int) -> int:
log.warning(msg)
batch_size = self.n_data
return batch_size

def prepare_data(self, i):
inputs = {k: v[i] for k, v in self.inputs.items()}
idx, offsets = compute_nl(inputs["positions"], inputs["box"], self.cutoff)
@@ -177,7 +177,7 @@ def init_input(self) -> Dict[str, np.ndarray]:

inputs = jax.tree_map(lambda x: jnp.array(x), inputs)
return inputs, np.array(box)

def __iter__(self):
raise NotImplementedError

@@ -192,7 +192,6 @@ def cleanup(self):


class CachedInMemoryDataset(InMemoryDataset):

def __iter__(self):
while self.count < self.n_data or len(self.buffer) > 0:
yield self.buffer.popleft()
@@ -202,8 +201,6 @@ def __iter__(self):
space = self.n_data - self.count
self.enqueue(space)
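The __iter__ above yields from a deque buffer and keeps topping it up until every structure has been served. A simplified, self-contained sketch of that pattern (the class name, buffer size, and data are stand-ins, not apax's actual implementation):

from collections import deque


class BufferedSource:
    """Yield items from a small deque buffer that is refilled as it drains."""

    def __init__(self, data, buffer_size=4):
        self.data = data
        self.n_data = len(data)
        self.buffer_size = buffer_size
        self.buffer = deque()
        self.count = 0  # how many items have been enqueued so far

    def enqueue(self, space):
        # Move at most `space` of the remaining items into the buffer.
        for _ in range(min(space, self.n_data - self.count)):
            self.buffer.append(self.data[self.count])
            self.count += 1

    def __iter__(self):
        self.enqueue(self.buffer_size)
        while self.count < self.n_data or len(self.buffer) > 0:
            yield self.buffer.popleft()
            self.enqueue(self.buffer_size - len(self.buffer))


print(list(BufferedSource(list(range(10)))))  # [0, 1, ..., 9]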



def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
inputs and labels for the whole training and prefetches the data.
Expand All @@ -213,9 +210,13 @@ def shuffle_and_batch(self):
ds :
Iterator that returns inputs and labels of one batch in each step.
"""
ds = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
).cache(self.file.as_posix()).repeat(self.n_epochs)
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)

ds = ds.shuffle(
buffer_size=self.buffer_size, reshuffle_each_iteration=True
@@ -226,13 +227,17 @@
return ds

def batch(self) -> Iterator[jax.Array]:
ds = tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
).cache(self.file.as_posix()).repeat(self.n_epochs)
ds = (
tf.data.Dataset.from_generator(
lambda: self, output_signature=self.make_signature()
)
.cache(self.file.as_posix())
.repeat(self.n_epochs)
)
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds

def cleanup(self):
for p in self.file.parent.glob(f"{self.file.name}.data*"):
p.unlink()
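shuffle_and_batch and batch above build essentially the same tf.data pipeline: the dataset object is wrapped as a generator with from_generator, cached to disk so the per-structure preprocessing only runs during the first epoch, repeated for n_epochs, then (for training) shuffled, batched, and finally moved onto the device with apax's prefetch_to_single_device helper. A standalone sketch of that pattern with made-up data, shapes, and hyperparameters, and a plain .prefetch standing in for the apax helper:

import tensorflow as tf

n_data, n_epochs, batch_size, buffer_size = 100, 3, 8, 32


def generator():
    # Stand-in for the per-structure preprocessing done by the dataset object.
    for i in range(n_data):
        yield {"positions": tf.random.uniform((4, 3)), "idx": tf.constant(i)}


signature = {
    "positions": tf.TensorSpec((4, 3), tf.float32),
    "idx": tf.TensorSpec((), tf.int32),
}

ds = (
    tf.data.Dataset.from_generator(generator, output_signature=signature)
    .cache("example_cache")  # first epoch writes example_cache.data* files, later epochs read them
    .repeat(n_epochs)        # the Python generator itself is only consumed once
    .shuffle(buffer_size=buffer_size, reshuffle_each_iteration=True)
    .batch(batch_size=batch_size)
    .prefetch(2)             # overlap input preparation with the consumer
)

n_batches = 0
for batch in ds:  # fully consuming the dataset finalises the on-disk cache
    n_batches += 1
print(n_batches, batch["positions"].shape)  # 38 batches; the last one holds the 4 leftover structures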
@@ -242,7 +247,6 @@ def cleanup(self):


class OTFInMemoryDataset(InMemoryDataset):

def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
@@ -285,9 +289,9 @@ def batch(self) -> Iterator[jax.Array]:
ds = ds.batch(batch_size=self.batch_size)
ds = prefetch_to_single_device(ds.as_numpy_iterator(), 2)
return ds


dataset_dict = {
"cached": CachedInMemoryDataset,
"otf": OTFInMemoryDataset,
}
}
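The dataset_dict registry at the end of the file is what initialize_datasets in run.py uses to pick a dataset implementation from the config. A hedged usage sketch (the "cached"/"otf" keys and the constructor arguments mirror the Dataset(...) calls in initialize_datasets; the variable names and the commented call are illustrative):

from apax.data.input_pipeline import dataset_dict

Dataset = dataset_dict["cached"]  # or "otf" for the on-the-fly variant
# train_ds = Dataset(train_atoms, r_max, batch_size, n_epochs, cache_path="models/my_run")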
10 changes: 7 additions & 3 deletions apax/train/run.py
@@ -4,7 +4,7 @@

import jax

from apax.config import LossConfig, parse_config, Config
from apax.config import Config, LossConfig, parse_config
from apax.data.initialization import load_data_files
from apax.data.input_pipeline import dataset_dict
from apax.data.statistics import compute_scale_shift_parameters
@@ -50,6 +50,7 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
loss_funcs.append(Loss(**loss.model_dump()))
return LossCollection(loss_funcs)


def initialize_datasets(config: Config):
train_raw_ds, val_raw_ds = load_data_files(config.data)

@@ -66,7 +67,11 @@ def initialize_datasets(config: Config):
cache_path=config.data.model_version_path,
)
val_ds = Dataset(
val_raw_ds, config.model.r_max, config.data.valid_batch_size, config.n_epochs, cache_path=config.data.model_version_path,
val_raw_ds,
config.model.r_max,
config.data.valid_batch_size,
config.n_epochs,
cache_path=config.data.model_version_path,
)
ds_stats = compute_scale_shift_parameters(
train_ds.inputs,
Expand All @@ -79,7 +84,6 @@ def initialize_datasets(config: Config):
return train_ds, val_ds, ds_stats



def run(user_config, log_level="error"):
config = parse_config(user_config)

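initialize_datasets builds the training and validation datasets through dataset_dict and returns them together with the scale/shift statistics. A hedged usage sketch (assuming parse_config accepts a path to an apax YAML config, as its use in run() suggests; the file name is illustrative):

from apax.config import parse_config
from apax.train.run import initialize_datasets

config = parse_config("config.yaml")
train_ds, val_ds, ds_stats = initialize_datasets(config)
print(train_ds.steps_per_epoch(), val_ds.steps_per_epoch())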
11 changes: 6 additions & 5 deletions apax/train/trainer.py
@@ -1,6 +1,5 @@
import functools
import logging
from pathlib import Path
import time
from functools import partial
from typing import Callable, Optional
@@ -108,10 +107,12 @@ def fit(
epoch_loss["val_loss"] /= val_steps_per_epoch
epoch_loss["val_loss"] = float(epoch_loss["val_loss"])

epoch_metrics.update({
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
})
epoch_metrics.update(
{
f"val_{key}": float(val)
for key, val in val_batch_metrics.compute().items()
}
)

epoch_metrics.update({**epoch_loss})

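The reflowed update above just prefixes each validation metric with val_ before merging it into the per-epoch log dictionary. A tiny illustration with made-up metric names and values:

epoch_loss = {"val_loss": 0.12}
val_metrics = {"energy_mae": 0.03, "forces_mae": 0.11}

epoch_metrics = {}
epoch_metrics.update(
    {f"val_{key}": float(val) for key, val in val_metrics.items()}
)
epoch_metrics.update({**epoch_loss})
print(epoch_metrics)  # {'val_energy_mae': 0.03, 'val_forces_mae': 0.11, 'val_loss': 0.12}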
