-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add basic unit test for training_utils
- Loading branch information
1 parent
dba9065
commit 7f72202
Showing
2 changed files
with
217 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
[models.mock_model] | ||
train = true | ||
|
||
[training] | ||
duration = "100:epoch" | ||
seed = 0 | ||
gpu_index = 0 | ||
batch_size = 4 | ||
gradient_accumulation = "4:step" | ||
clip_grad_norm = 1.0 | ||
evaluation_interval = "5:epoch" | ||
evaluation_seed = 1 | ||
|
||
[optimizer] | ||
optimizer = "SGD" | ||
learning_rate = 1 | ||
momentum = 0.9 | ||
|
||
[scheduler] | ||
scheduler_type = "ConstantLR" | ||
update_interval = "1:step" | ||
warmup = "20:step" | ||
|
||
[dropout] | ||
dropout = 0.0 | ||
|
||
[checkpointing] | ||
save_interval = "10:epoch" | ||
|
||
[wandb] | ||
mode = "disabled" | ||
project = "mock_project" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,185 @@ | ||
from dataclasses import dataclass | ||
from functools import cached_property | ||
from pathlib import Path | ||
from warnings import warn | ||
|
||
import pytest | ||
import torch | ||
from torch import Tensor, nn | ||
from torch.utils.data import Dataset | ||
|
||
from refiners.fluxion import layers as fl | ||
from refiners.fluxion.utils import norm | ||
from refiners.training_utils.config import BaseConfig, TimeUnit | ||
from refiners.training_utils.trainer import ( | ||
Trainer, | ||
TrainingClock, | ||
count_learnable_parameters, | ||
human_readable_number, | ||
) | ||
|
||
|
||
@dataclass | ||
class MockBatch: | ||
inputs: torch.Tensor | ||
targets: torch.Tensor | ||
|
||
|
||
class MockDataset(Dataset[MockBatch]): | ||
def __len__(self): | ||
return 20 | ||
|
||
def __getitem__(self, _: int) -> MockBatch: | ||
return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10)) | ||
|
||
def collate_fn(self, batch: list[MockBatch]) -> MockBatch: | ||
return MockBatch( | ||
inputs=torch.cat([b.inputs for b in batch]), | ||
targets=torch.cat([b.targets for b in batch]), | ||
) | ||
|
||
|
||
class MockConfig(BaseConfig): | ||
pass | ||
|
||
|
||
class MockModel(fl.Chain): | ||
def __init__(self): | ||
super().__init__( | ||
fl.Linear(10, 10), | ||
fl.Linear(10, 10), | ||
fl.Linear(10, 10), | ||
) | ||
|
||
|
||
class MockTrainer(Trainer[MockConfig, MockBatch]): | ||
step_counter: int = 0 | ||
|
||
@cached_property | ||
def mock_model(self) -> MockModel: | ||
return MockModel() | ||
|
||
def load_dataset(self) -> Dataset[MockBatch]: | ||
return MockDataset() | ||
|
||
def load_models(self) -> dict[str, fl.Module]: | ||
return {"mock_model": self.mock_model} | ||
|
||
def compute_loss(self, batch: MockBatch) -> Tensor: | ||
self.step_counter += 1 | ||
inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device) | ||
outputs = self.mock_model(inputs) | ||
return norm(outputs - targets) | ||
|
||
|
||
@pytest.fixture | ||
def mock_config(test_device: torch.device) -> MockConfig: | ||
if not test_device.type == "cuda": | ||
warn("only running on CUDA, skipping") | ||
pytest.skip("Skipping test because test_device is not CUDA") | ||
config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml") | ||
config.training.gpu_index = test_device.index | ||
return config | ||
|
||
|
||
@pytest.fixture | ||
def mock_trainer(mock_config: MockConfig) -> MockTrainer: | ||
return MockTrainer(config=mock_config) | ||
|
||
|
||
@pytest.fixture | ||
def mock_model() -> fl.Chain: | ||
return MockModel() | ||
|
||
|
||
def test_count_learnable_parameters_with_params() -> None: | ||
params = [ | ||
nn.Parameter(torch.randn(2, 2), requires_grad=True), | ||
nn.Parameter(torch.randn(5), requires_grad=False), | ||
nn.Parameter(torch.randn(3, 3), requires_grad=True), | ||
] | ||
assert count_learnable_parameters(params) == 13 | ||
|
||
|
||
def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: | ||
assert count_learnable_parameters(mock_model.parameters()) == 330 | ||
|
||
|
||
def test_human_readable_number() -> None: | ||
assert human_readable_number(123) == "123.0" | ||
assert human_readable_number(1234) == "1.2K" | ||
assert human_readable_number(1234567) == "1.2M" | ||
|
||
|
||
@pytest.fixture | ||
def training_clock() -> TrainingClock: | ||
return TrainingClock( | ||
dataset_length=100, | ||
batch_size=10, | ||
training_duration={"number": 5, "unit": TimeUnit.EPOCH}, | ||
gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH}, | ||
evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH}, | ||
lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH}, | ||
checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH}, | ||
) | ||
|
||
|
||
def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None: | ||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10 | ||
assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20 | ||
assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1 | ||
|
||
|
||
def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None: | ||
assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1 | ||
assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2 | ||
assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1 | ||
|
||
|
||
def test_clock_properties(training_clock: TrainingClock) -> None: | ||
assert training_clock.num_batches_per_epoch == 10 | ||
assert training_clock.num_epochs == 5 | ||
assert training_clock.num_iterations == 5 | ||
assert training_clock.num_steps == 50 | ||
|
||
|
||
def test_timer_functionality(training_clock: TrainingClock) -> None: | ||
training_clock.start_timer() | ||
assert training_clock.start_time is not None | ||
training_clock.stop_timer() | ||
assert training_clock.end_time is not None | ||
assert training_clock.time_elapsed >= 0 | ||
|
||
|
||
def test_state_based_properties(training_clock: TrainingClock) -> None: | ||
training_clock.step = 5 # Halfway through the first epoch | ||
assert not training_clock.is_evaluation_step # Assuming evaluation every epoch | ||
assert not training_clock.is_checkpointing_step | ||
training_clock.step = 10 # End of the first epoch | ||
assert training_clock.is_evaluation_step | ||
assert training_clock.is_checkpointing_step | ||
|
||
|
||
def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: | ||
assert mock_trainer.config == mock_config | ||
assert isinstance(mock_trainer, MockTrainer) | ||
assert mock_trainer.optimizer is not None | ||
assert mock_trainer.lr_scheduler is not None | ||
|
||
|
||
def test_training_cycle(mock_trainer: MockTrainer) -> None: | ||
clock = mock_trainer.clock | ||
config = mock_trainer.config | ||
|
||
assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"] | ||
assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size | ||
|
||
assert mock_trainer.step_counter == 0 | ||
assert mock_trainer.clock.epoch == 0 | ||
|
||
mock_trainer.train() | ||
|
||
assert clock.epoch == config.training.duration["number"] | ||
assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch | ||
|
||
assert mock_trainer.step_counter == mock_trainer.clock.step |