Skip to content

Commit

Permalink
lint
Browse files: browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 15, 2024
1 parent a05aaee commit c063e96
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 58 deletions.
6 changes: 3 additions & 3 deletions src/refiners/fluxion/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,9 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp
Values are clamped to the range `[0, 1]`.
"""
if image.mode == 'P':
image = image.convert('RGB')
if image.mode == "P":
image = image.convert("RGB")

image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

match image.mode:
Expand Down
2 changes: 1 addition & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@ def __init__(
lda: SD1Autoencoder | None = None,
clip_text_encoder: CLIPTextEncoderL | None = None,
scheduler: Scheduler | None = None,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)
unet = unet or SD1UNet(in_channels=4, device=device, dtype=dtype)
lda = lda or SD1Autoencoder(device=device, dtype=dtype)
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL(device=device, dtype=dtype)
scheduler = scheduler or DPMSolver(num_inference_steps=30, device=device, dtype=dtype)
super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
Expand Down Expand Up @@ -99,9 +101,7 @@ def __init__(
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
)
super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,12 @@ def __init__(
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)

super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
unet = unet or SDXLUNet(in_channels=4, device=device, dtype=dtype)
lda = lda or SDXLAutoencoder(device=device, dtype=dtype)
clip_text_encoder = clip_text_encoder or DoubleTextEncoder(device=device, dtype=dtype)
scheduler = scheduler or DDIM(num_inference_steps=30, device=device, dtype=dtype)

super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
Expand Down
7 changes: 3 additions & 4 deletions src/refiners/training_utils/latent_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Ten
return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype)

@cached_property
def mse_loss(self):
def mse_loss(self) -> Callable[[Tensor, Tensor], Tensor]:
return self.sharding_manager.wrap_device(mse_loss, self.device)

def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
Expand All @@ -199,9 +199,8 @@ def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

prediction = self.unet(noisy_latents)

loss = self.mse_loss(input=prediction, target=noise)
return loss
loss = self.mse_loss(input=prediction, target=noise) # type: ignore
return loss # type: ignore

def compute_evaluation(self) -> None:
sd = self.sd
Expand Down
64 changes: 37 additions & 27 deletions src/refiners/training_utils/sharding_manager.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
from abc import ABC, abstractmethod
from functools import cached_property, partial, update_wrapper
from typing import Any, Callable, List
from functools import partial, update_wrapper
from typing import Any, Callable, Dict, List

from torch import Tensor, device as Device
from torch.autograd import backward
from torch.nn import Module

from refiners.foundationals.latent_diffusion.schedulers import Scheduler

from .config import ModelConfig, TrainingConfig

Hookable = Module | Scheduler
WrappableMethod = Callable[..., Any]


class ShardingManager(ABC):
@abstractmethod
def backward(self, tensor: Tensor) -> None:
...

@abstractmethod
def setup_model(self, model: Module, config: ModelConfig):
def setup_model(self, model: Hookable, config: ModelConfig) -> None:
...

@abstractmethod
def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod:
...

@abstractmethod
def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None:
...

@abstractmethod
def add_device_hooks(self, module: Hookable, device: Device) -> None:
...

@property
Expand All @@ -24,28 +41,25 @@ def device(self) -> Device:
raise NotImplementedError("FabricTrainer does not support this property")


from refiners.fluxion.context import ContextProvider


class SimpleShardingManager(ShardingManager):
def __init__(self, config: TrainingConfig) -> None:
self.default_device = config.gpu_index if config.gpu_index is not None else "cpu"
device_str = config.gpu_index if config.gpu_index >= 0 else "cpu"
self.default_device = Device(device_str)

def backward(self, tensor: Tensor):
backward(tensor)

def setup_model(self, model: Module, config: ModelConfig) -> Module:
def setup_model(self, model: Hookable, config: ModelConfig) -> None:
if config.gpu_index is not None:
device = f"cuda:{config.gpu_index}"
device = Device(f"cuda:{config.gpu_index}")
else:
device = self.default_device
model = model.to(device=device)
model = self.add_device_hooks(model, device)
return model
self.add_device_hooks(model, device)

# inspired by https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18
def add_device_hooks(self, module: Module, device: Device) -> None:
method_list = []
def add_device_hooks(self, module: Hookable, device: Device) -> None:
method_list: List[str] = []
if hasattr(module, "forward") is True:
method_list.append("forward")

Expand All @@ -64,33 +78,29 @@ def add_device_hooks(self, module: Module, device: Device) -> None:
def recursive_to(self, obj: Any, device: Device) -> Any:
if hasattr(obj, "to"):
return obj.to(device)
elif isinstance(obj, dict):
return {k: self.recursive_to(v, device) for k, v in obj.items()}
elif isinstance(obj, list):
return [self.recursive_to(v, device) for v in obj]
elif isinstance(obj, tuple):
return tuple(self.recursive_to(v, device) for v in obj)
elif isinstance(obj, dict): # type: ignore
return {k: self.recursive_to(v, device) for k, v in obj.items()} # type: ignore
elif isinstance(obj, list): # type: ignore
return [self.recursive_to(v, device) for v in obj] # type: ignore
elif isinstance(obj, tuple): # type: ignore
return tuple(self.recursive_to(v, device) for v in obj) # type: ignore
else:
return obj

def add_device_hook(self, module: Module, device: Device, method_name: str) -> None:
def add_device_hook(self, module: Hookable, device: Device, method_name: str) -> None:
old_method = getattr(module, method_name)

new_method = self.wrap_device(old_method, device)
# new_method = update_wrapper(partial(new_method, module), old_method)

def new_method(module, *args, **kwargs):
args = self.recursive_to(args, device)
kwargs = self.recursive_to(kwargs, device)
output = old_method(*args, **kwargs)
return output
new_method = self.wrap_device(old_method, device)

new_method = update_wrapper(partial(new_method, module), old_method)

setattr(module, method_name, new_method)

def wrap_device(self, method: Callable, device: Device) -> Callable:
def new_method(*args, **kwargs):
def wrap_device(self, method: WrappableMethod, device: Device) -> WrappableMethod:
def new_method(*args: List[Any], **kwargs: Dict[Any, Any]) -> Any:
args = self.recursive_to(args, device)
kwargs = self.recursive_to(kwargs, device)
return method(*args, **kwargs)
Expand Down
6 changes: 3 additions & 3 deletions src/refiners/training_utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np
from loguru import logger
from torch import Tensor, cuda, device as Device, dtype as Dtype, float32, get_rng_state, set_rng_state, stack
from torch.nn import Module, Parameter
from torch import Tensor, cuda, device as Device, get_rng_state, set_rng_state, stack
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import (
CosineAnnealingLR,
Expand All @@ -35,7 +35,7 @@
GradientValueClipping,
MonitorLoss,
)
from refiners.training_utils.config import BaseConfig, ModelConfig, SchedulerType, TimeUnit, TimeValue
from refiners.training_utils.config import BaseConfig, SchedulerType, TimeUnit, TimeValue
from refiners.training_utils.dropout import DropoutCallback
from refiners.training_utils.wandb import WandbLoggable, WandbLogger

Expand Down

0 comments on commit c063e96

Please sign in to comment.