Skip to content

Commit

Permalink
Merge pull request #197 from apax-hub/absl_logging_fix
Browse files: browse the repository at this point in the history
Absl logging fix
  • Loading branch information
M-R-Schaefer authored Nov 14, 2023
2 parents b90fdb2 + a2c92c9 commit 7f61d0d
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 41 deletions.
5 changes: 2 additions & 3 deletions apax/cli/apax_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@ def train(
train_config_path: Path = typer.Argument(
..., help="Training configuration YAML file."
),
log_level: str = typer.Option("error", help="Sets the training logging level."),
log_file: str = typer.Option("train.log", help="Specifies the name of the log file"),
log_level: str = typer.Option("info", help="Sets the training logging level."),
):
"""
Starts the training of a model with parameters provided by a configuration file.
"""
from apax.train.run import run

run(train_config_path, log_file, log_level)
run(train_config_path, log_level)


@app.command()
Expand Down
1 change: 0 additions & 1 deletion apax/config/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") ->
config: Path to the config file or a dictionary
containing the config.
"""
log.info("Loading user config")
if isinstance(config, (str, os.PathLike)):
with open(config, "r") as stream:
config = yaml.safe_load(stream)
Expand Down
4 changes: 3 additions & 1 deletion apax/config/train_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,12 +112,14 @@ def validate_shift_scale_methods(self):

return self

@property
def model_version_path(self):
version_path = Path(self.directory) / self.experiment
return version_path

@property
def best_model_path(self):
return self.model_version_path() / "best"
return self.model_version_path / "best"


class ModelConfig(BaseModel, extra="forbid"):
Expand Down
5 changes: 2 additions & 3 deletions apax/data/initialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import dataclasses
import logging
import os
from typing import Optional

import numpy as np
Expand All @@ -19,7 +18,7 @@ class RawDataset:
additional_labels: Optional[dict] = None


def load_data_files(data_config, model_version_path):
def load_data_files(data_config):
log.info("Running Input Pipeline")
if data_config.data_path is not None:
log.info(f"Read data file {data_config.data_path}")
Expand All @@ -32,7 +31,7 @@ def load_data_files(data_config, model_version_path):
train_label_dict, val_label_dict = split_label(label_dict, train_idxs, val_idxs)

np.savez(
os.path.join(model_version_path, "train_val_idxs"),
data_config.model_version_path / "train_val_idxs",
train_idxs=train_idxs,
val_idxs=val_idxs,
)
Expand Down
2 changes: 1 addition & 1 deletion apax/md/nvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def md_setup(model_config: Config, md_config: MDConfig):
disable_cell_list=True,
)

_, params = restore_parameters(model_config.data.model_version_path())
_, params = restore_parameters(model_config.data.model_version_path)
params = canonicalize_energy_model_parameters(params)
energy_fn = create_energy_fn(
model.apply, params, system.atomic_numbers, system.box, model_config.n_models
Expand Down
3 changes: 2 additions & 1 deletion apax/train/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def stack_parameters(param_list: List[FrozenDict]) -> FrozenDict:


def load_params(model_version_path: Path, best=True) -> FrozenDict:
model_version_path = Path(model_version_path)
if best:
model_version_path = model_version_path / "best"
log.info(f"loading checkpoint from {model_version_path}")
Expand All @@ -142,7 +143,7 @@ def restore_single_parameters(model_dir: Path) -> Tuple[Config, FrozenDict]:
"""Load the config and parameters of a single model
"""
model_config = parse_config(Path(model_dir) / "config.yaml")
ckpt_dir = model_config.data.model_version_path()
ckpt_dir = model_config.data.model_version_path
return model_config, load_params(ckpt_dir)


Expand Down
31 changes: 17 additions & 14 deletions apax/train/run.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import os
from pathlib import Path
import sys
from typing import List

import jax
Expand Down Expand Up @@ -32,7 +31,15 @@ def setup_logging(log_file, log_level):
while len(logging.root.handlers) > 0:
logging.root.removeHandler(logging.root.handlers[-1])

logging.basicConfig(filename=log_file, level=log_levels[log_level])
# Remove uninformative checkpointing absl logs
logging.getLogger("absl").setLevel(logging.WARNING)

logging.basicConfig(
level=log_levels[log_level],
format="%(levelname)s | %(asctime)s | %(message)s",
datefmt="%H:%M:%S",
handlers=[logging.FileHandler(log_file), logging.StreamHandler(sys.stderr)],
)


def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
Expand All @@ -43,26 +50,22 @@ def initialize_loss_fn(loss_config_list: List[LossConfig]) -> LossCollection:
return LossCollection(loss_funcs)


def run(user_config, log_file="train.log", log_level="error"):
setup_logging(log_file, log_level)
log.info("Loading user config")
def run(user_config, log_level="error"):
config = parse_config(user_config)

seed_py_np_tf(config.seed)
rng_key = jax.random.PRNGKey(config.seed)

experiment = Path(config.data.experiment)
directory = Path(config.data.directory)
model_version_path = directory / experiment
log.info("Initializing directories")
model_version_path.mkdir(parents=True, exist_ok=True)
config.dump_config(model_version_path)
config.data.model_version_path.mkdir(parents=True, exist_ok=True)
setup_logging(config.data.model_version_path / "train.log", log_level)
config.dump_config(config.data.model_version_path)

callbacks = initialize_callbacks(config.callbacks, model_version_path)
callbacks = initialize_callbacks(config.callbacks, config.data.model_version_path)
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
train_raw_ds, val_raw_ds = load_data_files(config.data)
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

Expand Down Expand Up @@ -112,7 +115,7 @@ def run(user_config, log_file="train.log", log_level="error"):
Metrics,
callbacks,
n_epochs,
ckpt_dir=os.path.join(config.data.directory, config.data.experiment),
ckpt_dir=config.data.model_version_path,
ckpt_interval=config.checkpoints.ckpt_interval,
val_ds=val_ds,
sam_rho=config.optimizer.sam_rho,
Expand Down
4 changes: 2 additions & 2 deletions apax/train/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def fit(
log.info("Beginning Training")
callbacks.on_train_begin()

latest_dir = ckpt_dir + "/latest"
best_dir = ckpt_dir + "/best"
latest_dir = ckpt_dir / "latest"
best_dir = ckpt_dir / "best"
ckpt_manager = CheckpointManager()

train_step, val_step = make_step_fns(
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,6 @@ def load_and_dump_config(config_path, dump_path):
model_config_dict["data"]["directory"] = dump_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)
return model_config
4 changes: 2 additions & 2 deletions tests/integration_tests/bal/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
_, params = initialize_model(model_config, inputs)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = model_config.data.best_model_path()
best_dir = model_config.data.best_model_path
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
target=ckpt,
Expand All @@ -41,7 +41,7 @@ def test_kernel_selection(example_atoms, get_tmp_path, get_sample_input):
bs = 5

selected_indices = kernel_selection(
model_config.data.model_version_path(),
model_config.data.model_version_path,
train_atoms,
pool_atoms,
base_fm_options,
Expand Down
18 changes: 7 additions & 11 deletions tests/integration_tests/md/test_md.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ def test_run_md(get_tmp_path):
md_config_dict["initial_structure"] = get_tmp_path.as_posix() + "/atoms.extxyz"

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path())
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path)
model_config.dump_config(model_config.data.model_version_path)
md_config = MDConfig.model_validate(md_config_dict)

positions = jnp.array(
Expand Down Expand Up @@ -80,11 +80,8 @@ def test_run_md(get_tmp_path):
)

ckpt = {"model": {"params": params}, "epoch": 0}
best_dir = os.path.join(
model_config.data.directory, model_config.data.experiment, "best"
)
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
Expand All @@ -106,8 +103,8 @@ def test_ase_calc(get_tmp_path):
model_config_dict["data"]["directory"] = get_tmp_path.as_posix()

model_config = Config.model_validate(model_config_dict)
os.makedirs(model_config.data.model_version_path(), exist_ok=True)
model_config.dump_config(model_config.data.model_version_path())
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)

cell_size = 10.0
positions = np.array(
Expand Down Expand Up @@ -147,17 +144,16 @@ def test_ase_calc(get_tmp_path):
)
ckpt = {"model": {"params": params}, "epoch": 0}

best_dir = model_config.data.best_model_path()
checkpoints.save_checkpoint(
ckpt_dir=best_dir,
ckpt_dir=model_config.data.best_model_path,
target=ckpt,
step=0,
overwrite=True,
)

atoms = read(initial_structure_path.as_posix())
calc = ASECalculator(
[model_config.data.model_version_path(), model_config.data.model_version_path()]
[model_config.data.model_version_path, model_config.data.model_version_path]
)

atoms.calc = calc
Expand Down

0 comments on commit 7f61d0d

Please sign in to comment.