clean up hook mechanism
piercus committed Jan 15, 2024
1 parent d999fd4 commit d4256bd
Showing 4 changed files with 59 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/refiners/foundationals/latent_diffusion/model.py
@@ -28,6 +28,7 @@ def set_num_inference_steps(self, num_inference_steps: int) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
print(f"Setting num_inference_steps to {num_inference_steps}")
raise NotImplementedError
scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
9 changes: 2 additions & 7 deletions src/refiners/foundationals/latent_diffusion/stable_diffusion_1/model.py
@@ -2,7 +2,6 @@
import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from loguru import logger

from refiners.fluxion.utils import image_to_tensor, interpolate
from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
@@ -44,12 +43,8 @@ def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)

def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
# Question :
# Can we do this as part of the SetContext Logic ?
self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype))
self.unet.set_clip_text_embedding(
clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)
)
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

def set_self_attention_guidance(self, enable: bool, scale: float = 1.0) -> None:
if enable:
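Note: the device and dtype casts removed from set_unet_context above are expected to be handled by the execution hooks that the sharding manager installs (see sharding_manager.py later in this commit), so the method can pass tensors through untouched. A minimal, self-contained sketch of that wrapping idea follows; the FakeUNet class and the hook_setter helper are illustrative assumptions, not part of the codebase:

```python
import torch
from torch import nn


def hook_setter(module: nn.Module, method_name: str, device: torch.device) -> None:
    # Replace `module.<method_name>` with a wrapper that moves tensor
    # arguments onto `device` before delegating to the original method.
    original = getattr(module, method_name)

    def wrapped(*args, **kwargs):
        args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
        kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return original(*args, **kwargs)

    setattr(module, method_name, wrapped)


class FakeUNet(nn.Module):  # illustrative stand-in for SD1UNet
    def set_timestep(self, timestep: torch.Tensor) -> None:
        self.timestep = timestep


unet = FakeUNet()
hook_setter(unet, "set_timestep", torch.device("cpu"))
unet.set_timestep(timestep=torch.tensor([0]))  # caller no longer calls .to() itself
```

With such a hook in place, the caller's device does not matter; the wrapper normalizes placement before the original setter runs.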
24 changes: 11 additions & 13 deletions src/refiners/training_utils/latent_diffusion.py
@@ -141,6 +141,7 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
@cached_property
def unet(self) -> SD1UNet:
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
print(f"self.device: {self.device}")
return SD1UNet(in_channels=4, device=self.device)

@cached_property
@@ -151,9 +152,7 @@ def text_encoder(self) -> CLIPTextEncoderL:
@cached_property
def lda(self) -> SD1Autoencoder:
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
lda = SD1Autoencoder(device=self.device)
# TODO: clean this up
lda._tensor_methods = ["encode", "decode"]
lda = SD1Autoencoder()
return lda

def load_models(self) -> dict[str, fl.Module]:
@@ -165,8 +164,7 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:
@cached_property
def ddpm_scheduler(self) -> DDPM:
ddpm_scheduler = DDPM(num_inference_steps=1000, device=self.device)
ddpm_scheduler._tensor_methods = ["add_noise"]
self.sharding_manager.add_execution_hooks(ddpm_scheduler, self.device)
self.sharding_manager.add_execution_hook(ddpm_scheduler, self.device, "add_noise")
return ddpm_scheduler

@cached_property
@@ -176,7 +174,7 @@ def sd(self) -> StableDiffusion_1:
num_inference_steps=self.config.test_diffusion.num_inference_steps,
)

scheduler = self.sharding_manager.add_execution_hooks(scheduler, scheduler.device)
self.sharding_manager.add_execution_hooks(scheduler, scheduler.device)

return StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler)

@@ -187,25 +185,25 @@ def sample_timestep(self) -> Tensor:

def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype)

@cached_property
def mse_loss(self):
return self.sharding_manager.bind_input_to_device(mse_loss, self.device)

def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
clip_text_embedding, latents = batch.text_embeddings, batch.latents
timestep = self.sample_timestep()
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)

# Question :
# Can we do this as part of the SetContext Logic ?
self.unet.set_timestep(timestep=timestep.to(device=self.unet.device, dtype=self.unet.dtype))
self.unet.set_clip_text_embedding(
clip_text_embedding=clip_text_embedding.to(device=self.unet.device, dtype=self.unet.dtype)
)
self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

prediction = self.unet(noisy_latents)

# Question :
# Can we move this mse_loss device alignement outside of the compute_loss ?
loss = mse_loss(input=prediction, target=noise.to(device=prediction.device))
loss = self.mse_loss(input=prediction, target=noise)
return loss

def compute_evaluation(self) -> None:
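For reference, the new mse_loss cached property above defers device placement to bind_input_to_device, which is why compute_loss no longer casts the target with .to(device=prediction.device). A rough standalone sketch of what that binding amounts to is shown below; it assumes only torch, and the bind_input_to_device here merely mirrors the helper added in sharding_manager.py further down, so treat it as an approximation rather than the exact implementation:

```python
from typing import Any, Callable

import torch
from torch.nn.functional import mse_loss


def bind_input_to_device(method: Callable[..., Any], device: torch.device) -> Callable[..., Any]:
    # Wrap `method` so every tensor argument is moved to `device` first.
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
        kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return method(*args, **kwargs)

    return wrapped


device = torch.device("cpu")
loss_fn = bind_input_to_device(mse_loss, device)
prediction, noise = torch.randn(4), torch.randn(4)
loss = loss_fn(input=prediction, target=noise)  # both tensors are moved to `device`
```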
57 changes: 45 additions & 12 deletions src/refiners/training_utils/sharding_manager.py
@@ -4,7 +4,7 @@
from torch.autograd import backward
from abc import ABC, abstractmethod
from functools import cached_property, partial, update_wrapper

from typing import Any, List, Callable

class ShardingManager(ABC):
@abstractmethod
@@ -20,6 +20,7 @@ def setup_model(self, model: Module, config: ModelConfig):
def device(self) -> Device:
raise NotImplementedError("FabricTrainer does not support this property")

from refiners.fluxion.context import ContextProvider

class SimpleShardingManager(ShardingManager):
def __init__(self, config: TrainingConfig) -> None:
@@ -39,29 +40,61 @@ def setup_model(self, model: Module, config: ModelConfig) -> Module:

# inspired from https://github.com/huggingface/accelerate/blob/6f05bbd41a179cc9a86238c7c6f3f4eded70fbd8/src/accelerate/hooks.py#L159C1-L170C18
def add_execution_hooks(self, module: Module, device: Device) -> None:
if hasattr(module, "_tensor_methods") is False:
method_list = ["forward"]
else:
method_list = module._tensor_methods

method_list = []
if hasattr(module, "forward") is True:
method_list.append("forward")

if hasattr(module, "set_context") is True:
method_list.append("set_context")

if hasattr(module, "encode") is True:
method_list.append("encode")

if hasattr(module, "decode") is True:
method_list.append("decode")

for method_name in method_list:
module = self.add_execution_hook(module, device, method_name)
return module
self.add_execution_hook(module, device, method_name)


def recursive_to(self, obj: Any, device: Device) -> Any:
if hasattr(obj, "to"):
return obj.to(device)
elif isinstance(obj, dict):
return {k: self.recursive_to(v, device) for k, v in obj.items()}
elif isinstance(obj, list):
return [self.recursive_to(v, device) for v in obj]
elif isinstance(obj, tuple):
return tuple(self.recursive_to(v, device) for v in obj)
else:
return obj

def add_execution_hook(self, module: Module, device: Device, method_name: str) -> None:

old_method = getattr(module, method_name)

new_method = self.bind_input_to_device(old_method, device)
# new_method = update_wrapper(partial(new_method, module), old_method)

def new_method(module, *args, **kwargs):
args = [arg.to(device) if hasattr(arg, "to") else arg for arg in args]
kwargs = {k: v.to(device) if hasattr(v, "to") else v for k, v in kwargs.items()}
args = self.recursive_to(args, device)
kwargs = self.recursive_to(kwargs, device)
output = old_method(*args, **kwargs)
return output

new_method = update_wrapper(partial(new_method, module), old_method)

setattr(module, method_name, new_method)
return module


def bind_input_to_device(self, method: Callable, device: Device) -> Callable:
def new_method(*args, **kwargs):
args = self.recursive_to(args, device)
kwargs = self.recursive_to(kwargs, device)
return method(*args, **kwargs)

return new_method


@property
def device(self) -> Device:
return self.default_device
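
Putting the pieces together: recursive_to lets a hook move tensors that are nested inside lists, tuples, and dicts, not just top-level tensor arguments, which the previous per-argument hasattr(arg, "to") check could not do. Below is a self-contained usage sketch; the Toy module and the free-standing helper functions are illustrative stand-ins, since in the commit the real logic lives as methods on SimpleShardingManager:

```python
from functools import partial, update_wrapper
from typing import Any

import torch
from torch import nn


def recursive_to(obj: Any, device: torch.device) -> Any:
    # Move anything exposing `.to` onto `device`, descending into containers.
    if hasattr(obj, "to"):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: recursive_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [recursive_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        return tuple(recursive_to(v, device) for v in obj)
    return obj


def add_execution_hook(module: nn.Module, device: torch.device, method_name: str) -> None:
    # Rebind `module.<method_name>` so its inputs are moved to `device` on every call.
    old_method = getattr(module, method_name)

    def new_method(module: nn.Module, *args: Any, **kwargs: Any) -> Any:
        return old_method(*recursive_to(args, device), **recursive_to(kwargs, device))

    setattr(module, method_name, update_wrapper(partial(new_method, module), old_method))


class Toy(nn.Module):  # illustrative module
    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:
        return batch["x"] + batch["y"]


toy = Toy()
add_execution_hook(toy, torch.device("cpu"), "forward")
out = toy.forward({"x": torch.ones(2), "y": torch.zeros(2)})  # dict values are moved first
```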
