Skip to content

Commit

Permalink
rollback batch preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
piercus committed Jan 26, 2024
1 parent a09f91d commit d8ff953
Show file tree
Hide file tree
Showing 4 changed files with 174 additions and 186 deletions.
87 changes: 44 additions & 43 deletions src/refiners/training_utils/datasets/color_palette.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,30 +4,23 @@

import numpy as np
from pydantic import BaseModel
from torch import Tensor, tensor, empty
from PIL import Image
from refiners.training_utils.datasets.latent_diffusion import TextEmbeddingLatentsBaseDataset
from torch import Tensor, cat, tensor, empty

from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder
from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.training_utils.datasets.latent_diffusion import TextEmbeddingLatentsBaseDataset, TextEmbeddingLatentsBatch
from refiners.training_utils.huggingface_datasets import HuggingfaceDatasetConfig
from loguru import logger

Color = Tuple[int, int, int]

ColorPalette = List[Color]

@dataclass
class ColorPaletteDatasetItem:
    # One training sample: the sampled palette, the (possibly dropped) caption,
    # the augmented image, and the classifier-free-guidance flag
    # (False when the caption was replaced by "").
    color_palette: ColorPalette
    text: str
    image: Image.Image
    conditional_flag: bool

@dataclass
class DatasetItem:
    # Raw Hugging Face row: precomputed palettes (keyed by their string index,
    # "1".."8" — see get_color_palette) plus the source image.
    palettes: dict[str, ColorPalette]
    image: Image.Image

@dataclass
class TextEmbeddingColorPaletteLatentsBatch(TextEmbeddingLatentsBatch):
    # Extends the base batch (text embeddings + latents) with encoded palettes.
    color_palette_embeddings: Tensor

TextEmbeddingColorPaletteLatentsBatch = List[ColorPaletteDatasetItem]

class SamplingByPalette(BaseModel):
palette_1: float = 1.0
Expand All @@ -44,51 +37,59 @@ class ColorPaletteDataset(TextEmbeddingLatentsBaseDataset[TextEmbeddingColorPale
def __init__(
    self,
    config: HuggingfaceDatasetConfig,
    lda: SD1Autoencoder,
    text_encoder: CLIPTextEncoder,
    color_palette_encoder: ColorPaletteEncoder,
    # NOTE(review): default instance is created once at import time and shared
    # across all datasets — fine only if SamplingByPalette is never mutated; confirm.
    sampling_by_palette: SamplingByPalette = SamplingByPalette(),
    unconditional_sampling_probability: float = 0.2,
) -> None:
    """Dataset producing text-embedding / latents / palette-embedding batches.

    Stores the palette sampling weights and the palette encoder, then defers
    the common setup (HF dataset loading, lda/text-encoder wiring, caption
    dropping probability) to the base class.
    """
    self.sampling_by_palette = sampling_by_palette
    self.color_palette_encoder = color_palette_encoder
    super().__init__(
        config=config,
        lda=lda,
        text_encoder=text_encoder,
        unconditional_sampling_probability=unconditional_sampling_probability,
    )

def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch:

item : DatasetItem = self.hf_dataset[index]
resized_image = self.resize_image(
image=item["image"],
min_size=self.config.resize_image_min_size,
max_size=self.config.resize_image_max_size,
(latents, _) = self.get_processed_latents(index)
(clip_text_embedding, color_palette_embedding) = self.process_text_embedding_and_palette(index)

return TextEmbeddingColorPaletteLatentsBatch(
text_embeddings=clip_text_embedding, latents=latents, color_palette_embeddings=color_palette_embedding
)

image = self.process_image(resized_image)

caption_key = self.config.caption_key
caption = item[caption_key]
(caption_processed, conditional_flag) = self.process_caption(caption)

return [
ColorPaletteDatasetItem(
color_palette=self.process_color_palette(item),
text=caption_processed,
image=image,
conditional_flag=conditional_flag
)
]

def process_color_palette(self, item: DatasetItem) -> ColorPalette:
def process_text_embedding_and_palette(self, index: int) -> tuple[Tensor, Tensor]:
    """Return (CLIP text embedding, palette embedding) for item `index`.

    With probability `unconditional_sampling_probability` the caption is
    dropped (classifier-free guidance); both the text and the palette are
    then encoded from their "empty" forms.
    """
    caption = self.get_caption(index=index)

    (processed_caption, conditional_flag) = self.process_caption(caption=caption)

    if not conditional_flag:
        # Unconditional sample: encode the *processed* (empty) caption.
        # The previous code encoded the raw caption here, leaking the real
        # text into the supposedly-unconditional embedding.
        return (self.text_encoder(processed_caption), self.color_palette_encoder([[]]))

    clip_text_embedding = self.text_encoder(processed_caption)
    color_palette_embedding = self.get_processed_palette(index)
    return (clip_text_embedding, color_palette_embedding)

def get_processed_palette(self, index: int) -> Tensor:
    """Encode the palette sampled for item `index` as a single-element batch."""
    palette = self.get_color_palette(index)
    return self.color_palette_encoder([palette])

def get_color_palette(self, index: int) -> ColorPalette:
    """Sample one of the 8 stored palettes of item `index`.

    Palette k (k in 1..8) is drawn with probability proportional to
    `sampling_by_palette.palette_k`.
    """
    choices = range(1, 9)
    weights = [getattr(self.sampling_by_palette, f"palette_{i}") for i in choices]
    # random.choices normalizes the weights itself, so the manual
    # `weights / weights.sum()` step (which also shadowed the builtin `sum`
    # and pulled in numpy) is unnecessary.
    palette_index = int(random.choices(choices, weights=weights, k=1)[0])
    item = self.hf_dataset[index]
    palette: ColorPalette = item["palettes"][str(palette_index)]

    return palette
def get_color_palette(self, index: int) -> ColorPalette:
    """Return the palette selected for the dataset item at `index`."""
    item = self.hf_dataset[index]
    return self.process_color_palette(item)

def collate_fn(self, batch: list[TextEmbeddingColorPaletteLatentsBatch]) -> TextEmbeddingColorPaletteLatentsBatch:
    """Flatten the per-item lists produced by `__getitem__` into one batch list."""
    return [item for sublist in batch for item in sublist]
text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
latents = cat(tensors=[item.latents for item in batch])
color_palette_embeddings = cat(tensors=[item.color_palette_embeddings for item in batch])
return TextEmbeddingColorPaletteLatentsBatch(
text_embeddings=text_embeddings, latents=latents, color_palette_embeddings=color_palette_embeddings
)
228 changes: 114 additions & 114 deletions src/refiners/training_utils/datasets/latent_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,137 +1,137 @@
import random
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Callable, TypeVar, List
from typing import Any, Callable, TypeVar

from datasets import DownloadManager # type: ignore
from loguru import logger
from PIL import Image
from torch import Tensor, cat
from torch.nn import Module as TorchModule
from torch.utils.data import Dataset
from torchvision.transforms import Compose, RandomCrop, RandomHorizontalFlip # type: ignore

from refiners.foundationals.clip.text_encoder import CLIPTextEncoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.model import SD1Autoencoder
from refiners.training_utils.datasets.utils import resize_image
from refiners.training_utils.huggingface_datasets import HuggingfaceDataset, HuggingfaceDatasetConfig, load_hf_dataset


@dataclass
class TextImageDatasetItem:
text: str
image: Image.Image
class TextEmbeddingLatentsBatch:
text_embeddings: Tensor
latents: Tensor

TextEmbeddingLatentsBatch = List[TextImageDatasetItem]

BatchType = TypeVar("BatchType", bound=List[Any])
BatchType = TypeVar("BatchType", bound=Any)


class TextEmbeddingLatentsBaseDataset(Dataset[BatchType]):
def __init__(
    self,
    config: HuggingfaceDatasetConfig,
    unconditional_sampling_probability: float = 0.2,
) -> None:
    """Base dataset: loads the configured HF dataset and builds augmentations.

    `unconditional_sampling_probability` is the chance that a caption is
    dropped (replaced by "") for classifier-free guidance.
    """
    self.config = config
    self.hf_dataset = self.load_huggingface_dataset()
    self.process_image = self.build_image_processor()
    self.download_manager = DownloadManager()
    self.unconditional_sampling_probability = unconditional_sampling_probability

    logger.info(f"Loaded {len(self.hf_dataset)} samples from dataset")

def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
    """Build the image-augmentation callable described by the dataset config.

    Supports random crop (512px) and horizontal flip; returns the identity
    function when no augmentation is enabled.
    """
    # TODO: make this configurable and add other transforms
    steps: list[TorchModule] = []
    if self.config.random_crop:
        steps.append(RandomCrop(size=512))
    if self.config.horizontal_flip:
        steps.append(RandomHorizontalFlip(p=0.5))
    # Skip Compose entirely when there is nothing to apply.
    return Compose(steps) if steps else (lambda image: image)

def load_huggingface_dataset(self) -> HuggingfaceDataset[Any]:
    """Load the repo/revision/split described by `self.config` and return it."""
    dataset_config = self.config
    logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
    dataset = load_hf_dataset(
        path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
    )
    return dataset

def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
    """Resize `image` within [min_size, max_size]; delegates to the shared `datasets.utils.resize_image`."""
    return resize_image(image=image, min_size=min_size, max_size=max_size)

def process_caption(self, caption: str) -> tuple[str, bool]:
    """Randomly drop the caption for classifier-free guidance.

    Returns ``(caption, True)`` when the sample stays conditional, and
    ``("", False)`` when the caption is dropped — which happens with
    probability ``unconditional_sampling_probability``.
    """
    keep = random.random() > self.unconditional_sampling_probability
    return (caption if keep else "", keep)

def get_caption(self, index: int) -> str:
    """Return the caption for item `index`, read from the configured key ("caption" by default).

    Raises:
        RuntimeError: when the stored value is not a string.
    """
    caption_key = self.config.caption_key or "caption"
    value = self.hf_dataset[index][caption_key]
    if isinstance(value, str):
        return value
    raise RuntimeError(
        f"Dataset item at index [{index}] and caption_key [{caption_key}] does not contain a string caption"
    )

def get_image(self, index: int) -> Image.Image:
    """Return the PIL image for item `index`, downloading it first when only a URL is stored.

    Raises:
        RuntimeError: if the item has neither an 'image' nor a 'url' field.
    """
    record = self.hf_dataset[index]
    if "image" in record:
        return record["image"]
    if "url" in record:
        local_path: str = self.download_manager.download(record["url"])  # type: ignore
        return Image.open(local_path)
    raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'")

@abstractmethod
def get_hf_item(self, index: int) -> Any:
    # NOTE(review): marked @abstractmethod yet has a concrete body — subclasses
    # must override but may delegate back via super(). Confirm this is intended.
    return self.hf_dataset[index]

@abstractmethod
def __getitem__(self, index: int) -> BatchType:
    # Subclasses build one BatchType from the dataset row at `index`.
    ...

@abstractmethod
def collate_fn(self, batch: list[BatchType]) -> BatchType:
    # Subclasses merge per-item batches into one BatchType for the dataloader.
    ...

# def get_processed_text_embedding(self, index: int) -> Tensor:
# caption = self.get_caption(index=index)
# (processed_caption, _) = self.process_caption(caption=caption)
# return self.text_encoder(processed_caption)

def get_processed_image(self, index: int) -> Image.Image:
    """Load the image at `index`, resize it to the configured bounds, apply augmentations."""
    raw = self.get_image(index=index)
    # NOTE(review): per-item INFO logging looks like leftover debugging — consider DEBUG level.
    logger.info(f"resize_image image {index}")

    cfg = self.config
    resized = self.resize_image(image=raw, min_size=cfg.resize_image_min_size, max_size=cfg.resize_image_max_size)
    logger.info(f"resized_image image {index}")

    return self.process_image(resized)

def __len__(self) -> int:
    # Dataset length mirrors the underlying Hugging Face dataset.
    return len(self.hf_dataset)
def __init__(
    self,
    config: HuggingfaceDatasetConfig,
    lda: SD1Autoencoder,
    text_encoder: CLIPTextEncoder,
    unconditional_sampling_probability: float = 0.2,
) -> None:
    """Base text-embedding/latents dataset.

    Keeps the autoencoder (`lda`) and text encoder used to turn images and
    captions into tensors, loads the configured Hugging Face dataset, and
    builds the image-augmentation pipeline.
    """
    self.config = config
    self.lda = lda
    self.text_encoder = text_encoder
    self.hf_dataset = self.load_huggingface_dataset()
    self.process_image = self.build_image_processor()
    self.download_manager = DownloadManager()
    # Probability that a caption is dropped ("") for classifier-free guidance.
    self.unconditional_sampling_probability = unconditional_sampling_probability

    logger.info(f"Loaded {len(self.hf_dataset)} samples from dataset")

def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
    """Compose the configured augmentations (random crop, horizontal flip); identity if none."""
    # TODO: make this configurable and add other transforms
    transforms: list[TorchModule] = []
    if self.config.random_crop:
        transforms.append(RandomCrop(size=512))
    if self.config.horizontal_flip:
        transforms.append(RandomHorizontalFlip(p=0.5))
    if not transforms:
        return lambda image: image
    return Compose(transforms)

def load_huggingface_dataset(self) -> HuggingfaceDataset[Any]:
    """Load the repo/revision/split described by `self.config` and return it."""
    dataset_config = self.config
    logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
    dataset = load_hf_dataset(
        path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
    )
    return dataset

def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
    """Resize `image` within [min_size, max_size]; delegates to the shared `datasets.utils.resize_image`."""
    return resize_image(image=image, min_size=min_size, max_size=max_size)

def process_caption(self, caption: str) -> tuple[str, bool]:
    """Return (caption, True) to stay conditional, or ("", False) when the caption is dropped."""
    conditional_flag = random.random() > self.unconditional_sampling_probability
    if conditional_flag:
        return (caption, conditional_flag)
    else:
        return ("", conditional_flag)

def get_caption(self, index: int) -> str:
    """Fetch the caption string for item `index` (key from config, default "caption").

    Raises:
        RuntimeError: if the stored value is not a string.
    """
    caption_key = self.config.caption_key or "caption"

    caption = self.hf_dataset[index][caption_key]
    if not isinstance(caption, str):
        raise RuntimeError(
            f"Dataset item at index [{index}] and caption_key [{caption_key}] does not contain a string caption"
        )
    return caption

def get_image(self, index: int) -> Image.Image:
    """Return the PIL image for item `index`, downloading it first when only a URL is stored.

    Raises:
        RuntimeError: if the item has neither an 'image' nor a 'url' field.
    """
    # Index the dataset once: each `self.hf_dataset[index]` access re-fetches
    # (and may re-decode) the row, so the previous triple lookup was wasteful.
    item = self.hf_dataset[index]
    if "image" in item:
        return item["image"]
    elif "url" in item:
        url: str = item["url"]
        filename: str = self.download_manager.download(url)  # type: ignore
        return Image.open(filename)
    else:
        raise RuntimeError(f"Dataset item at index [{index}] does not contain 'image' or 'url'")

@abstractmethod
def get_hf_item(self, index: int) -> Any:
    # NOTE(review): marked @abstractmethod yet has a concrete body — subclasses
    # must override but may delegate back via super(). Confirm this is intended.
    return self.hf_dataset[index]

@abstractmethod
def __getitem__(self, index: int) -> BatchType:
    # Subclasses build one BatchType from the dataset row at `index`.
    ...

@abstractmethod
def collate_fn(self, batch: list[BatchType]) -> BatchType:
    # Subclasses merge per-item batches into one BatchType for the dataloader.
    ...

def get_processed_text_embedding(self, index: int) -> Tensor:
    """Encode the (possibly dropped) caption of item `index` with the text encoder."""
    raw_caption = self.get_caption(index=index)
    processed_caption, _flag = self.process_caption(caption=raw_caption)
    return self.text_encoder(processed_caption)

def get_processed_latents(self, index: int) -> tuple[Tensor, Image.Image]:
    """Load, resize and augment the image at `index`, then encode it to latents.

    Returns the (latents, processed image) pair so callers can reuse the image.
    """
    raw = self.get_image(index=index)
    cfg = self.config
    resized = self.resize_image(image=raw, min_size=cfg.resize_image_min_size, max_size=cfg.resize_image_max_size)
    augmented = self.process_image(resized)
    latents = self.lda.encode_image(image=augmented)
    return (latents, augmented)

def __len__(self) -> int:
    # Dataset length mirrors the underlying Hugging Face dataset.
    return len(self.hf_dataset)


class TextEmbeddingLatentsDataset(TextEmbeddingLatentsBaseDataset[TextEmbeddingLatentsBatch]):
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
    """Return a single-element batch with the processed image and (possibly dropped) caption."""
    image = self.get_processed_image(index)
    (caption, _) = self.process_caption(self.get_caption(index))

    return [
        TextImageDatasetItem(
            text=caption,
            image=image
        )
    ]

def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
    """Flatten the list-of-lists produced by `__getitem__` into one flat batch list."""
    return [item for sublist in batch for item in sublist]
def __getitem__(self, index: int) -> TextEmbeddingLatentsBatch:
    """Encode the caption and image of item `index` into a one-sample batch."""
    clip_text_embedding = self.get_processed_text_embedding(index)
    (latents, _) = self.get_processed_latents(index)
    return TextEmbeddingLatentsBatch(text_embeddings=clip_text_embedding, latents=latents)

def collate_fn(self, batch: list[TextEmbeddingLatentsBatch]) -> TextEmbeddingLatentsBatch:
    """Concatenate the per-sample tensors along the batch dimension."""
    text_embeddings = cat(tensors=[item.text_embeddings for item in batch])
    latents = cat(tensors=[item.latents for item in batch])
    return TextEmbeddingLatentsBatch(text_embeddings=text_embeddings, latents=latents)
Loading

0 comments on commit d8ff953

Please sign in to comment.