fix: upload non-conflicting files for sharded checkpointing [MD-298] (#…
azhou-determined authored Jul 15, 2024
1 parent 4ece949 commit 3663c5b
Showing 5 changed files with 164 additions and 13 deletions.
4 changes: 3 additions & 1 deletion harness/determined/common/storage/base.py
@@ -80,7 +80,9 @@ def pre_store_path(self, dst: str) -> pathlib.Path:
         return pathlib.Path(storage_dir)
 
     @abc.abstractmethod
-    def post_store_path(self, src: Union[str, os.PathLike], dst: str) -> None:
+    def post_store_path(
+        self, src: Union[str, os.PathLike], dst: str, paths: Optional[Set[str]] = None
+    ) -> None:
         """
         Subclasses typically push to persistent storage if necessary, then delete the src directory,
         if necessary.
8 changes: 5 additions & 3 deletions harness/determined/common/storage/cloud.py
@@ -1,7 +1,7 @@
 import contextlib
 import os
 import pathlib
-from typing import Iterator, Optional, Union
+from typing import Iterator, Optional, Set, Union
 
 from determined import util
 from determined.common import storage
@@ -22,12 +22,14 @@ def restore_path(
         finally:
             util.rmtree_nfs_safe(dst, ignore_errors=True)
 
-    def post_store_path(self, src: Union[str, os.PathLike], dst: str) -> None:
+    def post_store_path(
+        self, src: Union[str, os.PathLike], dst: str, paths: Optional[Set[str]] = None
+    ) -> None:
         """
         post_store_path uploads the checkpoint to cloud storage and deletes the original files.
         """
         try:
-            self.upload(src, dst)
+            self.upload(src, dst, paths)
         finally:
             util.rmtree_nfs_safe(src, ignore_errors=True)
 
6 changes: 4 additions & 2 deletions harness/determined/common/storage/shared.py
@@ -5,7 +5,7 @@
 import pathlib
 import shutil
 import urllib
-from typing import Any, Callable, Dict, Iterator, List, Optional, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, Set, Union
 
 from determined import errors, util
 from determined.common import check, storage
@@ -171,7 +171,9 @@ def from_config(
         )
         return cls(base_path)
 
-    def post_store_path(self, src: Union[str, os.PathLike], dst: str) -> None:
+    def post_store_path(
+        self, src: Union[str, os.PathLike], dst: str, paths: Optional[Set[str]] = None
+    ) -> None:
         """
         Nothing to clean up after writing directly to shared_fs.
         """
8 changes: 3 additions & 5 deletions harness/determined/core/_checkpoint.py
@@ -537,7 +537,6 @@ def _store_path_sharded(

         path = self._storage_manager.pre_store_path(storage_id)
         yield path, storage_id
-
         ckpt_dir = os.fspath(path)
 
         if self._storage_manager.store_path_is_direct_access():
@@ -567,20 +566,19 @@ def _store_path_sharded(
             resources = self._storage_manager._list_directory(ckpt_dir)
         else:
             resources = {}
-
         # Merge resources, detect conflicts.
         all_resources = self._dist.allgather(resources)
-
         merged_resources, conflicts = merge_resources(all_resources)
-        self._resolve_conflicts(resources, conflicts, ckpt_dir)
+        resources = self._resolve_conflicts(resources, conflicts, ckpt_dir)
 
         all_metadata = self._merge_metadata(metadata)
         if self._dist.rank == 0:
             self._write_metadata_file(ckpt_dir, all_metadata)
 
         if want_upload:
+            paths = set(resources.keys())
             # Use post_store_path to upload and clean up ckpt_dir after uploading.
-            self._storage_manager.post_store_path(src=ckpt_dir, dst=storage_id)
+            self._storage_manager.post_store_path(src=ckpt_dir, dst=storage_id, paths=paths)
 
         if self._dist.rank == 0:
             self._report_checkpoint(storage_id, merged_resources, all_metadata)
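The behavior this change targets, as exercised by the tests below: after merge_resources gathers every rank's file listing and _resolve_conflicts splits it up, each rank passes only the files it owns to post_store_path via the new paths argument. A toy sketch of that split policy follows; it is an illustration of the division the tests assert on, not the actual merge_resources/_resolve_conflicts code.

from typing import Dict, Set

def split_upload_ownership(per_rank_files: Dict[int, Set[str]]) -> Dict[int, Set[str]]:
    """Assign each file to exactly one uploader: unique files stay with the rank that
    wrote them; files reported by multiple ranks are uploaded by the chief (rank 0) only."""
    # Count how many ranks report each file.
    counts: Dict[str, int] = {}
    for files in per_rank_files.values():
        for f in files:
            counts[f] = counts.get(f, 0) + 1

    ownership: Dict[int, Set[str]] = {}
    for rank, files in per_rank_files.items():
        if rank == 0:
            # The chief uploads everything it holds, including shared/conflicting files.
            ownership[rank] = set(files)
        else:
            # Non-chief ranks upload only files that no other rank also reported.
            ownership[rank] = {f for f in files if counts[f] == 1}
    return ownership

# Mirrors the expectation in test_checkpoint_upload below: rank 0 uploads its own files
# plus the shared metadata.json, rank 1 uploads only its unique files.
print(split_upload_ownership({0: {"worker0-0", "metadata.json"}, 1: {"worker1-0", "metadata.json"}}))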
151 changes: 149 additions & 2 deletions harness/tests/core/test_checkpoint.py
@@ -10,7 +10,10 @@
 from tests import parallel
 
 
-def make_mock_storage_manager(basedir: pathlib.Path) -> Any:
+def make_mock_storage_manager(
+    basedir: pathlib.Path,
+    dir_files: Optional[List[str]] = None,
+) -> Any:
     @contextlib.contextmanager
     def store_path(dst: str) -> Iterator[pathlib.Path]:
         path = basedir.joinpath("store-path")
@@ -25,10 +28,24 @@ def restore_path(
         path.mkdir(exist_ok=True)
         yield pathlib.Path(path)
 
+    if dir_files:
+        for file in dir_files:
+            (basedir / file).touch(exist_ok=True)
+    else:
+        dir_files = ["one", "two"]
+
+    def pre_store_path(dst: str) -> pathlib.Path:
+        path = basedir.joinpath("store-path")
+        path.mkdir(exist_ok=True)
+        return pathlib.Path(path)
+
+    mock_list_dir = {f: i for i, f in enumerate(dir_files)}
+
     storage_manager = mock.MagicMock()
     storage_manager.store_path = mock.MagicMock(side_effect=store_path)
+    storage_manager.pre_store_path = mock.MagicMock(side_effect=pre_store_path)
     storage_manager.restore_path = mock.MagicMock(side_effect=restore_path)
-    storage_manager._list_directory = mock.MagicMock(return_value={"one": 1, "two": 2})
+    storage_manager._list_directory = mock.MagicMock(return_value=mock_list_dir)
     storage_manager.delete = mock.MagicMock()
 
     return storage_manager
@@ -232,3 +249,133 @@ def test_merge_metadata(
     merged, conflicts = core._checkpoint.merge_metadata(metadata)
     assert conflicts == expected_conflicts
     assert merged == expected_merged
+
+
+@pytest.mark.parametrize("sharded", [True, False])
+def test_checkpoint_upload(sharded: bool, tmp_path: pathlib.Path) -> None:
+    ckpt_dir = tmp_path.joinpath("ckpt-dir")
+    ckpt_dir.mkdir(exist_ok=True)
+
+    # Create some mock files for each worker to upload, and also identical files across all
+    # workers to test the sharded conflict case.
+    ckpt_files = {
+        0: ["worker0-0", "worker0-1"],
+        1: ["worker1-0", "worker1-1"],
+    }
+
+    all_workers_files = ["metadata.json", "file2", "file3"]
+
+    with parallel.Execution(2) as pex:
+
+        @pex.run
+        def upload_ckpt() -> List[str]:
+            storage_manager = make_mock_storage_manager(
+                basedir=ckpt_dir, dir_files=ckpt_files[pex.rank] + all_workers_files
+            )
+            checkpoint_context = core.DummyCheckpointContext(pex.distributed, storage_manager)
+
+            # Upload across all workers, expect exceptions on non-chief workers if shard=False.
+            with parallel.raises_when(
+                not sharded and pex.distributed.rank != 0,
+                RuntimeError,
+                match="upload.*non-chief",
+            ):
+                checkpoint_context.upload(
+                    ckpt_dir,
+                    metadata={"steps_completed": 1},
+                    shard=sharded,
+                    # Implement a selector to test file conflicts in sharded case.
+                    selector=lambda _: True,
+                )
+
+            # When shard=True, every worker will call upload. When shard=False, only the chief
+            # worker will upload.
+            upload_paths = []
+            if sharded or pex.rank == 0:
+                storage_manager.upload.assert_called_once()
+                upload_paths = storage_manager.upload.call_args.kwargs["paths"]
+                storage_manager.upload.reset_mock()
+                storage_manager._list_directory.assert_called_once()
+                storage_manager._list_directory.reset_mock()
+            else:
+                storage_manager.upload.assert_not_called()
+                storage_manager._list_directory.assert_not_called()
+            return upload_paths
+
+    assert len(upload_ckpt) == 2
+    assert sorted(upload_ckpt[0]) == sorted(ckpt_files[0] + all_workers_files)
+
+    if sharded:
+        # In the sharded case, expect each worker to upload unique files. Files that conflict across
+        # workers should only be uploaded by the chief worker.
+        assert sorted(upload_ckpt[1]) == sorted(ckpt_files[1])
+        assert len(set(upload_ckpt[0]).intersection(ckpt_files[1])) == 0
+    else:
+        # Only the chief worker should upload files in the non-sharded case.
+        assert len(list(upload_ckpt[1])) == 0
+
+
+@pytest.mark.parametrize("sharded", [True, False])
+def test_store_path(sharded: bool, tmp_path: pathlib.Path) -> None:
+    ckpt_dir = tmp_path.joinpath("ckpt-dir")
+    ckpt_dir.mkdir(exist_ok=True)
+
+    # Create some mock files for each worker to upload, and also identical files across all
+    # workers to test the sharded conflict case.
+    ckpt_files = {
+        0: ["worker0-0", "worker0-1"],
+        1: ["worker1-0", "worker1-1"],
+    }
+
+    all_workers_files = ["metadata.json", "file2", "file3"]
+
+    with parallel.Execution(2) as pex:
+
+        @pex.run
+        def do_store_path() -> None:
+            storage_manager = make_mock_storage_manager(
+                basedir=ckpt_dir, dir_files=ckpt_files[0] + ckpt_files[1] + all_workers_files
+            )
+
+            storage_manager.store_path_is_direct_access = mock.MagicMock(return_value=False)
+            checkpoint_context = core.DummyCheckpointContext(pex.distributed, storage_manager)
+
+            # Upload across all workers, expect exceptions on non-chief workers if shard=False.
+            with parallel.raises_when(
+                not sharded and pex.distributed.rank != 0,
+                RuntimeError,
+                match=r"\.store_path.*non-chief",
+            ):
+                with checkpoint_context.store_path(
+                    metadata={"steps_completed": 1}, shard=sharded
+                ) as (ckpt_path, storage_id):
+                    for f in ckpt_files[pex.rank] + all_workers_files:
+                        (ckpt_path / f).touch()
+
+            if not sharded:
+                if pex.rank == 0:
+                    storage_manager.store_path.assert_called_once()
+                    storage_manager.store_path.reset_mock()
+                    storage_manager._list_directory.assert_called_once()
+                    storage_manager._list_directory.reset_mock()
+                else:
+                    storage_manager.store_path.assert_not_called()
+                    storage_manager._list_directory.assert_not_called()
+            else:
+                # In the sharded case, the chief worker should upload all files written by each
+                # worker, merging duplicate, conflicting files.
+                storage_manager.pre_store_path.assert_called_once()
+                storage_manager.pre_store_path.reset_mock()
+                if pex.rank == 0:
+                    storage_manager.post_store_path.assert_called_once()
+                    _, call_kwargs = storage_manager.post_store_path.call_args_list[0]
+                    uploaded_paths = call_kwargs["paths"]
+                    assert sorted(uploaded_paths) == sorted(
+                        ckpt_files[0] + ckpt_files[1] + all_workers_files
+                    )
+                    storage_manager.post_store_path.reset_mock()
+                    storage_manager._list_directory.assert_called_once()
+                    storage_manager._list_directory.reset_mock()
+                else:
+                    storage_manager.post_store_path.assert_not_called()
+                    storage_manager._list_directory.assert_not_called()
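For context, a minimal sketch of how the sharded path tested above is typically driven from user code, assuming the public core.Context checkpoint API mirrors the DummyCheckpointContext used in these tests; the det.core.init() setup and the shard file name are illustrative and not part of this commit.

import pathlib

import determined as det

with det.core.init() as core_context:
    # With shard=True, every rank enters store_path and writes only its own shard.
    with core_context.checkpoint.store_path(
        metadata={"steps_completed": 1}, shard=True
    ) as (ckpt_path, storage_id):
        rank = core_context.distributed.rank
        shard_file = pathlib.Path(ckpt_path) / f"weights-rank{rank}.bin"  # illustrative name
        shard_file.write_bytes(b"...")
    # On exit, each rank uploads only its non-conflicting files; files written identically
    # by every rank (e.g. metadata.json) are uploaded once by the chief.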
