Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean up xyz vs. h5 config file parsing #462

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions mace/cli/run_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,12 +132,15 @@ def run(args: argparse.Namespace) -> None:
args.compute_avg_num_neighbors = False
args.E0s = statistics["atomic_energies"]

train_file_type = data.utils.get_configs_file_type(args.train_file)

# Data preparation
if args.train_file.endswith(".xyz"):
if train_file_type == "non_hdf5_file":
if args.valid_file is not None:
assert args.valid_file.endswith(
".xyz"
), "valid_file if given must be same format as train_file"
assert (Path(args.valid_file).suffix ==
Path(args.train_file).suffix
), (f"valid_file {args.valid_file} if given must be same "
f"format as train_file {args.train_file}")
config_type_weights = get_config_type_weights(args.config_type_weights)
collections, atomic_energies_dict = get_dataset_from_xyz(
train_path=args.train_file,
Expand Down Expand Up @@ -165,7 +168,9 @@ def run(args: argparse.Namespace) -> None:
# Atomic number table
# yapf: disable
if args.atomic_numbers is None:
assert args.train_file.endswith(".xyz"), "Must specify atomic_numbers when using .h5 train_file input"
assert train_file_type not in ("hdf5_dir", "hdf5_file"), (
"Must specify atomic_numbers when using hdf5 train_file input"
)
z_table = tools.get_atomic_number_table_from_zs(
z
for configs in (collections.train, collections.valid)
Expand Down Expand Up @@ -197,7 +202,7 @@ def run(args: argparse.Namespace) -> None:
for z in z_table.zs
}
else:
if args.train_file.endswith(".xyz"):
if train_file_type == "non_hdf5_file":
atomic_energies_dict = get_atomic_energies(
args.E0s, collections.train, z_table
)
Expand Down Expand Up @@ -228,7 +233,7 @@ def run(args: argparse.Namespace) -> None:
)
logging.info(f"Atomic energies: {atomic_energies.tolist()}")

if args.train_file.endswith(".xyz"):
if train_file_type == "non_hdf5_file":
train_set = [
data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
for config in collections.train
Expand All @@ -237,16 +242,18 @@ def run(args: argparse.Namespace) -> None:
data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
for config in collections.valid
]
elif args.train_file.endswith(".h5"):
elif train_file_type == "hdf5_file":
train_set = data.HDF5Dataset(args.train_file, r_max=args.r_max, z_table=z_table)
valid_set = data.HDF5Dataset(args.valid_file, r_max=args.r_max, z_table=z_table)
else: # This case would be for when the file path is to a directory of multiple .h5 files
elif train_file_type == "hdf5_dir": # This case would be for when the file path is to a directory of multiple .h5 files
train_set = data.dataset_from_sharded_hdf5(
args.train_file, r_max=args.r_max, z_table=z_table
)
valid_set = data.dataset_from_sharded_hdf5(
args.valid_file, r_max=args.r_max, z_table=z_table
)
else:
raise ValueError(f"Unknown train_file_type {train_file_type}")

train_sampler, valid_sampler = None, None
if args.distributed:
Expand Down Expand Up @@ -719,7 +726,7 @@ def run(args: argparse.Namespace) -> None:
}

test_sets = {}
if args.train_file.endswith(".xyz"):
if train_file_type == "non_hdf5_file":
for name, subset in collections.tests:
test_sets[name] = [
data.AtomicData.from_config(config, z_table=z_table, cutoff=args.r_max)
Expand Down
18 changes: 18 additions & 0 deletions mace/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import logging
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple
from pathlib import Path

import ase.data
import ase.io
Expand Down Expand Up @@ -189,6 +190,23 @@ def test_config_types(
return test_by_ct


def get_configs_file_type(
filename: str
) -> str:
filepath = Path(filename)
if filepath.is_dir():
# only support dirs when it contains (presumably sharded) hdf5 files
if len(list(filepath.glob("*.h5")) + list(filepath.glob("*.hdf5"))) == 0:
raise RuntimeError(f"Got directory {filename} with no .h5/.hdf5 files")
return "hdf5_dir"

if filepath.suffix in (".h5", ".hdf5"):
# special case
return "hdf5_file"

return "non_hdf5_file"


def load_from_xyz(
file_path: str,
config_type_weights: Dict,
Expand Down
36 changes: 18 additions & 18 deletions mace/tools/arg_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,13 +240,13 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
# Dataset
parser.add_argument(
"--train_file",
help="Training set file, format is .xyz or .h5",
help="Training set file readable by ase.io.read, or .h5/.hdf5, or directory containing sharded hdf5",
type=str,
required=True,
)
parser.add_argument(
"--valid_file",
help="Validation set .xyz or .h5 file",
help="Validation set readable by ase.io.read, or .h5/.hdf5, or directory containing sharded hdf5",
default=None,
type=str,
required=False,
Expand All @@ -260,7 +260,7 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--test_file",
help="Test set .xyz pt .h5 file",
help="Test set readable by ase.io.read, or .h5/.hdf5, or directory containing sharded hdf5",
type=str,
)
parser.add_argument(
Expand Down Expand Up @@ -332,37 +332,37 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--energy_key",
help="Key of reference energies in training xyz",
help="Key of reference energies in training file",
type=str,
default="energy",
)
parser.add_argument(
"--forces_key",
help="Key of reference forces in training xyz",
help="Key of reference forces in training file",
type=str,
default="forces",
)
parser.add_argument(
"--virials_key",
help="Key of reference virials in training xyz",
help="Key of reference virials in training file",
type=str,
default="virials",
)
parser.add_argument(
"--stress_key",
help="Key of reference stress in training xyz",
help="Key of reference stress in training file",
type=str,
default="stress",
)
parser.add_argument(
"--dipole_key",
help="Key of reference dipoles in training xyz",
help="Key of reference dipoles in training file",
type=str,
default="dipole",
)
parser.add_argument(
"--charges_key",
help="Key of atomic charges in training xyz",
help="Key of atomic charges in training file",
type=str,
default="charges",
)
Expand Down Expand Up @@ -611,14 +611,14 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
parser = argparse.ArgumentParser()
parser.add_argument(
"--train_file",
help="Training set h5 file",
help="Training set ase.io.read-compatible file",
type=str,
default=None,
required=True,
)
parser.add_argument(
"--valid_file",
help="Training set xyz file",
help="Training set ase.io.read-compatible file",
type=str,
default=None,
required=False,
Expand All @@ -638,7 +638,7 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--test_file",
help="Test set xyz file",
help="Test set ase.io.read-compatible file",
type=str,
default=None,
required=False,
Expand All @@ -660,37 +660,37 @@ def build_preprocess_arg_parser() -> argparse.ArgumentParser:
)
parser.add_argument(
"--energy_key",
help="Key of reference energies in training xyz",
help="Key of reference energies in training file",
type=str,
default="energy",
)
parser.add_argument(
"--forces_key",
help="Key of reference forces in training xyz",
help="Key of reference forces in training file",
type=str,
default="forces",
)
parser.add_argument(
"--virials_key",
help="Key of reference virials in training xyz",
help="Key of reference virials in training file",
type=str,
default="virials",
)
parser.add_argument(
"--stress_key",
help="Key of reference stress in training xyz",
help="Key of reference stress in training file",
type=str,
default="stress",
)
parser.add_argument(
"--dipole_key",
help="Key of reference dipoles in training xyz",
help="Key of reference dipoles in training file",
type=str,
default="dipole",
)
parser.add_argument(
"--charges_key",
help="Key of atomic charges in training xyz",
help="Key of atomic charges in training file",
type=str,
default="charges",
)
Expand Down
Loading