From 636b44281a27f5823973cafbd427664ea4f09acd Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Fri, 12 Jan 2024 15:47:59 +0100 Subject: [PATCH 01/14] fabric --- src/refiners/training_utils/config.py | 6 +- src/refiners/training_utils/fabric_trainer.py | 19 +++++ src/refiners/training_utils/trainer.py | 7 +- tests/conftest.py | 7 +- tests/training_utils/test_latent_diffusion.py | 73 +++++++++++++++++++ 5 files changed, 107 insertions(+), 5 deletions(-) create mode 100644 src/refiners/training_utils/fabric_trainer.py create mode 100644 tests/training_utils/test_latent_diffusion.py diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 6b61c1a70..c8c12c5c9 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -226,10 +226,12 @@ class BaseConfig(BaseModel): scheduler: SchedulerConfig dropout: DropoutConfig checkpointing: CheckpointingConfig - + @classmethod + def load_from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T: + return cls(**config_dict) @classmethod def load_from_toml(cls: Type[T], toml_path: Path | str) -> T: with open(file=toml_path, mode="rb") as f: config_dict = tomli.load(f) - return cls(**config_dict) + return cls.load_from_dict(**config_dict) diff --git a/src/refiners/training_utils/fabric_trainer.py b/src/refiners/training_utils/fabric_trainer.py new file mode 100644 index 000000000..6242398be --- /dev/null +++ b/src/refiners/training_utils/fabric_trainer.py @@ -0,0 +1,19 @@ +from .trainer import Trainer + +class FabricTrainer(Trainer): + @cached_property + def optimizer(self) -> Optimizer: + optimizer = super().optimizer + for model_name in self.models: + model, optimizer = fabric.setup(self.models[model_name], optimizer) + self.models[model_name] = model + return optimizer + + def _backward(self, tensors: torch.Tensor | List[torch.Tensor]): + + # Check if the input is a single tensor + if isinstance(input_tensor, torch.Tensor): + input_tensor = [input_tensor] # Wrap the tensor in a list + + for tensor in tensors: + fabric.backward(tensor) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 40cad5876..eeb411155 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -480,12 +480,15 @@ def compute_loss(self, batch: Batch) -> Tensor: def compute_evaluation(self) -> None: pass - + + def _backward(self, tensors) -> None: + backward(tensors=tensors) + def backward(self) -> None: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") scaled_loss = self.loss / self.clock.num_step_per_iteration - backward(tensors=scaled_loss) + self._backward(scaled_loss) self._call_callbacks(event_name="on_backward_end") if self.clock.is_optimizer_step: self._call_callbacks(event_name="on_optimizer_step_begin") diff --git a/tests/conftest.py b/tests/conftest.py index d1403ffb3..b03deaa81 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,6 @@ PARENT_PATH = Path(__file__).parent - @fixture(scope="session") def test_device() -> torch.device: test_device = os.getenv("REFINERS_TEST_DEVICE") @@ -14,6 +13,12 @@ def test_device() -> torch.device: return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") return torch.device(test_device) +@fixture(scope="session") +def test_second_device() -> torch.device: + test_device = os.getenv("REFINERS_TEST_SECOND_DEVICE") + if not test_device and torch.cuda.device_count() > 1: + return torch.device("cuda:1") + return 
torch.device("cpu") @fixture(scope="session") def test_weights_path() -> Path: diff --git a/tests/training_utils/test_latent_diffusion.py b/tests/training_utils/test_latent_diffusion.py new file mode 100644 index 000000000..5cf2955bb --- /dev/null +++ b/tests/training_utils/test_latent_diffusion.py @@ -0,0 +1,73 @@ +from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer +from torch import device as Device +from warnings import warn +import pytest + +DEFAULT_LATENT_DICT = dict( + script = "foo.py", + wandb = dict( + mode = "offline", + entity = "acme", + project = "test-ldm-training" + ), + latent_diffusion = dict( + unconditional_sampling_probability = 0.2, + offset_noise = 0.1 + ), + optimizer = dict( + optimizer = "AdamW", + learning_rate = 1e-5, + betas = [0.9, 0.999], + eps = 1e-8, + weight_decay = 1e-2 + ), + scheduler = dict(), + dropout = dict(dropout_probability = 0.2), + checkpointing=dict(save_interval = "1:epoch"), + test_diffusion=dict(prompts = [ + "A cute cat", + ]), + models = dict( + lda = dict( + checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", + train = False, + ), + text_encoder = dict( + checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", + train = True, + ), + unet= dict( + checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", + train = True, + ), + ), + training = dict( + duration= "1:epoch", + gpu_index= 0 + ), + dataset = dict( + hf_repo= "1aurent/unsplash-lite-palette", + revision= "main", + caption_key = "ai_description" + ) +) + + +def test_ldm_trainer_text_encoder_on_two_devices(test_device: Device, test_second_device: Device): + + if test_device.type == "cpu": + warn("not running on CPU, skipping") + pytest.skip() + + if test_second_device.type == "cpu": + warn("Running with only one GPU, skipping") + pytest.skip() + + config = FinetuneLatentDiffusionConfig.load_from_dict( + dict(DEFAULT_LATENT_DICT) + ) + + trainer = LatentDiffusionTrainer(config=config) + trainer.train() + assert trainer.lda.device == test_device + assert trainer.text_encoder.device.type == test_second_device \ No newline at end of file From dba90652298db8dc39ae84682d4d335a55aa0842 Mon Sep 17 00:00:00 2001 From: Colle Date: Fri, 12 Jan 2024 18:32:22 +0100 Subject: [PATCH 02/14] fix test_debug_print Follow-up of #173 --- tests/fluxion/layers/test_chain.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/fluxion/layers/test_chain.py b/tests/fluxion/layers/test_chain.py index 5b5cd1f77..402564cc1 100644 --- a/tests/fluxion/layers/test_chain.py +++ b/tests/fluxion/layers/test_chain.py @@ -230,8 +230,8 @@ def test_setattr_dont_register() -> None: EXPECTED_TREE = ( - "(CHAIN)\n ├── Linear(in_features=1, out_features=1) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1," - " out_features=1) #1\n └── Linear(in_features=2, out_features=1) #2" + "(CHAIN)\n ├── Linear(in_features=1, out_features=1, device=cpu, dtype=float32) (x2)\n └── (CHAIN)\n ├── Linear(in_features=1," + " out_features=1, device=cpu, dtype=float32) #1\n └── Linear(in_features=2, out_features=1, device=cpu, dtype=float32) #2" ) From 7f722029be15e346c0521cbe57581d2752e82388 Mon Sep 17 00:00:00 2001 From: limiteinductive Date: Sun, 14 Jan 2024 15:06:48 +0100 Subject: [PATCH 03/14] add basic unit test for training_utils --- tests/training_utils/mock_config.toml | 32 +++++ tests/training_utils/test_trainer.py | 185 ++++++++++++++++++++++++++ 2 files changed, 217 insertions(+) create mode 100644 
tests/training_utils/mock_config.toml create mode 100644 tests/training_utils/test_trainer.py diff --git a/tests/training_utils/mock_config.toml b/tests/training_utils/mock_config.toml new file mode 100644 index 000000000..6064f495c --- /dev/null +++ b/tests/training_utils/mock_config.toml @@ -0,0 +1,32 @@ +[models.mock_model] +train = true + +[training] +duration = "100:epoch" +seed = 0 +gpu_index = 0 +batch_size = 4 +gradient_accumulation = "4:step" +clip_grad_norm = 1.0 +evaluation_interval = "5:epoch" +evaluation_seed = 1 + +[optimizer] +optimizer = "SGD" +learning_rate = 1 +momentum = 0.9 + +[scheduler] +scheduler_type = "ConstantLR" +update_interval = "1:step" +warmup = "20:step" + +[dropout] +dropout = 0.0 + +[checkpointing] +save_interval = "10:epoch" + +[wandb] +mode = "disabled" +project = "mock_project" \ No newline at end of file diff --git a/tests/training_utils/test_trainer.py b/tests/training_utils/test_trainer.py new file mode 100644 index 000000000..38e45d1c3 --- /dev/null +++ b/tests/training_utils/test_trainer.py @@ -0,0 +1,185 @@ +from dataclasses import dataclass +from functools import cached_property +from pathlib import Path +from warnings import warn + +import pytest +import torch +from torch import Tensor, nn +from torch.utils.data import Dataset + +from refiners.fluxion import layers as fl +from refiners.fluxion.utils import norm +from refiners.training_utils.config import BaseConfig, TimeUnit +from refiners.training_utils.trainer import ( + Trainer, + TrainingClock, + count_learnable_parameters, + human_readable_number, +) + + +@dataclass +class MockBatch: + inputs: torch.Tensor + targets: torch.Tensor + + +class MockDataset(Dataset[MockBatch]): + def __len__(self): + return 20 + + def __getitem__(self, _: int) -> MockBatch: + return MockBatch(inputs=torch.randn(1, 10), targets=torch.randn(1, 10)) + + def collate_fn(self, batch: list[MockBatch]) -> MockBatch: + return MockBatch( + inputs=torch.cat([b.inputs for b in batch]), + targets=torch.cat([b.targets for b in batch]), + ) + + +class MockConfig(BaseConfig): + pass + + +class MockModel(fl.Chain): + def __init__(self): + super().__init__( + fl.Linear(10, 10), + fl.Linear(10, 10), + fl.Linear(10, 10), + ) + + +class MockTrainer(Trainer[MockConfig, MockBatch]): + step_counter: int = 0 + + @cached_property + def mock_model(self) -> MockModel: + return MockModel() + + def load_dataset(self) -> Dataset[MockBatch]: + return MockDataset() + + def load_models(self) -> dict[str, fl.Module]: + return {"mock_model": self.mock_model} + + def compute_loss(self, batch: MockBatch) -> Tensor: + self.step_counter += 1 + inputs, targets = batch.inputs.to(self.device), batch.targets.to(self.device) + outputs = self.mock_model(inputs) + return norm(outputs - targets) + + +@pytest.fixture +def mock_config(test_device: torch.device) -> MockConfig: + if not test_device.type == "cuda": + warn("only running on CUDA, skipping") + pytest.skip("Skipping test because test_device is not CUDA") + config = MockConfig.load_from_toml(Path(__file__).parent / "mock_config.toml") + config.training.gpu_index = test_device.index + return config + + +@pytest.fixture +def mock_trainer(mock_config: MockConfig) -> MockTrainer: + return MockTrainer(config=mock_config) + + +@pytest.fixture +def mock_model() -> fl.Chain: + return MockModel() + + +def test_count_learnable_parameters_with_params() -> None: + params = [ + nn.Parameter(torch.randn(2, 2), requires_grad=True), + nn.Parameter(torch.randn(5), requires_grad=False), + nn.Parameter(torch.randn(3, 3), 
requires_grad=True), + ] + assert count_learnable_parameters(params) == 13 + + +def test_count_learnable_parameters_with_model(mock_model: fl.Chain) -> None: + assert count_learnable_parameters(mock_model.parameters()) == 330 + + +def test_human_readable_number() -> None: + assert human_readable_number(123) == "123.0" + assert human_readable_number(1234) == "1.2K" + assert human_readable_number(1234567) == "1.2M" + + +@pytest.fixture +def training_clock() -> TrainingClock: + return TrainingClock( + dataset_length=100, + batch_size=10, + training_duration={"number": 5, "unit": TimeUnit.EPOCH}, + gradient_accumulation={"number": 1, "unit": TimeUnit.EPOCH}, + evaluation_interval={"number": 1, "unit": TimeUnit.EPOCH}, + lr_scheduler_interval={"number": 1, "unit": TimeUnit.EPOCH}, + checkpointing_save_interval={"number": 1, "unit": TimeUnit.EPOCH}, + ) + + +def test_time_unit_to_steps_conversion(training_clock: TrainingClock) -> None: + assert training_clock.convert_time_unit_to_steps(1, TimeUnit.EPOCH) == 10 + assert training_clock.convert_time_unit_to_steps(2, TimeUnit.EPOCH) == 20 + assert training_clock.convert_time_unit_to_steps(1, TimeUnit.STEP) == 1 + + +def test_steps_to_time_unit_conversion(training_clock: TrainingClock) -> None: + assert training_clock.convert_steps_to_time_unit(10, TimeUnit.EPOCH) == 1 + assert training_clock.convert_steps_to_time_unit(20, TimeUnit.EPOCH) == 2 + assert training_clock.convert_steps_to_time_unit(1, TimeUnit.STEP) == 1 + + +def test_clock_properties(training_clock: TrainingClock) -> None: + assert training_clock.num_batches_per_epoch == 10 + assert training_clock.num_epochs == 5 + assert training_clock.num_iterations == 5 + assert training_clock.num_steps == 50 + + +def test_timer_functionality(training_clock: TrainingClock) -> None: + training_clock.start_timer() + assert training_clock.start_time is not None + training_clock.stop_timer() + assert training_clock.end_time is not None + assert training_clock.time_elapsed >= 0 + + +def test_state_based_properties(training_clock: TrainingClock) -> None: + training_clock.step = 5 # Halfway through the first epoch + assert not training_clock.is_evaluation_step # Assuming evaluation every epoch + assert not training_clock.is_checkpointing_step + training_clock.step = 10 # End of the first epoch + assert training_clock.is_evaluation_step + assert training_clock.is_checkpointing_step + + +def test_mock_trainer_initialization(mock_config: MockConfig, mock_trainer: MockTrainer) -> None: + assert mock_trainer.config == mock_config + assert isinstance(mock_trainer, MockTrainer) + assert mock_trainer.optimizer is not None + assert mock_trainer.lr_scheduler is not None + + +def test_training_cycle(mock_trainer: MockTrainer) -> None: + clock = mock_trainer.clock + config = mock_trainer.config + + assert clock.num_step_per_iteration == config.training.gradient_accumulation["number"] + assert clock.num_batches_per_epoch == mock_trainer.dataset_length // config.training.batch_size + + assert mock_trainer.step_counter == 0 + assert mock_trainer.clock.epoch == 0 + + mock_trainer.train() + + assert clock.epoch == config.training.duration["number"] + assert clock.step == config.training.duration["number"] * clock.num_batches_per_epoch + + assert mock_trainer.step_counter == mock_trainer.clock.step From d203974af2f4d0e0db680d27422f5bb141425b80 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 14:41:38 +0100 Subject: [PATCH 04/14] experiments with model device sharding --- 
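Note on this commit: it experiments with placing each model on its own GPU. ModelConfig gains a gpu_index field, and a SimpleShardingManager moves every model to its configured device and then wraps selected methods (forward by default, or an explicit _tensor_methods list) so tensor arguments are moved to the owning model's device before the call, mirroring accelerate's hook mechanism. Below is a minimal, self-contained sketch of that input-moving hook pattern under those assumptions; the helper name move_args_to_device and the demo usage are illustrative only and not part of the patch.

    import torch
    from functools import partial, update_wrapper
    from torch import nn

    def move_args_to_device(module: nn.Module, device: torch.device, method_name: str = "forward") -> nn.Module:
        # Wrap `method_name` so tensor arguments are moved to `device` before the call.
        original = getattr(module, method_name)

        def hooked(mod: nn.Module, *args, **kwargs):
            args = [a.to(device) if hasattr(a, "to") else a for a in args]
            kwargs = {k: v.to(device) if hasattr(v, "to") else v for k, v in kwargs.items()}
            return original(*args, **kwargs)

        # Same trick as accelerate's hooks: bind the module and keep the original
        # method's metadata on the wrapper.
        setattr(module, method_name, update_wrapper(partial(hooked, module), original))
        return module

    if __name__ == "__main__":
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        linear = move_args_to_device(nn.Linear(8, 4).to(device), device)
        out = linear(torch.randn(2, 8))  # input created on CPU, moved to `device` by the hook
        print(out.device)

Context that is set outside of forward (timesteps, text embeddings, the loss target) still needs explicit moves, which is why the diff below also adds .to(...) calls in set_unet_context and around mse_loss.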
.../foundationals/latent_diffusion/model.py | 27 +- .../latent_diffusion/schedulers/dpm_solver.py | 41 ++- .../latent_diffusion/schedulers/scheduler.py | 4 +- .../stable_diffusion_1/model.py | 17 +- .../training_utils/accelerate_trainer.py | 48 +++ src/refiners/training_utils/color_palette.py | 306 ++++++++++++++++++ src/refiners/training_utils/config.py | 1 + src/refiners/training_utils/fabric_trainer.py | 75 ++++- .../training_utils/latent_diffusion.py | 74 +++-- .../training_utils/sharding_manager.py | 71 ++++ src/refiners/training_utils/trainer.py | 33 +- tests/training_utils/test_latent_diffusion.py | 14 +- 12 files changed, 616 insertions(+), 95 deletions(-) create mode 100644 src/refiners/training_utils/accelerate_trainer.py create mode 100644 src/refiners/training_utils/color_palette.py create mode 100644 src/refiners/training_utils/sharding_manager.py diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 22618d7f1..36e6c038c 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -20,22 +20,23 @@ def __init__( unet: fl.Module, lda: LatentDiffusionAutoencoder, clip_text_encoder: fl.Module, - scheduler: Scheduler, - device: Device | str = "cpu", - dtype: DType = torch.float32, + scheduler: Scheduler ) -> None: super().__init__() - self.device: Device = device if isinstance(device, Device) else Device(device=device) - self.dtype = dtype - self.unet = unet.to(device=self.device, dtype=self.dtype) - self.lda = lda.to(device=self.device, dtype=self.dtype) - self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype) - self.scheduler = scheduler.to(device=self.device, dtype=self.dtype) + self.unet = unet + self.lda = lda + self.clip_text_encoder = clip_text_encoder + self.scheduler = scheduler def set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate + + # Question : + # Is there a better way to do this ? + # What is the purpose of this ? 
device, dtype = self.scheduler.device, self.scheduler.dtype + print(f"set_num_inference_steps device: {device}, dtype: {dtype}") self.scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, @@ -51,7 +52,7 @@ def init_latents( ) -> Tensor: height, width = size if noise is None: - noise = torch.randn(1, 4, height // 8, width // 8, device=self.device) + noise = torch.randn(1, 4, height // 8, width // 8) assert list(noise.shape[2:]) == [ height // 8, width // 8, @@ -90,6 +91,7 @@ def forward( self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs) latents = torch.cat(tensors=(x, x)) # for classifier-free guidance + # scale latents for schedulers that need it latents = self.scheduler.scale_model_input(latents, step=step) unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2) @@ -102,7 +104,6 @@ def forward( noise += self.compute_self_attention_guidance( x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs ) - return self.scheduler(x, noise=noise, step=step) def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel: @@ -110,7 +111,5 @@ def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel: unet=self.unet.structural_copy(), lda=self.lda.structural_copy(), clip_text_encoder=self.clip_text_encoder.structural_copy(), - scheduler=self.scheduler, - device=self.device, - dtype=self.dtype, + scheduler=self.scheduler ) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index 52e706c1a..dc7d3f26b 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -49,14 +49,18 @@ def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> self.timesteps[step], self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0], ) + # Remark: + # We use noise.device as the target device + # Note: the scheduler here cannot be used with accelerate.prepare + # Cause it's not inherinting from torch.Module previous_ratio, current_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[timestep], + self.signal_to_noise_ratios[previous_timestep].to(device=noise.device, dtype=noise.dtype), + self.signal_to_noise_ratios[timestep].to(device=noise.device, dtype=noise.dtype), ) - previous_scale_factor = self.cumulative_scale_factors[previous_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep].to(device=noise.device, dtype=noise.dtype) previous_noise_std, current_noise_std = ( - self.noise_std[previous_timestep], - self.noise_std[timestep], + self.noise_std[previous_timestep].to(device=noise.device, dtype=noise.dtype), + self.noise_std[timestep].to(device=noise.device, dtype=noise.dtype), ) factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise @@ -70,14 +74,14 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens ) current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] previous_ratio, current_ratio, next_ratio = ( - self.signal_to_noise_ratios[previous_timestep], - self.signal_to_noise_ratios[current_timestep], - self.signal_to_noise_ratios[next_timestep], + 
self.signal_to_noise_ratios[previous_timestep].to(device=x.device, dtype=x.dtype), + self.signal_to_noise_ratios[current_timestep].to(device=x.device, dtype=x.dtype), + self.signal_to_noise_ratios[next_timestep].to(device=x.device, dtype=x.dtype), ) - previous_scale_factor = self.cumulative_scale_factors[previous_timestep] + previous_scale_factor = self.cumulative_scale_factors[previous_timestep].to(device=x.device, dtype=x.dtype) previous_std, current_std = ( - self.noise_std[previous_timestep], - self.noise_std[current_timestep], + self.noise_std[previous_timestep].to(device=x.device, dtype=x.dtype), + self.noise_std[current_timestep].to(device=x.device, dtype=x.dtype), ) estimation_delta = (current_data_estimation - next_data_estimation) / ( (current_ratio - next_ratio) / (previous_ratio - current_ratio) @@ -100,12 +104,21 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N """ current_timestep = self.timesteps[step] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] - estimated_denoised_data = (x - noise_ratio * noise) / scale_factor + + # Remark: + # We use noise.device as the target device + # Note: the scheduler here cannot be used with accelerate.prepare + # Cause it's not inherinting from torch.Module + noise_ratio2 = noise_ratio.to(device=noise.device, dtype=noise.dtype) + x2 = x.to(device=noise.device, dtype=noise.dtype) + scale_factor2 = scale_factor.to(device=noise.device, dtype=noise.dtype) + + estimated_denoised_data = (x2 - noise_ratio2 * noise) / scale_factor2 self.estimated_data.append(estimated_denoised_data) denoised_x = ( - self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) + self.dpm_solver_first_order_update(x=x2, noise=estimated_denoised_data, step=step) if (self.initial_steps == 0) - else self.multistep_dpm_solver_second_order_update(x=x, step=step) + else self.multistep_dpm_solver_second_order_update(x=x2, step=step) ) if self.initial_steps < 2: self.initial_steps += 1 diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index f64a4cc91..f2e3341f7 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -107,8 +107,8 @@ def add_noise( step: int, ) -> Tensor: timestep = self.timesteps[step] - cumulative_scale_factors = self.cumulative_scale_factors[timestep] - noise_stds = self.noise_std[timestep] + cumulative_scale_factors = self.cumulative_scale_factors[timestep].to(device=x.device, dtype=x.dtype) + noise_stds = self.noise_std[timestep].to(device=x.device, dtype=x.dtype) noised_x = cumulative_scale_factors * x + noise_stds * noise return noised_x diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 532c68ff1..ecaef7774 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -2,6 +2,7 @@ import torch from PIL import Image from torch import Tensor, device as Device, dtype as DType +from loguru import logger from refiners.fluxion.utils import image_to_tensor, interpolate from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL @@ -26,22 +27,17 @@ def __init__( unet: SD1UNet | None = None, lda: SD1Autoencoder | None = 
None, clip_text_encoder: CLIPTextEncoderL | None = None, - scheduler: Scheduler | None = None, - device: Device | str = "cpu", - dtype: DType = torch.float32, + scheduler: Scheduler | None = None ) -> None: unet = unet or SD1UNet(in_channels=4) lda = lda or SD1Autoencoder() clip_text_encoder = clip_text_encoder or CLIPTextEncoderL() scheduler = scheduler or DPMSolver(num_inference_steps=30) - super().__init__( unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, - scheduler=scheduler, - device=device, - dtype=dtype, + scheduler=scheduler ) def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor: @@ -53,8 +49,11 @@ def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Ten return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0) def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: - self.unet.set_timestep(timestep=timestep) - self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) + + # Question : + # Can we do this as part of the SetContext Logic ? + self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) + self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)) def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: if enable: diff --git a/src/refiners/training_utils/accelerate_trainer.py b/src/refiners/training_utils/accelerate_trainer.py new file mode 100644 index 000000000..43a245bb3 --- /dev/null +++ b/src/refiners/training_utils/accelerate_trainer.py @@ -0,0 +1,48 @@ +from accelerate import Accelerator +from .trainer import Trainer +from functools import cached_property +from torch.optim import Optimizer +from torch import Tensor, cuda +from typing import Sequence +from refiners.training_utils.config import BaseConfig +from typing import Generic, TypeVar, Any +from refiners.training_utils.callback import Callback +from torch import device as Device +from loguru import logger +from accelerate import Accelerator + +Batch = TypeVar("Batch") +ConfigType = TypeVar("ConfigType", bound=BaseConfig) + + +class AccelerateTrainer(Trainer, Generic[ConfigType, Batch]): + def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: + self.accelerator = Accelerator() + print(self.accelerator.distributed_type) + super().__init__(config, callbacks) + + + def _backward(self, tensors: Tensor | Sequence[Tensor]): + + # Check if the input is a single tensor + if isinstance(tensors, Tensor): + tensors = [tensors] # Wrap the tensor in a list + + for tensor in tensors: + self.accelerator.backward(tensor) + + def __str__(self) -> str: + return f"Trainer : \n"+ "\n".join([f"* {self.models[model_name]}:{self.models[model_name].device}" for model_name in self.models]) + + @cached_property + def optimizer(self) -> Optimizer: + optimizer = super().optimizer + return self.accelerator.prepare(optimizer) + + def setup_model(self, model, **kwargs) -> None: + out_model = self.accelerator.prepare(model) + return out_model + + @property + def device(self) -> Device: + return self.accelerator.device \ No newline at end of file diff --git a/src/refiners/training_utils/color_palette.py b/src/refiners/training_utils/color_palette.py new file mode 100644 index 000000000..51506b796 --- /dev/null +++ b/src/refiners/training_utils/color_palette.py @@ -0,0 +1,306 @@ +import hashlib +import os +from dataclasses import dataclass 
+from functools import cached_property +from random import random +from typing import Any + +import requests +from loguru import logger +from PIL import Image +from pydantic import BaseModel +from torch import Tensor, cat, float32, randn, tensor, bfloat16 +from torch.utils.data import Dataset +from tqdm import tqdm + +import refiners.fluxion.layers as fl +from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder, SD1ColorPaletteAdapter +from refiners.fluxion.utils import save_to_safetensors +from refiners.foundationals.latent_diffusion import ( + DPMSolver, + StableDiffusion_1, +) +from refiners.training_utils.callback import Callback +from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig +from refiners.training_utils.latent_diffusion import ( + FinetuneLatentDiffusionConfig, + LatentDiffusionConfig, + LatentDiffusionTrainer, + TestDiffusionConfig, + TextEmbeddingLatentsDataset, +) + + +class ColorPaletteConfig(BaseModel): + model_dim: int + trigger_phrase: str = "" + use_only_trigger_probability: float = 0.0 + max_colors: int + download_local: bool = True + + +class ColorPalettePromptConfig(BaseModel): + text: str + color_palette: list[list[float]] + + +class TestColorPaletteConfig(TestDiffusionConfig): + prompts: list[ColorPalettePromptConfig] = [] + + +class ColorPaletteDatasetConfig(HuggingfaceDatasetConfig): + local_folder: str = "data/color-palette" + + +@dataclass +class TextEmbeddingColorPaletteLatentsBatch: + text_embeddings: Tensor + latents: Tensor + color_palette_embeddings: Tensor + + +class ColorPaletteDataset(TextEmbeddingLatentsDataset): + def __init__( + self, + trainer: "ColorPaletteLatentDiffusionTrainer", + ) -> None: + super().__init__(trainer=trainer) + self.trigger_phrase = trainer.config.color_palette.trigger_phrase + self.use_only_trigger_probability = trainer.config.color_palette.use_only_trigger_probability + logger.info(f"Trigger phrase: {self.trigger_phrase}") + self.color_palette_encoder = trainer.color_palette_encoder + + self.local_folder = trainer.config.dataset.local_folder + + # Download images + # Question : there might be a more efficient way to do this + # I didn't find the way to do this easily with hugging face + # dataset library + if trainer.config.color_palette.download_local: + for item in tqdm(self.dataset, desc="Downloading images"): + self.download_image(item) + + def get_image_path_from_url(self, url: str) -> str: + hash_md5 = hashlib.md5() + hash_md5.update(url.encode()) + filename = hash_md5.hexdigest() + return self.local_folder + f"/{filename}" + + def download_image(self, item: dict[str, Any]) -> None: + url = item["url"] + image_path = self.get_image_path_from_url(url) + if not os.path.exists(image_path): + # download image from url + logger.info(f"Downloading image {image_path} from {url}") + response = requests.get(url) + + # Check if the request was successful + if response.status_code == 200: + # Save the image bytes to the image_path + with open(image_path, "wb") as file: + file.write(response.content) + else: + print(f"Failed to download image from {url}") + return None + + def get_caption(self, index: int) -> str: + return self.dataset[index]["ai_description"] + + def get_image(self, index: int) -> str: + url = self.dataset[index]["url"] + image_path = self.get_image_path_from_url(url) + + if not os.path.exists(image_path): + raise Exception(f"Image {image_path} does not exist") + return Image.open(image_path) + + def process_caption(self, caption: str) -> str: + caption = 
super().process_caption(caption=caption) + if self.trigger_phrase: + caption = ( + f"{self.trigger_phrase} {caption}" + if random() < self.use_only_trigger_probability + else self.trigger_phrase + ) + return caption + + def get_color_palette(self, index: int) -> Tensor: + # TO IMPLEMENT : use other palettes + return tensor([self.dataset[index]["palette_8"]]) + + def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch: + caption = self.get_caption(index=index) + color_palette = self.get_color_palette(index=index) + image = self.get_image(index=index) + resized_image = self.resize_image( + image=image, + min_size=self.config.dataset.resize_image_min_size, + max_size=self.config.dataset.resize_image_max_size, + ) + processed_image = self.process_image(resized_image) + latents = self.lda.encode_image(image=processed_image) + processed_caption = self.process_caption(caption=caption) + + clip_text_embedding = self.text_encoder(processed_caption) + color_palette_embedding = self.color_palette_encoder(color_palette) + return TextEmbeddingColorPaletteLatentsBatch( + text_embeddings=clip_text_embedding, latents=latents, color_palette_embeddings=color_palette_embedding + ) + + def collate_fn(self, batch: list[TextEmbeddingColorPaletteLatentsBatch]) -> TextEmbeddingColorPaletteLatentsBatch: + text_embeddings = cat(tensors=[item.text_embeddings for item in batch]) + latents = cat(tensors=[item.latents for item in batch]) + color_palette_embeddings = cat(tensors=[item.color_palette_embeddings for item in batch]) + return TextEmbeddingColorPaletteLatentsBatch( + text_embeddings=text_embeddings, latents=latents, color_palette_embeddings=color_palette_embeddings + ) + + +class ColorPaletteLatentDiffusionConfig(FinetuneLatentDiffusionConfig): + dataset: ColorPaletteDatasetConfig + latent_diffusion: LatentDiffusionConfig + color_palette: ColorPaletteConfig + test_diffusion: TestColorPaletteConfig + + def model_post_init(self, __context: Any) -> None: + """Pydantic v2 does post init differently, so we need to override this method too.""" + logger.info("Freezing models to train only the color palette.") + self.models["text_encoder"].train = False + self.models["lda"].train = False + self.models["color_palette_encoder"].train = True + + # Question : Here I should not freeze the CrossAttentionBlock2d + # But what is the unfreeze only this block ? + self.models["unet"].train = False + + +class ColorPaletteLatentDiffusionTrainer(LatentDiffusionTrainer[ColorPaletteLatentDiffusionConfig]): + @cached_property + def color_palette_encoder(self) -> ColorPaletteEncoder: + assert ( + self.config.models["color_palette_encoder"] is not None + ), "The config must contain a color_palette_encoder entry." 
+ + # TO FIX : connect this to unet cross attention embedding dim + EMBEDDING_DIM = 768 + + return ColorPaletteEncoder( + max_colors=self.config.color_palette.max_colors, + embedding_dim=EMBEDDING_DIM, + model_dim=self.config.color_palette.model_dim, + device=self.device, + ) + + def __init__( + self, + config: ColorPaletteLatentDiffusionConfig, + callbacks: "list[Callback[Any]] | None" = None, + ) -> None: + super().__init__(config=config, callbacks=callbacks) + self.callbacks.extend((LoadColorPalette(), SaveColorPalette())) + + def load_dataset(self) -> Dataset[TextEmbeddingColorPaletteLatentsBatch]: + return ColorPaletteDataset(trainer=self) + + def load_models(self) -> dict[str, fl.Module]: + return { + "unet": self.unet, + "text_encoder": self.text_encoder, + "lda": self.lda, + "color_palette_encoder": self.color_palette_encoder, + } + + def set_adapter(self, adapter) -> None: + self.adapter = adapter + + def compute_loss(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor: + text_embeddings, latents, color_palette_embeddings = ( + batch.text_embeddings, + batch.latents, + batch.color_palette_embeddings, + ) + timestep = self.sample_timestep() + noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) + noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) + + self.unet.set_timestep(timestep=timestep) + + clip_text_embedding = cat([text_embeddings, color_palette_embeddings], dim=1) + + # Used to run training on 2 parallel GPUs + # self.unet.to(device=1) + # clip_text_embedding.to(device=1) + # noisy_latents.to(device=1) + # noise.to(device=1) + + self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) + prediction = self.unet(noisy_latents) + loss = mse_loss(input=prediction, target=noise) + return loss + + def compute_evaluation(self) -> None: + sd = StableDiffusion_1( + unet=self.unet, + lda=self.lda, + clip_text_encoder=self.text_encoder, + scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps), + device=self.device, + dtype=self.dtype + ) + prompts = self.config.test_diffusion.prompts + num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt + if self.config.test_diffusion.use_short_prompts: + prompts = [prompt.split(sep=",")[0] for prompt in prompts] + images: dict[str, WandbLoggable] = {} + for prompt in prompts: + canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt)) + image_name = prompt.text + str(prompt.color_palette) + for i in range(num_images_per_prompt): + logger.info( + f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt.text} and palette {prompt.color_palette}" + ) + x = randn(1, 4, 64, 64) + + # cfg means classifier-free guidance + clip_text_embedding = sd.compute_clip_text_embedding(text=prompt.text) + cfg_color_palette_embedding = self.adapter.compute_color_palette_embedding( + tensor([prompt.color_palette]) + ) + + self.adapter.set_color_palette_embedding(cfg_color_palette_embedding) + + for step in sd.steps: + x = sd( + x, + step=step, + clip_text_embedding=clip_text_embedding, + ) + canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i)) + images[image_name] = canvas_image + self.log(data=images) + + +class LoadColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]): + def on_train_begin(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None: + color_palette_config = trainer.config.color_palette + adapter = SD1ColorPaletteAdapter(target=trainer.unet, 
color_palette_encoder=trainer.color_palette_encoder) + trainer.set_adapter(adapter) + adapter.inject() + + +class SaveColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]): + def on_checkpoint_save(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None: + tensors: dict[str, Tensor] = {} + metadata: dict[str, str] = {} + + model = trainer.unet + adapter = model.parent + + tensors = {f"unet.{i:03d}": w for i, w in enumerate(adapter.weights)} + metadata = {f"unet_targets": ",".join(adapter.sub_targets)} + + save_to_safetensors( + path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", + tensors=tensors, + metadata=metadata, + ) \ No newline at end of file diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index c8c12c5c9..0e80b7d85 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -174,6 +174,7 @@ class ModelConfig(BaseModel): checkpoint: Path | None = None train: bool = True learning_rate: float | None = None # TODO: Implement this + gpu_index: int | None = None class GyroDropoutConfig(BaseModel): diff --git a/src/refiners/training_utils/fabric_trainer.py b/src/refiners/training_utils/fabric_trainer.py index 6242398be..59555eafa 100644 --- a/src/refiners/training_utils/fabric_trainer.py +++ b/src/refiners/training_utils/fabric_trainer.py @@ -1,19 +1,70 @@ +from lightning.fabric import Fabric from .trainer import Trainer +from functools import cached_property +from torch.optim import Optimizer +from torch import Tensor, cuda +from typing import Sequence +from refiners.training_utils.config import TrainingConfig, BaseConfig +from typing import Generic, TypeVar, Any +from refiners.training_utils.callback import Callback +from torch import device as Device +from loguru import logger -class FabricTrainer(Trainer): - @cached_property - def optimizer(self) -> Optimizer: - optimizer = super().optimizer - for model_name in self.models: - model, optimizer = fabric.setup(self.models[model_name], optimizer) - self.models[model_name] = model - return optimizer +class FabricTrainingConfig(TrainingConfig): + devices: int = 1 + +class FabricBaseConfig(BaseConfig): + training: FabricTrainingConfig + + +Batch = TypeVar("Batch") +ConfigType = TypeVar("ConfigType", bound=FabricBaseConfig) + + +class FabricTrainer(Trainer, Generic[ConfigType, Batch]): + fabric_optimizer: Optimizer = None + + def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: + self.fabric = Fabric(strategy="fsdp") + self.fabric.launch() + self.fabric.seed_everything(42) + + super().__init__(config, callbacks) + - def _backward(self, tensors: torch.Tensor | List[torch.Tensor]): + def _backward(self, tensors: Tensor | Sequence[Tensor]): # Check if the input is a single tensor - if isinstance(input_tensor, torch.Tensor): - input_tensor = [input_tensor] # Wrap the tensor in a list + if isinstance(tensors, Tensor): + tensors = [tensors] # Wrap the tensor in a list for tensor in tensors: - fabric.backward(tensor) + self.fabric.backward(tensor) + + def __str__(self) -> str: + return f"Trainer : \n"+ "\n".join([f"* {self.models[model_name]}:{self.models[model_name].device}" for model_name in self.models]) + + @cached_property + def optimizer(self) -> Optimizer: + optimizer = super().optimizer + return fabric.setup_optimizers(optimizer) + + def prepare_model(self, model_name: str) -> None: + self.fabric.print(model_name, cuda.memory_summary()) + + model = 
self.fabric.setup(self.models[model_name]) + # self.fabric_optimizer = optimizer + self.models[model_name] = model + + if (checkpoint := self.config.models[model_name].checkpoint) is not None: + model.load_from_safetensors(tensors_path=checkpoint) + else: + logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.") + model.requires_grad_(requires_grad=self.config.models[model_name].train) + model.zero_grad() + + + + @cached_property + def device(self) -> Device: + raise NotImplementedError("FabricTrainer does not support this property") \ No newline at end of file diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 397c680ed..2917456ae 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -7,7 +7,7 @@ from loguru import logger from PIL import Image from pydantic import BaseModel -from torch import Generator, Tensor, cat, device as Device, dtype as DType, randn +from torch import Generator, Tensor, cat, dtype as DType, randn from torch.nn import Module from torch.nn.functional import mse_loss from torch.utils.data import Dataset @@ -23,10 +23,10 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.callback import Callback -from refiners.training_utils.config import BaseConfig from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset from refiners.training_utils.trainer import Trainer from refiners.training_utils.wandb import WandbLoggable +from refiners.training_utils.config import BaseConfig class LatentDiffusionConfig(BaseModel): @@ -69,7 +69,6 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]): def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None: self.trainer = trainer self.config = trainer.config - self.device = self.trainer.device self.lda = self.trainer.lda self.text_encoder = self.trainer.text_encoder self.dataset = self.load_huggingface_dataset() @@ -124,9 +123,9 @@ def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch: max_size=self.config.dataset.resize_image_max_size, ) processed_image = self.process_image(resized_image) - latents = self.lda.encode_image(image=processed_image).to(device=self.device) + latents = self.lda.encode_image(image=processed_image) processed_caption = self.process_caption(caption=caption) - clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device) + clip_text_embedding = self.text_encoder(processed_caption) return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents) def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch: @@ -142,19 +141,23 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): @cached_property def unet(self) -> SD1UNet: assert self.config.models["unet"] is not None, "The config must contain a unet entry." - return SD1UNet(in_channels=4, device=self.device).to(device=self.device) + return SD1UNet(in_channels=4, device=self.device) @cached_property def text_encoder(self) -> CLIPTextEncoderL: assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry." 
- return CLIPTextEncoderL(device=self.device).to(device=self.device) + return CLIPTextEncoderL(device=self.device) @cached_property def lda(self) -> SD1Autoencoder: assert self.config.models["lda"] is not None, "The config must contain a lda entry." - return SD1Autoencoder(device=self.device).to(device=self.device) + lda = SD1Autoencoder(device=self.device) + #TODO: clean this up + lda._tensor_methods = ["encode", "decode"] + return lda def load_models(self) -> dict[str, fl.Module]: + return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda} def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @@ -162,11 +165,25 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @cached_property def ddpm_scheduler(self) -> DDPM: - return DDPM( + ddpm_scheduler = DDPM( num_inference_steps=1000, - device=self.device, - ).to(device=self.device) + device=self.device + ) + ddpm_scheduler._tensor_methods = ["add_noise"] + ddpm_scheduler = self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device) + return ddpm_scheduler + @cached_property + def sd(self) -> StableDiffusion_1: + + scheduler = DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps) + return StableDiffusion_1( + unet=self.unet, + lda=self.lda, + clip_text_encoder=self.text_encoder, + scheduler=scheduler + ) + def sample_timestep(self) -> Tensor: random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step) self.current_step = random_step @@ -174,7 +191,7 @@ def sample_timestep(self) -> Tensor: def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor: return sample_noise( - size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype + size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype ) def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: @@ -182,20 +199,21 @@ def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: timestep = self.sample_timestep() noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) - self.unet.set_timestep(timestep=timestep) - self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) + + # Question : + # Can we do this as part of the SetContext Logic ? + self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) + self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)) + prediction = self.unet(noisy_latents) - loss = mse_loss(input=prediction, target=noise) + + # Question : + # Can we move this mse_loss device alignement outside of the compute_loss ? 
+ loss = mse_loss(input=prediction, target=noise.to(device=prediction.device)) return loss - def compute_evaluation(self) -> None: - sd = StableDiffusion_1( - unet=self.unet, - lda=self.lda, - clip_text_encoder=self.text_encoder, - scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps), - device=self.device, - ) + def compute_evaluation(self) -> None: + sd = self.sd prompts = self.config.test_diffusion.prompts num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt if self.config.test_diffusion.use_short_prompts: @@ -205,8 +223,8 @@ def compute_evaluation(self) -> None: canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt)) for i in range(num_images_per_prompt): logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}") - x = randn(1, 4, 64, 64, device=self.device) - clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device) + x = randn(1, 4, 64, 64) + clip_text_embedding = sd.compute_clip_text_embedding(text=prompt) for step in sd.steps: x = sd( x, @@ -221,7 +239,6 @@ def compute_evaluation(self) -> None: def sample_noise( size: tuple[int, ...], offset_noise: float = 0.1, - device: Device | str = "cpu", dtype: DType | None = None, generator: Generator | None = None, ) -> Tensor: @@ -230,9 +247,8 @@ def sample_noise( If `offset_noise` is more than 0, the noise will be offset by a small amount. It allows the model to generate images with a wider range of contrast https://www.crosslabs.org/blog/diffusion-with-offset-noise. """ - device = Device(device) - noise = randn(*size, generator=generator, device=device, dtype=dtype) - return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype) + noise = randn(*size, generator=generator, dtype=dtype) + return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, dtype=dtype) def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image: diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py new file mode 100644 index 000000000..c82188ee4 --- /dev/null +++ b/src/refiners/training_utils/sharding_manager.py @@ -0,0 +1,71 @@ +from .config import ModelConfig, TrainingConfig +from torch.nn import Module +from torch import Tensor, device as Device +from torch.autograd import backward +from abc import ABC, abstractmethod +from functools import cached_property, partial, update_wrapper + +class ShardingManager(ABC): + @abstractmethod + def backward(self, tensor: Tensor) -> None: + ... + + @abstractmethod + def setup_model(self, model: Module, config: ModelConfig): + ... 
+ + @property + @abstractmethod + def device(self) -> Device: + raise NotImplementedError("FabricTrainer does not support this property") + +class SimpleShardingManager(ShardingManager): + def __init__(self, config: TrainingConfig) -> None: + self.default_device = config.gpu_index if config.gpu_index is not None else "cpu" + + def backward(self, tensor: Tensor): + backward(tensor) + + def setup_model(self, model: Module, config: ModelConfig) -> Module: + + if config.gpu_index is not None: + device = f"cuda:{config.gpu_index}" + else: + device = self.default_device + model = model.to(device=device) + model = self.add_execution_hooks(model, device) + return model + + # inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18 + def add_execution_hooks(self, module: Module, device: Device) -> None: + + if(hasattr(module, "_tensor_methods") is False): + method_list = ["forward"] + else: + method_list = module._tensor_methods + + for method_name in method_list: + module = self.add_execution_hook(module, device, method_name) + return module + + def add_execution_hook(self, module: Module, device: Device, method_name: str) -> None: + + old_method = getattr(module, method_name) + + def new_method(module, *args, **kwargs): + args = [arg.to(device) if hasattr(arg, "to") else arg for arg in args] + kwargs = {k: v.to(device) if hasattr(v, "to") else v for k, v in kwargs.items()} + output = old_method(*args, **kwargs) + return output + + new_method = update_wrapper( + partial(new_method, module), + old_method + ) + + setattr(module, method_name, new_method) + return module + + @property + def device(self) -> Device: + return self.default_device diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index eeb411155..900f0cdcf 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -7,9 +7,8 @@ import numpy as np from loguru import logger -from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack -from torch.autograd import backward -from torch.nn import Parameter +from torch import Tensor, cuda, float32, dtype as Dtype, device as Device, get_rng_state, set_rng_state, stack +from torch.nn import Parameter, Module from torch.optim import Optimizer from torch.optim.lr_scheduler import ( CosineAnnealingLR, @@ -36,10 +35,12 @@ GradientValueClipping, MonitorLoss, ) -from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue +from refiners.training_utils.config import ModelConfig, BaseConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback from refiners.training_utils.wandb import WandbLoggable, WandbLogger +from .sharding_manager import SimpleShardingManager, ShardingManager + __all__ = ["seed_everything", "scoped_seed", "Trainer"] @@ -298,9 +299,10 @@ def default_callbacks(self) -> list[Callback[Any]]: @cached_property def device(self) -> Device: - selected_device = Device(device=f"cuda:{self.config.training.gpu_index}") - logger.info(f"Using device: {selected_device}") - return selected_device + return self.sharding_manager.device + # selected_device = Device(device=f"cuda:{self.config.training.gpu_index}") + # logger.info(f"Using device: {selected_device}") + # return selected_device @property def parameters(self) -> list[Parameter]: @@ -394,6 +396,11 @@ def lr_scheduler(self) -> LRScheduler: ) return lr_scheduler + + @cached_property + def 
sharding_manager(self) -> ShardingManager: + # TODO : implement accelerate and fabric sharding manager + return SimpleShardingManager(self.config.training) @cached_property def models(self) -> dict[str, fl.Module]: @@ -421,9 +428,12 @@ def prepare_model(self, model_name: str) -> None: else: logger.info(f"No checkpoint found. Initializing model `{model_name}` from scratch.") model.requires_grad_(requires_grad=self.config.models[model_name].train) - model.to(self.device) model.zero_grad() - + self.sharding_manager.setup_model( + model=model, + config=self.config.models[model_name] + ) + def prepare_models(self) -> None: assert self.models, "No models found." for model_name in self.models: @@ -480,15 +490,12 @@ def compute_loss(self, batch: Batch) -> Tensor: def compute_evaluation(self) -> None: pass - - def _backward(self, tensors) -> None: - backward(tensors=tensors) def backward(self) -> None: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") scaled_loss = self.loss / self.clock.num_step_per_iteration - self._backward(scaled_loss) + self.sharding_manager.backward(scaled_loss) self._call_callbacks(event_name="on_backward_end") if self.clock.is_optimizer_step: self._call_callbacks(event_name="on_optimizer_step_begin") diff --git a/tests/training_utils/test_latent_diffusion.py b/tests/training_utils/test_latent_diffusion.py index 5cf2955bb..1ec760e32 100644 --- a/tests/training_utils/test_latent_diffusion.py +++ b/tests/training_utils/test_latent_diffusion.py @@ -31,14 +31,17 @@ lda = dict( checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", train = False, + gpu_index= 0 ), text_encoder = dict( checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train = True, + gpu_index= 0 ), unet= dict( checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", - train = True, + train = False, + gpu_index= 1 ), ), training = dict( @@ -52,6 +55,12 @@ ) ) +from lightning import Fabric +from refiners.foundationals.latent_diffusion import ( + SD1UNet +) +from lightning.fabric.strategies import FSDPStrategy +from accelerate import Accelerator, DistributedType def test_ldm_trainer_text_encoder_on_two_devices(test_device: Device, test_second_device: Device): @@ -66,8 +75,9 @@ def test_ldm_trainer_text_encoder_on_two_devices(test_device: Device, test_secon config = FinetuneLatentDiffusionConfig.load_from_dict( dict(DEFAULT_LATENT_DICT) ) - + trainer = LatentDiffusionTrainer(config=config) trainer.train() + assert trainer.lda.device == test_device assert trainer.text_encoder.device.type == test_second_device \ No newline at end of file From 017099c9ce87db22649730d99ccbab8243e1bed0 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 17:22:04 +0100 Subject: [PATCH 05/14] hook scheduler --- .../foundationals/latent_diffusion/model.py | 11 ++- .../latent_diffusion/schedulers/dpm_solver.py | 49 ++++++------- .../training_utils/accelerate_trainer.py | 48 ------------- src/refiners/training_utils/fabric_trainer.py | 70 ------------------- .../training_utils/latent_diffusion.py | 10 ++- 5 files changed, 34 insertions(+), 154 deletions(-) delete mode 100644 src/refiners/training_utils/accelerate_trainer.py delete mode 100644 src/refiners/training_utils/fabric_trainer.py diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 36e6c038c..5b0b1d14b 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ 
b/src/refiners/foundationals/latent_diffusion/model.py @@ -32,16 +32,13 @@ def set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate - # Question : - # Is there a better way to do this ? - # What is the purpose of this ? - device, dtype = self.scheduler.device, self.scheduler.dtype - print(f"set_num_inference_steps device: {device}, dtype: {dtype}") - self.scheduler = self.scheduler.__class__( + scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, - ).to(device=device, dtype=dtype) + ) + + self.scheduler = scheduler def init_latents( self, diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index dc7d3f26b..52af1f708 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -49,18 +49,14 @@ def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> self.timesteps[step], self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0], ) - # Remark: - # We use noise.device as the target device - # Note: the scheduler here cannot be used with accelerate.prepare - # Cause it's not inherinting from torch.Module previous_ratio, current_ratio = ( - self.signal_to_noise_ratios[previous_timestep].to(device=noise.device, dtype=noise.dtype), - self.signal_to_noise_ratios[timestep].to(device=noise.device, dtype=noise.dtype), + self.signal_to_noise_ratios[previous_timestep], + self.signal_to_noise_ratios[timestep], ) - previous_scale_factor = self.cumulative_scale_factors[previous_timestep].to(device=noise.device, dtype=noise.dtype) + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_noise_std, current_noise_std = ( - self.noise_std[previous_timestep].to(device=noise.device, dtype=noise.dtype), - self.noise_std[timestep].to(device=noise.device, dtype=noise.dtype), + self.noise_std[previous_timestep], + self.noise_std[timestep], ) factor = exp(-(previous_ratio - current_ratio)) - 1.0 denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise @@ -74,14 +70,14 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens ) current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2] previous_ratio, current_ratio, next_ratio = ( - self.signal_to_noise_ratios[previous_timestep].to(device=x.device, dtype=x.dtype), - self.signal_to_noise_ratios[current_timestep].to(device=x.device, dtype=x.dtype), - self.signal_to_noise_ratios[next_timestep].to(device=x.device, dtype=x.dtype), + self.signal_to_noise_ratios[previous_timestep], + self.signal_to_noise_ratios[current_timestep], + self.signal_to_noise_ratios[next_timestep], ) - previous_scale_factor = self.cumulative_scale_factors[previous_timestep].to(device=x.device, dtype=x.dtype) + previous_scale_factor = self.cumulative_scale_factors[previous_timestep] previous_std, current_std = ( - self.noise_std[previous_timestep].to(device=x.device, dtype=x.dtype), - self.noise_std[current_timestep].to(device=x.device, dtype=x.dtype), + self.noise_std[previous_timestep], + self.noise_std[current_timestep], ) estimation_delta = (current_data_estimation - next_data_estimation) / ( (current_ratio - 
next_ratio) / (previous_ratio - current_ratio) @@ -93,8 +89,15 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens - 0.5 * (factor * previous_scale_factor) * estimation_delta ) return denoised_x - + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: + # We pass forward to give the ability to + # dynamically change the behavior of the solver + # using the sharding_manager + # TODO: change the Scheduler abstract class + return self.forward(x, noise, step, generator) + + def forward(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: """ Represents one step of the backward diffusion process that iteratively denoises the input data `x`. @@ -105,20 +108,12 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N current_timestep = self.timesteps[step] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] - # Remark: - # We use noise.device as the target device - # Note: the scheduler here cannot be used with accelerate.prepare - # Cause it's not inherinting from torch.Module - noise_ratio2 = noise_ratio.to(device=noise.device, dtype=noise.dtype) - x2 = x.to(device=noise.device, dtype=noise.dtype) - scale_factor2 = scale_factor.to(device=noise.device, dtype=noise.dtype) - - estimated_denoised_data = (x2 - noise_ratio2 * noise) / scale_factor2 + estimated_denoised_data = (x - noise_ratio * noise) / scale_factor self.estimated_data.append(estimated_denoised_data) denoised_x = ( - self.dpm_solver_first_order_update(x=x2, noise=estimated_denoised_data, step=step) + self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step) if (self.initial_steps == 0) - else self.multistep_dpm_solver_second_order_update(x=x2, step=step) + else self.multistep_dpm_solver_second_order_update(x=x, step=step) ) if self.initial_steps < 2: self.initial_steps += 1 diff --git a/src/refiners/training_utils/accelerate_trainer.py b/src/refiners/training_utils/accelerate_trainer.py deleted file mode 100644 index 43a245bb3..000000000 --- a/src/refiners/training_utils/accelerate_trainer.py +++ /dev/null @@ -1,48 +0,0 @@ -from accelerate import Accelerator -from .trainer import Trainer -from functools import cached_property -from torch.optim import Optimizer -from torch import Tensor, cuda -from typing import Sequence -from refiners.training_utils.config import BaseConfig -from typing import Generic, TypeVar, Any -from refiners.training_utils.callback import Callback -from torch import device as Device -from loguru import logger -from accelerate import Accelerator - -Batch = TypeVar("Batch") -ConfigType = TypeVar("ConfigType", bound=BaseConfig) - - -class AccelerateTrainer(Trainer, Generic[ConfigType, Batch]): - def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: - self.accelerator = Accelerator() - print(self.accelerator.distributed_type) - super().__init__(config, callbacks) - - - def _backward(self, tensors: Tensor | Sequence[Tensor]): - - # Check if the input is a single tensor - if isinstance(tensors, Tensor): - tensors = [tensors] # Wrap the tensor in a list - - for tensor in tensors: - self.accelerator.backward(tensor) - - def __str__(self) -> str: - return f"Trainer : \n"+ "\n".join([f"* {self.models[model_name]}:{self.models[model_name].device}" for model_name in self.models]) - - @cached_property - def optimizer(self) -> Optimizer: - optimizer = 
super().optimizer - return self.accelerator.prepare(optimizer) - - def setup_model(self, model, **kwargs) -> None: - out_model = self.accelerator.prepare(model) - return out_model - - @property - def device(self) -> Device: - return self.accelerator.device \ No newline at end of file diff --git a/src/refiners/training_utils/fabric_trainer.py b/src/refiners/training_utils/fabric_trainer.py deleted file mode 100644 index 59555eafa..000000000 --- a/src/refiners/training_utils/fabric_trainer.py +++ /dev/null @@ -1,70 +0,0 @@ -from lightning.fabric import Fabric -from .trainer import Trainer -from functools import cached_property -from torch.optim import Optimizer -from torch import Tensor, cuda -from typing import Sequence -from refiners.training_utils.config import TrainingConfig, BaseConfig -from typing import Generic, TypeVar, Any -from refiners.training_utils.callback import Callback -from torch import device as Device -from loguru import logger - -class FabricTrainingConfig(TrainingConfig): - devices: int = 1 - -class FabricBaseConfig(BaseConfig): - training: FabricTrainingConfig - - -Batch = TypeVar("Batch") -ConfigType = TypeVar("ConfigType", bound=FabricBaseConfig) - - -class FabricTrainer(Trainer, Generic[ConfigType, Batch]): - fabric_optimizer: Optimizer = None - - def __init__(self, config: ConfigType, callbacks: list[Callback[Any]] | None = None) -> None: - self.fabric = Fabric(strategy="fsdp") - self.fabric.launch() - self.fabric.seed_everything(42) - - super().__init__(config, callbacks) - - - def _backward(self, tensors: Tensor | Sequence[Tensor]): - - # Check if the input is a single tensor - if isinstance(tensors, Tensor): - tensors = [tensors] # Wrap the tensor in a list - - for tensor in tensors: - self.fabric.backward(tensor) - - def __str__(self) -> str: - return f"Trainer : \n"+ "\n".join([f"* {self.models[model_name]}:{self.models[model_name].device}" for model_name in self.models]) - - @cached_property - def optimizer(self) -> Optimizer: - optimizer = super().optimizer - return fabric.setup_optimizers(optimizer) - - def prepare_model(self, model_name: str) -> None: - self.fabric.print(model_name, cuda.memory_summary()) - - model = self.fabric.setup(self.models[model_name]) - # self.fabric_optimizer = optimizer - self.models[model_name] = model - - if (checkpoint := self.config.models[model_name].checkpoint) is not None: - model.load_from_safetensors(tensors_path=checkpoint) - else: - logger.info(f"No checkpoint found. 
Initializing model `{model_name}` from scratch.") - model.requires_grad_(requires_grad=self.config.models[model_name].train) - model.zero_grad() - - - - @cached_property - def device(self) -> Device: - raise NotImplementedError("FabricTrainer does not support this property") \ No newline at end of file diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 2917456ae..5afa72ff6 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -170,13 +170,19 @@ def ddpm_scheduler(self) -> DDPM: device=self.device ) ddpm_scheduler._tensor_methods = ["add_noise"] - ddpm_scheduler = self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device) + self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device) return ddpm_scheduler @cached_property def sd(self) -> StableDiffusion_1: - scheduler = DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps) + scheduler = DPMSolver( + device=self.sharding_manager.default_device, + num_inference_steps=self.config.test_diffusion.num_inference_steps + ) + + scheduler = self.sharding_manager.add_execution_hooks(scheduler, scheduler.device) + return StableDiffusion_1( unet=self.unet, lda=self.lda, From d999fd404847281921df6802b3b0ef598b4e2d87 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 17:27:29 +0100 Subject: [PATCH 06/14] clean --- .../foundationals/latent_diffusion/model.py | 12 +- .../latent_diffusion/schedulers/dpm_solver.py | 4 +- .../stable_diffusion_1/model.py | 14 +- src/refiners/training_utils/color_palette.py | 306 ------------------ src/refiners/training_utils/config.py | 2 + .../training_utils/latent_diffusion.py | 48 ++- .../training_utils/sharding_manager.py | 36 +-- src/refiners/training_utils/trainer.py | 11 +- tests/conftest.py | 3 + tests/training_utils/test_latent_diffusion.py | 82 ++--- 10 files changed, 80 insertions(+), 438 deletions(-) delete mode 100644 src/refiners/training_utils/color_palette.py diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 5b0b1d14b..cbb3a1b5e 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -16,11 +16,7 @@ class LatentDiffusionModel(fl.Module, ABC): def __init__( - self, - unet: fl.Module, - lda: LatentDiffusionAutoencoder, - clip_text_encoder: fl.Module, - scheduler: Scheduler + self, unet: fl.Module, lda: LatentDiffusionAutoencoder, clip_text_encoder: fl.Module, scheduler: Scheduler ) -> None: super().__init__() self.unet = unet @@ -31,13 +27,13 @@ def __init__( def set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate - + print(f"Setting num_inference_steps to {num_inference_steps}") scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, ) - + self.scheduler = scheduler def init_latents( @@ -108,5 +104,5 @@ def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel: unet=self.unet.structural_copy(), lda=self.lda.structural_copy(), clip_text_encoder=self.clip_text_encoder.structural_copy(), - scheduler=self.scheduler + scheduler=self.scheduler, ) diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py 
b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py index 52af1f708..67b033581 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py @@ -89,7 +89,7 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens - 0.5 * (factor * previous_scale_factor) * estimation_delta ) return denoised_x - + def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor: # We pass forward to give the ability to # dynamically change the behavior of the solver @@ -107,7 +107,7 @@ def forward(self, x: Tensor, noise: Tensor, step: int, generator: Generator | No """ current_timestep = self.timesteps[step] scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep] - + estimated_denoised_data = (x - noise_ratio * noise) / scale_factor self.estimated_data.append(estimated_denoised_data) denoised_x = ( diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index ecaef7774..284b82a57 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -27,18 +27,13 @@ def __init__( unet: SD1UNet | None = None, lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, - scheduler: Scheduler | None = None + scheduler: Scheduler | None = None, ) -> None: unet = unet or SD1UNet(in_channels=4) lda = lda or SD1Autoencoder() clip_text_encoder = clip_text_encoder or CLIPTextEncoderL() scheduler = scheduler or DPMSolver(num_inference_steps=30) - super().__init__( - unet=unet, - lda=lda, - clip_text_encoder=clip_text_encoder, - scheduler=scheduler - ) + super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler) def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor: conditional_embedding = self.clip_text_encoder(text) @@ -49,11 +44,12 @@ def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Ten return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0) def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: - # Question : # Can we do this as part of the SetContext Logic ? 
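
The `__call__`/`forward` split kept as context above is what makes the scheduler hookable: Python resolves `__call__` on the type, so replacing it on an instance has no effect, whereas a `__call__` that only delegates to `self.forward(...)` does pick up an instance-level override. A minimal sketch of the pattern, independent of refiners; the `Solver` class and `move_inputs_to` helper are illustrative names, assuming only torch:

import torch
from functools import wraps


class Solver:
    # __call__ only delegates, so a per-instance override of `forward` takes effect.
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        return self.forward(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x * 2


def move_inputs_to(solver: Solver, device: torch.device) -> None:
    original = solver.forward

    @wraps(original)
    def hooked(x: torch.Tensor) -> torch.Tensor:
        return original(x.to(device))

    solver.forward = hooked  # instance attribute, visible through the delegating __call__


solver = Solver()
move_inputs_to(solver, torch.device("cpu"))
print(solver(torch.ones(2)))  # tensor([2., 2.]), with the hook applied first

Patching `__call__` itself on the instance would be ignored by the interpreter, which is why the indirection exists.
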
self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) - self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)) + self.unet.set_clip_text_embedding( + clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype) + ) def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: if enable: diff --git a/src/refiners/training_utils/color_palette.py b/src/refiners/training_utils/color_palette.py deleted file mode 100644 index 51506b796..000000000 --- a/src/refiners/training_utils/color_palette.py +++ /dev/null @@ -1,306 +0,0 @@ -import hashlib -import os -from dataclasses import dataclass -from functools import cached_property -from random import random -from typing import Any - -import requests -from loguru import logger -from PIL import Image -from pydantic import BaseModel -from torch import Tensor, cat, float32, randn, tensor, bfloat16 -from torch.utils.data import Dataset -from tqdm import tqdm - -import refiners.fluxion.layers as fl -from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder, SD1ColorPaletteAdapter -from refiners.fluxion.utils import save_to_safetensors -from refiners.foundationals.latent_diffusion import ( - DPMSolver, - StableDiffusion_1, -) -from refiners.training_utils.callback import Callback -from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig -from refiners.training_utils.latent_diffusion import ( - FinetuneLatentDiffusionConfig, - LatentDiffusionConfig, - LatentDiffusionTrainer, - TestDiffusionConfig, - TextEmbeddingLatentsDataset, -) - - -class ColorPaletteConfig(BaseModel): - model_dim: int - trigger_phrase: str = "" - use_only_trigger_probability: float = 0.0 - max_colors: int - download_local: bool = True - - -class ColorPalettePromptConfig(BaseModel): - text: str - color_palette: list[list[float]] - - -class TestColorPaletteConfig(TestDiffusionConfig): - prompts: list[ColorPalettePromptConfig] = [] - - -class ColorPaletteDatasetConfig(HuggingfaceDatasetConfig): - local_folder: str = "data/color-palette" - - -@dataclass -class TextEmbeddingColorPaletteLatentsBatch: - text_embeddings: Tensor - latents: Tensor - color_palette_embeddings: Tensor - - -class ColorPaletteDataset(TextEmbeddingLatentsDataset): - def __init__( - self, - trainer: "ColorPaletteLatentDiffusionTrainer", - ) -> None: - super().__init__(trainer=trainer) - self.trigger_phrase = trainer.config.color_palette.trigger_phrase - self.use_only_trigger_probability = trainer.config.color_palette.use_only_trigger_probability - logger.info(f"Trigger phrase: {self.trigger_phrase}") - self.color_palette_encoder = trainer.color_palette_encoder - - self.local_folder = trainer.config.dataset.local_folder - - # Download images - # Question : there might be a more efficient way to do this - # I didn't find the way to do this easily with hugging face - # dataset library - if trainer.config.color_palette.download_local: - for item in tqdm(self.dataset, desc="Downloading images"): - self.download_image(item) - - def get_image_path_from_url(self, url: str) -> str: - hash_md5 = hashlib.md5() - hash_md5.update(url.encode()) - filename = hash_md5.hexdigest() - return self.local_folder + f"/{filename}" - - def download_image(self, item: dict[str, Any]) -> None: - url = item["url"] - image_path = self.get_image_path_from_url(url) - if not os.path.exists(image_path): - # download image from url - logger.info(f"Downloading 
image {image_path} from {url}") - response = requests.get(url) - - # Check if the request was successful - if response.status_code == 200: - # Save the image bytes to the image_path - with open(image_path, "wb") as file: - file.write(response.content) - else: - print(f"Failed to download image from {url}") - return None - - def get_caption(self, index: int) -> str: - return self.dataset[index]["ai_description"] - - def get_image(self, index: int) -> str: - url = self.dataset[index]["url"] - image_path = self.get_image_path_from_url(url) - - if not os.path.exists(image_path): - raise Exception(f"Image {image_path} does not exist") - return Image.open(image_path) - - def process_caption(self, caption: str) -> str: - caption = super().process_caption(caption=caption) - if self.trigger_phrase: - caption = ( - f"{self.trigger_phrase} {caption}" - if random() < self.use_only_trigger_probability - else self.trigger_phrase - ) - return caption - - def get_color_palette(self, index: int) -> Tensor: - # TO IMPLEMENT : use other palettes - return tensor([self.dataset[index]["palette_8"]]) - - def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch: - caption = self.get_caption(index=index) - color_palette = self.get_color_palette(index=index) - image = self.get_image(index=index) - resized_image = self.resize_image( - image=image, - min_size=self.config.dataset.resize_image_min_size, - max_size=self.config.dataset.resize_image_max_size, - ) - processed_image = self.process_image(resized_image) - latents = self.lda.encode_image(image=processed_image) - processed_caption = self.process_caption(caption=caption) - - clip_text_embedding = self.text_encoder(processed_caption) - color_palette_embedding = self.color_palette_encoder(color_palette) - return TextEmbeddingColorPaletteLatentsBatch( - text_embeddings=clip_text_embedding, latents=latents, color_palette_embeddings=color_palette_embedding - ) - - def collate_fn(self, batch: list[TextEmbeddingColorPaletteLatentsBatch]) -> TextEmbeddingColorPaletteLatentsBatch: - text_embeddings = cat(tensors=[item.text_embeddings for item in batch]) - latents = cat(tensors=[item.latents for item in batch]) - color_palette_embeddings = cat(tensors=[item.color_palette_embeddings for item in batch]) - return TextEmbeddingColorPaletteLatentsBatch( - text_embeddings=text_embeddings, latents=latents, color_palette_embeddings=color_palette_embeddings - ) - - -class ColorPaletteLatentDiffusionConfig(FinetuneLatentDiffusionConfig): - dataset: ColorPaletteDatasetConfig - latent_diffusion: LatentDiffusionConfig - color_palette: ColorPaletteConfig - test_diffusion: TestColorPaletteConfig - - def model_post_init(self, __context: Any) -> None: - """Pydantic v2 does post init differently, so we need to override this method too.""" - logger.info("Freezing models to train only the color palette.") - self.models["text_encoder"].train = False - self.models["lda"].train = False - self.models["color_palette_encoder"].train = True - - # Question : Here I should not freeze the CrossAttentionBlock2d - # But what is the unfreeze only this block ? - self.models["unet"].train = False - - -class ColorPaletteLatentDiffusionTrainer(LatentDiffusionTrainer[ColorPaletteLatentDiffusionConfig]): - @cached_property - def color_palette_encoder(self) -> ColorPaletteEncoder: - assert ( - self.config.models["color_palette_encoder"] is not None - ), "The config must contain a color_palette_encoder entry." 
- - # TO FIX : connect this to unet cross attention embedding dim - EMBEDDING_DIM = 768 - - return ColorPaletteEncoder( - max_colors=self.config.color_palette.max_colors, - embedding_dim=EMBEDDING_DIM, - model_dim=self.config.color_palette.model_dim, - device=self.device, - ) - - def __init__( - self, - config: ColorPaletteLatentDiffusionConfig, - callbacks: "list[Callback[Any]] | None" = None, - ) -> None: - super().__init__(config=config, callbacks=callbacks) - self.callbacks.extend((LoadColorPalette(), SaveColorPalette())) - - def load_dataset(self) -> Dataset[TextEmbeddingColorPaletteLatentsBatch]: - return ColorPaletteDataset(trainer=self) - - def load_models(self) -> dict[str, fl.Module]: - return { - "unet": self.unet, - "text_encoder": self.text_encoder, - "lda": self.lda, - "color_palette_encoder": self.color_palette_encoder, - } - - def set_adapter(self, adapter) -> None: - self.adapter = adapter - - def compute_loss(self, batch: TextEmbeddingColorPaletteLatentsBatch) -> Tensor: - text_embeddings, latents, color_palette_embeddings = ( - batch.text_embeddings, - batch.latents, - batch.color_palette_embeddings, - ) - timestep = self.sample_timestep() - noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) - noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) - - self.unet.set_timestep(timestep=timestep) - - clip_text_embedding = cat([text_embeddings, color_palette_embeddings], dim=1) - - # Used to run training on 2 parallel GPUs - # self.unet.to(device=1) - # clip_text_embedding.to(device=1) - # noisy_latents.to(device=1) - # noise.to(device=1) - - self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) - prediction = self.unet(noisy_latents) - loss = mse_loss(input=prediction, target=noise) - return loss - - def compute_evaluation(self) -> None: - sd = StableDiffusion_1( - unet=self.unet, - lda=self.lda, - clip_text_encoder=self.text_encoder, - scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps), - device=self.device, - dtype=self.dtype - ) - prompts = self.config.test_diffusion.prompts - num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt - if self.config.test_diffusion.use_short_prompts: - prompts = [prompt.split(sep=",")[0] for prompt in prompts] - images: dict[str, WandbLoggable] = {} - for prompt in prompts: - canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt)) - image_name = prompt.text + str(prompt.color_palette) - for i in range(num_images_per_prompt): - logger.info( - f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt.text} and palette {prompt.color_palette}" - ) - x = randn(1, 4, 64, 64) - - # cfg means classifier-free guidance - clip_text_embedding = sd.compute_clip_text_embedding(text=prompt.text) - cfg_color_palette_embedding = self.adapter.compute_color_palette_embedding( - tensor([prompt.color_palette]) - ) - - self.adapter.set_color_palette_embedding(cfg_color_palette_embedding) - - for step in sd.steps: - x = sd( - x, - step=step, - clip_text_embedding=clip_text_embedding, - ) - canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i)) - images[image_name] = canvas_image - self.log(data=images) - - -class LoadColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]): - def on_train_begin(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None: - color_palette_config = trainer.config.color_palette - adapter = SD1ColorPaletteAdapter(target=trainer.unet, 
color_palette_encoder=trainer.color_palette_encoder) - trainer.set_adapter(adapter) - adapter.inject() - - -class SaveColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]): - def on_checkpoint_save(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None: - tensors: dict[str, Tensor] = {} - metadata: dict[str, str] = {} - - model = trainer.unet - adapter = model.parent - - tensors = {f"unet.{i:03d}": w for i, w in enumerate(adapter.weights)} - metadata = {f"unet_targets": ",".join(adapter.sub_targets)} - - save_to_safetensors( - path=trainer.ensure_checkpoints_save_folder / f"step{trainer.clock.step}.safetensors", - tensors=tensors, - metadata=metadata, - ) \ No newline at end of file diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 0e80b7d85..158ccdbee 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -227,9 +227,11 @@ class BaseConfig(BaseModel): scheduler: SchedulerConfig dropout: DropoutConfig checkpointing: CheckpointingConfig + @classmethod def load_from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T: return cls(**config_dict) + @classmethod def load_from_toml(cls: Type[T], toml_path: Path | str) -> T: with open(file=toml_path, mode="rb") as f: diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 5afa72ff6..ff5fbb7cf 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -3,7 +3,7 @@ from functools import cached_property from typing import Any, Callable, TypedDict, TypeVar -from datasets import DownloadManager # type: ignore +from datasets import DownloadManager # type: ignore from loguru import logger from PIL import Image from pydantic import BaseModel @@ -102,14 +102,14 @@ def process_caption(self, caption: str) -> str: return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else "" def get_caption(self, index: int, caption_key: str) -> str: - return self.dataset[index][caption_key] # type: ignore + return self.dataset[index][caption_key] # type: ignore def get_image(self, index: int) -> Image.Image: if "image" in self.dataset[index]: return self.dataset[index]["image"] elif "url" in self.dataset[index]: - url : str = self.dataset[index]["url"] - filename : str = self.download_manager.download(url) # type: ignore + url: str = self.dataset[index]["url"] + filename: str = self.download_manager.download(url) # type: ignore return Image.open(filename) else: raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'") @@ -152,12 +152,11 @@ def text_encoder(self) -> CLIPTextEncoderL: def lda(self) -> SD1Autoencoder: assert self.config.models["lda"] is not None, "The config must contain a lda entry." 
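
For reference, the `load_from_dict`/`load_from_toml` pair reformatted above follows a simple pattern: parse TOML into a plain dict, then validate it through the pydantic model. A small self-contained sketch, assuming pydantic v2 and tomli are installed; `ExampleConfig` and `ModelItem` are illustrative stand-ins, not the refiners `BaseConfig`:

from pathlib import Path
from typing import Any, Type, TypeVar

import tomli
from pydantic import BaseModel

T = TypeVar("T", bound="ExampleConfig")


class ModelItem(BaseModel):
    checkpoint: str | None = None
    train: bool = False
    gpu_index: int | None = None


class ExampleConfig(BaseModel):
    script: str
    models: dict[str, ModelItem]

    @classmethod
    def load_from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T:
        return cls(**config_dict)

    @classmethod
    def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
        with open(file=toml_path, mode="rb") as f:
            return cls.load_from_dict(config_dict=tomli.load(f))


config = ExampleConfig.load_from_dict(
    config_dict={"script": "train.py", "models": {"unet": {"train": True, "gpu_index": 1}}}
)
print(config.models["unet"].gpu_index)  # 1
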
lda = SD1Autoencoder(device=self.device) - #TODO: clean this up + # TODO: clean this up lda._tensor_methods = ["encode", "decode"] return lda def load_models(self) -> dict[str, fl.Module]: - return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda} def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @@ -165,60 +164,51 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @cached_property def ddpm_scheduler(self) -> DDPM: - ddpm_scheduler = DDPM( - num_inference_steps=1000, - device=self.device - ) + ddpm_scheduler = DDPM(num_inference_steps=1000, device=self.device) ddpm_scheduler._tensor_methods = ["add_noise"] self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device) return ddpm_scheduler @cached_property def sd(self) -> StableDiffusion_1: - scheduler = DPMSolver( - device=self.sharding_manager.default_device, - num_inference_steps=self.config.test_diffusion.num_inference_steps + device=self.sharding_manager.default_device, + num_inference_steps=self.config.test_diffusion.num_inference_steps, ) - + scheduler = self.sharding_manager.add_execution_hooks(scheduler, scheduler.device) - - return StableDiffusion_1( - unet=self.unet, - lda=self.lda, - clip_text_encoder=self.text_encoder, - scheduler=scheduler - ) - + + return StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler) + def sample_timestep(self) -> Tensor: random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step) self.current_step = random_step return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0) def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor: - return sample_noise( - size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype - ) + return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype) def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: clip_text_embedding, latents = batch.text_embeddings, batch.latents timestep = self.sample_timestep() noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) - + # Question : # Can we do this as part of the SetContext Logic ? self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) - self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)) + self.unet.set_clip_text_embedding( + clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype) + ) prediction = self.unet(noisy_latents) - + # Question : # Can we move this mse_loss device alignement outside of the compute_loss ? 
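
The `ddpm_scheduler.add_noise` call used in `compute_loss` above is the standard DDPM forward-noising step, x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps. A standalone sketch of that step, assuming only torch; the schedule constants are illustrative rather than the exact diffusion rates refiners configures:

import torch


def make_schedule(num_steps: int = 1000, initial_rate: float = 8.5e-4, final_rate: float = 1.2e-2):
    # A sqrt-linear beta schedule, squared; other parameterizations work the same way.
    betas = torch.linspace(initial_rate**0.5, final_rate**0.5, num_steps) ** 2
    alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
    cumulative_scale_factors = alphas_cumprod.sqrt()  # sqrt(alpha_bar_t)
    noise_std = (1.0 - alphas_cumprod).sqrt()         # sqrt(1 - alpha_bar_t)
    return cumulative_scale_factors, noise_std


def add_noise(x: torch.Tensor, noise: torch.Tensor, timestep: int,
              cumulative_scale_factors: torch.Tensor, noise_std: torch.Tensor) -> torch.Tensor:
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return cumulative_scale_factors[timestep] * x + noise_std[timestep] * noise


scale_factors, noise_std = make_schedule()
latents = torch.randn(2, 4, 64, 64)
noise = torch.randn_like(latents)
noisy_latents = add_noise(latents, noise, timestep=999,
                          cumulative_scale_factors=scale_factors, noise_std=noise_std)
print(noisy_latents.shape)  # torch.Size([2, 4, 64, 64])
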
loss = mse_loss(input=prediction, target=noise.to(device=prediction.device)) return loss - def compute_evaluation(self) -> None: + def compute_evaluation(self) -> None: sd = self.sd prompts = self.config.test_diffusion.prompts num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index c82188ee4..81b314dc8 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -5,12 +5,13 @@ from abc import ABC, abstractmethod from functools import cached_property, partial, update_wrapper + class ShardingManager(ABC): - @abstractmethod + @abstractmethod def backward(self, tensor: Tensor) -> None: ... - - @abstractmethod + + @abstractmethod def setup_model(self, model: Module, config: ModelConfig): ... @@ -19,50 +20,45 @@ def setup_model(self, model: Module, config: ModelConfig): def device(self) -> Device: raise NotImplementedError("FabricTrainer does not support this property") + class SimpleShardingManager(ShardingManager): def __init__(self, config: TrainingConfig) -> None: self.default_device = config.gpu_index if config.gpu_index is not None else "cpu" - + def backward(self, tensor: Tensor): backward(tensor) - + def setup_model(self, model: Module, config: ModelConfig) -> Module: - if config.gpu_index is not None: device = f"cuda:{config.gpu_index}" - else: + else: device = self.default_device model = model.to(device=device) model = self.add_execution_hooks(model, device) return model - + # inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18 def add_execution_hooks(self, module: Module, device: Device) -> None: - - if(hasattr(module, "_tensor_methods") is False): + if hasattr(module, "_tensor_methods") is False: method_list = ["forward"] else: method_list = module._tensor_methods - + for method_name in method_list: module = self.add_execution_hook(module, device, method_name) return module - + def add_execution_hook(self, module: Module, device: Device, method_name: str) -> None: - old_method = getattr(module, method_name) - + def new_method(module, *args, **kwargs): args = [arg.to(device) if hasattr(arg, "to") else arg for arg in args] kwargs = {k: v.to(device) if hasattr(v, "to") else v for k, v in kwargs.items()} output = old_method(*args, **kwargs) return output - - new_method = update_wrapper( - partial(new_method, module), - old_method - ) - + + new_method = update_wrapper(partial(new_method, module), old_method) + setattr(module, method_name, new_method) return module diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 900f0cdcf..71e977de0 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -396,7 +396,7 @@ def lr_scheduler(self) -> LRScheduler: ) return lr_scheduler - + @cached_property def sharding_manager(self) -> ShardingManager: # TODO : implement accelerate and fabric sharding manager @@ -429,11 +429,8 @@ def prepare_model(self, model_name: str) -> None: logger.info(f"No checkpoint found. 
Initializing model `{model_name}` from scratch.") model.requires_grad_(requires_grad=self.config.models[model_name].train) model.zero_grad() - self.sharding_manager.setup_model( - model=model, - config=self.config.models[model_name] - ) - + self.sharding_manager.setup_model(model=model, config=self.config.models[model_name]) + def prepare_models(self) -> None: assert self.models, "No models found." for model_name in self.models: @@ -490,7 +487,7 @@ def compute_loss(self, batch: Batch) -> Tensor: def compute_evaluation(self) -> None: pass - + def backward(self) -> None: """Backward pass on the loss.""" self._call_callbacks(event_name="on_backward_begin") diff --git a/tests/conftest.py b/tests/conftest.py index b03deaa81..1b129a5af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,6 +6,7 @@ PARENT_PATH = Path(__file__).parent + @fixture(scope="session") def test_device() -> torch.device: test_device = os.getenv("REFINERS_TEST_DEVICE") @@ -13,6 +14,7 @@ def test_device() -> torch.device: return torch.device("cuda:0" if torch.cuda.is_available() else "cpu") return torch.device(test_device) + @fixture(scope="session") def test_second_device() -> torch.device: test_device = os.getenv("REFINERS_TEST_SECOND_DEVICE") @@ -20,6 +22,7 @@ def test_second_device() -> torch.device: return torch.device("cuda:1") return torch.device("cpu") + @fixture(scope="session") def test_weights_path() -> Path: from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR") diff --git a/tests/training_utils/test_latent_diffusion.py b/tests/training_utils/test_latent_diffusion.py index 1ec760e32..cffb3a0c7 100644 --- a/tests/training_utils/test_latent_diffusion.py +++ b/tests/training_utils/test_latent_diffusion.py @@ -4,80 +4,48 @@ import pytest DEFAULT_LATENT_DICT = dict( - script = "foo.py", - wandb = dict( - mode = "offline", - entity = "acme", - project = "test-ldm-training" + script="foo.py", + wandb=dict(mode="offline", entity="acme", project="test-ldm-training"), + latent_diffusion=dict(unconditional_sampling_probability=0.2, offset_noise=0.1), + optimizer=dict(optimizer="AdamW", learning_rate=1e-5, betas=[0.9, 0.999], eps=1e-8, weight_decay=1e-2), + scheduler=dict(), + dropout=dict(dropout_probability=0.2), + checkpointing=dict(save_interval="1:epoch"), + test_diffusion=dict( + prompts=[ + "A cute cat", + ] ), - latent_diffusion = dict( - unconditional_sampling_probability = 0.2, - offset_noise = 0.1 - ), - optimizer = dict( - optimizer = "AdamW", - learning_rate = 1e-5, - betas = [0.9, 0.999], - eps = 1e-8, - weight_decay = 1e-2 - ), - scheduler = dict(), - dropout = dict(dropout_probability = 0.2), - checkpointing=dict(save_interval = "1:epoch"), - test_diffusion=dict(prompts = [ - "A cute cat", - ]), - models = dict( - lda = dict( - checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", - train = False, - gpu_index= 0 - ), - text_encoder = dict( - checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", - train = True, - gpu_index= 0 + models=dict( + lda=dict(checkpoint="tests/weights/stable-diffusion-1-5/lda.safetensors", train=False, gpu_index=0), + text_encoder=dict( + checkpoint="tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train=True, gpu_index=0 ), - unet= dict( - checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", - train = False, - gpu_index= 1 - ), - ), - training = dict( - duration= "1:epoch", - gpu_index= 0 + unet=dict(checkpoint="tests/weights/stable-diffusion-1-5/unet.safetensors", train=False, gpu_index=1), ), - dataset = 
dict( - hf_repo= "1aurent/unsplash-lite-palette", - revision= "main", - caption_key = "ai_description" - ) + training=dict(duration="1:epoch", gpu_index=0), + dataset=dict(hf_repo="1aurent/unsplash-lite-palette", revision="main", caption_key="ai_description"), ) from lightning import Fabric -from refiners.foundationals.latent_diffusion import ( - SD1UNet -) +from refiners.foundationals.latent_diffusion import SD1UNet from lightning.fabric.strategies import FSDPStrategy from accelerate import Accelerator, DistributedType + def test_ldm_trainer_text_encoder_on_two_devices(test_device: Device, test_second_device: Device): - if test_device.type == "cpu": warn("not running on CPU, skipping") pytest.skip() - + if test_second_device.type == "cpu": warn("Running with only one GPU, skipping") pytest.skip() - - config = FinetuneLatentDiffusionConfig.load_from_dict( - dict(DEFAULT_LATENT_DICT) - ) + + config = FinetuneLatentDiffusionConfig.load_from_dict(dict(DEFAULT_LATENT_DICT)) trainer = LatentDiffusionTrainer(config=config) trainer.train() - + assert trainer.lda.device == test_device - assert trainer.text_encoder.device.type == test_second_device \ No newline at end of file + assert trainer.text_encoder.device.type == test_second_device From d4256bd7cb05bc565164ce16f38cbfb3037bc051 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 18:51:20 +0100 Subject: [PATCH 07/14] clean up hook mechanism --- .../foundationals/latent_diffusion/model.py | 1 + .../stable_diffusion_1/model.py | 9 +-- .../training_utils/latent_diffusion.py | 24 ++++---- .../training_utils/sharding_manager.py | 57 +++++++++++++++---- 4 files changed, 59 insertions(+), 32 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index cbb3a1b5e..62ec97673 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -28,6 +28,7 @@ def set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate print(f"Setting num_inference_steps to {num_inference_steps}") + raise NotImplementedError scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 284b82a57..53cc16ebb 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -2,7 +2,6 @@ import torch from PIL import Image from torch import Tensor, device as Device, dtype as DType -from loguru import logger from refiners.fluxion.utils import image_to_tensor, interpolate from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL @@ -44,12 +43,8 @@ def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Ten return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0) def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None: - # Question : - # Can we do this as part of the SetContext Logic ? 
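
The per-model `gpu_index` entries in the configuration above (lda and text encoder on one device, unet on another) are what the sharding manager in this series consumes: place each model on its configured device, then wrap its tensor-taking methods so inputs follow it there. A self-contained sketch of that idea, assuming only torch; `TinyShardingManager` and `ModelPlacement` are illustrative names, not refiners classes:

from dataclasses import dataclass
from functools import partial, update_wrapper
from typing import Any

import torch
from torch import nn


@dataclass
class ModelPlacement:
    gpu_index: int | None = None  # None means "use the default device"


def recursive_to(obj: Any, device: torch.device) -> Any:
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: recursive_to(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(recursive_to(v, device) for v in obj)
    return obj


class TinyShardingManager:
    def __init__(self, default_device: str = "cpu") -> None:
        self.default_device = torch.device(default_device)

    def setup_model(self, model: nn.Module, placement: ModelPlacement) -> nn.Module:
        device = (
            torch.device(f"cuda:{placement.gpu_index}")
            if placement.gpu_index is not None and torch.cuda.is_available()
            else self.default_device
        )
        model = model.to(device)
        self.add_device_hook(model, device, "forward")
        return model

    def add_device_hook(self, module: nn.Module, device: torch.device, method_name: str) -> None:
        old_method = getattr(module, method_name)

        def hooked(module: nn.Module, *args: Any, **kwargs: Any) -> Any:
            args = recursive_to(args, device)
            kwargs = recursive_to(kwargs, device)
            return old_method(*args, **kwargs)

        # Same trick as accelerate's hooks: rebind the method on the instance.
        setattr(module, method_name, update_wrapper(partial(hooked, module), old_method))


manager = TinyShardingManager()
linear = manager.setup_model(nn.Linear(4, 2), ModelPlacement(gpu_index=None))
print(linear(torch.ones(1, 4)).shape)  # inputs are moved to the model's device before forward
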
- self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) - self.unet.set_clip_text_embedding( - clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype) - ) + self.unet.set_timestep(timestep=timestep) + self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None: if enable: diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index ff5fbb7cf..b1227fdd2 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -141,6 +141,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): @cached_property def unet(self) -> SD1UNet: assert self.config.models["unet"] is not None, "The config must contain a unet entry." + print(f"self.device: {self.device}") return SD1UNet(in_channels=4, device=self.device) @cached_property @@ -151,9 +152,7 @@ def text_encoder(self) -> CLIPTextEncoderL: @cached_property def lda(self) -> SD1Autoencoder: assert self.config.models["lda"] is not None, "The config must contain a lda entry." - lda = SD1Autoencoder(device=self.device) - # TODO: clean this up - lda._tensor_methods = ["encode", "decode"] + lda = SD1Autoencoder() return lda def load_models(self) -> dict[str, fl.Module]: @@ -165,8 +164,7 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @cached_property def ddpm_scheduler(self) -> DDPM: ddpm_scheduler = DDPM(num_inference_steps=1000, device=self.device) - ddpm_scheduler._tensor_methods = ["add_noise"] - self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device) + self.sharding_manager.add_execution_hook(ddpm_scheduler, self.device, "add_noise") return ddpm_scheduler @cached_property @@ -176,7 +174,7 @@ def sd(self) -> StableDiffusion_1: num_inference_steps=self.config.test_diffusion.num_inference_steps, ) - scheduler = self.sharding_manager.add_execution_hooks(scheduler, scheduler.device) + self.sharding_manager.add_execution_hooks(scheduler, scheduler.device) return StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler) @@ -187,6 +185,10 @@ def sample_timestep(self) -> Tensor: def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor: return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype) + + @cached_property + def mse_loss(self): + return self.sharding_manager.bind_input_to_device(mse_loss, self.device) def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: clip_text_embedding, latents = batch.text_embeddings, batch.latents @@ -194,18 +196,14 @@ def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: noise = self.sample_noise(size=latents.shape, dtype=latents.dtype) noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step) - # Question : - # Can we do this as part of the SetContext Logic ? 
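
The question in the comment above (can the mse_loss device alignment live outside `compute_loss`?) is what the `mse_loss` cached property added in this hunk answers: bind the plain loss function to a device once and let the wrapper move whatever it receives. A minimal sketch, assuming only torch; `bind_to_device` is a stand-alone helper written for this illustration, not the refiners method:

from typing import Any, Callable

import torch
from torch.nn.functional import mse_loss


def bind_to_device(fn: Callable[..., torch.Tensor], device: torch.device) -> Callable[..., torch.Tensor]:
    def wrapped(*args: Any, **kwargs: Any) -> torch.Tensor:
        moved_args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
        moved_kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return fn(*moved_args, **moved_kwargs)

    return wrapped


loss_fn = bind_to_device(mse_loss, torch.device("cpu"))
prediction = torch.randn(1, 4, 8, 8)    # e.g. produced on the UNet's device
target_noise = torch.randn(1, 4, 8, 8)  # e.g. sampled on another device
print(loss_fn(input=prediction, target=target_noise))  # both moved before the loss is computed
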
- self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype)) - self.unet.set_clip_text_embedding( - clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype) - ) + self.unet.set_timestep(timestep=timestep) + self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) prediction = self.unet(noisy_latents) # Question : # Can we move this mse_loss device alignement outside of the compute_loss ? - loss = mse_loss(input=prediction, target=noise.to(device=prediction.device)) + loss = self.mse_loss(input=prediction, target=noise) return loss def compute_evaluation(self) -> None: diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index 81b314dc8..49b64487b 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -4,7 +4,7 @@ from torch.autograd import backward from abc import ABC, abstractmethod from functools import cached_property, partial, update_wrapper - +from typing import Any, List, Callable class ShardingManager(ABC): @abstractmethod @@ -20,6 +20,7 @@ def setup_model(self, model: Module, config: ModelConfig): def device(self) -> Device: raise NotImplementedError("FabricTrainer does not support this property") +from refiners.fluxion.context import ContextProvider class SimpleShardingManager(ShardingManager): def __init__(self, config: TrainingConfig) -> None: @@ -39,29 +40,61 @@ def setup_model(self, model: Module, config: ModelConfig) -> Module: # inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18 def add_execution_hooks(self, module: Module, device: Device) -> None: - if hasattr(module, "_tensor_methods") is False: - method_list = ["forward"] - else: - method_list = module._tensor_methods - + method_list = [] + if hasattr(module, "forward") is True: + method_list.append("forward") + + if hasattr(module, "set_context") is True: + method_list.append("set_context") + + if hasattr(module, "encode") is True: + method_list.append("encode") + + if hasattr(module, "decode") is True: + method_list.append("decode") + for method_name in method_list: - module = self.add_execution_hook(module, device, method_name) - return module + self.add_execution_hook(module, device, method_name) + + def recursive_to(self, obj: Any, device: Device) -> Any: + if hasattr(obj, "to"): + return obj.to(device) + elif isinstance(obj, dict): + return {k: self.recursive_to(v, device) for k, v in obj.items()} + elif isinstance(obj, list): + return [self.recursive_to(v, device) for v in obj] + elif isinstance(obj, tuple): + return tuple(self.recursive_to(v, device) for v in obj) + else: + return obj + def add_execution_hook(self, module: Module, device: Device, method_name: str) -> None: + old_method = getattr(module, method_name) + new_method = self.bind_input_to_device(old_method, device) + # new_method = update_wrapper(partial(new_method, module), old_method) + def new_method(module, *args, **kwargs): - args = [arg.to(device) if hasattr(arg, "to") else arg for arg in args] - kwargs = {k: v.to(device) if hasattr(v, "to") else v for k, v in kwargs.items()} + args = self.recursive_to(args, device) + kwargs = self.recursive_to(kwargs, device) output = old_method(*args, **kwargs) return output new_method = update_wrapper(partial(new_method, module), old_method) setattr(module, method_name, new_method) - return module - + + def 
bind_input_to_device(self, method: Callable, device: Device) -> Callable: + def new_method(*args, **kwargs): + args = self.recursive_to(args, device) + kwargs = self.recursive_to(kwargs, device) + return method(*args, **kwargs) + + return new_method + + @property def device(self) -> Device: return self.default_device From 6e5dfa439d3d166660379d06dd37a7b2e79f8307 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 19:02:58 +0100 Subject: [PATCH 08/14] rename wrap --- src/refiners/foundationals/latent_diffusion/model.py | 6 +++--- .../latent_diffusion/schedulers/scheduler.py | 4 ++-- src/refiners/training_utils/latent_diffusion.py | 12 +++++------- src/refiners/training_utils/sharding_manager.py | 12 ++++++------ 4 files changed, 16 insertions(+), 18 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 62ec97673..644f2f6f6 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -27,13 +27,13 @@ def __init__( def set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate - print(f"Setting num_inference_steps to {num_inference_steps}") - raise NotImplementedError + device, dtype = self.scheduler.device, self.scheduler.dtype + scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, final_diffusion_rate=final_diffusion_rate, - ) + ).to(device=device, dtype=dtype) self.scheduler = scheduler diff --git a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py index f2e3341f7..f64a4cc91 100644 --- a/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py +++ b/src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py @@ -107,8 +107,8 @@ def add_noise( step: int, ) -> Tensor: timestep = self.timesteps[step] - cumulative_scale_factors = self.cumulative_scale_factors[timestep].to(device=x.device, dtype=x.dtype) - noise_stds = self.noise_std[timestep].to(device=x.device, dtype=x.dtype) + cumulative_scale_factors = self.cumulative_scale_factors[timestep] + noise_stds = self.noise_std[timestep] noised_x = cumulative_scale_factors * x + noise_stds * noise return noised_x diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index b1227fdd2..f0599bd20 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -164,17 +164,17 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]: @cached_property def ddpm_scheduler(self) -> DDPM: ddpm_scheduler = DDPM(num_inference_steps=1000, device=self.device) - self.sharding_manager.add_execution_hook(ddpm_scheduler, self.device, "add_noise") + self.sharding_manager.add_device_hook(ddpm_scheduler, ddpm_scheduler.device, "add_noise") return ddpm_scheduler @cached_property def sd(self) -> StableDiffusion_1: scheduler = DPMSolver( - device=self.sharding_manager.default_device, - num_inference_steps=self.config.test_diffusion.num_inference_steps, + device=self.device, + num_inference_steps=self.config.test_diffusion.num_inference_steps,, ) - self.sharding_manager.add_execution_hooks(scheduler, scheduler.device) + self.sharding_manager.add_device_hooks(scheduler, scheduler.device) return 
StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler) @@ -188,7 +188,7 @@ def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Ten @cached_property def mse_loss(self): - return self.sharding_manager.bind_input_to_device(mse_loss, self.device) + return self.sharding_manager.wrap_device(mse_loss, self.device) def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: clip_text_embedding, latents = batch.text_embeddings, batch.latents @@ -201,8 +201,6 @@ def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: prediction = self.unet(noisy_latents) - # Question : - # Can we move this mse_loss device alignement outside of the compute_loss ? loss = self.mse_loss(input=prediction, target=noise) return loss diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index 49b64487b..3ea6f13e2 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -35,11 +35,11 @@ def setup_model(self, model: Module, config: ModelConfig) -> Module: else: device = self.default_device model = model.to(device=device) - model = self.add_execution_hooks(model, device) + model = self.add_device_hooks(model, device) return model # inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18 - def add_execution_hooks(self, module: Module, device: Device) -> None: + def add_device_hooks(self, module: Module, device: Device) -> None: method_list = [] if hasattr(module, "forward") is True: method_list.append("forward") @@ -54,7 +54,7 @@ def add_execution_hooks(self, module: Module, device: Device) -> None: method_list.append("decode") for method_name in method_list: - self.add_execution_hook(module, device, method_name) + self.add_device_hook(module, device, method_name) def recursive_to(self, obj: Any, device: Device) -> Any: @@ -69,11 +69,11 @@ def recursive_to(self, obj: Any, device: Device) -> Any: else: return obj - def add_execution_hook(self, module: Module, device: Device, method_name: str) -> None: + def add_device_hook(self, module: Module, device: Device, method_name: str) -> None: old_method = getattr(module, method_name) - new_method = self.bind_input_to_device(old_method, device) + new_method = self.wrap_device(old_method, device) # new_method = update_wrapper(partial(new_method, module), old_method) def new_method(module, *args, **kwargs): @@ -86,7 +86,7 @@ def new_method(module, *args, **kwargs): setattr(module, method_name, new_method) - def bind_input_to_device(self, method: Callable, device: Device) -> Callable: + def wrap_device(self, method: Callable, device: Device) -> Callable: def new_method(*args, **kwargs): args = self.recursive_to(args, device) kwargs = self.recursive_to(kwargs, device) From 138bb08a1830baa7d041a6b8682b95b33c0ad3e4 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 19:49:18 +0100 Subject: [PATCH 09/14] training is running --- configs/finetune-lora.toml | 10 ++-- src/refiners/training_utils/config.py | 7 +-- .../training_utils/latent_diffusion.py | 3 +- src/refiners/training_utils/trainer.py | 6 +-- tests/training_utils/test_latent_diffusion.py | 51 ------------------- 5 files changed, 10 insertions(+), 67 deletions(-) delete mode 100644 tests/training_utils/test_latent_diffusion.py diff --git a/configs/finetune-lora.toml b/configs/finetune-lora.toml index ba786ac28..3c9ce6cbf 100644 
--- a/configs/finetune-lora.toml +++ b/configs/finetune-lora.toml @@ -4,9 +4,9 @@ entity = "acme" project = "test-lora-training" [models] -unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors"} -text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"} -lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors"} +unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index = 1} +text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index = 0} +lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index = 0} [latent_diffusion] unconditional_sampling_probability = 0.05 @@ -24,8 +24,8 @@ lda_targets = [] duration = "1000:epoch" seed = 0 gpu_index = 0 -batch_size = 4 -gradient_accumulation = "4:step" +batch_size = 1 +gradient_accumulation = "1:step" clip_grad_norm = 1.0 # clip_grad_value = 1.0 evaluation_interval = "5:epoch" diff --git a/src/refiners/training_utils/config.py b/src/refiners/training_utils/config.py index 158ccdbee..54a1ab776 100644 --- a/src/refiners/training_utils/config.py +++ b/src/refiners/training_utils/config.py @@ -228,13 +228,8 @@ class BaseConfig(BaseModel): dropout: DropoutConfig checkpointing: CheckpointingConfig - @classmethod - def load_from_dict(cls: Type[T], config_dict: dict[str, Any]) -> T: - return cls(**config_dict) - @classmethod def load_from_toml(cls: Type[T], toml_path: Path | str) -> T: with open(file=toml_path, mode="rb") as f: config_dict = tomli.load(f) - - return cls.load_from_dict(**config_dict) + return cls(**config_dict) diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index f0599bd20..1e087d2ec 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -141,7 +141,6 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]): @cached_property def unet(self) -> SD1UNet: assert self.config.models["unet"] is not None, "The config must contain a unet entry." - print(f"self.device: {self.device}") return SD1UNet(in_channels=4, device=self.device) @cached_property @@ -171,7 +170,7 @@ def ddpm_scheduler(self) -> DDPM: def sd(self) -> StableDiffusion_1: scheduler = DPMSolver( device=self.device, - num_inference_steps=self.config.test_diffusion.num_inference_steps,, + num_inference_steps=self.config.test_diffusion.num_inference_steps, ) self.sharding_manager.add_device_hooks(scheduler, scheduler.device) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 71e977de0..5401c0334 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -252,13 +252,13 @@ def is_checkpointing_step(self) -> bool: return self.step % self.checkpointing_save_interval_steps == 0 -def compute_grad_norm(parameters: Iterable[Parameter]) -> float: +def compute_grad_norm(parameters: Iterable[Parameter], device: Device) -> float: """ Computes the gradient norm of the parameters of a given model similar to `clip_grad_norm_` returned value. """ gradients: list[Tensor] = [p.grad.detach() for p in parameters if p.grad is not None] assert gradients, "The model has no gradients to compute the norm." 
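
The `device` parameter threaded into `compute_grad_norm` in this hunk exists because, with models sharded across GPUs, the per-parameter gradient norms are scalars living on different devices and must be moved to a common device before being stacked. A runnable standalone version of that function, assuming only torch:

from typing import Iterable

import torch
from torch.nn import Parameter


def compute_grad_norm(parameters: Iterable[Parameter], device: torch.device) -> float:
    # Norm of each gradient, moved to a single device before the final stack and norm.
    gradients = [p.grad.detach() for p in parameters if p.grad is not None]
    assert gradients, "The model has no gradients to compute the norm."
    return torch.stack([g.norm().to(device) for g in gradients]).norm().item()


params = [Parameter(torch.randn(3)), Parameter(torch.randn(2, 2))]
loss = sum(p.sum() for p in params)
loss.backward()
print(compute_grad_norm(params, torch.device("cpu")))
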
- total_norm = stack(tensors=[gradient.norm() for gradient in gradients]).norm().item() # type: ignore + total_norm = stack(tensors=[gradient.norm().to(device=device) for gradient in gradients]).norm().item() # type: ignore return total_norm # type: ignore @@ -332,7 +332,7 @@ def gradients(self) -> list[Tensor]: @property def total_gradient_norm(self) -> float: """Returns the total gradient norm for all learnable parameters in all models""" - return compute_grad_norm(parameters=self.parameters) + return compute_grad_norm(parameters=self.parameters, device= self.device) @cached_property def optimizer(self) -> Optimizer: diff --git a/tests/training_utils/test_latent_diffusion.py b/tests/training_utils/test_latent_diffusion.py deleted file mode 100644 index cffb3a0c7..000000000 --- a/tests/training_utils/test_latent_diffusion.py +++ /dev/null @@ -1,51 +0,0 @@ -from refiners.training_utils.latent_diffusion import FinetuneLatentDiffusionConfig, LatentDiffusionTrainer -from torch import device as Device -from warnings import warn -import pytest - -DEFAULT_LATENT_DICT = dict( - script="foo.py", - wandb=dict(mode="offline", entity="acme", project="test-ldm-training"), - latent_diffusion=dict(unconditional_sampling_probability=0.2, offset_noise=0.1), - optimizer=dict(optimizer="AdamW", learning_rate=1e-5, betas=[0.9, 0.999], eps=1e-8, weight_decay=1e-2), - scheduler=dict(), - dropout=dict(dropout_probability=0.2), - checkpointing=dict(save_interval="1:epoch"), - test_diffusion=dict( - prompts=[ - "A cute cat", - ] - ), - models=dict( - lda=dict(checkpoint="tests/weights/stable-diffusion-1-5/lda.safetensors", train=False, gpu_index=0), - text_encoder=dict( - checkpoint="tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", train=True, gpu_index=0 - ), - unet=dict(checkpoint="tests/weights/stable-diffusion-1-5/unet.safetensors", train=False, gpu_index=1), - ), - training=dict(duration="1:epoch", gpu_index=0), - dataset=dict(hf_repo="1aurent/unsplash-lite-palette", revision="main", caption_key="ai_description"), -) - -from lightning import Fabric -from refiners.foundationals.latent_diffusion import SD1UNet -from lightning.fabric.strategies import FSDPStrategy -from accelerate import Accelerator, DistributedType - - -def test_ldm_trainer_text_encoder_on_two_devices(test_device: Device, test_second_device: Device): - if test_device.type == "cpu": - warn("not running on CPU, skipping") - pytest.skip() - - if test_second_device.type == "cpu": - warn("Running with only one GPU, skipping") - pytest.skip() - - config = FinetuneLatentDiffusionConfig.load_from_dict(dict(DEFAULT_LATENT_DICT)) - - trainer = LatentDiffusionTrainer(config=config) - trainer.train() - - assert trainer.lda.device == test_device - assert trainer.text_encoder.device.type == test_second_device From 0c73c93535b19c213e26a82135584a1d77a91585 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 23:02:33 +0100 Subject: [PATCH 10/14] lint --- .../foundationals/latent_diffusion/model.py | 2 +- .../training_utils/latent_diffusion.py | 4 +- .../training_utils/sharding_manager.py | 38 ++++++++++--------- src/refiners/training_utils/trainer.py | 10 ++--- 4 files changed, 28 insertions(+), 26 deletions(-) diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 644f2f6f6..5ed6c9583 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -28,7 +28,7 @@ def 
set_num_inference_steps(self, num_inference_steps: int) -> None: initial_diffusion_rate = self.scheduler.initial_diffusion_rate final_diffusion_rate = self.scheduler.final_diffusion_rate device, dtype = self.scheduler.device, self.scheduler.dtype - + scheduler = self.scheduler.__class__( num_inference_steps, initial_diffusion_rate=initial_diffusion_rate, diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index 1e087d2ec..cbcacf156 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -23,10 +23,10 @@ from refiners.foundationals.latent_diffusion.schedulers import DDPM from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder from refiners.training_utils.callback import Callback +from refiners.training_utils.config import BaseConfig from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset from refiners.training_utils.trainer import Trainer from refiners.training_utils.wandb import WandbLoggable -from refiners.training_utils.config import BaseConfig class LatentDiffusionConfig(BaseModel): @@ -184,7 +184,7 @@ def sample_timestep(self) -> Tensor: def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor: return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype) - + @cached_property def mse_loss(self): return self.sharding_manager.wrap_device(mse_loss, self.device) diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index 3ea6f13e2..979b34bf4 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -1,10 +1,13 @@ -from .config import ModelConfig, TrainingConfig -from torch.nn import Module -from torch import Tensor, device as Device -from torch.autograd import backward from abc import ABC, abstractmethod from functools import cached_property, partial, update_wrapper -from typing import Any, List, Callable +from typing import Any, Callable, List + +from torch import Tensor, device as Device +from torch.autograd import backward +from torch.nn import Module + +from .config import ModelConfig, TrainingConfig + class ShardingManager(ABC): @abstractmethod @@ -20,8 +23,10 @@ def setup_model(self, model: Module, config: ModelConfig): def device(self) -> Device: raise NotImplementedError("FabricTrainer does not support this property") + from refiners.fluxion.context import ContextProvider + class SimpleShardingManager(ShardingManager): def __init__(self, config: TrainingConfig) -> None: self.default_device = config.gpu_index if config.gpu_index is not None else "cpu" @@ -43,20 +48,19 @@ def add_device_hooks(self, module: Module, device: Device) -> None: method_list = [] if hasattr(module, "forward") is True: method_list.append("forward") - + if hasattr(module, "set_context") is True: method_list.append("set_context") - + if hasattr(module, "encode") is True: method_list.append("encode") - + if hasattr(module, "decode") is True: - method_list.append("decode") - + method_list.append("decode") + for method_name in method_list: self.add_device_hook(module, device, method_name) - def recursive_to(self, obj: Any, device: Device) -> Any: if hasattr(obj, "to"): return obj.to(device) @@ -68,14 +72,13 @@ def recursive_to(self, obj: Any, device: Device) -> Any: return tuple(self.recursive_to(v, device) for v in 
obj) else: return obj - + def add_device_hook(self, module: Module, device: Device, method_name: str) -> None: - old_method = getattr(module, method_name) new_method = self.wrap_device(old_method, device) # new_method = update_wrapper(partial(new_method, module), old_method) - + def new_method(module, *args, **kwargs): args = self.recursive_to(args, device) kwargs = self.recursive_to(kwargs, device) @@ -85,16 +88,15 @@ def new_method(module, *args, **kwargs): new_method = update_wrapper(partial(new_method, module), old_method) setattr(module, method_name, new_method) - + def wrap_device(self, method: Callable, device: Device) -> Callable: def new_method(*args, **kwargs): args = self.recursive_to(args, device) kwargs = self.recursive_to(kwargs, device) return method(*args, **kwargs) - + return new_method - - + @property def device(self) -> Device: return self.default_device diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index 5401c0334..fca7cfa84 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -7,8 +7,8 @@ import numpy as np from loguru import logger -from torch import Tensor, cuda, float32, dtype as Dtype, device as Device, get_rng_state, set_rng_state, stack -from torch.nn import Parameter, Module +from torch import Tensor, cuda, device as Device, dtype as Dtype, float32, get_rng_state, set_rng_state, stack +from torch.nn import Module, Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import ( CosineAnnealingLR, @@ -35,11 +35,11 @@ GradientValueClipping, MonitorLoss, ) -from refiners.training_utils.config import ModelConfig, BaseConfig, SchedulerType, TimeUnit, TimeValue +from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback from refiners.training_utils.wandb import WandbLoggable, WandbLogger -from .sharding_manager import SimpleShardingManager, ShardingManager +from .sharding_manager import ShardingManager, SimpleShardingManager __all__ = ["seed_everything", "scoped_seed", "Trainer"] @@ -332,7 +332,7 @@ def gradients(self) -> list[Tensor]: @property def total_gradient_norm(self) -> float: """Returns the total gradient norm for all learnable parameters in all models""" - return compute_grad_norm(parameters=self.parameters, device= self.device) + return compute_grad_norm(parameters=self.parameters, device=self.device) @cached_property def optimizer(self) -> Optimizer: From a05aaee7739aeb65456a336db9e24763a918fdd4 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 23:02:50 +0100 Subject: [PATCH 11/14] feat: P mode in image --- src/refiners/fluxion/utils.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 47789a2eb..0efe971ef 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -121,6 +121,9 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp Values are clamped to the range `[0, 1]`. 
""" + if image.mode == 'P': + image = image.convert('RGB') + image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) match image.mode: From c063e96ade517e2ff0c10b28bea32b07031a6d12 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Mon, 15 Jan 2024 23:58:29 +0100 Subject: [PATCH 12/14] lint --- src/refiners/fluxion/utils.py | 6 +- .../foundationals/latent_diffusion/model.py | 2 +- .../stable_diffusion_1/model.py | 14 ++-- .../stable_diffusion_xl/model.py | 19 ++---- .../training_utils/latent_diffusion.py | 7 +- .../training_utils/sharding_manager.py | 64 +++++++++++-------- src/refiners/training_utils/trainer.py | 6 +- 7 files changed, 60 insertions(+), 58 deletions(-) diff --git a/src/refiners/fluxion/utils.py b/src/refiners/fluxion/utils.py index 0efe971ef..485552351 100644 --- a/src/refiners/fluxion/utils.py +++ b/src/refiners/fluxion/utils.py @@ -121,9 +121,9 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp Values are clamped to the range `[0, 1]`. """ - if image.mode == 'P': - image = image.convert('RGB') - + if image.mode == "P": + image = image.convert("RGB") + image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype) match image.mode: diff --git a/src/refiners/foundationals/latent_diffusion/model.py b/src/refiners/foundationals/latent_diffusion/model.py index 5ed6c9583..ba66fe042 100644 --- a/src/refiners/foundationals/latent_diffusion/model.py +++ b/src/refiners/foundationals/latent_diffusion/model.py @@ -3,7 +3,7 @@ import torch from PIL import Image -from torch import Tensor, device as Device, dtype as DType +from torch import Tensor import refiners.fluxion.layers as fl from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder diff --git a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py index 53cc16ebb..f7cc3acf7 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py @@ -27,11 +27,13 @@ def __init__( lda: SD1Autoencoder | None = None, clip_text_encoder: CLIPTextEncoderL | None = None, scheduler: Scheduler | None = None, + device: Device | str = "cpu", + dtype: DType = torch.float32, ) -> None: - unet = unet or SD1UNet(in_channels=4) - lda = lda or SD1Autoencoder() - clip_text_encoder = clip_text_encoder or CLIPTextEncoderL() - scheduler = scheduler or DPMSolver(num_inference_steps=30) + unet = unet or SD1UNet(in_channels=4, device=device, dtype=dtype) + lda = lda or SD1Autoencoder(device=device, dtype=dtype) + clip_text_encoder = clip_text_encoder or CLIPTextEncoderL(device=device, dtype=dtype) + scheduler = scheduler or DPMSolver(num_inference_steps=30, device=device, dtype=dtype) super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler) def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor: @@ -99,9 +101,7 @@ def __init__( ) -> None: self.mask_latents: Tensor | None = None self.target_image_latents: Tensor | None = None - super().__init__( - unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype - ) + super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler) def forward( self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor diff --git 
a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py index 0cb979b94..425262ff6 100644 --- a/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py +++ b/src/refiners/foundationals/latent_diffusion/stable_diffusion_xl/model.py @@ -27,19 +27,12 @@ def __init__( device: Device | str = "cpu", dtype: DType = torch.float32, ) -> None: - unet = unet or SDXLUNet(in_channels=4) - lda = lda or SDXLAutoencoder() - clip_text_encoder = clip_text_encoder or DoubleTextEncoder() - scheduler = scheduler or DDIM(num_inference_steps=30) - - super().__init__( - unet=unet, - lda=lda, - clip_text_encoder=clip_text_encoder, - scheduler=scheduler, - device=device, - dtype=dtype, - ) + unet = unet or SDXLUNet(in_channels=4, device=device, dtype=dtype) + lda = lda or SDXLAutoencoder(device=device, dtype=dtype) + clip_text_encoder = clip_text_encoder or DoubleTextEncoder(device=device, dtype=dtype) + scheduler = scheduler or DDIM(num_inference_steps=30, device=device, dtype=dtype) + + super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler) def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]: conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text) diff --git a/src/refiners/training_utils/latent_diffusion.py b/src/refiners/training_utils/latent_diffusion.py index cbcacf156..20754c86b 100644 --- a/src/refiners/training_utils/latent_diffusion.py +++ b/src/refiners/training_utils/latent_diffusion.py @@ -186,7 +186,7 @@ def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Ten return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype) @cached_property - def mse_loss(self): + def mse_loss(self) -> Callable[[Tensor, Tensor], Tensor]: return self.sharding_manager.wrap_device(mse_loss, self.device) def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: @@ -199,9 +199,8 @@ def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor: self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding) prediction = self.unet(noisy_latents) - - loss = self.mse_loss(input=prediction, target=noise) - return loss + loss = self.mse_loss(input=prediction, target=noise) # type: ignore + return loss # type: ignore def compute_evaluation(self) -> None: sd = self.sd diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index 979b34bf4..8d8efee51 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -1,13 +1,18 @@ from abc import ABC, abstractmethod -from functools import cached_property, partial, update_wrapper -from typing import Any, Callable, List +from functools import partial, update_wrapper +from typing import Any, Callable, Dict, List from torch import Tensor, device as Device from torch.autograd import backward from torch.nn import Module +from refiners.foundationals.latent_diffusion.schedulers import Scheduler + from .config import ModelConfig, TrainingConfig +Hookable = Module | Scheduler +WrappableMethod = Callable[..., Any] + class ShardingManager(ABC): @abstractmethod @@ -15,7 +20,19 @@ def backward(self, tensor: Tensor) -> None: ... @abstractmethod - def setup_model(self, model: Module, config: ModelConfig): + def setup_model(self, model: Hookable, config: ModelConfig) -> None: + ... 
+ + @abstractmethod + def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod: + ... + + @abstractmethod + def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None: + ... + + @abstractmethod + def add_device_hooks(self, module: Hookable, device: Device) -> None: ... @property @@ -24,28 +41,25 @@ def device(self) -> Device: raise NotImplementedError("FabricTrainer does not support this property") -from refiners.fluxion.context import ContextProvider - - class SimpleShardingManager(ShardingManager): def __init__(self, config: TrainingConfig) -> None: - self.default_device = config.gpu_index if config.gpu_index is not None else "cpu" + device_str = config.gpu_index if config.gpu_index >= 0 else "cpu" + self.default_device = Device(device_str) def backward(self, tensor: Tensor): backward(tensor) - def setup_model(self, model: Module, config: ModelConfig) -> Module: + def setup_model(self, model: Hookable, config: ModelConfig) -> None: if config.gpu_index is not None: - device = f"cuda:{config.gpu_index}" + device = Device(f"cuda:{config.gpu_index}") else: device = self.default_device model = model.to(device=device) - model = self.add_device_hooks(model, device) - return model + self.add_device_hooks(model, device) # inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18 - def add_device_hooks(self, module: Module, device: Device) -> None: - method_list = [] + def add_device_hooks(self, module: Hookable, device: Device) -> None: + method_list: List[str] = [] if hasattr(module, "forward") is True: method_list.append("forward") @@ -64,33 +78,29 @@ def add_device_hooks(self, module: Module, device: Device) -> None: def recursive_to(self, obj: Any, device: Device) -> Any: if hasattr(obj, "to"): return obj.to(device) - elif isinstance(obj, dict): - return {k: self.recursive_to(v, device) for k, v in obj.items()} - elif isinstance(obj, list): - return [self.recursive_to(v, device) for v in obj] - elif isinstance(obj, tuple): - return tuple(self.recursive_to(v, device) for v in obj) + elif isinstance(obj, dict): # type: ignore + return {k: self.recursive_to(v, device) for k, v in obj.items()} # type: ignore + elif isinstance(obj, list): # type: ignore + return [self.recursive_to(v, device) for v in obj] # type: ignore + elif isinstance(obj, tuple): # type: ignore + return tuple(self.recursive_to(v, device) for v in obj) # type: ignore else: return obj - def add_device_hook(self, module: Module, device: Device, method_name: str) -> None: + def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None: old_method = getattr(module, method_name) new_method = self.wrap_device(old_method, device) # new_method = update_wrapper(partial(new_method, module), old_method) - def new_method(module, *args, **kwargs): - args = self.recursive_to(args, device) - kwargs = self.recursive_to(kwargs, device) - output = old_method(*args, **kwargs) - return output + new_method = self.wrap_device(old_method, device) new_method = update_wrapper(partial(new_method, module), old_method) setattr(module, method_name, new_method) - def wrap_device(self, method: Callable, device: Device) -> Callable: - def new_method(*args, **kwargs): + def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod: + def new_method(*args: List[Any], **kwargs: Dict[Any, Any]) -> Any: args = self.recursive_to(args, device) kwargs = self.recursive_to(kwargs, device) 
return method(*args, **kwargs) diff --git a/src/refiners/training_utils/trainer.py b/src/refiners/training_utils/trainer.py index fca7cfa84..231a76ffa 100644 --- a/src/refiners/training_utils/trainer.py +++ b/src/refiners/training_utils/trainer.py @@ -7,8 +7,8 @@ import numpy as np from loguru import logger -from torch import Tensor, cuda, device as Device, dtype as Dtype, float32, get_rng_state, set_rng_state, stack -from torch.nn import Module, Parameter +from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack +from torch.nn import Parameter from torch.optim import Optimizer from torch.optim.lr_scheduler import ( CosineAnnealingLR, @@ -35,7 +35,7 @@ GradientValueClipping, MonitorLoss, ) -from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType, TimeUnit, TimeValue +from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue from refiners.training_utils.dropout import DropoutCallback from refiners.training_utils.wandb import WandbLoggable, WandbLogger From 14fa0963e28a933d7c98b9defe443d6ee022d8da Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Tue, 16 Jan 2024 00:25:19 +0100 Subject: [PATCH 13/14] rollback second_device in conftest --- tests/conftest.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1b129a5af..d1403ffb3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,14 +15,6 @@ def test_device() -> torch.device: return torch.device(test_device) -@fixture(scope="session") -def test_second_device() -> torch.device: - test_device = os.getenv("REFINERS_TEST_SECOND_DEVICE") - if not test_device and torch.cuda.device_count() > 1: - return torch.device("cuda:1") - return torch.device("cpu") - - @fixture(scope="session") def test_weights_path() -> Path: from_env = os.getenv("REFINERS_TEST_WEIGHTS_DIR") From 2d6476fc649dc168f8703d46a155f5666ae21873 Mon Sep 17 00:00:00 2001 From: Pierre Colle Date: Tue, 16 Jan 2024 01:36:34 +0100 Subject: [PATCH 14/14] fix: sharding_manager --- src/refiners/training_utils/sharding_manager.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/refiners/training_utils/sharding_manager.py b/src/refiners/training_utils/sharding_manager.py index 8d8efee51..b4e5933b9 100644 --- a/src/refiners/training_utils/sharding_manager.py +++ b/src/refiners/training_utils/sharding_manager.py @@ -1,5 +1,4 @@ from abc import ABC, abstractmethod -from functools import partial, update_wrapper from typing import Any, Callable, Dict, List from torch import Tensor, device as Device @@ -89,14 +88,7 @@ def recursive_to(self, obj: Any, device: Device) -> Any: def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None: old_method = getattr(module, method_name) - - new_method = self.wrap_device(old_method, device) - # new_method = update_wrapper(partial(new_method, module), old_method) - new_method = self.wrap_device(old_method, device) - - new_method = update_wrapper(partial(new_method, module), old_method) - setattr(module, method_name, new_method) def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod:
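
The sharding patches above wrap selected methods of each model (forward, set_context, encode, decode) so that their inputs are moved to that model's device before the call, which is what allows the unet, text encoder and lda to sit on different GPUs. Below is a minimal, self-contained sketch of that device-hook pattern, assuming plain PyTorch. It is illustrative only and not code from this repository: the standalone helpers `move_to_device` and `add_device_hook` and the toy usage under `__main__` are assumptions made for the example.

from typing import Any, Callable

import torch
from torch import nn


def move_to_device(obj: Any, device: torch.device) -> Any:
    # Recursively move tensors (and containers of tensors) onto `device`,
    # mirroring the role played by `recursive_to` in the patches above.
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to_device(v, device) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(move_to_device(v, device) for v in obj)
    return obj


def add_device_hook(module: nn.Module, device: torch.device, method_name: str = "forward") -> None:
    # Replace `method_name` with a wrapper that moves every input to `device`
    # before delegating to the original method (hypothetical standalone helper).
    old_method: Callable[..., Any] = getattr(module, method_name)

    def hooked(*args: Any, **kwargs: Any) -> Any:
        moved_args = move_to_device(args, device)
        moved_kwargs = move_to_device(kwargs, device)
        return old_method(*moved_args, **moved_kwargs)

    setattr(module, method_name, hooked)


if __name__ == "__main__":
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = nn.Linear(4, 2).to(device)
    add_device_hook(model, device)
    # Callers can now pass CPU tensors; the hook moves them to the model's device.
    print(model(torch.randn(1, 4)).device)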