
Commit

Update transformer_consisid.py
SHYuanBest authored Dec 22, 2024
1 parent 1ae5e49 commit 8817a46
Showing 1 changed file with 44 additions and 48 deletions.
models/transformer_consisid.py — 92 changes: 44 additions & 48 deletions
@@ -24,11 +24,7 @@
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.attention_processor import (
-    AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
-    FusedCogVideoXAttnProcessor2_0,
-)
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
@@ -40,6 +36,24 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


+def reshape_tensor(x, heads):
+    """
+    Reshapes the input tensor for multi-head attention.
+
+    Args:
+        x (torch.Tensor): The input tensor with shape (batch_size, length, width).
+        heads (int): The number of attention heads.
+
+    Returns:
+        torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
+    """
+    bs, length, width = x.shape
+    x = x.view(bs, length, heads, -1)
+    x = x.transpose(1, 2)
+    x = x.reshape(bs, heads, length, -1)
+    return x
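As a quick illustration of how reshape_tensor splits the channel dimension across heads (a hypothetical example, not part of this commit):

import torch

x = torch.randn(2, 16, 1024)        # (batch_size, length, width) — hypothetical shapes
out = reshape_tensor(x, heads=8)    # view -> transpose -> reshape
print(out.shape)                    # torch.Size([2, 8, 16, 128]); width is split into 8 heads of 128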


def ConsisIDFeedForward(dim, mult=4):
"""
Creates a consistent ID feedforward block consisting of layer normalization, two linear layers, and a GELU
@@ -61,24 +75,6 @@ def ConsisIDFeedForward(dim, mult=4):
)


-def reshape_tensor(x, heads):
-    """
-    Reshapes the input tensor for multi-head attention.
-
-    Args:
-        x (torch.Tensor): The input tensor with shape (batch_size, length, width).
-        heads (int): The number of attention heads.
-
-    Returns:
-        torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
-    """
-    bs, length, width = x.shape
-    x = x.view(bs, length, heads, -1)
-    x = x.transpose(1, 2)
-    x = x.reshape(bs, heads, length, -1)
-    return x
-
-
class PerceiverAttention(nn.Module):
"""
Implements the Perceiver attention mechanism with multi-head attention.
@@ -158,6 +154,7 @@ def __init__(
num_queries=32,
output_dim=2048,
ff_mult=4,
+        num_scale=5,
):
"""
Initializes the LocalFacialExtractor class.
@@ -172,15 +169,17 @@
- num_queries (int): Number of query tokens for the latent representation.
- output_dim (int): Output dimension after projection.
- ff_mult (int): Multiplier for the feed-forward network hidden dimension.
+        - num_scale (int): The number of different scales of visual features.
"""
super().__init__()

# Storing identity token and query information
self.num_id_token = num_id_token
self.vit_dim = vit_dim
self.num_queries = num_queries
-        assert depth % 5 == 0
-        self.depth = depth // 5
+        assert depth % num_scale == 0
+        self.depth = depth // num_scale
+        self.num_scale = num_scale
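For intuition, the divisibility check simply splits the total depth evenly across the visual-feature scales; a minimal sketch with hypothetical numbers (not taken from this commit):

depth, num_scale = 10, 5               # hypothetical values
assert depth % num_scale == 0          # mirrors the check above
blocks_per_scale = depth // num_scale
print(blocks_per_scale)                # 2 attention/feed-forward blocks allotted to each scale (see self.depth)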
scale = vit_dim**-0.5

# Learnable latent query embeddings
@@ -201,7 +200,7 @@
)

# Mappings for each of the 5 different ViT features
-        for i in range(5):
+        for i in range(num_scale):
setattr(
self,
f"mapping_{i}",
@@ -249,8 +248,8 @@ def forward(self, x, y):
# Concatenate identity tokens with the latent queries
latents = torch.cat((latents, x), dim=1)

-        # Process each of the 5 visual feature inputs
-        for i in range(5):
+        # Process each of the num_scale visual feature inputs
+        for i in range(self.num_scale):
vit_feature = getattr(self, f"mapping_{i}")(y[i])
ctx_feature = torch.cat((x, vit_feature), dim=1)

@@ -567,6 +566,9 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
transformations, but also increases the computation and memory requirements.
+        LFE_num_scale (`int`, optional, defaults to `5`):
+            The number of different scales of visual features. A higher value increases the model's capacity to learn
+            more complex facial feature transformations, but also increases the computation and memory requirements.
local_face_scale (`float`, defaults to `1.0`):
A scaling factor used to adjust the importance of local facial features in the model. This can influence
how strongly the model focuses on high frequency face-related content.
@@ -616,6 +618,7 @@ def __init__(
LFE_num_querie: int = 32,
LFE_output_dim: int = 2048,
LFE_ff_mult: int = 4,
+        LFE_num_scale: int = 5,
local_face_scale: float = 1.0,
):
super().__init__()
@@ -697,6 +700,7 @@ def __init__(
self.LFE_num_querie = LFE_num_querie
self.LFE_output_dim = LFE_output_dim
self.LFE_ff_mult = LFE_ff_mult
+        self.LFE_num_scale = LFE_num_scale
# cross configs
self.inner_dim = inner_dim
self.cross_attn_interval = cross_attn_interval
@@ -724,6 +728,7 @@ def _init_face_inputs(self):
num_queries=self.LFE_num_querie,
output_dim=self.LFE_output_dim,
ff_mult=self.LFE_ff_mult,
+            num_scale=self.LFE_num_scale,
)
self.local_facial_extractor.to(device, dtype=weight_dtype)
self.perceiver_cross_attention = nn.ModuleList(
@@ -738,19 +743,6 @@
]
)

-    def save_face_modules(self, path: str):
-        save_dict = {
-            "local_facial_extractor": self.local_facial_extractor.state_dict(),
-            "perceiver_cross_attention": [ca.state_dict() for ca in self.perceiver_cross_attention],
-        }
-        torch.save(save_dict, path)
-
-    def load_face_modules(self, path: str):
-        checkpoint = torch.load(path, map_location=self.device)
-        self.local_facial_extractor.load_state_dict(checkpoint["local_facial_extractor"])
-        for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint["perceiver_cross_attention"]):
-            ca.load_state_dict(state_dict)
-
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -823,13 +815,6 @@ def forward(
id_vit_hidden: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
-        # fuse clip and insightface
-        if self.is_train_face:
-            assert id_cond is not None and id_vit_hidden is not None
-            valid_face_emb = self.local_facial_extractor(
-                id_cond, id_vit_hidden
-            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
-
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -845,6 +830,17 @@
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)

+        # fuse clip and insightface
+        if self.is_train_face:
+            assert id_cond is not None and id_vit_hidden is not None
+            id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
+            id_vit_hidden = [
+                tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
+            ]
+            valid_face_emb = self.local_facial_extractor(
+                id_cond, id_vit_hidden
+            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
+
batch_size, num_frames, channels, height, width = hidden_states.shape

# 1. Time embedding
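For context, the relocated fusion block expects id_cond and id_vit_hidden in the shapes noted in the inline comment, cast to the hidden states' device and dtype. A minimal sketch with hypothetical tensors (illustrative only, not part of the diff; `transformer` stands in for a loaded ConsisIDTransformer3DModel):

import torch

hidden_states = torch.randn(1, 13, 16, 60, 90, dtype=torch.float16)   # hypothetical video latents
id_cond = torch.randn(1, 1280)                                          # fused CLIP + InsightFace embedding
id_vit_hidden = [torch.randn(1, 577, 1024) for _ in range(5)]           # one tensor per visual-feature scale

id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
id_vit_hidden = [t.to(device=hidden_states.device, dtype=hidden_states.dtype) for t in id_vit_hidden]
# valid_face_emb = transformer.local_facial_extractor(id_cond, id_vit_hidden)  # -> (1, 32, 2048), per the comment above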
