Commit d3f6f4c

fix: evaluating db samples + zero init adapter + fix eos
piercus committed Jan 19, 2024
1 parent 0cce02b commit d3f6f4c
Showing 6 changed files with 175 additions and 57 deletions.
18 changes: 14 additions & 4 deletions src/refiners/fluxion/adapters/color_palette.py
@@ -60,7 +60,7 @@ def compute_color_palette_embedding(
    ) -> Float[Tensor, "cfg_batch n_colors 3"]:
        tensor_x = tensor(x, device=self.device, dtype=self.dtype)
        conditional_embedding = self(tensor_x)
-        if x == negative_color_palette:
+        if tensor_x == negative_color_palette:
            return torch.cat(tensors=(conditional_embedding, conditional_embedding), dim=0)

        if negative_color_palette is None:
@@ -84,14 +84,17 @@ def end_of_sequence_token(
            .repeat(x.shape[0], 1, 1)
        )

-        return torch.cat((x, end_of_sequence_tensor), dim=1)
+        with_eos = torch.cat((x, end_of_sequence_tensor), dim=1)
+        return with_eos[:, : self.max_colors, :]

    def zero_right_padding(
        self, x: Float[Tensor, "*batch colors_with_end embedding_dim"]
    ) -> Float[Tensor, "*batch max_colors model_dim"]:
        # Zero padding for the right side
        padding_width = (self.max_colors - x.shape[1] % self.max_colors) % self.max_colors
-        return pad(x, (0, 0, 0, padding_width))
+        result = pad(x, (0, 0, 0, padding_width))
+        return result


class SD1ColorPaletteAdapter(fl.Chain, Adapter[TSDNet]):
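The eos fix above appends the end-of-sequence token and then clips the sequence back to `max_colors`, so a full palette no longer overflows; `zero_right_padding` then fills shorter sequences with zeros. A minimal standalone sketch of the combined behavior (the `max_colors = 8` and `embedding_dim = 3` values are illustrative assumptions, not the encoder's real configuration):

```python
import torch
from torch.nn.functional import pad

max_colors = 8     # assumed; the real value comes from ColorPaletteConfig
embedding_dim = 3  # assumed; stands in for the encoder's embedding width

def append_eos_then_pad(x: torch.Tensor, eos: torch.Tensor) -> torch.Tensor:
    # Append one EOS token per batch item, then clip so a full palette keeps
    # at most max_colors entries (the fix: EOS is dropped when already full).
    with_eos = torch.cat((x, eos.repeat(x.shape[0], 1, 1)), dim=1)
    with_eos = with_eos[:, :max_colors, :]
    # Zero-pad on the right up to max_colors (a no-op when already full).
    padding_width = (max_colors - with_eos.shape[1] % max_colors) % max_colors
    return pad(with_eos, (0, 0, 0, padding_width))

eos = torch.ones(1, 1, embedding_dim)
print(append_eos_then_pad(torch.zeros(2, 3, embedding_dim), eos).shape)  # torch.Size([2, 8, 3])
print(append_eos_then_pad(torch.zeros(2, 8, embedding_dim), eos).shape)  # torch.Size([2, 8, 3])
```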
@@ -115,7 +118,14 @@ def __init__(
            CrossAttentionAdapter(target=cross_attn, scale=scale)
            for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
        ]

+    @property
+    def weights(self) -> List[Tensor]:
+        weights = []
+        for adapter in self.sub_adapters:
+            weights += adapter.weights
+        return weights

    def inject(self, parent: fl.Chain | None = None) -> "SD1ColorPaletteAdapter[Any]":
        for adapter in self.sub_adapters:
            adapter.inject()
7 changes: 5 additions & 2 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
@@ -274,8 +274,7 @@ def scale(self) -> float:
    def scale(self, value: float) -> None:
        self._scale = value
        self.ensure_find(fl.Multiply).scale = value
-

class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
    def __init__(
        self,
@@ -326,6 +325,10 @@ def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
        self.image_key_projection.weight = nn.Parameter(key_tensor)
        self.image_value_projection.weight = nn.Parameter(value_tensor)
        self.image_cross_attention.to(self.device, self.dtype)

+    @property
+    def weights(self) -> list[Tensor]:
+        return [self.image_key_projection.weight, self.image_value_projection.weight]


class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
187 changes: 141 additions & 46 deletions src/refiners/training_utils/color_palette.py
@@ -1,14 +1,15 @@
from dataclasses import dataclass
from functools import cached_property
from random import randint
-from typing import Any
+from typing import Any, List, Tuple

from loguru import logger
from PIL import Image
from pydantic import BaseModel
from torch import Tensor, cat, randn, tensor
from torch.utils.data import Dataset

+from sklearn.neighbors import NearestNeighbors
+from torch.nn import init
import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder, SD1ColorPaletteAdapter
from refiners.fluxion.utils import save_to_safetensors
@@ -27,7 +28,7 @@
    TextEmbeddingLatentsBatch,
)
from refiners.training_utils.wandb import WandbLoggable
-
+import numpy as np

class ColorPaletteConfig(BaseModel):
    model_dim: int
@@ -38,7 +39,7 @@ class ColorPaletteConfig(BaseModel):

class ColorPalettePromptConfig(BaseModel):
    text: str
-    color_palette: list[list[float]]
+    color_palette: List[List[int]]


class ColorPaletteDatasetConfig(HuggingfaceDatasetConfig):
@@ -47,7 +48,7 @@ class ColorPaletteDatasetConfig(HuggingfaceDatasetConfig):

class TestColorPaletteConfig(TestDiffusionBaseConfig):
    prompts: list[ColorPalettePromptConfig]
-
+    num_palette_sample: int = 0

@dataclass
class TextEmbeddingColorPaletteLatentsBatch(TextEmbeddingLatentsBatch):
@@ -57,14 +58,14 @@ class TextEmbeddingColorPaletteLatentsBatch(TextEmbeddingLatentsBatch):


class CaptionPaletteImage(CaptionImage):
-    palette_1: list[list[float]]
-    palette_2: list[list[float]]
-    palette_3: list[list[float]]
-    palette_4: list[list[float]]
-    palette_5: list[list[float]]
-    palette_6: list[list[float]]
-    palette_7: list[list[float]]
-    palette_8: list[list[float]]
+    palette_1: List[List[int]]
+    palette_2: List[List[int]]
+    palette_3: List[List[int]]
+    palette_4: List[List[int]]
+    palette_5: List[List[int]]
+    palette_6: List[List[int]]
+    palette_7: List[List[int]]
+    palette_8: List[List[int]]


class ColorPaletteDataset(TextEmbeddingLatentsBaseDataset[TextEmbeddingColorPaletteLatentsBatch]):
@@ -78,14 +79,14 @@ def __init__(
        logger.info(f"Trigger phrase: {self.trigger_phrase}")
        self.color_palette_encoder = trainer.color_palette_encoder

-    def get_color_palette(self, index: int) -> Tensor:
+    def get_color_palette(self, index: int) -> List[List[int]]:
        # Randomly pick a palette between 1 and 8
        palette_index = randint(1, 8)
-        return tensor([self.dataset[index][f"palette_{palette_index}"]])
+        return self.dataset[index][f"palette_{palette_index}"]

    def __getitem__(self, index: int) -> TextEmbeddingColorPaletteLatentsBatch:
        caption = self.get_caption(index=index, caption_key=self.config.dataset.caption_key)
-        color_palette = self.get_color_palette(index=index)
+        color_palette = tensor([self.get_color_palette(index=index)])
        image = self.get_image(index=index)
        resized_image = self.resize_image(
            image=image,
@@ -136,7 +137,7 @@ def color_palette_encoder(self) -> ColorPaletteEncoder:
        )

    @cached_property
-    def color_palette_adapter(self) -> ColorPaletteEncoder:
+    def color_palette_adapter(self) -> SD1ColorPaletteAdapter[Any]:
        adapter = SD1ColorPaletteAdapter(target=self.unet, color_palette_encoder=self.color_palette_encoder)

        return adapter
@@ -188,43 +189,137 @@ def sd(self) -> StableDiffusion_1:
        self.sharding_manager.add_device_hooks(scheduler, scheduler.device)

        return StableDiffusion_1(unet=self.unet, lda=self.lda, clip_text_encoder=self.text_encoder, scheduler=scheduler)

-    def compute_evaluation(self) -> None:
+    def compute_prompt_evaluation(self, prompt: ColorPalettePromptConfig, num_images_per_prompt: int, img_size: int = 512) -> Image.Image:
        sd = self.sd
-        prompts = self.config.test_color_palette.prompts
-        num_images_per_prompt = self.config.test_color_palette.num_images_per_prompt
+        palette_img_size = img_size // self.config.color_palette.max_colors
+        canvas_image: Image.Image = Image.new(mode="RGB", size=(img_size * num_images_per_prompt, img_size + palette_img_size))
+        for i in range(num_images_per_prompt):
+            logger.info(
+                f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt.text} and palette {prompt.color_palette}"
+            )
+            x = randn(1, 4, 64, 64)
+
+            cfg_clip_text_embedding = sd.compute_clip_text_embedding(text=prompt.text).to(device=self.device)
+            cfg_color_palette_embedding = self.color_palette_encoder.compute_color_palette_embedding(
+                [prompt.color_palette]
+            )
+
+            self.color_palette_adapter.set_color_palette_embedding(cfg_color_palette_embedding)
+
+            for step in sd.steps:
+                x = sd(
+                    x,
+                    step=step,
+                    clip_text_embedding=cfg_clip_text_embedding,
+                )
+            canvas_image.paste(sd.lda.decode_latents(x=x), box=(img_size * i, 0))
+            for index, palette in enumerate(prompt.color_palette):
+                color_box = Image.fromarray(np.full((palette_img_size, palette_img_size, 3), palette, dtype=np.uint8))  # type: ignore
+                canvas_image.paste(color_box, box=(img_size * i + palette_img_size * index, img_size))
+
+        return canvas_image
+
+    def compute_edge_case_evaluation(self, prompts: List[ColorPalettePromptConfig], num_images_per_prompt: int) -> None:
        images: dict[str, WandbLoggable] = {}
        for prompt in prompts:
-            canvas_image: Image.Image = Image.new(mode="RGB", size=(512, 512 * num_images_per_prompt))
-            image_name = prompt.text + str(prompt.color_palette)
-            for i in range(num_images_per_prompt):
-                logger.info(
-                    f"Generating image {i+1}/{num_images_per_prompt} for prompt: {prompt.text} and palette {prompt.color_palette}"
-                )
-                x = randn(1, 4, 64, 64)
-
-                cfg_clip_text_embedding = sd.compute_clip_text_embedding(text=prompt.text).to(device=self.device)
-                cfg_color_palette_embedding = self.color_palette_encoder.compute_color_palette_embedding(
-                    [prompt.color_palette]
-                )
-
-                self.color_palette_adapter.set_color_palette_embedding(cfg_color_palette_embedding)
-
-                for step in sd.steps:
-                    x = sd(
-                        x,
-                        step=step,
-                        clip_text_embedding=cfg_clip_text_embedding,
-                    )
-                canvas_image.paste(sd.lda.decode_latents(x=x), box=(0, 512 * i))
-
-            images[image_name] = canvas_image
+            image_name = f"edge_case/{prompt.text.replace(' ', '_')}"
+            images[image_name] = self.compute_prompt_evaluation(prompt, num_images_per_prompt)
+
        self.log(data=images)
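For reference, the swatch strip that `compute_prompt_evaluation` paints under each generated image can be reproduced standalone. A small sketch with arbitrary palette values (`img_size` and `max_colors` mirror the defaults above):

```python
import numpy as np
from PIL import Image

img_size, max_colors = 512, 8
palette = [[255, 0, 0], [0, 128, 255], [250, 220, 10]]  # arbitrary RGB colors
palette_img_size = img_size // max_colors  # 64-pixel squares, as above

# One square swatch per palette color, pasted left to right.
strip = Image.new(mode="RGB", size=(img_size, palette_img_size))
for index, color in enumerate(palette):
    swatch = Image.fromarray(np.full((palette_img_size, palette_img_size, 3), color, dtype=np.uint8))
    strip.paste(swatch, box=(palette_img_size * index, 0))
strip.save("palette_strip.png")
```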


+    @cached_property
+    def eval_indices(self) -> list[int]:
+        l = self.dataset_length
+        size = self.config.test_color_palette.num_palette_sample
+        indices = list(np.random.choice(l, size=size, replace=False))
+        return list(map(int, indices))

+    def image_palette_metrics(self, image: Image.Image, palette: List[List[int]], img_size: Tuple[int, int] = (256, 256), sampling_size: int = 1000):
+        resized_img = image.resize(img_size)
+        all_points: List[List[int]] = np.array(resized_img.getdata(), dtype=np.float64)
+        choices = np.random.choice(len(all_points), sampling_size)
+        points = all_points[choices]
+
+        num = len(palette)
+
+        centroids = np.stack(palette)
+
+        nn = NearestNeighbors(n_neighbors=num)
+        nn.fit(centroids)
+
+        indices = nn.kneighbors(points, return_distance=False)
+        indices = indices[:, 0]
+
+        counts = np.bincount(indices)
+        counts = np.pad(counts, (0, num - len(counts)), 'constant')
+        ordered_centroids = np.argsort(counts)[::-1]
+        y_true_ranking = list(range(num, 0, -1))
+        if num > 1:
+            ndcg = ndcg_score([y_true_ranking], [counts])
+        else:
+            ndcg = 1.0
+
+        def calculate_std_dev(clusters, points, centroids):
+            distances_list = []
+
+            for i in range(len(centroids)):
+                cluster_points = points[np.where(clusters == i)]
+                distances = [distance(p, centroids[i]) for p in cluster_points]
+                distances_list.extend(distances)
+
+            return np.std(distances_list)
+
+        std_dev = calculate_std_dev(indices, points, centroids)
+
+        self.log({
+            "palette-img/ndcg": ndcg,
+            "palette-img/std_dev": std_dev
+        })
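The metric above assigns each sampled pixel to its nearest palette color, then scores the resulting frequency ranking against the palette's declared order with NDCG (note that `ndcg_score` from `sklearn.metrics` and a pointwise `distance` helper, e.g. Euclidean, are assumed to be available in the module; neither import appears in this diff). A toy check of the scoring step:

```python
from sklearn.metrics import ndcg_score

# Pixels assigned to each palette color; the declared palette order serves as
# the ideal ranking (first color expected to claim the most pixels).
counts = [500, 300, 150, 50]
y_true_ranking = list(range(len(counts), 0, -1))  # [4, 3, 2, 1]
print(ndcg_score([y_true_ranking], [counts]))  # 1.0: frequencies match the declared order
```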

+    def compute_db_samples_evaluation(self, num_images_per_prompt: int, img_size: int = 512) -> None:
+        sd = self.sd
+        images: dict[str, WandbLoggable] = {}
+
+        palette_img_size = img_size // self.config.color_palette.max_colors
+
+        for eval_index, db_index in enumerate(self.eval_indices):
+            palette = self.dataset.get_color_palette(db_index)
+            caption = self.dataset.get_caption(db_index, self.config.dataset.caption_key)
+
+            prompt = ColorPalettePromptConfig(text=caption, color_palette=palette)
+            generated_image = self.compute_prompt_evaluation(prompt, 1, img_size=img_size)
+
+            image = self.dataset.get_image(db_index)
+            resized_image = image.resize((img_size, img_size))
+            join_canvas_image: Image.Image = Image.new(mode="RGB", size=(img_size, img_size * 2 + palette_img_size))
+            join_canvas_image.paste(generated_image, box=(0, 0))
+            join_canvas_image.paste(resized_image, box=(0, img_size + palette_img_size))
+            image_name = f"db_samples/{db_index}_{caption}"
+
+            images[image_name] = join_canvas_image
+
+        self.log(data=images)
+
+    def compute_evaluation(self) -> None:
+        prompts = self.config.test_color_palette.prompts
+        num_images_per_prompt = self.config.test_color_palette.num_images_per_prompt
+        if len(prompts) > 0:
+            self.compute_edge_case_evaluation(prompts, num_images_per_prompt)
+
+        num_palette_sample = self.config.test_color_palette.num_palette_sample
+        if num_palette_sample > 0:
+            self.compute_db_samples_evaluation(num_images_per_prompt)

class LoadColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]):
    def on_train_begin(self, trainer: ColorPaletteLatentDiffusionTrainer) -> None:
-        trainer.color_palette_adapter.inject()
+        adapter = trainer.color_palette_adapter
+        weights = adapter.weights
+        for weight in weights:
+            init.zeros_(weight)
+
+        adapter.inject()
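Zero-initializing the adapter weights before injection (the "zero init adapter" part of this commit) makes the new cross-attention projections contribute nothing at step 0, so training starts from the unmodified base model, the same idea as ControlNet's zero convolutions. A minimal sketch of the pattern with a stand-in module (not the real fluxion adapter):

```python
import torch
from torch import nn
from torch.nn import init

class TinyAdapter(nn.Module):
    """Stand-in adapter: two projections whose outputs are added into a host model."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.key_projection = nn.Linear(dim, dim, bias=False)
        self.value_projection = nn.Linear(dim, dim, bias=False)

    @property
    def weights(self) -> list[torch.Tensor]:
        return [self.key_projection.weight, self.value_projection.weight]

adapter = TinyAdapter()
for weight in adapter.weights:
    init.zeros_(weight)  # the adapter contributes exactly zero at train start

x = torch.randn(2, 8)
assert adapter.key_projection(x).abs().sum() == 0  # output is all zeros
```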


class SaveColorPalette(Callback[ColorPaletteLatentDiffusionTrainer]):
3 changes: 2 additions & 1 deletion src/refiners/training_utils/huggingface_datasets.py
@@ -1,6 +1,6 @@
from typing import Any, Generic, Protocol, TypeVar, cast

-from datasets import DownloadManager, VerificationMode, load_dataset as _load_dataset # type: ignore
+from datasets import VerificationMode, load_dataset as _load_dataset # type: ignore
from pydantic import BaseModel # type: ignore

__all__ = ["load_hf_dataset", "HuggingfaceDataset"]
@@ -35,3 +35,4 @@ class HuggingfaceDatasetConfig(BaseModel):
    resize_image_min_size: int = 512
    resize_image_max_size: int = 576
    caption_key: str = "caption"
+    n_samples: int | None = None
5 changes: 4 additions & 1 deletion src/refiners/training_utils/latent_diffusion.py
@@ -101,9 +101,12 @@ def build_image_processor(self) -> Callable[[Image.Image], Image.Image]:
    def load_huggingface_dataset(self) -> HuggingfaceDataset[Any]:
        dataset_config = self.config.dataset
        logger.info(f"Loading dataset from {dataset_config.hf_repo} revision {dataset_config.revision}")
-        return load_hf_dataset(
+        dataset = load_hf_dataset(
            path=dataset_config.hf_repo, revision=dataset_config.revision, split=dataset_config.split
        )
+        if dataset_config.n_samples is not None:
+            dataset = dataset[:dataset_config.n_samples]
+        return dataset

    def resize_image(self, image: Image.Image, min_size: int = 512, max_size: int = 576) -> Image.Image:
        return resize_image(image=image, min_size=min_size, max_size=max_size)
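One caveat on the `n_samples` truncation above: slicing a Hugging Face `datasets.Dataset` with `[:n]` returns a plain dict of columns rather than a `Dataset` object. If the dataset type must be preserved downstream, `select` is the usual alternative; a hedged sketch:

```python
# Alternative sketch: keep the Dataset type when truncating.
# Assumes `dataset` is a datasets.Dataset; dataset[:n] would return a dict
# of column lists, while select(...) returns a Dataset.
if dataset_config.n_samples is not None:
    dataset = dataset.select(range(dataset_config.n_samples))
```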
12 changes: 9 additions & 3 deletions tests/adapters/test_color_palette.py
@@ -26,6 +26,12 @@ def test_color_palette_encoder() -> None:
    assert encoded.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])

    # test with 0-colors palette
-    encodeded_empty = color_palette_encoder(torch.zeros(batch_size, 0, 3))
-    assert isinstance(encodeded_empty.shape, torch.Size)
-    assert encodeded_empty.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])
+    encoded_empty = color_palette_encoder(torch.zeros(batch_size, 0, 3))
+    assert isinstance(encoded_empty.shape, torch.Size)
+    assert encoded_empty.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])
+
+    # test with 10-colors palette
+    encoded_full = color_palette_encoder(torch.zeros(batch_size, max_colors, 3))
+    assert isinstance(encoded_full.shape, torch.Size)
+    assert encoded_full.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])
