Skip to content

Commit

Permalink
refactor CrossAttentionAdapter to work with context.
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Jan 8, 2024
1 parent a08e04c commit bd044b1
Show file tree
Hide file tree
Showing 6 changed files with 92 additions and 203 deletions.
11 changes: 1 addition & 10 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,16 +128,7 @@ with no_grad():
text="best quality, high quality", negative_text="monochrome, lowres, bad anatomy, worst quality, low quality"
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(image_prompt))

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)
time_ids = sdxl.default_time_ids

condition = image_to_tensor(condition_image.convert("RGB"), device=sdxl.device, dtype=sdxl.dtype)
Expand Down
18 changes: 4 additions & 14 deletions scripts/conversion/convert_diffusers_ip_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,24 +133,14 @@ def main() -> None:
ip_adapter_weights: dict[str, torch.Tensor] = weights["ip_adapter"]
assert len(ip_adapter.sub_adapters) == len(ip_adapter_weights.keys()) // 2

for i, cross_attn in enumerate(ip_adapter.sub_adapters):
for i, _ in enumerate(ip_adapter.sub_adapters):
cross_attn_index = cross_attn_mapping[i]
k_ip = f"{cross_attn_index}.to_k_ip.weight"
v_ip = f"{cross_attn_index}.to_v_ip.weight"

# Ignore Wq, Wk, Wv and Proj (hence strict=False): at runtime, they will be part of the UNet original weights

names = [k for k, _ in cross_attn.named_parameters()]
assert len(names) == 2

cross_attn_state_dict: dict[str, Any] = {
names[0]: ip_adapter_weights[k_ip],
names[1]: ip_adapter_weights[v_ip],
}
cross_attn.load_state_dict(state_dict=cross_attn_state_dict, strict=False)

for k, v in cross_attn_state_dict.items():
state_dict[f"ip_adapter.{i:03d}.{k}"] = v
# the name of the key is not checked at runtime, so we keep the original name
state_dict[f"ip_adapter.{i:03d}.to_k_ip.weight"] = ip_adapter_weights[k_ip]
state_dict[f"ip_adapter.{i:03d}.to_v_ip.weight"] = ip_adapter_weights[v_ip]

if args.half:
state_dict = {key: value.half() for key, value in state_dict.items()}
Expand Down
184 changes: 82 additions & 102 deletions src/refiners/foundationals/latent_diffusion/image_prompt.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,15 @@
import math
from enum import IntEnum
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Generic, TypeVar
from typing import TYPE_CHECKING, Any, Generic, TypeVar

from jaxtyping import Float
from PIL import Image
from torch import Tensor, cat, device as Device, dtype as DType, softmax, zeros_like
from torch import Tensor, cat, device as Device, dtype as DType, nn, softmax, zeros_like

import refiners.fluxion.layers as fl
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.adapters.lora import Lora
from refiners.fluxion.context import Contexts
from refiners.fluxion.layers.attentions import ScaledDotProductAttention
from refiners.fluxion.layers.chain import Distribute
from refiners.fluxion.utils import image_to_tensor, normalize
from refiners.foundationals.clip.image_encoder import CLIPImageEncoderH

Expand Down Expand Up @@ -236,120 +234,89 @@ def init_context(self) -> Contexts:
return {"perceiver_resampler": {"x": None}}


class _CrossAttnIndex(IntEnum):
TXT_CROSS_ATTN = 0 # text cross-attention
IMG_CROSS_ATTN = 1 # image cross-attention


class InjectionPoint(fl.Chain):
pass


class ImageCrossAttention(fl.Chain):
def __init__(self, text_cross_attention: fl.Attention, scale: float = 1.0) -> None:
self.scale = scale
super().__init__(
fl.Distribute(
fl.UseContext(context="ip_adapter", key="query_projection"),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
fl.Chain(
fl.UseContext(context="ip_adapter", key="clip_image_embedding"),
fl.Linear(
in_features=text_cross_attention.key_embedding_dim,
out_features=text_cross_attention.inner_dim,
bias=text_cross_attention.use_bias,
device=text_cross_attention.device,
dtype=text_cross_attention.dtype,
),
),
),
ScaledDotProductAttention(
num_heads=text_cross_attention.num_heads, is_causal=text_cross_attention.is_causal
),
fl.Multiply(self.scale),
)


class SetQueryProjection(fl.Passthrough):
def __init__(self) -> None:
super().__init__(fl.GetArg(index=0), fl.SetContext(context="ip_adapter", key="query_projection"))


class CrossAttentionAdapter(fl.Chain, Adapter[fl.Attention]):
def __init__(
self,
target: fl.Attention,
text_sequence_length: int = 77,
image_sequence_length: int = 4,
scale: float = 1.0,
) -> None:
self.text_sequence_length = text_sequence_length
self.image_sequence_length = image_sequence_length
self.scale = scale

with self.setup_adapter(target):
super().__init__(
fl.Distribute(
# Note: the same query is used for image cross-attention as for text cross-attention
InjectionPoint(), # Wq
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wk
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wk'
),
),
fl.Parallel(
fl.Chain(
fl.Slicing(dim=1, end=text_sequence_length),
InjectionPoint(), # Wv
),
fl.Chain(
fl.Slicing(dim=1, start=text_sequence_length),
fl.Linear(
in_features=self.target.key_embedding_dim,
out_features=self.target.inner_dim,
bias=self.target.use_bias,
device=target.device,
dtype=target.dtype,
), # Wv'
),
),
),
fl.Sum(
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.TXT_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
),
fl.Chain(
fl.Lambda(func=partial(self.select_qkv, index=_CrossAttnIndex.IMG_CROSS_ATTN)),
ScaledDotProductAttention(num_heads=target.num_heads, is_causal=target.is_causal),
fl.Lambda(func=self.scale_outputs),
),
target[:-1], # original text cross attention
ImageCrossAttention(text_cross_attention=target, scale=scale),
),
InjectionPoint(), # proj
target[-1], # projection
)
self.ensure_find(fl.Attention).insert_after_type(Distribute, SetQueryProjection())

def select_qkv(
self, query: Tensor, keys: tuple[Tensor, Tensor], values: tuple[Tensor, Tensor], index: _CrossAttnIndex
) -> tuple[Tensor, Tensor, Tensor]:
return (query, keys[index.value], values[index.value])

def scale_outputs(self, x: Tensor) -> Tensor:
return x * self.scale

def _predicate(self, k: type[fl.Module]) -> Callable[[fl.Module, fl.Chain], bool]:
def f(m: fl.Module, _: fl.Chain) -> bool:
if isinstance(m, Lora): # do not adapt LoRAs
raise StopIteration
return isinstance(m, k)

return f

def _target_linears(self) -> list[fl.Linear]:
return [m for m, _ in self.target.walk(self._predicate(fl.Linear)) if isinstance(m, fl.Linear)]

def inject(self: "CrossAttentionAdapter", parent: fl.Chain | None = None) -> "CrossAttentionAdapter":
linears = self._target_linears()
assert len(linears) == 4 # Wq, Wk, Wv and Proj

injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
@property
def image_cross_attention(self) -> ImageCrossAttention:
return self.ensure_find(ImageCrossAttention)

for linear, ip in zip(linears, injection_points):
ip.append(linear)
assert len(ip) == 1
@property
def image_key_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[1].Linear

return super().inject(parent)
@property
def image_value_projection(self) -> fl.Linear:
return self.image_cross_attention.Distribute[2].Linear

def eject(self) -> None:
injection_points = list(self.layers(InjectionPoint))
assert len(injection_points) == 4
@property
def scale(self) -> float:
return self.image_cross_attention.scale

for ip in injection_points:
ip.pop()
assert len(ip) == 0
@scale.setter
def scale(self, value: float) -> None:
self.image_cross_attention.scale = value

super().eject()
def load_weights(self, key_tensor: Tensor, value_tensor: Tensor) -> None:
self.image_key_projection.weight = nn.Parameter(key_tensor)
self.image_value_projection.weight = nn.Parameter(value_tensor)
self.image_cross_attention.to(self.device, self.dtype)


class IPAdapter(Generic[T], fl.Chain, Adapter[T]):
Expand Down Expand Up @@ -377,7 +344,7 @@ def __init__(
self._image_proj = [image_proj]

self.sub_adapters = [
CrossAttentionAdapter(target=cross_attn, scale=scale, image_sequence_length=self.image_proj.num_tokens)
CrossAttentionAdapter(target=cross_attn, scale=scale)
for cross_attn in filter(lambda attn: type(attn) != fl.SelfAttention, target.layers(fl.Attention))
]

Expand All @@ -388,14 +355,15 @@ def __init__(
self.image_proj.load_state_dict(image_proj_state_dict)

for i, cross_attn in enumerate(self.sub_adapters):
cross_attn_state_dict: dict[str, Tensor] = {}
cross_attention_weights: list[Tensor] = []
for k, v in weights.items():
prefix = f"ip_adapter.{i:03d}."
if not k.startswith(prefix):
continue
cross_attn_state_dict[k.removeprefix(prefix)] = v
cross_attention_weights.append(v)

cross_attn.load_state_dict(state_dict=cross_attn_state_dict)
assert len(cross_attention_weights) == 2
cross_attn.load_weights(*cross_attention_weights)

@property
def clip_image_encoder(self) -> CLIPImageEncoderH:
Expand All @@ -420,10 +388,22 @@ def eject(self) -> None:
adapter.eject()
super().eject()

@property
def scale(self) -> float:
return self.sub_adapters[0].scale

@scale.setter
def scale(self, value: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = value

def set_scale(self, scale: float) -> None:
for cross_attn in self.sub_adapters:
cross_attn.scale = scale

def set_clip_image_embedding(self, image_embedding: Tensor) -> None:
self.set_context("ip_adapter", {"clip_image_embedding": image_embedding})

# These should be concatenated to the CLIP text embedding before setting the UNet context
def compute_clip_image_embedding(self, image_prompt: Tensor) -> Tensor:
image_encoder = self.clip_image_encoder if not self.fine_grained else self.grid_image_encoder
Expand Down
53 changes: 5 additions & 48 deletions tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,16 +1168,7 @@ def test_diffusion_ip_adapter(

clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)

sd15.set_num_inference_steps(n_steps)

Expand Down Expand Up @@ -1220,16 +1211,8 @@ def test_diffusion_sdxl_ip_adapter(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)

Expand Down Expand Up @@ -1285,16 +1268,7 @@ def test_diffusion_ip_adapter_controlnet(

clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(input_image))

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)

depth_cn_condition = image_to_tensor(
depth_condition_image.convert("RGB"),
Expand Down Expand Up @@ -1343,16 +1317,7 @@ def test_diffusion_ip_adapter_plus(

clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(statue_image))

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
ip_adapter.set_clip_image_embedding(clip_image_embedding)

sd15.set_num_inference_steps(n_steps)

Expand Down Expand Up @@ -1396,16 +1361,8 @@ def test_diffusion_sdxl_ip_adapter_plus(
text=prompt, negative_text=negative_prompt
)
clip_image_embedding = ip_adapter.compute_clip_image_embedding(ip_adapter.preprocess_image(woman_image))
ip_adapter.set_clip_image_embedding(clip_image_embedding)

negative_text_embedding, conditional_text_embedding = clip_text_embedding.chunk(2)
negative_image_embedding, conditional_image_embedding = clip_image_embedding.chunk(2)

clip_text_embedding = torch.cat(
(
torch.cat([negative_text_embedding, negative_image_embedding], dim=1),
torch.cat([conditional_text_embedding, conditional_image_embedding], dim=1),
)
)
time_ids = sdxl.default_time_ids
sdxl.set_num_inference_steps(n_steps)

Expand Down
Binary file modified tests/e2e/test_diffusion_ref/expected_ip_adapter_controlnet.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading

0 comments on commit bd044b1

Please sign in to comment.