
Simple sharding manager #2

Open · wants to merge 16 commits into base: handle-url-hf-datasets
10 changes: 5 additions & 5 deletions configs/finetune-lora.toml
@@ -4,9 +4,9 @@ entity = "acme"
project = "test-lora-training"

[models]
unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors"}
text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors"}
lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors"}
unet = {checkpoint = "tests/weights/stable-diffusion-1-5/unet.safetensors", gpu_index = 1}
text_encoder = {checkpoint = "tests/weights/stable-diffusion-1-5/CLIPTextEncoderL.safetensors", gpu_index = 0}
lda = {checkpoint = "tests/weights/stable-diffusion-1-5/lda.safetensors", gpu_index = 0}

[latent_diffusion]
unconditional_sampling_probability = 0.05
@@ -24,8 +24,8 @@ lda_targets = []
duration = "1000:epoch"
seed = 0
gpu_index = 0
batch_size = 4
gradient_accumulation = "4:step"
batch_size = 1
gradient_accumulation = "1:step"
clip_grad_norm = 1.0
# clip_grad_value = 1.0
evaluation_interval = "5:epoch"
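Note on the hyperparameters: if gradient_accumulation counts optimizer micro-steps, the effective batch size drops from 4 × 4 = 16 samples per update to 1 × 1 = 1, presumably to make the two-GPU placement easy to exercise rather than to train for real.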
3 changes: 3 additions & 0 deletions src/refiners/fluxion/utils.py
@@ -121,6 +121,9 @@ def image_to_tensor(image: Image.Image, device: Device | str | None = None, dtyp

Values are clamped to the range `[0, 1]`.
"""
if image.mode == "P":
image = image.convert("RGB")

image_tensor = torch.tensor(array(image).astype(float32) / 255.0, device=device, dtype=dtype)

match image.mode:
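With the new branch, palette-mode images are converted before they reach the mode match below. A minimal usage sketch (the file name is hypothetical):

from PIL import Image
from refiners.fluxion.utils import image_to_tensor

image = Image.open("indexed.png")  # a palette ("P") mode PNG
tensor = image_to_tensor(image)  # converted to RGB before tensorization
print(tensor.shape)  # expected: a batched (1, 3, H, W) float tensor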
31 changes: 12 additions & 19 deletions src/refiners/foundationals/latent_diffusion/model.py
@@ -3,7 +3,7 @@

import torch
from PIL import Image
from torch import Tensor, device as Device, dtype as DType
from torch import Tensor

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
@@ -16,32 +16,27 @@

class LatentDiffusionModel(fl.Module, ABC):
def __init__(
self,
unet: fl.Module,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module,
scheduler: Scheduler,
device: Device | str = "cpu",
dtype: DType = torch.float32,
self, unet: fl.Module, lda: LatentDiffusionAutoencoder, clip_text_encoder: fl.Module, scheduler: Scheduler
) -> None:
super().__init__()
self.device: Device = device if isinstance(device, Device) else Device(device=device)
self.dtype = dtype
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
self.unet = unet
self.lda = lda
self.clip_text_encoder = clip_text_encoder
self.scheduler = scheduler

def set_num_inference_steps(self, num_inference_steps: int) -> None:
initial_diffusion_rate = self.scheduler.initial_diffusion_rate
final_diffusion_rate = self.scheduler.final_diffusion_rate
device, dtype = self.scheduler.device, self.scheduler.dtype
self.scheduler = self.scheduler.__class__(

scheduler = self.scheduler.__class__(
num_inference_steps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
).to(device=device, dtype=dtype)

self.scheduler = scheduler

def init_latents(
self,
size: tuple[int, int],
@@ -51,7 +46,7 @@ def init_latents(
) -> Tensor:
height, width = size
if noise is None:
noise = torch.randn(1, 4, height // 8, width // 8, device=self.device)
noise = torch.randn(1, 4, height // 8, width // 8)
assert list(noise.shape[2:]) == [
height // 8,
width // 8,
@@ -90,6 +85,7 @@ def forward(
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

latents = torch.cat(tensors=(x, x)) # for classifier-free guidance

# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)
@@ -102,7 +98,6 @@ def forward(
noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)

return self.scheduler(x, noise=noise, step=step)

def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
@@ -111,6 +106,4 @@ def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
device=self.device,
dtype=self.dtype,
)
(file name not captured in this page — DPMSolver scheduler)
@@ -91,6 +91,13 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens
return denoised_x

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
# We delegate to forward so the sharding_manager can dynamically
# change the behavior of the solver at runtime (see the note below).
# TODO: change the Scheduler abstract class
return self.forward(x, noise, step, generator)

def forward(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.

@@ -100,6 +107,7 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N
"""
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]

estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
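Why `__call__` delegates to `forward`: Python resolves dunder methods like `__call__` on the class rather than the instance, so a sharding manager cannot patch `solver.__call__` per instance, but it can rebind a plain `forward` attribute. A sketch of the kind of hook this indirection enables (the function name and signature are assumptions, not code from this PR):

import torch

def add_device_hook(obj, device: torch.device, method_name: str = "forward") -> None:
    # Rebind `method_name` on the instance so tensor arguments are moved
    # to `device` before the original method runs.
    original = getattr(obj, method_name)

    def hooked(*args, **kwargs):
        args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
        kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
        return original(*args, **kwargs)

    setattr(obj, method_name, hooked)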
(file name not captured in this page — StableDiffusion_1 model)
@@ -30,19 +30,11 @@ def __init__(
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SD1UNet(in_channels=4)
lda = lda or SD1Autoencoder()
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL()
scheduler = scheduler or DPMSolver(num_inference_steps=30)

super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
unet = unet or SD1UNet(in_channels=4, device=device, dtype=dtype)
lda = lda or SD1Autoencoder(device=device, dtype=dtype)
clip_text_encoder = clip_text_encoder or CLIPTextEncoderL(device=device, dtype=dtype)
scheduler = scheduler or DPMSolver(num_inference_steps=30, device=device, dtype=dtype)
super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def compute_clip_text_embedding(self, text: str, negative_text: str = "") -> Tensor:
conditional_embedding = self.clip_text_encoder(text)
@@ -109,9 +101,7 @@ def __init__(
) -> None:
self.mask_latents: Tensor | None = None
self.target_image_latents: Tensor | None = None
super().__init__(
unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler, device=device, dtype=dtype
)
super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **_: Tensor
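Because the submodules now own their device and dtype, a pipeline can be spread across GPUs at construction time instead of living on a single device. A sketch assuming two CUDA devices (import paths may differ between refiners versions):

from refiners.foundationals.clip.text_encoder import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import SD1Autoencoder, SD1UNet, StableDiffusion_1
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver

sd = StableDiffusion_1(
    unet=SD1UNet(in_channels=4, device="cuda:1"),  # the heavy denoiser gets its own GPU
    lda=SD1Autoencoder(device="cuda:0"),
    clip_text_encoder=CLIPTextEncoderL(device="cuda:0"),
    scheduler=DPMSolver(num_inference_steps=30, device="cuda:0"),
)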
(file name not captured in this page — StableDiffusion_XL model)
@@ -27,19 +27,12 @@ def __init__(
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
unet = unet or SDXLUNet(in_channels=4)
lda = lda or SDXLAutoencoder()
clip_text_encoder = clip_text_encoder or DoubleTextEncoder()
scheduler = scheduler or DDIM(num_inference_steps=30)

super().__init__(
unet=unet,
lda=lda,
clip_text_encoder=clip_text_encoder,
scheduler=scheduler,
device=device,
dtype=dtype,
)
unet = unet or SDXLUNet(in_channels=4, device=device, dtype=dtype)
lda = lda or SDXLAutoencoder(device=device, dtype=dtype)
clip_text_encoder = clip_text_encoder or DoubleTextEncoder(device=device, dtype=dtype)
scheduler = scheduler or DDIM(num_inference_steps=30, device=device, dtype=dtype)

super().__init__(unet=unet, lda=lda, clip_text_encoder=clip_text_encoder, scheduler=scheduler)

def compute_clip_text_embedding(self, text: str, negative_text: str | None = None) -> tuple[Tensor, Tensor]:
conditional_embedding, conditional_pooled_embedding = self.clip_text_encoder(text)
2 changes: 1 addition & 1 deletion src/refiners/training_utils/config.py
@@ -174,6 +174,7 @@ class ModelConfig(BaseModel):
checkpoint: Path | None = None
train: bool = True
learning_rate: float | None = None # TODO: Implement this
gpu_index: int | None = None


class GyroDropoutConfig(BaseModel):
@@ -231,5 +232,4 @@ class BaseConfig(BaseModel):
def load_from_toml(cls: Type[T], toml_path: Path | str) -> T:
with open(file=toml_path, mode="rb") as f:
config_dict = tomli.load(f)

return cls(**config_dict)
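Since `gpu_index` is optional, consumers need a CPU fallback. One way a trainer might resolve it to a torch device (a hypothetical helper, not part of this diff):

import torch

from refiners.training_utils.config import ModelConfig

def resolve_device(model_config: ModelConfig) -> torch.device:
    # Fall back to CPU when no gpu_index is configured.
    if model_config.gpu_index is None:
        return torch.device("cpu")
    return torch.device(f"cuda:{model_config.gpu_index}")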
70 changes: 38 additions & 32 deletions src/refiners/training_utils/latent_diffusion.py
@@ -3,11 +3,11 @@
from functools import cached_property
from typing import Any, Callable, TypedDict, TypeVar

from datasets import DownloadManager # type: ignore
from datasets import DownloadManager # type: ignore
from loguru import logger
from PIL import Image
from pydantic import BaseModel
from torch import Generator, Tensor, cat, device as Device, dtype as DType, randn
from torch import Generator, Tensor, cat, dtype as DType, randn
from torch.nn import Module
from torch.nn.functional import mse_loss
from torch.utils.data import Dataset
@@ -69,7 +69,6 @@ class TextEmbeddingLatentsDataset(Dataset[TextEmbeddingLatentsBatch]):
def __init__(self, trainer: "LatentDiffusionTrainer[Any]") -> None:
self.trainer = trainer
self.config = trainer.config
self.device = self.trainer.device
self.lda = self.trainer.lda
self.text_encoder = self.trainer.text_encoder
self.dataset = self.load_huggingface_dataset()
@@ -103,14 +102,14 @@ def process_caption(self, caption: str) -> str:
return caption if random.random() > self.config.latent_diffusion.unconditional_sampling_probability else ""

def get_caption(self, index: int, caption_key: str) -> str:
return self.dataset[index][caption_key] # type: ignore
return self.dataset[index][caption_key] # type: ignore

def get_image(self, index: int) -> Image.Image:
if "image" in self.dataset[index]:
return self.dataset[index]["image"]
elif "url" in self.dataset[index]:
url : str = self.dataset[index]["url"]
filename : str = self.download_manager.download(url) # type: ignore
url: str = self.dataset[index]["url"]
filename: str = self.download_manager.download(url) # type: ignore
return Image.open(filename)
else:
raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'")
@@ -124,9 +123,9 @@ def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
max_size=self.config.dataset.resize_image_max_size,
)
processed_image = self.process_image(resized_image)
latents = self.lda.encode_image(image=processed_image).to(device=self.device)
latents = self.lda.encode_image(image=processed_image)
processed_caption = self.process_caption(caption=caption)
clip_text_embedding = self.text_encoder(processed_caption).to(device=self.device)
clip_text_embedding = self.text_encoder(processed_caption)
return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)

def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
@@ -142,17 +141,18 @@ class LatentDiffusionTrainer(Trainer[ConfigType, TextEmbeddingLatentsBatch]):
@cached_property
def unet(self) -> SD1UNet:
assert self.config.models["unet"] is not None, "The config must contain a unet entry."
return SD1UNet(in_channels=4, device=self.device).to(device=self.device)
return SD1UNet(in_channels=4, device=self.device)

@cached_property
def text_encoder(self) -> CLIPTextEncoderL:
assert self.config.models["text_encoder"] is not None, "The config must contain a text_encoder entry."
return CLIPTextEncoderL(device=self.device).to(device=self.device)
return CLIPTextEncoderL(device=self.device)

@cached_property
def lda(self) -> SD1Autoencoder:
assert self.config.models["lda"] is not None, "The config must contain a lda entry."
return SD1Autoencoder(device=self.device).to(device=self.device)
lda = SD1Autoencoder()
return lda

def load_models(self) -> dict[str, fl.Module]:
return {"unet": self.unet, "text_encoder": self.text_encoder, "lda": self.lda}
@@ -162,40 +162,48 @@ def load_dataset(self) -> Dataset[TextEmbeddingLatentsBatch]:

@cached_property
def ddpm_scheduler(self) -> DDPM:
return DDPM(
num_inference_steps=1000,
device=self.device,
).to(device=self.device)
ddpm_scheduler = DDPM(num_inference_steps=1000, device=self.device)
self.sharding_manager.add_device_hook(ddpm_scheduler, ddpm_scheduler.device, "add_noise")
return ddpm_scheduler

@cached_property
def sd(self) -> StableDiffusion_1:
scheduler = DPMSolver(
num_inference_steps=self.config.test_diffusion.num_inference_steps,
)

self.sharding_manager.add_device_hooks(scheduler, scheduler.device)

return StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler)

def sample_timestep(self) -> Tensor:
random_step = random.randint(a=self.config.latent_diffusion.min_step, b=self.config.latent_diffusion.max_step)
self.current_step = random_step
return self.ddpm_scheduler.timesteps[random_step].unsqueeze(dim=0)

def sample_noise(self, size: tuple[int, ...], dtype: DType | None = None) -> Tensor:
return sample_noise(
size=size, offset_noise=self.config.latent_diffusion.offset_noise, device=self.device, dtype=dtype
)
return sample_noise(size=size, offset_noise=self.config.latent_diffusion.offset_noise, dtype=dtype)

@cached_property
def mse_loss(self) -> Callable[[Tensor, Tensor], Tensor]:
return self.sharding_manager.wrap_device(mse_loss, self.device)

def compute_loss(self, batch: TextEmbeddingLatentsBatch) -> Tensor:
clip_text_embedding, latents = batch.text_embeddings, batch.latents
timestep = self.sample_timestep()
noise = self.sample_noise(size=latents.shape, dtype=latents.dtype)
noisy_latents = self.ddpm_scheduler.add_noise(x=latents, noise=noise, step=self.current_step)

self.unet.set_timestep(timestep=timestep)
self.unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

prediction = self.unet(noisy_latents)
loss = mse_loss(input=prediction, target=noise)
return loss
loss = self.mse_loss(input=prediction, target=noise) # type: ignore
return loss # type: ignore

def compute_evaluation(self) -> None:
sd = StableDiffusion_1(
unet=self.unet,
lda=self.lda,
clip_text_encoder=self.text_encoder,
scheduler=DPMSolver(num_inference_steps=self.config.test_diffusion.num_inference_steps),
device=self.device,
)
sd = self.sd
prompts = self.config.test_diffusion.prompts
num_images_per_prompt = self.config.test_diffusion.num_images_per_prompt
if self.config.test_diffusion.use_short_prompts:
@@ -205,8 +213,8 @@ def compute_evaluation(self) -> None:
canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
for i in range(num_images_per_prompt):
logger.info(f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt}")
x = randn(1, 4, 64, 64, device=self.device)
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt).to(device=self.device)
x = randn(1, 4, 64, 64)
clip_text_embedding = sd.compute_clip_text_embedding(text=prompt)
for step in sd.steps:
x = sd(
x,
@@ -221,7 +229,6 @@ def sample_noise(
def sample_noise(
size: tuple[int, ...],
offset_noise: float = 0.1,
device: Device | str = "cpu",
dtype: DType | None = None,
generator: Generator | None = None,
) -> Tensor:
@@ -230,9 +237,8 @@
If `offset_noise` is more than 0, the noise will be offset by a small amount. It allows the model to generate
images with a wider range of contrast https://www.crosslabs.org/blog/diffusion-with-offset-noise.
"""
device = Device(device)
noise = randn(*size, generator=generator, device=device, dtype=dtype)
return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, device=device, dtype=dtype)
noise = randn(*size, generator=generator, dtype=dtype)
return noise + offset_noise * randn(*size[:2], 1, 1, generator=generator, dtype=dtype)


def resize_image(image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
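The trainer now leans on `self.sharding_manager` (`wrap_device`, `add_device_hook`, `add_device_hooks`), whose implementation is not part of this diff. A minimal sketch of what such a manager could look like, assuming its only job is to move tensor arguments to the owning device at call time:

import torch

class ShardingManager:
    # Hypothetical sketch, not the PR's actual implementation.

    def wrap_device(self, fn, device: torch.device | str):
        device = torch.device(device)

        def wrapped(*args, **kwargs):
            # Move every tensor argument to the target device before calling.
            args = tuple(a.to(device) if isinstance(a, torch.Tensor) else a for a in args)
            kwargs = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in kwargs.items()}
            return fn(*args, **kwargs)

        return wrapped

    def add_device_hook(self, obj, device: torch.device | str, method_name: str = "forward") -> None:
        setattr(obj, method_name, self.wrap_device(getattr(obj, method_name), device))

    def add_device_hooks(self, obj, device: torch.device | str) -> None:
        # Hook every method the diff routes through the manager.
        for name in ("forward", "scale_model_input", "add_noise"):
            if hasattr(obj, name):
                self.add_device_hook(obj, device, name)

With hooks like these, the wrapped `mse_loss` and calls such as `ddpm_scheduler.add_noise(...)` receive tensors already on the right device, which is why the explicit `.to(device=...)` calls could be removed throughout this file.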