feat: clip-inspired color encoder
piercus committed Jan 22, 2024
1 parent e403c93 commit b1eaa66
Showing 3 changed files with 103 additions and 81 deletions.
138 changes: 88 additions & 50 deletions src/refiners/fluxion/adapters/color_palette.py
@@ -13,42 +13,107 @@
from refiners.foundationals.latent_diffusion.range_adapter import compute_sinusoidal_embedding
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.unet import SDXLUNet
from refiners.foundationals.clip.common import FeedForward, PositionalEncoder
from refiners.foundationals.clip.text_encoder import TransformerLayer

TSDNet = TypeVar("TSDNet", bound="SD1UNet | SDXLUNet")
TColorPaletteAdapter = TypeVar("TColorPaletteAdapter", bound="SD1ColorPaletteAdapter[Any]") # Self (see PEP 673)


class ColorPaletteEncoder(fl.Chain):

class ColorsTokenizer(fl.Module):
def __init__(
self,
max_colors: int = 8,
) -> None:
super().__init__()
self.max_colors = max_colors

def forward(self, colors):
colors = self.add_channel(colors)
colors = self.zero_right_padding(colors)
return colors

def add_channel(
self, x: Float[Tensor, "*batch colors 3"]
) -> Float[Tensor, "*batch colors_with_end 4"]:
return torch.cat((x, torch.ones(x.shape[0], x.shape[1], 1, dtype=x.dtype, device=x.device)), dim=2)

def zero_right_padding(
self, x: Float[Tensor, "*batch colors_with_end embedding_dim"]
) -> Float[Tensor, "*batch max_colors feedforward_dim"]:
# Zero padding for the right side
padding_width = (self.max_colors - x.shape[1] % self.max_colors) % self.max_colors
if x.shape[1] == 0:
padding_width = self.max_colors
result = pad(x, (0, 0, 0, padding_width))
return result


class ColorEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int,
max_colors: int,
model_dim: int = 256,
sinuosidal_embedding_dim: int = 32,
device: Device | str | None = None,
dtype: DType = float32,
context_key: str = "color_palette_embedding",
dtype: DType | None = None,
) -> None:
super().__init__(
fl.Linear(in_features=4, out_features=embedding_dim, device=device, dtype=dtype),
)

class ColorPaletteEncoder(fl.Chain):
def __init__(
self,
embedding_dim: int = 768,
max_colors: int = 8,
num_layers: int = 3,
num_attention_heads: int = 6,
feedforward_dim: int = 512,
layer_norm_eps: float = 1e-5,
use_quick_gelu: bool = False,
device: Device | str | None = None,
dtype: DType | None = None,
) -> None:
self.embedding_dim = embedding_dim
self.model_dim = model_dim
self.max_colors = max_colors

self.num_layers = num_layers
self.num_attention_heads = num_attention_heads
self.feedforward_dim = feedforward_dim
self.layer_norm_eps = layer_norm_eps
self.use_quick_gelu = use_quick_gelu
super().__init__(
fl.Linear(in_features=3, out_features=model_dim, device=device, dtype=dtype),
fl.Residual(fl.Lambda(self.compute_sinuosoidal_embedding)),
fl.Linear(in_features=model_dim, out_features=model_dim, device=device, dtype=dtype),
fl.GeLU(),
fl.Linear(in_features=model_dim, out_features=embedding_dim, device=device, dtype=dtype),
fl.Lambda(self.end_of_sequence_token),
fl.Lambda(self.zero_right_padding),
ColorsTokenizer(
max_colors=max_colors
),
fl.Sum(
ColorEncoder(
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
PositionalEncoder(
max_sequence_length=max_colors,
embedding_dim=embedding_dim,
device=device,
dtype=dtype,
),
),
*(
TransformerLayer(
embedding_dim=embedding_dim,
num_attention_heads=num_attention_heads,
feedforward_dim=feedforward_dim,
layer_norm_eps=layer_norm_eps,
device=device,
dtype=dtype,
)
for _ in range(num_layers)
),
fl.LayerNorm(normalized_shape=embedding_dim, eps=layer_norm_eps, device=device, dtype=dtype),
)

def compute_sinuosoidal_embedding(
self, x: Int[Tensor, "*batch n_colors 3"]
) -> Float[Tensor, "*batch n_colors 3 model_dim"]:
range = arange(start=0, end=x.shape[1], dtype=self.dtype, device=x.device).unsqueeze(1)
embedding = compute_sinusoidal_embedding(range, embedding_dim=self.model_dim)
return embedding.squeeze(1).unsqueeze(0).repeat(x.shape[0], 1, 1).to(dtype=self.dtype)
if use_quick_gelu:
for gelu, parent in self.walk(predicate=lambda m, _: isinstance(m, fl.GeLU)):
parent.replace(old_module=gelu, new_module=fl.ApproximateGeLU())

def compute_color_palette_embedding(
self,
@@ -67,33 +132,6 @@ def compute_color_palette_embedding(
negative_embedding = self(negative_color_palette)
return torch.cat(tensors=(negative_embedding, conditional_embedding), dim=0)

def end_of_sequence_token(
self, x: Float[Tensor, "*batch colors embedding_dim"]
) -> Float[Tensor, "*batch colors_with_end embedding_dim"]:
# Build a tensor of size (batch_size, 1, embedding_dim) with the end-of-sequence token.
# The end-of-sequence token is an embedding_dim vector with a 1 in its last position.
numpy_end_of_sequence_token = np.zeros((1, self.embedding_dim))
numpy_end_of_sequence_token[-1] = 1

end_of_sequence_tensor: Float[Tensor, "*batch 1 embedding_dim"] = (
tensor(numpy_end_of_sequence_token, device=x.device, dtype=x.dtype)
.reshape(1, 1, -1)
.repeat(x.shape[0], 1, 1)
)

with_eos = torch.cat((x, end_of_sequence_tensor), dim=1)
return with_eos[:, : self.max_colors, :]

def zero_right_padding(
self, x: Float[Tensor, "*batch colors_with_end embedding_dim"]
) -> Float[Tensor, "*batch max_colors model_dim"]:
# Zero padding for the right side
padding_width = (self.max_colors - x.shape[1] % self.max_colors) % self.max_colors

result = pad(x, (0, 0, 0, padding_width))
return result


class SD1ColorPaletteAdapter(fl.Chain, Adapter[TSDNet]):
# Prevent PyTorch module registration
_color_palette_encoder: list[ColorPaletteEncoder]
@@ -114,7 +152,7 @@ def __init__(
self.sub_adapters: list[CrossAttentionAdapter] = [
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]
]

@property
def weights(self) -> List[Tensor]:
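
Taken together, the new encoder is a small CLIP-style transformer over color tokens: ColorsTokenizer appends a constant fourth channel (presumably flagging real colors against the zero padding) and right-pads each palette to max_colors; ColorEncoder linearly projects each 4-vector into the embedding space; a PositionalEncoder is summed in; and the stack of TransformerLayers plus a final LayerNorm yields one embedding per color slot. A minimal usage sketch, not part of this commit, assuming the refiners code at this revision and the default hyperparameters visible above:

import torch

from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder

# Defaults mirror the signature above: embedding_dim=768, max_colors=8,
# num_layers=3, num_attention_heads=6, feedforward_dim=512.
encoder = ColorPaletteEncoder()

# A batch of 2 palettes with 5 RGB colors each; the tokenizer right-pads
# every palette to max_colors, so the output always has 8 color slots.
palettes = torch.zeros(2, 5, 3)
embedding = encoder(palettes)
assert embedding.shape == torch.Size([2, 8, 768])

Setting use_quick_gelu=True walks the chain and swaps every fl.GeLU for fl.ApproximateGeLU, presumably to match the quick-GELU activation of the original CLIP checkpoints.
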
7 changes: 4 additions & 3 deletions src/refiners/training_utils/color_palette.py
@@ -34,7 +34,7 @@
import numpy as np

class ColorPaletteConfig(BaseModel):
model_dim: int
feedforward_dim: int
trigger_phrase: str = ""
use_only_trigger_probability: float = 0.0
max_colors: int
@@ -140,12 +140,13 @@ def color_palette_encoder(self) -> ColorPaletteEncoder:
# TO FIX : connect this to unet cross attention embedding dim
EMBEDDING_DIM = 768

return ColorPaletteEncoder(
encoder = ColorPaletteEncoder(
max_colors=self.config.color_palette.max_colors,
embedding_dim=EMBEDDING_DIM,
model_dim=self.config.color_palette.model_dim,
feedforward_dim=self.config.color_palette.feedforward_dim,
device=self.device,
)
return encoder

@cached_property
def color_palette_adapter(self) -> SD1ColorPaletteAdapter[Any]:
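
On the training side, the config change is a straight rename: feedforward_dim (the transformer MLP width) replaces the old model_dim. A hypothetical instantiation, assuming ColorPaletteConfig has no required fields beyond those visible in this hunk:

from refiners.training_utils.color_palette import ColorPaletteConfig

config = ColorPaletteConfig(
    feedforward_dim=512,  # replaces the former `model_dim`
    max_colors=8,
)
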
39 changes: 11 additions & 28 deletions tests/adapters/test_color_palette.py
@@ -1,39 +1,21 @@
import torch

from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder
from refiners.fluxion.adapters.color_palette import ColorPaletteEncoder, ColorsTokenizer
from refiners.foundationals.latent_diffusion.cross_attention import CrossAttentionBlock2d
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1UNet


def test_color_palette_encoder() -> None:
in_channels = 22
def test_colors_tokenizer() -> None:
max_colors = 10
unet = SD1UNet(in_channels)
cross_attn_2d = unet.ensure_find(CrossAttentionBlock2d)

color_palette_encoder = ColorPaletteEncoder(
model_dim=in_channels, max_colors=max_colors, embedding_dim=cross_attn_2d.context_embedding_dim
).to(device="cuda:0")

tokenizer = ColorsTokenizer(max_colors=max_colors)

batch_size = 5
color_size = 4

palettes = torch.zeros(batch_size, color_size, 3)

encoded = color_palette_encoder(palettes)

assert isinstance(encoded.shape, torch.Size)
assert encoded.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])
colors = torch.zeros(batch_size, 0, 3)

# test with 0-colors palette
encoded_empty = color_palette_encoder(torch.zeros(batch_size, 0, 3))
assert isinstance(encoded_empty.shape, torch.Size)
assert encoded_empty.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])

# test with 10-colors palette
encoded_full = color_palette_encoder(torch.zeros(batch_size, max_colors, 3))
assert isinstance(encoded_full.shape, torch.Size)
assert encoded_full.shape == torch.Size([batch_size, max_colors, cross_attn_2d.context_embedding_dim])
color_tokens = tokenizer(colors)
assert isinstance(color_tokens.shape, torch.Size)
assert color_tokens.shape == torch.Size([batch_size, max_colors, 4])


def test_color_palette_encoder() -> None:
@@ -44,7 +26,9 @@ def test_color_palette_encoder() -> None:
cross_attn_2d = unet.ensure_find(CrossAttentionBlock2d)

color_palette_encoder = ColorPaletteEncoder(
model_dim=in_channels, max_colors=max_colors, embedding_dim=cross_attn_2d.context_embedding_dim
feedforward_dim=in_channels,
max_colors=max_colors,
embedding_dim=cross_attn_2d.context_embedding_dim
).to(device=device)

batch_size = 5
@@ -71,4 +55,3 @@ def test_color_palette_encoder() -> None:
palette = torch.zeros(batch_size, max_colors, 3, dtype=torch.float16, device=device)
encoded_half = color_palette_encoder(palette)
assert encoded_half.dtype == torch.float16
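
The new tokenizer test above only exercises the empty-palette case. A companion check for a partially filled palette, a sketch rather than part of the commit, would look like this:

import torch

from refiners.fluxion.adapters.color_palette import ColorsTokenizer

tokenizer = ColorsTokenizer(max_colors=10)

# 4 RGB colors: add_channel appends a constant fourth channel, then
# zero_right_padding pads the sequence dimension from 4 up to max_colors.
color_tokens = tokenizer(torch.zeros(5, 4, 3))
assert color_tokens.shape == torch.Size([5, 10, 4])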
