From e836727fb2946d2a0baf0c1eac3553cf8b8d649b Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Thu, 26 Dec 2024 10:29:08 +0800 Subject: [PATCH] add teacache --- README.md | 30 +- tools/cache_inference/README.md | 54 ++++ tools/cache_inference/run.sh | 8 + .../teacache_inference_consisid.py | 277 ++++++++++++++++++ .../parallel_inference/README.md | 2 +- .../parallel_inference/README_zh.md | 2 +- .../parallel_inference_xdit.py | 0 .../parallel_inference_xdit_usp.py | 0 .../parallel_inference}/run.sh | 0 .../parallel_inference}/run_usp.sh | 0 10 files changed, 364 insertions(+), 9 deletions(-) create mode 100644 tools/cache_inference/README.md create mode 100644 tools/cache_inference/run.sh create mode 100644 tools/cache_inference/teacache_inference_consisid.py rename parallel_inference/performance.md => tools/parallel_inference/README.md (97%) rename parallel_inference/performance_zh.md => tools/parallel_inference/README_zh.md (97%) rename {parallel_inference => tools/parallel_inference}/parallel_inference_xdit.py (100%) rename {parallel_inference => tools/parallel_inference}/parallel_inference_xdit_usp.py (100%) rename {parallel_inference => tools/parallel_inference}/run.sh (100%) rename {parallel_inference => tools/parallel_inference}/run_usp.sh (100%) diff --git a/README.md b/README.md index 95029d7..aede16d 100644 --- a/README.md +++ b/README.md @@ -27,11 +27,6 @@ This repository is the official implementation of ConsisID, a tuning-free DiT-based controllable IPT2V model to keep human-identity consistent in the generated video. The approach draws inspiration from previous studies on frequency analysis of vision/diffusion transformers. - - - - -
💡 We also have other video generation projects that may interest you ✨.

@@ -54,7 +49,8 @@ This repository is the official implementation of ConsisID, a tuning-free DiT-ba ## 📣 News * ⏳⏳⏳ Release the full code & datasets & weights. -* `[2024.12.24]` 🚀 We release the [parallel inference code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/parallel_inference) for ConsisID powered by [xDiT](https://github.com/xdit-project/xDiT). Thanks [@feifeibear](https://github.com/feifeibear) for his help. +* `[2024.12.26]` 🚀 We release the [cache inference code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/tools/cache_inference) for ConsisID powered by [TeaCache](https://github.com/LiewFeng/TeaCache). Thanks [@LiewFeng](https://github.com/LiewFeng) for his help. +* `[2024.12.24]` 🚀 We release the [parallel inference code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/tools/parallel_inference) for ConsisID powered by [xDiT](https://github.com/xdit-project/xDiT). Thanks [@feifeibear](https://github.com/feifeibear) for his help. * `[2024.12.22]` 🤗 ConsisID will be merged into [diffusers](https://github.com/huggingface/diffusers) in the next version. So for now, please use `pip install git+https://github.com/SHYuanBest/ConsisID_diffusers.git` to install diffusers dev version. And we have reorganized the code and weight configs, so it's better to update your local files if you have cloned them previously. * `[2024.12.09]` 🔥We release the [test set](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data/tree/main/eval) and [metric calculation code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) used in the paper, now your can measure the metrics on your own machine. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) for more details. * `[2024.12.08]` 🔥The code for data preprocessing is out, which is used to obtain the [training data](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data) required by ConsisID. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/data_preprocess) for more details. @@ -173,6 +169,25 @@ pipe.vae.enable_tiling() ``` warning: it will cost more time in inference and may also reduce the quality. +## 🚀 Parallel Inference on Multiple GPUs by xDiT + +[xDiT](https://github.com/xdit-project/xDiT) is a Scalable Inference Engine for Diffusion Transformers (DiTs) on multi-GPU Clusters. It has successfully provided low-latency parallel inference solutions for a variety of DiTs models. For example, to generate a video with 6 GPUs, you can use the following command: + +``` +cd tools/parallel_inference +bash run.sh +# run_usp.sh +``` + +## 🚀 Cache Inference by TeaCache + +[TeaCache](https://github.com/LiewFeng/TeaCache) is a training-free caching approach that estimates and leverages the fluctuating differences among model outputs across timesteps, thereby accelerate the inference. For example, you can use the following command: + +``` +cd tools/cache_inference +bash run.sh +``` + ## ⚙️ Requirements and Installation We recommend the requirements as follows. @@ -282,6 +297,7 @@ We found some plugins created by community developers. Thanks for their efforts: - Windows Docker. [🤗Windows-ConsisID](https://huggingface.co/pkuhexianyi/ConsisID-Windows/tree/main) and [🟣Windows-ConsisID](https://www.wisemodel.cn/models/PkuHexianyi/ConsisID-Windows/file) (by [@shizi](https://www.bilibili.com/video/BV1v3iUY4EeQ/?vd_source=ae3f2652765c02e41cdd698b311989e3)). - Diffusres. [Diffusers-ConsisID](https://github.com/huggingface/diffusers) (thanks [@arrow](https://github.com/a-r-r-o-w), [@yiyixuxu](https://github.com/yiyixuxu), [@hlky](https://github.com/hlky) and [@stevhliu](https://github.com/stevhliu) for their help). - xDiT. [xDiT-ConsisID](https://github.com/xdit-project/xDiT) (thanks [@feifeibear](https://github.com/feifeibear) for his help). + - TeaCache. [TeaCache-ConsisID](https://github.com/LiewFeng/TeaCache) (thanks [@LiewFeng](https://github.com/LiewFeng) for his help). If you find related work, please let us know. @@ -327,4 +343,4 @@ If you find our paper and code useful in your research, please consider giving a - + \ No newline at end of file diff --git a/tools/cache_inference/README.md b/tools/cache_inference/README.md new file mode 100644 index 0000000..6614bce --- /dev/null +++ b/tools/cache_inference/README.md @@ -0,0 +1,54 @@ + +# TeaCache4ConsisID + +[TeaCache](https://github.com/LiewFeng/TeaCache) can speedup [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 2x without much visual quality degradation, in a training-free manner. + +## 📈 Inference Latency Comparisons on a Single H100 GPU + +| ConsisID | TeaCache (0.1) | TeaCache (0.15) | TeaCache (0.20) | +| :------: | :------------: | :-------------: | :-------------: | +| ~110 s | ~70 s | ~53 s | ~41 s | + + +## Usage + +Follow [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) to clone the repo and finish the installation, then you can modify the `rel_l1_thresh` to obtain your desired trade-off between latency and visul quality, and change the `ckpts_path`, `prompt`, `image` to customize your identity-preserving video. + +For single-gpu inference, you can use the following command: + +```bash +python teacache_inference_consisid.py \ + --rel_l1_thresh 0.1 \ + --ckpts_path BestWishYsh/ConsisID-preview \ + --image "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ + --prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ + --seed 42 \ + --num_infer_steps 50 \ + --output_path ./teacache_results +``` + +To generate a video with 8 GPUs, you can use the following [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/tools). + +## Resources + +Learn more about ConsisID with the following resources. +- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features. +- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details. + +## Citation + +If you find TeaCache is useful in your research or applications, please consider giving us a star 🌟 and citing it by the following BibTeX entry. + +``` +@article{liu2024timestep, + title={Timestep Embedding Tells: It's Time to Cache for Video Diffusion Model}, + author={Liu, Feng and Zhang, Shiwei and Wang, Xiaofeng and Wei, Yujie and Qiu, Haonan and Zhao, Yuzhong and Zhang, Yingya and Ye, Qixiang and Wan, Fang}, + journal={arXiv preprint arXiv:2411.19108}, + year={2024} +} +``` + + +## Acknowledgements + +We would like to thank the contributors to the [ConsisID](https://github.com/PKU-YuanGroup/ConsisID). diff --git a/tools/cache_inference/run.sh b/tools/cache_inference/run.sh new file mode 100644 index 0000000..14c7b9b --- /dev/null +++ b/tools/cache_inference/run.sh @@ -0,0 +1,8 @@ +python teacache_inference_consisid.py \ + --rel_l1_thresh 0.1 \ + --ckpts_path BestWishYsh/ConsisID-preview \ + --image "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ + --prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ + --seed 42 \ + --num_infer_steps 50 \ + --output_path ./teacache_results \ No newline at end of file diff --git a/tools/cache_inference/teacache_inference_consisid.py b/tools/cache_inference/teacache_inference_consisid.py new file mode 100644 index 0000000..25fd5de --- /dev/null +++ b/tools/cache_inference/teacache_inference_consisid.py @@ -0,0 +1,277 @@ +import os +import argparse +import numpy as np +from typing import Any, Dict, Optional, Tuple, Union + +import torch + +from diffusers import ConsisIDPipeline +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from diffusers.models.modeling_outputs import Transformer2DModelOutput +from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers +from diffusers.utils import export_to_video +from huggingface_hub import snapshot_download + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def teacache_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: torch.Tensor, + timestep: Union[int, float, torch.LongTensor], + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + id_cond: Optional[torch.Tensor] = None, + id_vit_hidden: Optional[torch.Tensor] = None, + return_dict: bool = True, +): + if attention_kwargs is not None: + attention_kwargs = attention_kwargs.copy() + lora_scale = attention_kwargs.pop("scale", 1.0) + else: + lora_scale = 1.0 + + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + else: + if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective." + ) + + # fuse clip and insightface + if self.is_train_face: + assert id_cond is not None and id_vit_hidden is not None + id_cond = id_cond.to(device=hidden_states.device, dtype=hidden_states.dtype) + id_vit_hidden = [ + tensor.to(device=hidden_states.device, dtype=hidden_states.dtype) for tensor in id_vit_hidden + ] + valid_face_emb = self.local_facial_extractor( + id_cond, id_vit_hidden + ) # torch.Size([1, 1280]), list[5](torch.Size([1, 577, 1024])) -> torch.Size([1, 32, 2048]) + + batch_size, num_frames, channels, height, width = hidden_states.shape + + # 1. Time embedding + timesteps = timestep + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + # 2. Patch embedding + # torch.Size([1, 226, 4096]) torch.Size([1, 13, 32, 60, 90]) + hidden_states = self.patch_embed(encoder_hidden_states, hidden_states) # torch.Size([1, 17776, 3072]) + hidden_states = self.embedding_dropout(hidden_states) # torch.Size([1, 17776, 3072]) + + text_seq_length = encoder_hidden_states.shape[1] + encoder_hidden_states = hidden_states[:, :text_seq_length] # torch.Size([1, 226, 3072]) + hidden_states = hidden_states[:, text_seq_length:] # torch.Size([1, 17550, 3072]) + + if self.enable_teacache: + if self.cnt == 0 or self.cnt == self.num_steps-1: + should_calc = True + self.accumulated_rel_l1_distance = 0 + else: + coefficients = [-1.53880483e+03, 8.43202495e+02, -1.34363087e+02, 7.97131516e+00, -5.23162339e-02] + rescale_func = np.poly1d(coefficients) + self.accumulated_rel_l1_distance += rescale_func(((emb-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) + if self.accumulated_rel_l1_distance < self.rel_l1_thresh: + should_calc = False + else: + should_calc = True + self.accumulated_rel_l1_distance = 0 + self.previous_modulated_input = emb + self.cnt = 0 if self.cnt == self.num_steps-1 else self.cnt + 1 + + if self.enable_teacache: + if not should_calc: + hidden_states += self.previous_residual + encoder_hidden_states += self.previous_residual_encoder + else: + ori_hidden_states = hidden_states.clone() + ori_encoder_hidden_states = encoder_hidden_states.clone() + # 3. Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx]( + valid_face_emb, hidden_states + ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + self.previous_residual = hidden_states - ori_hidden_states + self.previous_residual_encoder = encoder_hidden_states - ori_encoder_hidden_states + else: + # 3. Transformer blocks + ca_idx = 0 + for i, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + encoder_hidden_states, + emb, + image_rotary_emb, + **ckpt_kwargs, + ) + else: + hidden_states, encoder_hidden_states = block( + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb=emb, + image_rotary_emb=image_rotary_emb, + ) + + if self.is_train_face: + if i % self.cross_attn_interval == 0 and valid_face_emb is not None: + hidden_states = hidden_states + self.local_face_scale * self.perceiver_cross_attention[ca_idx]( + valid_face_emb, hidden_states + ) # torch.Size([2, 32, 2048]) torch.Size([2, 17550, 3072]) + ca_idx += 1 + + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + hidden_states = self.norm_final(hidden_states) + hidden_states = hidden_states[:, text_seq_length:] + + # 4. Final block + hidden_states = self.norm_out(hidden_states, temb=emb) + hidden_states = self.proj_out(hidden_states) + + # 5. Unpatchify + # Note: we use `-1` instead of `channels`: + # - It is okay to `channels` use for ConsisID (number of input channels is equal to output channels) + p = self.config.patch_size + output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p) + output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (output,) + return Transformer2DModelOutput(sample=output) + + +def main(args): + seed = args.seed + num_infer_steps = args.num_infer_steps + output_path = args.output_path + ckpts_path = args.ckpts_path + # higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + rel_l1_thresh = args.rel_l1_thresh + # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + prompt = args.prompt + image = args.image + + if not os.path.exists(ckpts_path): + print("Base Model not found, downloading from Hugging Face...") + snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=ckpts_path) + else: + print(f"Base Model already exists in {ckpts_path}, skipping download.") + + if not os.path.exists(output_path): + os.makedirs(output_path, exist_ok=True) + + face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( + prepare_face_models(ckpts_path, device="cuda", dtype=torch.bfloat16) + ) + pipe = ConsisIDPipeline.from_pretrained(ckpts_path, torch_dtype=torch.bfloat16) + pipe.to("cuda") + + id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + face_helper_1, + face_clip_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + face_main_model, + "cuda", + torch.bfloat16, + image, + is_align_face=True, + ) + + # TeaCache Config + pipe.transformer.__class__.enable_teacache = True + pipe.transformer.__class__.cnt = 0 + pipe.transformer.__class__.num_steps = num_infer_steps - 1 + pipe.transformer.__class__.rel_l1_thresh = rel_l1_thresh # 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + pipe.transformer.__class__.accumulated_rel_l1_distance = 0 + pipe.transformer.__class__.previous_modulated_input = None + pipe.transformer.__class__.previous_residual = None + pipe.transformer.__class__.previous_residual_encoder = None + pipe.transformer.__class__.forward = teacache_forward + + video = pipe( + image=image, + prompt=prompt, + num_inference_steps=num_infer_steps, + guidance_scale=6.0, + use_dynamic_cfg=False, + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=face_kps, + generator=torch.Generator("cuda").manual_seed(seed), + ) + file_count = len([f for f in os.listdir(output_path) if os.path.isfile(os.path.join(output_path, f))]) + video_path = f"{output_path}/{seed}_{rel_l1_thresh}_{file_count:04d}.mp4" + export_to_video(video.frames[0], video_path, fps=8) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Run ConsisID with given parameters") + + parser.add_argument('--seed', type=int, default=42, help='Random seed') + parser.add_argument('--num_infer_steps', type=int, default=50, help='Number of inference steps') + parser.add_argument("--output_path", type=str, default="./teacache_results", help="The path where the generated video will be saved") + parser.add_argument('--ckpts_path', type=str, default="BestWishYsh/ConsisID-preview", help='Path to checkpoint') + # higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup + parser.add_argument('--rel_l1_thresh', type=float, default=0.1, help='Higher speedup will cause to worse quality -- 0.1 for 1.6x speedup -- 0.15 for 2.1x speedup -- 0.2 for 2.5x speedup') + # ConsisID works well with long and well-described prompts. Make sure the face in the image is clearly visible (e.g., preferably half-body or full-body). + parser.add_argument('--prompt', type=str, default="The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy\'s path, adding depth to the scene. The lighting highlights the boy\'s subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel.", help='Description of the video for the model to generate') + parser.add_argument('--image', type=str, default="https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true", help='URL or path to input image') + args = parser.parse_args() + + main(args) \ No newline at end of file diff --git a/parallel_inference/performance.md b/tools/parallel_inference/README.md similarity index 97% rename from parallel_inference/performance.md rename to tools/parallel_inference/README.md index bd44dd0..5bd2d4a 100644 --- a/parallel_inference/performance.md +++ b/tools/parallel_inference/README.md @@ -1,4 +1,4 @@ -## ConsisID Performance Report +## xDiT-ConsisID Performance Report [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition. [xDiT](https://github.com/xdit-project/xDiT) currently integrates USP techniques, including Ulysses Attention, Ring Attention, and CFG parallelization, to enhance inference speed, while work on PipeFusion is ongoing. We conducted an in-depth analysis comparing single-GPU ConsisID inference, based on the diffusers library, with our proposed parallelized version for generating 49 frames (6 seconds) of 720x480 resolution video. By flexibly combining different parallelization methods, we achieved varying performance outcomes. In this study, we systematically evaluate xDiT's acceleration performance across 1 to 6 Nvidia H100 GPUs. diff --git a/parallel_inference/performance_zh.md b/tools/parallel_inference/README_zh.md similarity index 97% rename from parallel_inference/performance_zh.md rename to tools/parallel_inference/README_zh.md index 7258dd8..ec1d984 100644 --- a/parallel_inference/performance_zh.md +++ b/tools/parallel_inference/README_zh.md @@ -1,4 +1,4 @@ -## ConsisID Performance Report +## xDiT-ConsisID Performance Report [ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 是一种身份保持的文本到视频生成模型,其通过频率分解在生成的视频中保持面部一致性。[xDiT](https://github.com/xdit-project/xDiT)目前整合了USP技术(包括Ulysses Attention和Ring Attention)和Classifier-Free Guidance(CFG)来提高推理速度,同时我们还将整合PipeFusion策略。我们对基于diffusers库的单GPU ConsisID推理与我们提出的并行化版本在生成49帧(6秒)720x480分辨率视频时的性能差异进行了深入分析。由于我们可以任意组合不同的并行方式以获得不同的性能。在本文中,我们对xDiT在1-6张H100(Nvidia)GPU上的加速性能进行了系统测试。 diff --git a/parallel_inference/parallel_inference_xdit.py b/tools/parallel_inference/parallel_inference_xdit.py similarity index 100% rename from parallel_inference/parallel_inference_xdit.py rename to tools/parallel_inference/parallel_inference_xdit.py diff --git a/parallel_inference/parallel_inference_xdit_usp.py b/tools/parallel_inference/parallel_inference_xdit_usp.py similarity index 100% rename from parallel_inference/parallel_inference_xdit_usp.py rename to tools/parallel_inference/parallel_inference_xdit_usp.py diff --git a/parallel_inference/run.sh b/tools/parallel_inference/run.sh similarity index 100% rename from parallel_inference/run.sh rename to tools/parallel_inference/run.sh diff --git a/parallel_inference/run_usp.sh b/tools/parallel_inference/run_usp.sh similarity index 100% rename from parallel_inference/run_usp.sh rename to tools/parallel_inference/run_usp.sh