From ea1c6943b63e390e1d63800619e31cb8554d41ac Mon Sep 17 00:00:00 2001 From: Ryan Date: Fri, 1 Nov 2024 16:59:27 -0600 Subject: [PATCH] fix: keras continue from cloud checkpoint (#10192) What happened was an unfortunate combination of fancy context managers and early exits, where I was unwittingly deleting checkpoints before keras could read them, but only in the case of checkpoints from different trial IDs. I had tested pause/continue, which didn't have the bug. The new code is structurally incapable of the same bug, and has a regression test anyway. --- harness/determined/keras/_callback.py | 60 ++++++++++--------- .../tests/experiment/keras/test_callback.py | 31 +++++++++- 2 files changed, 61 insertions(+), 30 deletions(-) diff --git a/harness/determined/keras/_callback.py b/harness/determined/keras/_callback.py index 843d15f766e..f1b38830bc4 100644 --- a/harness/determined/keras/_callback.py +++ b/harness/determined/keras/_callback.py @@ -1,6 +1,7 @@ import contextlib import logging import os +import pathlib import pickle import shutil import tempfile @@ -337,34 +338,8 @@ def _load(self, checkpoint: Optional[str]) -> Optional[contextlib.ExitStack]: # Load model. self.load_model(self.model, str(path / "model_checkpoint"), self._core.distributed) - # Load training state also. - state_path = path / "callback_state" - if not state_path.exists(): - return None - with state_path.open("rb") as f: - state = pickle.load(f) - if state["continue_id"] != self._continue_id: - return None - # Continue training where we left off. - self._steps_completed = state["steps_completed"] - self._training_length = state["training_length"] - self._validation_length = state["validation_length"] - initial_epoch: int = state["epoch"] + 1 - - # HACK: Trick the training loop into starting on a different epoch. Internally, this is - # how keras.callbacks.BackupAndRestore() sets the initial_epoch. - class WorkerTrainingState: - # For tf.keras. - def maybe_load_initial_epoch_from_ckpt(*_: Any, **__: Any) -> int: - return initial_epoch - - # For plain keras. - def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]: - # We only save on epoch boundaries. - initial_batch = 0 - return initial_epoch, initial_batch - - self.model._training_state = WorkerTrainingState() + # Load our own state. + self._load_training_state(path) # Success! Don't delete the checkpoint until after the first batch runs though, because # the checkpoint isn't actually read until then. @@ -373,6 +348,35 @@ def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int] # mypy thinks it's possible to arrive here, but it isn't. raise RuntimeError("impossible codepath") + def _load_training_state(self, path: pathlib.Path) -> None: + state_path = path / "callback_state" + if not state_path.exists(): + return + with state_path.open("rb") as f: + state = pickle.load(f) + if state["continue_id"] != self._continue_id: + return + # Continue training where we left off. + self._steps_completed = state["steps_completed"] + self._training_length = state["training_length"] + self._validation_length = state["validation_length"] + initial_epoch: int = state["epoch"] + 1 + + # HACK: Trick the training loop into starting on a different epoch. Internally, this is + # how keras.callbacks.BackupAndRestore() sets the initial_epoch. + class WorkerTrainingState: + # For tf.keras. + def maybe_load_initial_epoch_from_ckpt(*_: Any, **__: Any) -> int: + return initial_epoch + + # For plain keras. 
+            def maybe_load_initial_counters_from_ckpt(*_: Any, **__: Any) -> Tuple[int, int]:
+                # We only save on epoch boundaries.
+                initial_batch = 0
+                return initial_epoch, initial_batch
+
+        self.model._training_state = WorkerTrainingState()
+
     def save_model(
         self, model: models.Model, path: str, distributed: core.DistributedContext
     ) -> None:
diff --git a/harness/tests/experiment/keras/test_callback.py b/harness/tests/experiment/keras/test_callback.py
index ff32e11a729..2823a8474f4 100644
--- a/harness/tests/experiment/keras/test_callback.py
+++ b/harness/tests/experiment/keras/test_callback.py
@@ -1,10 +1,11 @@
+import contextlib
 import json
 import os
 import pathlib
 import re
 import subprocess
 import sys
-from typing import Any, Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, Optional, Tuple, Union
 from unittest import mock
 
 import keras
@@ -30,8 +31,21 @@ def mock_core_context(
     """
     # Set up a functional DistributedContext.
     distributed = distributed or core.DummyDistributedContext()
+
     # Set up a functional CheckpointContext.
-    storage_manager = storage.SharedFSStorageManager(path)
+    class StorageManagerForTesting(storage.SharedFSStorageManager):
+        @contextlib.contextmanager
+        def restore_path(
+            self, src: str, selector: Optional[storage.Selector] = None
+        ) -> Iterator[pathlib.Path]:
+            events.append(("restore_path:enter", None))
+            try:
+                with super().restore_path(src, selector) as x:
+                    yield x
+            finally:
+                events.append(("restore_path:exit", None))
+
+    storage_manager = StorageManagerForTesting(path)
     checkpoint = core.DummyCheckpointContext(distributed, storage_manager)
 
     # Mock everything else, logging report-like calls to events.
@@ -74,6 +88,7 @@ class DeterminedCallbackForTesting(det.keras.DeterminedCallback):
 
     def __init__(self, events: utils.Events, *args: Any, **kwargs: Any) -> None:
         self.events = events
+        self.first_train_batch_end = False
         super().__init__(*args, **kwargs)
 
     def on_train_begin(self, logs: Any) -> None:
@@ -82,6 +97,12 @@ def on_train_begin(self, logs: Any) -> None:
         fourdigits = "%.4f" % weight
         self.events.append((f"after_train_begin:{fourdigits}", weight))
 
+    def on_train_batch_end(self, batch: int, logs: Any) -> None:
+        if not self.first_train_batch_end:
+            self.first_train_batch_end = True
+            self.events.append(("first_train_batch_end", None))
+        super().on_train_batch_end(batch, logs)
+
     def on_epoch_end(self, epoch: int, logs: Any) -> None:
         self.events.append((f"before_epoch_end:{epoch}", logs))
         super().on_epoch_end(epoch, logs)
@@ -250,12 +271,15 @@ def test_save_restore_and_warm_start(tmp_path: pathlib.Path, eager: bool) -> Non
     # - initial weight is nonzero (checkpoint was loaded)
     # - initial epoch is nonzero (training state was loaded)
     # - steps_completed was properly restored
+    # - checkpoint is not destroyed until first batch is completed
     events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=1)
     utils.assert_events_match(
        events,
        "set_status:restoring",
        "load_model",
        "after_train_begin:%.4f" % weight,
+        "first_train_batch_end",
+        "restore_path:exit",
        "!after_epoch_end:0",
        "before_epoch_end:1",
        "report_metrics:training:16",
@@ -267,12 +291,15 @@ def test_save_restore_and_warm_start(tmp_path: pathlib.Path, eager: bool) -> Non
     # - initial weight is nonzero (no checkpoint was loaded)
     # - initial epoch is zero (no training state was loaded)
     # - steps_completed was properly reset
+    # - checkpoint is not destroyed until first batch is completed
     events = do_fit(tmp_path, eager=eager, checkpoint=ckpt, continue_id=2)
utils.assert_events_match( events, "set_status:restoring", "load_model", "after_train_begin:%.4f" % weight, + "first_train_batch_end", + "restore_path:exit", "report_metrics:training:8", "after_epoch_end:0", "after_epoch_end:1",
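
Editor's note, not part of the patch: below is a minimal, self-contained sketch of the failure mode the commit message describes and of the shape of the fix. The names restore_path, load_buggy, and load_fixed are hypothetical stand-ins rather than Determined's actual API; the real change is the _load/_load_training_state split above, plus holding the restore context open until the first training batch has run.

import contextlib
import pathlib
import shutil
import tempfile
from typing import Iterator, Optional


@contextlib.contextmanager
def restore_path(src: str) -> Iterator[pathlib.Path]:
    # Stand-in for a storage manager's restore context: materialize a checkpoint
    # in a temp directory and delete it as soon as the with-block is exited.
    path = pathlib.Path(tempfile.mkdtemp())
    (path / "model_checkpoint").write_text(f"weights for {src}")
    try:
        yield path
    finally:
        shutil.rmtree(path)


def load_buggy(src: str, continue_id: int, saved_id: int) -> Optional[contextlib.ExitStack]:
    # Buggy shape: the early exit for a mismatched continue_id returns from inside
    # the with-block, so the checkpoint directory is removed immediately, before
    # the framework gets around to lazily reading the model weights from it.
    with contextlib.ExitStack() as exit_stack:
        path = exit_stack.enter_context(restore_path(src))
        # ... model weights are scheduled to be read lazily from `path` ...
        if saved_id != continue_id:
            return None  # early exit deletes the checkpoint too soon
        return exit_stack.pop_all()  # only this path keeps the files alive


def load_fixed(src: str, continue_id: int, saved_id: int) -> contextlib.ExitStack:
    # Fixed shape: early exits live in a helper that simply returns, so the
    # ExitStack is always handed back to the caller, who closes it only after
    # the checkpoint has actually been read.
    def load_training_state(path: pathlib.Path) -> None:
        if saved_id != continue_id:
            return  # early exit no longer tears down the checkpoint
        # ... restore steps_completed, epoch, etc. ...

    with contextlib.ExitStack() as exit_stack:
        path = exit_stack.enter_context(restore_path(src))
        load_training_state(path)
        return exit_stack.pop_all()

In either case the caller is expected to close the returned ExitStack once the checkpoint has been consumed; keeping it open until the first batch completes is what the new first_train_batch_end/restore_path:exit ordering assertions in the test verify.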