Skip to content

Commit

Permalink
added dataset splitting cli utility
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 3, 2024
1 parent 731592c commit b9ebb65
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
32 changes: 32 additions & 0 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,38 @@ def eval(
eval_model(train_config_path, n_data)


@app.command()
def split(
path: Path = typer.Argument(..., help="Dataset to split."),
n_train: int = typer.Argument(
...,
help="Number of training samples.",
),
n_val: int = typer.Argument(
...,
help="Number of validation samples.",
),
seed: int = typer.Argument(
0,
help="Random number generator seed.",
),
indices: bool = typer.Option(
False,
help="Whether or not to write train/val indices to disk.",
),
):
"""
Small utility for splitting datasets on the command line.
"""
import numpy as np

from apax.utils.data import split_dataset

np.random.seed(seed)

split_dataset(path, n_train, n_val, indices)


@app.command()
def docs():
"""
Expand Down
22 changes: 21 additions & 1 deletion apax/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import jax.numpy as jnp
import numpy as np
from ase.io import read
from ase.io import read, write

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -112,3 +112,23 @@ def split_atoms(atoms_list, train_idxs, val_idxs=None):
val_atoms_list = []

return train_atoms_list, val_atoms_list


def split_dataset(path: Path, n_train: int, n_val: int, save_indices: bool = False):
atoms_list = load_data(path)

train_idxs, val_idxs = split_idxs(atoms_list, n_train, n_val)
train_atoms_list, val_atoms_list = split_atoms(atoms_list, train_idxs, val_idxs)
train_path = path.parent / f"{path.stem}_train.extxyz"
val_path = path.parent / f"{path.stem}_val.extxyz"
write(train_path.as_posix(), train_atoms_list)
write(val_path.as_posix(), val_atoms_list)

if save_indices:
idx_path = (path.parent / f"{path.stem}_train_val_idxs").as_posix()

np.savez(
idx_path,
train_idxs=train_idxs,
val_idxs=val_idxs,
)

0 comments on commit b9ebb65

Please sign in to comment.