Skip to content

Commit

Permalink
feat: Shuffle between epochs (#456)
Browse files Browse the repository at this point in the history
This PR introduces a `shuffle` option for training: If `True`, then we
shuffle the order of the partitions and the keys within the partitions
between each epoch.

Note that as described in #460, we might need to have this a bit more
finegrained for things like Criteo to optimize performance.
  • Loading branch information
MaxiBoether authored and robinholzi committed Jun 4, 2024
1 parent 5982bad commit db5d38f
Show file tree
Hide file tree
Showing 38 changed files with 320 additions and 201 deletions.
1 change: 1 addition & 0 deletions benchmark/mnist/mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
1 change: 1 addition & 0 deletions benchmark/wildtime_benchmarks/example_pipelines/arxiv.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 128
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 96
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
1 change: 1 addition & 0 deletions benchmark/wildtime_benchmarks/example_pipelines/fmow.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: True
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
89 changes: 73 additions & 16 deletions integrationtests/online_dataset/test_online_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import random
import shutil
import time
from typing import Iterable, Tuple
from typing import Iterable, Optional, Tuple

import grpc
import modyn.storage.internal.grpc.generated.storage_pb2 as storage_pb2
Expand Down Expand Up @@ -275,6 +275,8 @@ def test_dataset_impl(
pipeline_id: int,
trigger_id: int,
items: list[int],
shuffle: bool,
consistency_check: bool,
) -> None:
dataloader, _ = prepare_dataloaders(
pipeline_id,
Expand All @@ -289,6 +291,7 @@ def test_dataset_impl(
42,
prefetched_partitions,
parallel_prefetch_requests,
shuffle,
None,
None,
)
Expand Down Expand Up @@ -326,7 +329,7 @@ def test_dataset_impl(
+ f"expected_min = {expected_min_batches}, expected_max = {expected_max_batches}"
)

assert set(all_samples) == set(items)
assert set(all_samples) == set(items), f"all_samples = {all_samples} \n\n items = {items}"
assert set(all_labels) == set(range(len(items)))

trans = transforms.Compose([transforms.ToPILImage()])
Expand All @@ -339,6 +342,47 @@ def test_dataset_impl(
if image_bytes not in FIRST_ADDED_IMAGES:
raise ValueError(f"Could not find image {idx} in created images, all_samples = {all_samples}")

if not consistency_check:
return

print("Iterating again to check across epochs.")

second_samples = []
second_data = []
second_labels = []

for batch_number, batch in enumerate(dataloader):
sample_ids = batch[0]
if isinstance(sample_ids, torch.Tensor):
sample_ids = sample_ids.tolist()
elif isinstance(sample_ids, tuple):
sample_ids = list(sample_ids)

assert isinstance(sample_ids, list), "Cannot parse result from DataLoader"
assert isinstance(batch[1], torch.Tensor) and isinstance(batch[2], torch.Tensor)

second_samples.extend(sample_ids)
for sample in batch[1]:
second_data.append(sample) # iterate over batch dimension to extract samples
second_labels.extend(batch[2].tolist())

# Same content, but not same order
# (even without shuffle, the storage may return samples in a slightly different order)

assert set(second_samples) == set(
all_samples
), f"second_samples = {second_samples} \n\n all_samples = {all_samples}"
assert set(second_labels) == set(all_labels), f"second_labels = {second_labels} \n\n all_labels = {all_labels}"
for data1 in second_data:
assert any(torch.allclose(data1, data2) for data2 in all_data)

# when shuffling, we expect a different order

if shuffle:
assert second_samples != all_samples, f"second_samples = {second_samples} \n\n all_samples = {all_samples}"
assert not all(torch.allclose(data1, data2) for data1, data2 in zip(second_data, all_data))
assert second_labels != all_labels, f"second_labels = {second_labels} \n\n all_labels = {all_labels}"


def test_dataset() -> None:
NUM_IMAGES = 10
Expand All @@ -359,22 +403,35 @@ def test_dataset() -> None:
if prefetched_partitions == 5:
ppr_list = [1, 2, 5, 999]

# By default, we do neither test shuffle nor cross-epoch consistency
# Only in a selected case, we test it to avoid blowing up the test further.
shuffles = [False]
consistency_checks = [False]
if num_dataworkers in [0, 4] and prefetched_partitions in [0, 4]:
shuffles = [False, True]
consistency_checks = [True]

for parallel_prefetch_requests in ppr_list:
for batch_size in [1, 2, 10]:
print(
f"Testing num_workers = {num_dataworkers}, partitions = {prefetched_partitions},"
+ f"batch_size = {batch_size}, parallel_prefetch_requests={parallel_prefetch_requests}"
)
test_dataset_impl(
num_dataworkers,
batch_size,
prefetched_partitions,
parallel_prefetch_requests,
pipeline_id,
trigger_id,
keys,
)
gc.collect()
for consistency_check in consistency_checks:
for shuffle in shuffles:
print(
f"Testing num_workers = {num_dataworkers}, partitions = {prefetched_partitions},"
+ f"batch_size = {batch_size}, parallel_prefetch_requests={parallel_prefetch_requests}"
+ f" consistency_check = {consistency_check} shuffle = {shuffle}"
)
test_dataset_impl(
num_dataworkers,
batch_size,
prefetched_partitions,
parallel_prefetch_requests,
pipeline_id,
trigger_id,
keys,
shuffle,
consistency_check,
)
gc.collect()


def main() -> None:
Expand Down
1 change: 1 addition & 0 deletions modyn/config/examples/example-pipeline.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ training:
use_previous_model: True
initial_model: random
batch_size: 64
shuffle: False
optimizers:
- name: "default"
algorithm: "SGD"
Expand Down
6 changes: 6 additions & 0 deletions modyn/config/schema/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,12 @@ class TrainingConfig(ModynBaseModel):
description="The number of data loader workers on the trainer node that fetch data from storage.", ge=1
)
batch_size: int = Field(description="The batch size to be used during training.", ge=1)
shuffle: bool = Field(
description=(
"If True, we shuffle the order of partitions and the data within each partition at each worker."
"Otherwise, the output order is deterministic."
)
)
use_previous_model: bool = Field(
description=(
"If True, on trigger, we continue training on the model outputted by the previous trigger. If False, "
Expand Down
1 change: 1 addition & 0 deletions modyn/protos/trainer_server.proto
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ message StartTrainingRequest {
optional int32 seed = 21;
optional PythonString tokenizer = 22;
int64 num_samples_to_pass = 23;
bool shuffle = 24;
}

message StartTrainingResponse {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,6 @@ def store_training_set(

swt.start("store_triggersamples", overwrite=True)
if insertion_threads == 1:

AbstractSelectionStrategy._store_triggersamples_impl(
partition,
target_trigger_id,
Expand Down
1 change: 1 addition & 0 deletions modyn/supervisor/internal/triggers/datadrifttrigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def _init_dataloader_info(self) -> None:
selector_address=f"{self.context.modyn_config.selector.address}",
num_prefetched_partitions=training_config.num_prefetched_partitions,
parallel_prefetch_requests=training_config.parallel_prefetch_requests,
shuffle=training_config.shuffle,
tokenizer=data_config.tokenizer,
)

Expand Down
1 change: 0 additions & 1 deletion modyn/supervisor/internal/triggers/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ class TriggerContext:


class Trigger(ABC):

# pylint: disable=unnecessary-pass
def init_trigger(self, context: TriggerContext) -> None:
"""The supervisor initializes the concrete Trigger with Trigger-type-specific configurations
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def __init__(
selector_address: str,
num_prefetched_partitions: int,
parallel_prefetch_requests: int,
shuffle: bool,
tokenizer: Optional[str],
):
self.pipeline_id = pipeline_id
Expand All @@ -29,3 +30,4 @@ def __init__(
self.parallel_prefetch_requests = parallel_prefetch_requests
self.tokenizer = tokenizer
self.training_id = -1
self.shuffle = shuffle
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
training_id: int,
num_prefetched_partitions: int,
parallel_prefetch_requests: int,
shuffle: bool,
tokenizer: Optional[str] = None,
sample_prob: Optional[float] = None,
):
Expand All @@ -46,6 +47,7 @@ def __init__(
training_id,
num_prefetched_partitions,
parallel_prefetch_requests,
shuffle,
tokenizer,
None,
)
Expand Down
1 change: 1 addition & 0 deletions modyn/supervisor/internal/triggers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def prepare_trigger_dataloader_by_trigger(
dataloader_info.training_id,
dataloader_info.num_prefetched_partitions,
dataloader_info.parallel_prefetch_requests,
dataloader_info.shuffle,
dataloader_info.tokenizer,
sample_prob,
)
Expand Down
1 change: 1 addition & 0 deletions modyn/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ def pipeline_training_config() -> TrainingConfig:
],
optimization_criterion=OptimizationCriterion(name="CrossEntropyLoss"),
checkpointing=CheckpointingConfig(activated=False),
shuffle=False,
)


Expand Down
1 change: 0 additions & 1 deletion modyn/tests/selector/internal/storage_backend/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@


class MockStorageBackend(AbstractStorageBackend):

# pylint: disable=super-init-not-called
def __init__(self, pipeline_id: int, modyn_config: dict, maximum_keys_in_memory: int):
self.insertion_threads = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,10 @@ def test_initialization(non_connecting_pipeline_executor: PipelineExecutor) -> N


def test_pipeline_stage_decorator(dummy_pipeline_args: PipelineExecutionParams) -> None:

class TestStageLogInfo(StageInfo):
name: str

class TestPipelineExecutor(PipelineExecutor):

@pipeline_stage(PipelineStage.INIT, log=True, track=True)
def _stage_func(self, s: ExecutionState, log: StageLog) -> int:
time.sleep(0.1)
Expand All @@ -170,7 +168,6 @@ def _stage_func(self, s: ExecutionState, log: StageLog) -> int:


def test_pipeline_stage_decorator_generator(dummy_pipeline_args: PipelineExecutionParams) -> None:

class TestStageLogInfo(StageInfo):
elements: list[int]
finalized: bool = False
Expand All @@ -182,7 +179,6 @@ def create_generator(x: int = 3) -> Generator[int, None, None]:
yield i

class TestPipelineExecutor(PipelineExecutor):

@pipeline_stage(PipelineStage.INIT, log=True, track=True)
def _stage_func(self, s: ExecutionState, log: StageLog) -> Generator[int, None, None]:
try:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def noop_dataloader_info_constructor_mock(
selector_address: str,
num_prefetched_partitions: int,
parallel_prefetch_requests: int,
shuffle: bool,
tokenizer: Optional[None],
) -> None:
pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def test_init():
num_prefetched_partitions=1,
parallel_prefetch_requests=1,
sample_prob=0.5,
shuffle=False,
)
assert online_trigger_dataset._pipeline_id == 1
assert online_trigger_dataset._trigger_id == 1
Expand All @@ -78,6 +79,7 @@ def test_dataset_iter():
num_prefetched_partitions=1,
parallel_prefetch_requests=1,
sample_prob=0.5,
shuffle=False,
)

all_trigger_data = list(online_trigger_dataset)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_prepare_dataloaders(
test_weights, test_insecure_channel, test_grpc_connection_established, test_grpc_connection_established_selector
):
train_dataloader, _ = prepare_dataloaders(
1, 1, "MNIST", 4, 128, get_mock_bytes_parser(), [], "", "", 42, 5, 5, None, None
1, 1, "MNIST", 4, 128, get_mock_bytes_parser(), [], "", "", 42, 5, 5, False, None, None
)

assert train_dataloader.num_workers == 4
Expand Down
Loading

0 comments on commit db5d38f

Please sign in to comment.