diff --git a/mace/cli/run_train.py b/mace/cli/run_train.py
index c63f1c4d..959f3d6b 100644
--- a/mace/cli/run_train.py
+++ b/mace/cli/run_train.py
@@ -446,6 +446,21 @@ def main() -> None:
     )
 
     start_epoch = 0
+    if args.init_latest:
+        try:
+            opt_start_epoch = checkpoint_handler.load_latest(
+                state=tools.CheckpointState(model, optimizer, lr_scheduler),
+                swa=True,
+                device=device,
+                model_only=True,
+            )
+        except Exception:  # pylint: disable=W0703
+            opt_start_epoch = checkpoint_handler.load_latest(
+                state=tools.CheckpointState(model, optimizer, lr_scheduler),
+                swa=False,
+                device=device,
+                model_only=True,
+            )
     if args.restart_latest:
         try:
             opt_start_epoch = checkpoint_handler.load_latest(
diff --git a/mace/tools/arg_parser.py b/mace/tools/arg_parser.py
index e16be03f..963e76db 100644
--- a/mace/tools/arg_parser.py
+++ b/mace/tools/arg_parser.py
@@ -432,7 +432,14 @@ def build_default_arg_parser() -> argparse.ArgumentParser:
         action="store_true",
         default=False,
     )
-    parser.add_argument(
+    parser_restart = parser.add_mutually_exclusive_group()
+    parser_restart.add_argument(
+        "--init_latest",
+        help="initialize model from latest checkpoint",
+        action="store_true",
+        default=False,
+    )
+    parser_restart.add_argument(
         "--restart_latest",
         help="restart optimizer from latest checkpoint",
         action="store_true",
diff --git a/mace/tools/checkpoint.py b/mace/tools/checkpoint.py
index 8a62f1f2..95fee177 100644
--- a/mace/tools/checkpoint.py
+++ b/mace/tools/checkpoint.py
@@ -35,11 +35,12 @@ def create_checkpoint(state: CheckpointState) -> Checkpoint:
 
     @staticmethod
     def load_checkpoint(
-        state: CheckpointState, checkpoint: Checkpoint, strict: bool
+        state: CheckpointState, checkpoint: Checkpoint, strict: bool, model_only: bool=False
     ) -> None:
         state.model.load_state_dict(checkpoint["model"], strict=strict)  # type: ignore
-        state.optimizer.load_state_dict(checkpoint["optimizer"])
-        state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
+        if not model_only:
+            state.optimizer.load_state_dict(checkpoint["optimizer"])
+            state.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
 
 
 @dataclasses.dataclass
@@ -206,13 +207,14 @@ def load_latest(
         swa: Optional[bool] = False,
         device: Optional[torch.device] = None,
         strict=False,
+        model_only: bool = False
     ) -> Optional[int]:
         result = self.io.load_latest(swa=swa, device=device)
         if result is None:
             return None
 
         checkpoint, epochs = result
-        self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict)
+        self.builder.load_checkpoint(state=state, checkpoint=checkpoint, strict=strict, model_only=model_only)
         return epochs
 
     def load(