Patch summary for models/transformer_consisid.py (text before the first
"diff --git" line is ignored by git apply / patch):
  * Collapse the attention_processor import to one line and drop
    FusedCogVideoXAttnProcessor2_0 from it.
  * Move reshape_tensor() above ConsisIDFeedForward() (body unchanged).
  * LocalFacialExtractor: add a `num_scale=5` parameter and use it in place
    of the hard-coded 5 (depth check, mapping_{i} creation, forward loop).
  * ConsisIDTransformer3DModel: add the `LFE_num_scale` argument (default 5),
    store it, document it, and pass it to LocalFacialExtractor.
  * Remove save_face_modules() / load_face_modules().
  * forward(): run the local facial extractor after the attention_kwargs /
    LoRA-scale handling, and move id_cond / id_vit_hidden to the device and
    dtype of hidden_states first.
NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy; hunk line counts were checked against the @@
headers, but indentation depth should be confirmed against the target file.

diff --git a/models/transformer_consisid.py b/models/transformer_consisid.py
index e4abc52..77540e7 100644
--- a/models/transformer_consisid.py
+++ b/models/transformer_consisid.py
@@ -24,11 +24,7 @@
 from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.loaders import PeftAdapterMixin
 from diffusers.models.attention import Attention, FeedForward
-from diffusers.models.attention_processor import (
-    AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
-    FusedCogVideoXAttnProcessor2_0,
-)
+from diffusers.models.attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0
 from diffusers.models.embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from diffusers.models.modeling_outputs import Transformer2DModelOutput
 from diffusers.models.modeling_utils import ModelMixin
@@ -40,6 +36,24 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+def reshape_tensor(x, heads):
+    """
+    Reshapes the input tensor for multi-head attention.
+
+    Args:
+        x (torch.Tensor): The input tensor with shape (batch_size, length, width).
+        heads (int): The number of attention heads.
+
+    Returns:
+        torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
+    """
+    bs, length, width = x.shape
+    x = x.view(bs, length, heads, -1)
+    x = x.transpose(1, 2)
+    x = x.reshape(bs, heads, length, -1)
+    return x
+
+
 def ConsisIDFeedForward(dim, mult=4):
     """
     Creates a consistent ID feedforward block consisting of layer normalization, two linear layers, and a GELU
@@ -61,24 +75,6 @@ def ConsisIDFeedForward(dim, mult=4):
     )
 
 
-def reshape_tensor(x, heads):
-    """
-    Reshapes the input tensor for multi-head attention.
-
-    Args:
-        x (torch.Tensor): The input tensor with shape (batch_size, length, width).
-        heads (int): The number of attention heads.
-
-    Returns:
-        torch.Tensor: The reshaped tensor, with shape (batch_size, heads, length, width).
-    """
-    bs, length, width = x.shape
-    x = x.view(bs, length, heads, -1)
-    x = x.transpose(1, 2)
-    x = x.reshape(bs, heads, length, -1)
-    return x
-
-
 class PerceiverAttention(nn.Module):
     """
     Implements the Perceiver attention mechanism with multi-head attention.
@@ -158,6 +154,7 @@ def __init__(
         num_queries=32,
         output_dim=2048,
         ff_mult=4,
+        num_scale=5,
     ):
         """
         Initializes the LocalFacialExtractor class.
@@ -172,6 +169,7 @@ def __init__(
         - num_queries (int): Number of query tokens for the latent representation.
         - output_dim (int): Output dimension after projection.
         - ff_mult (int): Multiplier for the feed-forward network hidden dimension.
+        - num_scale (int): The number of different scales visual feature.
         """
         super().__init__()
 
@@ -179,8 +177,9 @@ def __init__(
         self.num_id_token = num_id_token
         self.vit_dim = vit_dim
         self.num_queries = num_queries
-        assert depth % 5 == 0
-        self.depth = depth // 5
+        assert depth % num_scale == 0
+        self.depth = depth // num_scale
+        self.num_scale = num_scale
         scale = vit_dim**-0.5
 
         # Learnable latent query embeddings
@@ -201,7 +200,7 @@ def __init__(
             )
 
         # Mappings for each of the 5 different ViT features
-        for i in range(5):
+        for i in range(num_scale):
             setattr(
                 self,
                 f"mapping_{i}",
@@ -249,8 +248,8 @@ def forward(self, x, y):
         # Concatenate identity tokens with the latent queries
         latents = torch.cat((latents, x), dim=1)
 
-        # Process each of the 5 visual feature inputs
-        for i in range(5):
+        # Process each of the num_scale visual feature inputs
+        for i in range(self.num_scale):
             vit_feature = getattr(self, f"mapping_{i}")(y[i])
             ctx_feature = torch.cat((x, vit_feature), dim=1)
 
@@ -567,6 +566,9 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             The multiplication factor applied to the feed-forward network's hidden layer size in the Local Facial
             Extractor (LFE). A higher value increases the model's capacity to learn more complex facial feature
             transformations, but also increases the computation and memory requirements.
+        LFE_num_scale (`int`, optional, defaults to `5`):
+            The number of different scales visual feature. A higher value increases the model's capacity to learn more
+            complex facial feature transformations, but also increases the computation and memory requirements.
         local_face_scale (`float`, defaults to `1.0`):
             A scaling factor used to adjust the importance of local facial features in the model. This can influence
             how strongly the model focuses on high frequency face-related content.
@@ -616,6 +618,7 @@ def __init__(
         LFE_num_querie: int = 32,
         LFE_output_dim: int = 2048,
         LFE_ff_mult: int = 4,
+        LFE_num_scale: int = 5,
         local_face_scale: float = 1.0,
     ):
         super().__init__()
@@ -697,6 +700,7 @@ def __init__(
             self.LFE_num_querie = LFE_num_querie
             self.LFE_output_dim = LFE_output_dim
             self.LFE_ff_mult = LFE_ff_mult
+            self.LFE_num_scale = LFE_num_scale
             # cross configs
             self.inner_dim = inner_dim
             self.cross_attn_interval = cross_attn_interval
@@ -724,6 +728,7 @@ def _init_face_inputs(self):
             num_queries=self.LFE_num_querie,
             output_dim=self.LFE_output_dim,
             ff_mult=self.LFE_ff_mult,
+            num_scale=self.LFE_num_scale,
         )
         self.local_facial_extractor.to(device, dtype=weight_dtype)
         self.perceiver_cross_attention = nn.ModuleList(
@@ -738,19 +743,6 @@ def _init_face_inputs(self):
             ]
         )
 
-    def save_face_modules(self, path: str):
-        save_dict = {
-            "local_facial_extractor": self.local_facial_extractor.state_dict(),
-            "perceiver_cross_attention": [ca.state_dict() for ca in self.perceiver_cross_attention],
-        }
-        torch.save(save_dict, path)
-
-    def load_face_modules(self, path: str):
-        checkpoint = torch.load(path, map_location=self.device)
-        self.local_facial_extractor.load_state_dict(checkpoint["local_facial_extractor"])
-        for ca, state_dict in zip(self.perceiver_cross_attention, checkpoint["perceiver_cross_attention"]):
-            ca.load_state_dict(state_dict)
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -823,13 +815,6 @@ def forward(
         id_vit_hidden: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
-        # fuse clip and insightface
-        if self.is_train_face:
-            assert id_cond is not None and id_vit_hidden is not None
-            valid_face_emb = self.local_facial_extractor(
-                id_cond, id_vit_hidden
-            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
-
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -845,6 +830,17 @@ def forward(
                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
                 )
 
+        # fuse clip and insightface
+        if self.is_train_face:
+            assert id_cond is not None and id_vit_hidden is not None
+            id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype)
+            id_vit_hidden = [
+                tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden
+            ]
+            valid_face_emb = self.local_facial_extractor(
+                id_cond, id_vit_hidden
+            )  # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048])
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding