From 23cecd31f06ad8ea966e0b498c8fdad2c757ad92 Mon Sep 17 00:00:00 2001 From: yoontkim Date: Tue, 25 Jun 2024 21:35:53 +0900 Subject: [PATCH] [refactor]: add type checking to sample image and video functions --- .../models/diffusion/latte/modeling_latte.py | 22 +++++----- opensora/sample/pipeline_videogen.py | 41 +++++++++---------- opensora/sample/sample_t2v.py | 2 +- 3 files changed, 32 insertions(+), 33 deletions(-) diff --git a/opensora/models/diffusion/latte/modeling_latte.py b/opensora/models/diffusion/latte/modeling_latte.py index 7a5706b20..d7e952e9b 100644 --- a/opensora/models/diffusion/latte/modeling_latte.py +++ b/opensora/models/diffusion/latte/modeling_latte.py @@ -5,7 +5,7 @@ from dataclasses import dataclass from einops import rearrange, repeat -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Optional, Tuple, Union from diffusers.models import Transformer2DModel from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, ImagePositionalEmbeddings @@ -86,7 +86,7 @@ def __init__( rope_scaling_type: str = 'linear', compress_kv_factor: int = 1, interpolation_scale_1d: float = None, - ): + ) -> None: super().__init__() self.use_linear_projection = use_linear_projection self.num_attention_heads = num_attention_heads @@ -249,15 +249,15 @@ def __init__( self.gradient_checkpointing = False - def _set_gradient_checkpointing(self, module, value=False): + def _set_gradient_checkpointing(self, value: bool=False) -> None: self.gradient_checkpointing = value - def make_position(self, b, t, use_image_num, h, w, device): + def make_position(self, b: int, t: int, use_image_num: int, h: int, w: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]: pos_hw = self.position_getter_2d(b*(t+use_image_num), h, w, device) # fake_b = b*(t+use_image_num) pos_t = self.position_getter_1d(b*h*w, t, device) # fake_b = b*h*w return pos_hw, pos_t - def make_attn_mask(self, attention_mask, frame, dtype): + def make_attn_mask(self, attention_mask: torch.Tensor, frame: int, dtype: torch.dtype) -> torch.Tensor: attention_mask = rearrange(attention_mask, 'b t h w -> (b t) 1 (h w)') # assume that mask is expressed as: # (1 = keep, 0 = discard) @@ -267,7 +267,7 @@ def make_attn_mask(self, attention_mask, frame, dtype): attention_mask = attention_mask.to(self.dtype) return attention_mask - def vae_to_diff_mask(self, attention_mask, use_image_num): + def vae_to_diff_mask(self, attention_mask: torch.Tensor, use_image_num: int) -> torch.Tensor: dtype = attention_mask.dtype # b, t+use_image_num, h, w, assume t as channel # this version do not use 3d patch embedding @@ -288,7 +288,7 @@ def forward( use_image_num: int = 0, enable_temporal_attentions: bool = True, return_dict: bool = True, - ): + ) -> Union[Tuple[torch.Tensor,], Transformer3DModelOutput]: """ The [`Transformer2DModel`] forward method. @@ -571,14 +571,14 @@ def forward( shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) ) output = rearrange(output, '(b f) c h w -> b c f h w', b=input_batch_size).contiguous() - + if not return_dict: return (output,) return Transformer3DModelOutput(sample=output) @classmethod - def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): + def from_pretrained_2d(cls, pretrained_model_path: str, subfolder=None, **kwargs): if subfolder is not None: pretrained_model_path = os.path.join(pretrained_model_path, subfolder) @@ -592,10 +592,10 @@ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, **kwargs): return model # depth = num_layers * 2 -def LatteT2V_XL_122(**kwargs): +def LatteT2V_XL_122(**kwargs) -> LatteT2V: return LatteT2V(num_layers=28, attention_head_dim=72, num_attention_heads=16, patch_size_t=1, patch_size=2, norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) -def LatteT2V_D64_XL_122(**kwargs): +def LatteT2V_D64_XL_122(**kwargs) -> LatteT2V: return LatteT2V(num_layers=28, attention_head_dim=64, num_attention_heads=18, patch_size_t=1, patch_size=2, norm_type="ada_norm_single", caption_channels=4096, cross_attention_dim=1152, **kwargs) diff --git a/opensora/sample/pipeline_videogen.py b/opensora/sample/pipeline_videogen.py index 1263473fa..fcbbaa675 100644 --- a/opensora/sample/pipeline_videogen.py +++ b/opensora/sample/pipeline_videogen.py @@ -16,7 +16,7 @@ import inspect import re import urllib.parse as ul -from typing import Callable, List, Optional, Tuple, Union +from typing import Callable, List, Optional, Tuple, Union, Any, Dict import torch import einops @@ -104,7 +104,7 @@ def __init__( vae: AutoencoderKL, transformer: Transformer2DModel, scheduler: DPMSolverMultistepScheduler, - ): + ) -> None: super().__init__() self.register_modules( @@ -114,7 +114,7 @@ def __init__( # self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/utils.py - def mask_text_embeddings(self, emb, mask): + def mask_text_embeddings(self, emb: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, int]: if emb.shape[0] == 1: keep_index = mask.sum().item() return emb[:, :, :keep_index, :], keep_index # 1, 120, 4096 -> 1 7 4096 @@ -134,7 +134,7 @@ def encode_prompt( negative_prompt_embeds: Optional[torch.FloatTensor] = None, clean_caption: bool = False, mask_feature: bool = True, - ): + ) -> Tuple[torch.Tensor, torch.Tensor]: r""" Encodes the prompt into text encoder hidden states. @@ -280,19 +280,17 @@ def encode_prompt( # masked_negative_prompt_embeds_ = F.pad(masked_negative_prompt_embeds, padding, "constant", 0) # print(masked_prompt_embeds == masked_prompt_embeds_[:, :masked_negative_prompt_embeds.shape[1], ...]) - return masked_prompt_embeds, masked_negative_prompt_embeds # return masked_prompt_embeds_, masked_negative_prompt_embeds_ return prompt_embeds, negative_prompt_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs - def prepare_extra_step_kwargs(self, generator, eta): + def prepare_extra_step_kwargs(self, generator: Optional[Union[torch.Generator, List[torch.Generator]]], eta: float) -> Dict[str, Any]: # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) extra_step_kwargs = {} if accepts_eta: @@ -302,18 +300,19 @@ def prepare_extra_step_kwargs(self, generator, eta): accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) if accepts_generator: extra_step_kwargs["generator"] = generator + return extra_step_kwargs def check_inputs( self, - prompt, - height, - width, - negative_prompt, - callback_steps, - prompt_embeds=None, - negative_prompt_embeds=None, - ): + prompt: Union[str, List[str]], + height: Optional[int], + width: Optional[int], + negative_prompt: Optional[torch.FloatTensor], + callback_steps: Optional[Callable[[int, int, torch.FloatTensor], None]], + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ) -> None: if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -358,7 +357,7 @@ def check_inputs( ) # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing - def _text_preprocessing(self, text, clean_caption=False): + def _text_preprocessing(self, text: Union[str, List[str]], clean_caption: bool=False) -> List[str]: if clean_caption and not is_bs4_available(): logger.warn(BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`")) logger.warn("Setting `clean_caption` to False...") @@ -379,11 +378,11 @@ def process(text: str): else: text = text.lower().strip() return text - + return [process(t) for t in text] # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption - def _clean_caption(self, caption): + def _clean_caption(self, caption: str) -> str: caption = str(caption) caption = ul.unquote_plus(caption) caption = caption.strip().lower() @@ -501,8 +500,8 @@ def _clean_caption(self, caption): return caption.strip() # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, - latents=None): + def prepare_latents(self, batch_size: int, num_channels_latents: int, num_frames: int, height: Optional[int], width: Optional[int], dtype: torch.float16, device: torch.device, generator: Optional[Union[torch.Generator, List[torch.Generator]]], + latents: Optional[torch.FloatTensor]=None) -> torch.Tensor: shape = ( batch_size, num_channels_latents, @@ -755,7 +754,7 @@ def __call__( return VideoPipelineOutput(video=video) - def decode_latents(self, latents): + def decode_latents(self, latents: torch.Tensor) -> torch.Tensor: video = self.vae.decode(latents) # b t c h w # b t c h w -> b t h w c video = ((video / 2.0 + 0.5).clamp(0, 1) * 255).to(dtype=torch.uint8).cpu().permute(0, 1, 3, 4, 2).contiguous() diff --git a/opensora/sample/sample_t2v.py b/opensora/sample/sample_t2v.py index 425739b85..8013c81cc 100644 --- a/opensora/sample/sample_t2v.py +++ b/opensora/sample/sample_t2v.py @@ -28,7 +28,7 @@ import imageio -def main(args): +def main(args: argparse.Namespace) -> None: # torch.manual_seed(args.seed) torch.set_grad_enabled(False) device = "cuda" if torch.cuda.is_available() else "cpu"