diff --git a/apax/config/common.py b/apax/config/common.py index c2acaf5a..393dfdb4 100644 --- a/apax/config/common.py +++ b/apax/config/common.py @@ -32,8 +32,7 @@ def parse_config(config: Union[str, os.PathLike, dict], mode: str = "train") -> def flatten(dictionary, parent_key="", separator="_"): - """https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys - """ + """https://stackoverflow.com/questions/6027558/flatten-nested-dictionaries-compressing-keys""" items = [] for key, value in dictionary.items(): new_key = parent_key + separator + key if parent_key else key diff --git a/apax/train/trainer.py b/apax/train/trainer.py index 6d6bc0f0..8c040a3f 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -107,12 +107,10 @@ def fit( epoch_loss["val_loss"] /= val_steps_per_epoch epoch_loss["val_loss"] = float(epoch_loss["val_loss"]) - epoch_metrics.update( - { - f"val_{key}": float(val) - for key, val in val_batch_metrics.compute().items() - } - ) + epoch_metrics.update({ + f"val_{key}": float(val) + for key, val in val_batch_metrics.compute().items() + }) epoch_metrics.update({**epoch_loss})