From c4ae88cc031ab4e03b1b1198a952f8b96e88b8aa Mon Sep 17 00:00:00 2001 From: SHYuanBest Date: Tue, 24 Dec 2024 14:55:52 +0800 Subject: [PATCH] add parallel inference --- README.md | 12 +- parallel_inference/parallel_inference_xdit.py | 119 ++++++++ .../parallel_inference_xdit_usp.py | 256 ++++++++++++++++++ parallel_inference/performance.md | 29 ++ parallel_inference/performance_zh.md | 23 ++ parallel_inference/run.sh | 42 +++ parallel_inference/run_usp.sh | 42 +++ 7 files changed, 518 insertions(+), 5 deletions(-) create mode 100644 parallel_inference/parallel_inference_xdit.py create mode 100644 parallel_inference/parallel_inference_xdit_usp.py create mode 100644 parallel_inference/performance.md create mode 100644 parallel_inference/performance_zh.md create mode 100644 parallel_inference/run.sh create mode 100644 parallel_inference/run_usp.sh diff --git a/README.md b/README.md index 5cab7aa..f5a7293 100644 --- a/README.md +++ b/README.md @@ -54,15 +54,16 @@ This repository is the official implementation of ConsisID, a tuning-free DiT-ba ## 📣 News * ⏳⏳⏳ Release the full code & datasets & weights. -* `[2024.12.22]` 🔥ConsisID will be merged into [diffusers](https://github.com/huggingface/diffusers) in the next version. So for now, please use `pip install git+https://github.com/SHYuanBest/ConsisID_diffusers.git` to install diffusers dev version. And we have reorganized the code and weight configs, so it's better to update your local files if you have cloned them previously. +* `[2024.12.24]` 🚀 We release the [parallel inference code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/parallel_inference) for ConsisID powered by [xDiT](https://github.com/xdit-project/xDiT). Thanks [@feifeibear](https://github.com/feifeibear) for his help. +* `[2024.12.22]` 🤗 ConsisID will be merged into [diffusers](https://github.com/huggingface/diffusers) in the next version. So for now, please use `pip install git+https://github.com/SHYuanBest/ConsisID_diffusers.git` to install diffusers dev version. And we have reorganized the code and weight configs, so it's better to update your local files if you have cloned them previously. * `[2024.12.09]` 🔥We release the [test set](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data/tree/main/eval) and [metric calculation code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) used in the paper, now your can measure the metrics on your own machine. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) for more details. * `[2024.12.08]` 🔥The code for data preprocessing is out, which is used to obtain the [training data](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data) required by ConsisID. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/data_preprocess) for more details. * `[2024.12.04]` Thanks [@shizi](https://www.bilibili.com/video/BV1v3iUY4EeQ/?vd_source=ae3f2652765c02e41cdd698b311989e3) for providing [🤗Windows-ConsisID](https://huggingface.co/pkuhexianyi/ConsisID-Windows/tree/main) and [🟣Windows-ConsisID](https://www.wisemodel.cn/models/PkuHexianyi/ConsisID-Windows/file), which make it easy to run ConsisID on Windows. * `[2024.12.01]` 🔥 We provide full text prompts corresponding to all the videos on project page. Click [here](https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/prompt.xlsx) to get and try the demo. -* `[2024.11.30]` 🔥 We have fixed the [huggingface demo](https://huggingface.co/spaces/BestWishYsh/ConsisID-preview-Space), welcome to try it. -* `[2024.11.29]` 🔥 The current code and weights are our early versions, and the differences with the latest version in [arxiv](https://github.com/PKU-YuanGroup/ConsisID) can be viewed [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/util/on_going_module). And we will release the full code in the next few days. +* `[2024.11.30]` 🤗 We have fixed the [huggingface demo](https://huggingface.co/spaces/BestWishYsh/ConsisID-preview-Space), welcome to try it. +* `[2024.11.29]` 🏃‍♂️ The current code and weights are our early versions, and the differences with the latest version in [arxiv](https://github.com/PKU-YuanGroup/ConsisID) can be viewed [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/util/on_going_module). And we will release the full code in the next few days. * `[2024.11.28]` Thanks [@camenduru](https://twitter.com/camenduru) for providing [Jupyter Notebook](https://colab.research.google.com/github/camenduru/ConsisID-jupyter/blob/main/ConsisID_jupyter.ipynb) and [@Kijai](https://github.com/kijai) for providing ComfyUI Extension [ComfyUI-ConsisIDWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper). If you find related work, please let us know. -* `[2024.11.27]` 🔥 Due to policy restrictions, we only open-source part of the dataset. You can download it by clicking [here](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data). And we will release the data processing code in the next few days. +* `[2024.11.27]` 🏃‍♂️ Due to policy restrictions, we only open-source part of the dataset. You can download it by clicking [here](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data). And we will release the data processing code in the next few days. * `[2024.11.26]` 🔥 We release the arXiv paper for ConsisID, and you can click [here](https://arxiv.org/abs/2411.17440) to see more details. * `[2024.11.22]` 🔥 **All code & datasets** are coming soon! Stay tuned 👀! @@ -279,7 +280,8 @@ We found some plugins created by community developers. Thanks for their efforts: - ComfyUI Extension. [ComfyUI-ConsisIDWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper) (by [@Kijai](https://github.com/kijai)). - Jupyter Notebook. [Jupyter-ConsisID](https://colab.research.google.com/github/camenduru/ConsisID-jupyter/blob/main/ConsisID_jupyter.ipynb) (by [@camenduru](https://github.com/camenduru/consisid-tost)). - Windows Docker. [🤗Windows-ConsisID](https://huggingface.co/pkuhexianyi/ConsisID-Windows/tree/main) and [🟣Windows-ConsisID](https://www.wisemodel.cn/models/PkuHexianyi/ConsisID-Windows/file) (by [@shizi](https://www.bilibili.com/video/BV1v3iUY4EeQ/?vd_source=ae3f2652765c02e41cdd698b311989e3)). - - xDiT. [xDit-ConsisID](https://github.com/xdit-project/xDiT) (by [pkuhxy](https://github.com/pkuhxy) and [feifeibear](https://github.com/feifeibear)). + - Diffusres. [Diffusers-ConsisID](https://github.com/huggingface/diffusers) (thanks [@arrow](https://github.com/a-r-r-o-w), [@yiyixuxu](https://github.com/yiyixuxu), [@hlky](https://github.com/hlky) and [@stevhliu](https://github.com/stevhliu) for their help). + - xDiT. [xDit-ConsisID](https://github.com/xdit-project/xDiT) (thanks [@feifeibear](https://github.com/feifeibear) for his help). If you find related work, please let us know. diff --git a/parallel_inference/parallel_inference_xdit.py b/parallel_inference/parallel_inference_xdit.py new file mode 100644 index 0000000..355bc1a --- /dev/null +++ b/parallel_inference/parallel_inference_xdit.py @@ -0,0 +1,119 @@ +import logging +import os +import time +import torch +import torch.distributed + +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from diffusers.utils import export_to_video +from huggingface_hub import snapshot_download + +from xfuser import xFuserConsisIDPipeline, xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_runtime_state, + is_dp_last_group, +) + + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion." + assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID" + + # 1. Prepare all the Checkpoints + if not os.path.exists(engine_config.model_config.model): + print("Base Model not found, downloading from Hugging Face...") + snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model) + else: + print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.") + + # 2. Load Pipeline + device = torch.device(f"cuda:{local_rank}") + pipe = xFuserConsisIDPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + engine_config=engine_config, + torch_dtype=torch.bfloat16, + ) + if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + elif args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} model CPU offload enabled") + else: + pipe = pipe.to(device) + + face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( + prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16) + ) + + if args.enable_tiling: + pipe.vae.enable_tiling() + + if args.enable_slicing: + pipe.vae.enable_slicing() + + # 3. Prepare Model Input + id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + face_helper_1, + face_clip_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + face_main_model, + device, + torch.bfloat16, + input_config.img_file_path, + is_align_face=True, + ) + + # 4. Generate Identity-Preserving Video + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + image=image, + prompt=input_config.prompt[0], + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=face_kps, + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + guidance_scale=6.0, + use_dynamic_cfg=False, + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if is_dp_last_group(): + resolution = f"{input_config.width}x{input_config.height}" + output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4" + export_to_video(output, output_filename, fps=8) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB") + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() diff --git a/parallel_inference/parallel_inference_xdit_usp.py b/parallel_inference/parallel_inference_xdit_usp.py new file mode 100644 index 0000000..778dd72 --- /dev/null +++ b/parallel_inference/parallel_inference_xdit_usp.py @@ -0,0 +1,256 @@ +import functools +from typing import Optional, Tuple, Any, Dict + +import logging +import os +import time +import torch + +from diffusers import DiffusionPipeline, ConsisIDPipeline +from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer +from diffusers.utils import export_to_video +from huggingface_hub import snapshot_download + +from xfuser import xFuserArgs +from xfuser.config import FlexibleArgumentParser +from xfuser.core.distributed import ( + get_world_group, + get_runtime_state, + get_classifier_free_guidance_world_size, + get_classifier_free_guidance_rank, + get_cfg_group, + get_sequence_parallel_world_size, + get_sequence_parallel_rank, + get_sp_group, + is_dp_last_group, + initialize_runtime_state, + get_pipeline_parallel_world_size, +) +from xfuser.model_executor.layers.attention_processor import xFuserConsisIDAttnProcessor2_0 + +def parallelize_transformer(pipe: DiffusionPipeline): + transformer = pipe.transformer + original_forward = transformer.forward + + @functools.wraps(transformer.__class__.forward) + def new_forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: torch.LongTensor = None, + timestep_cond: Optional[torch.Tensor] = None, + image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + attention_kwargs: Optional[Dict[str, Any]] = None, + id_cond: Optional[torch.Tensor] = None, + id_vit_hidden: Optional[torch.Tensor] = None, + **kwargs, + ): + if encoder_hidden_states.shape[-2] % get_sequence_parallel_world_size() != 0: + get_runtime_state().split_text_embed_in_sp = False + else: + get_runtime_state().split_text_embed_in_sp = True + + temporal_size = hidden_states.shape[1] + if isinstance(timestep, torch.Tensor) and timestep.ndim != 0 and timestep.shape[0] == hidden_states.shape[0]: + timestep = torch.chunk(timestep, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + hidden_states = torch.chunk(hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + encoder_hidden_states = torch.chunk(encoder_hidden_states, get_classifier_free_guidance_world_size(),dim=0)[get_classifier_free_guidance_rank()] + if get_runtime_state().split_text_embed_in_sp: + encoder_hidden_states = torch.chunk(encoder_hidden_states, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + if image_rotary_emb is not None: + freqs_cos, freqs_sin = image_rotary_emb + + def get_rotary_emb_chunk(freqs): + dim_thw = freqs.shape[-1] + freqs = freqs.reshape(temporal_size, -1, dim_thw) + freqs = torch.chunk(freqs, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + freqs = freqs.reshape(-1, dim_thw) + return freqs + + freqs_cos = get_rotary_emb_chunk(freqs_cos) + freqs_sin = get_rotary_emb_chunk(freqs_sin) + image_rotary_emb = (freqs_cos, freqs_sin) + + for block in transformer.transformer_blocks: + block.attn1.processor = xFuserConsisIDAttnProcessor2_0() + + output = original_forward( + hidden_states, + encoder_hidden_states, + timestep=timestep, + timestep_cond=timestep_cond, + image_rotary_emb=image_rotary_emb, + attention_kwargs=attention_kwargs, + id_cond=id_cond, + id_vit_hidden=id_vit_hidden, + **kwargs, + ) + + return_dict = not isinstance(output, tuple) + sample = output[0] + sample = get_sp_group().all_gather(sample, dim=-2) + sample = get_cfg_group().all_gather(sample, dim=0) + if return_dict: + return output.__class__(sample, *output[1:]) + return (sample, *output[1:]) + + new_forward = new_forward.__get__(transformer) + transformer.forward = new_forward + + original_patch_embed_forward = transformer.patch_embed.forward + + @functools.wraps(transformer.patch_embed.__class__.forward) + def new_patch_embed( + self, text_embeds: torch.Tensor, image_embeds: torch.Tensor + ): + text_embeds = get_sp_group().all_gather(text_embeds.contiguous(), dim=-2) + image_embeds = get_sp_group().all_gather(image_embeds.contiguous(), dim=-2) + batch, num_frames, channels, height, width = image_embeds.shape + text_len = text_embeds.shape[-2] + + output = original_patch_embed_forward(text_embeds, image_embeds) + + text_embeds = output[:,:text_len,:] + image_embeds = output[:,text_len:,:].reshape(batch, num_frames, -1, output.shape[-1]) + + text_embeds = torch.chunk(text_embeds, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + image_embeds = torch.chunk(image_embeds, get_sequence_parallel_world_size(),dim=-2)[get_sequence_parallel_rank()] + image_embeds = image_embeds.reshape(batch, -1, image_embeds.shape[-1]) + return torch.cat([text_embeds, image_embeds], dim=1) + + new_patch_embed = new_patch_embed.__get__(transformer.patch_embed) + transformer.patch_embed.forward = new_patch_embed + +def main(): + parser = FlexibleArgumentParser(description="xFuser Arguments") + args = xFuserArgs.add_cli_args(parser).parse_args() + engine_args = xFuserArgs.from_cli_args(args) + + engine_config, input_config = engine_args.create_config() + local_rank = get_world_group().local_rank + + assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion." + assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID" + + # 1. Prepare all the Checkpoints + if not os.path.exists(engine_config.model_config.model): + print("Base Model not found, downloading from Hugging Face...") + snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model) + else: + print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.") + + # 2. Load Pipeline + device = torch.device(f"cuda:{local_rank}") + pipe = ConsisIDPipeline.from_pretrained( + pretrained_model_name_or_path=engine_config.model_config.model, + torch_dtype=torch.bfloat16, + ) + if args.enable_sequential_cpu_offload: + pipe.enable_sequential_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} sequential CPU offload enabled") + elif args.enable_model_cpu_offload: + pipe.enable_model_cpu_offload(gpu_id=local_rank) + logging.info(f"rank {local_rank} model CPU offload enabled") + else: + pipe = pipe.to(device) + + face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = ( + prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16) + ) + + if args.enable_tiling: + pipe.vae.enable_tiling() + + if args.enable_slicing: + pipe.vae.enable_slicing() + + parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + initialize_runtime_state(pipe, engine_config) + get_runtime_state().set_video_input_parameters( + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + batch_size=1, + num_inference_steps=input_config.num_inference_steps, + split_text_embed_in_sp=get_pipeline_parallel_world_size() == 1, + ) + parallelize_transformer(pipe) + + # 3. Prepare Model Input + id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer( + face_helper_1, + face_clip_model, + face_helper_2, + eva_transform_mean, + eva_transform_std, + face_main_model, + device, + torch.bfloat16, + input_config.img_file_path, + is_align_face=True, + ) + + # 4. Generate Identity-Preserving Video + if engine_config.runtime_config.use_torch_compile: + torch._inductor.config.reorder_for_compute_comm_overlap = True + pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune-no-cudagraphs") + + # one step to warmup the torch compiler + output = pipe( + image=image, + prompt=input_config.prompt[0], + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=face_kps, + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + num_inference_steps=1, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + guidance_scale=6.0, + use_dynamic_cfg=False, + ).frames[0] + + torch.cuda.reset_peak_memory_stats() + start_time = time.time() + + output = pipe( + image=image, + prompt=input_config.prompt[0], + id_vit_hidden=id_vit_hidden, + id_cond=id_cond, + kps_cond=face_kps, + height=input_config.height, + width=input_config.width, + num_frames=input_config.num_frames, + num_inference_steps=input_config.num_inference_steps, + generator=torch.Generator(device="cuda").manual_seed(input_config.seed), + guidance_scale=6.0, + use_dynamic_cfg=False, + ).frames[0] + + end_time = time.time() + elapsed_time = end_time - start_time + peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") + + parallel_info = ( + f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_" + f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_" + f"tp{engine_args.tensor_parallel_degree}_" + f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}" + ) + if is_dp_last_group(): + resolution = f"{input_config.width}x{input_config.height}" + output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4" + export_to_video(output, output_filename, fps=8) + print(f"output saved to {output_filename}") + + if get_world_group().rank == get_world_group().world_size - 1: + print(f"epoch time: {elapsed_time:.2f} sec, parameter memory: {parameter_peak_memory/1e9:.2f} GB, memory: {peak_memory/1e9} GB") + get_runtime_state().destory_distributed_env() + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/parallel_inference/performance.md b/parallel_inference/performance.md new file mode 100644 index 0000000..41cdd01 --- /dev/null +++ b/parallel_inference/performance.md @@ -0,0 +1,29 @@ +## ConsisID Performance Report + +[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) is an identity-preserving text-to-video generation model that keeps the face consistent in the generated video by frequency decomposition.xDiT currently integrates USP techniques, including Ulysses Attention, Ring Attention, and CFG parallelization, to enhance inference speed, while work on PipeFusion is ongoing. We conducted an in-depth analysis comparing single-GPU ConsisID inference, based on the diffusers library, with our proposed parallelized version for generating 49 frames (6 seconds) of 720x480 resolution video. By flexibly combining different parallelization methods, we achieved varying performance outcomes. In this study, we systematically evaluate xDiT's acceleration performance across 1 to 6 Nvidia H100 GPUs. + +As shown in the table, the ConsisID model achieves a significant reduction in inference latency with Ulysses Attention, Ring Attention, or Classifier-Free Guidance (CFG) parallelization. Notably, CFG parallelization outperforms the other two techniques due to its lower communication overhead. By combining sequence parallelization and CFG parallelization, inference efficiency was further improved. With increased parallelism, inference latency continued to decrease. Under the optimal configuration, xDiT achieved a 3.21× speedup over single-GPU inference, reducing iteration time to just 0.72 seconds. For the default 50 iterations of ConsisID, this enables end-to-end generation of 49 frames in 35 seconds, with a GPU memory usage of 40 GB. + +### 720x480 Resolution (49 frames, 50 steps) + + +| N-GPUs | Ulysses Degree | Ring Degree | Cfg Parallel | Times | +| :----: | :------------: | :---------: | :----------: | :-----: | +| 6 | 2 | 3 | 1 | 44.89s | +| 6 | 3 | 2 | 1 | 44.24s | +| 6 | 1 | 3 | 2 | 35.78s | +| 6 | 3 | 1 | 2 | 38.35s | +| 4 | 2 | 1 | 2 | 41.37s | +| 4 | 1 | 2 | 2 | 40.68s | +| 3 | 3 | 1 | 1 | 53.57s | +| 3 | 1 | 3 | 1 | 55.51s | +| 2 | 1 | 2 | 1 | 70.19s | +| 2 | 2 | 1 | 1 | 76.56s | +| 2 | 1 | 1 | 2 | 59.72s | +| 1 | 1 | 1 | 1 | 114.87s | + +## Resources + +Learn more about ConsisID with the following resources. +- A [video](https://www.youtube.com/watch?v=PhlgC-bI5SQ) demonstrating ConsisID's main features. +- The research paper, [Identity-Preserving Text-to-Video Generation by Frequency Decomposition](https://hf.co/papers/2411.17440) for more details. diff --git a/parallel_inference/performance_zh.md b/parallel_inference/performance_zh.md new file mode 100644 index 0000000..a82c78e --- /dev/null +++ b/parallel_inference/performance_zh.md @@ -0,0 +1,23 @@ +## ConsisID Performance Report + +[ConsisID](https://github.com/PKU-YuanGroup/ConsisID) 是一种身份保持的文本到视频生成模型,其通过频率分解在生成的视频中保持面部一致性。xDiT 目前整合了 USP 技术(包括 Ulysses 注意力和 Ring 注意力)和 CFG 并行来提高推理速度,同时 PipeFusion 的工作正在进行中。我们对基于 diffusers 库的单 GPU ConsisID 推理与我们提出的并行化版本在生成 49帧(6秒)720x480 分辨率视频时的性能差异进行了深入分析。由于我们可以任意组合不同的并行方式以获得不同的性能。在本文中,我们对xDiT在1-6张H100(Nvidia)GPU上的加速性能进行了系统测试。 + +如表所示,对于模型ConsisID,无论是采用 Ulysses Attention、Ring Attention 还是 Classifier-Free Guidance(CFG)并行,均观察到推理延迟的显著降低。值得注意的是,由于其较低的通信开销,CFG 并行方法在性能上优于其他两种技术。通过结合序列并行和 CFG 并行,我们成功提升了推理效率。随着并行度的增加,推理延迟持续下降。在最优配置下,xDiT 相对于单GPU推理实现了 3.21 倍的加速,使得每次迭代仅需 0.72 秒。鉴于 ConsisID 默认的 50 次迭代,总计 35 秒即可完成 49帧 视频的端到端生成,并且运行过程中占用GPU显存40G。 + +### 720x480 Resolution (49 frames, 50 steps) + + +| N-GPUs | ulysses_degree | ring_degree | cfg-parallel | times | +|:------:|:--------------:|:-----------:|:------------:|:---------:| +| 6 | 2 | 3 | 1 | 44.89s | +| 6 | 3 | 2 | 1 | 44.24s | +| 6 | 1 | 3 | 2 | 35.78s | +| 6 | 3 | 1 | 2 | 38.35s | +| 4 | 2 | 1 | 2 | 41.37s | +| 4 | 1 | 2 | 2 | 40.68s | +| 3 | 3 | 1 | 1 | 53.57s | +| 3 | 1 | 3 | 1 | 55.51s | +| 2 | 1 | 2 | 1 | 70.19s | +| 2 | 2 | 1 | 1 | 76.56s | +| 2 | 1 | 1 | 2 | 59.72s | +| 1 | 1 | 1 | 1 | 114.87s | diff --git a/parallel_inference/run.sh b/parallel_inference/run.sh new file mode 100644 index 0000000..01ccf86 --- /dev/null +++ b/parallel_inference/run.sh @@ -0,0 +1,42 @@ +#!/bin/bash +pip install xfuser +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# ConsisID configuration +SCRIPT="parallel_inference_xdit.py" +MODEL_ID="BestWishYsh/ConsisID-preview" +INFERENCE_STEP=50 + +mkdir -p ./results + +# ConsisID specific task args +TASK_ARGS="--height 480 --width 720 --num_frames 49" + +# ConsisID parallel configuration +N_GPUS=6 +PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 3" +# CFG_ARGS="--use_cfg_parallel" + +# Uncomment and modify these as needed +# PIPEFUSION_ARGS="--num_pipeline_patch 8" +# OUTPUT_ARGS="--output_type latent" +# PARALLLEL_VAE="--use_parallel_vae" +# ENABLE_TILING="--enable_tiling" +# COMPILE_FLAG="--use_torch_compile" + +torchrun --master_port=1234 --nproc_per_node=$N_GPUS ./examples/$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ +--img_file_path "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$ENABLE_TILING \ +$COMPILE_FLAG \ No newline at end of file diff --git a/parallel_inference/run_usp.sh b/parallel_inference/run_usp.sh new file mode 100644 index 0000000..02ded85 --- /dev/null +++ b/parallel_inference/run_usp.sh @@ -0,0 +1,42 @@ +#!/bin/bash +pip install xfuser +set -x + +export PYTHONPATH=$PWD:$PYTHONPATH + +# ConsisID configuration +SCRIPT="parallel_inference_xdit_usp.py" +MODEL_ID="BestWishYsh/ConsisID-preview" +INFERENCE_STEP=50 + +mkdir -p ./results + +# ConsisID specific task args +TASK_ARGS="--height 480 --width 720 --num_frames 49" + +# ConsisID parallel configuration +N_GPUS=6 +PARALLEL_ARGS="--ulysses_degree 2 --ring_degree 3" +# CFG_ARGS="--use_cfg_parallel" + +# Uncomment and modify these as needed +# PIPEFUSION_ARGS="--num_pipeline_patch 8" +# OUTPUT_ARGS="--output_type latent" +# PARALLLEL_VAE="--use_parallel_vae" +# ENABLE_TILING="--enable_tiling" +# COMPILE_FLAG="--use_torch_compile" + +torchrun --master_port=1234 --nproc_per_node=$N_GPUS ./examples/$SCRIPT \ +--model $MODEL_ID \ +$PARALLEL_ARGS \ +$TASK_ARGS \ +$PIPEFUSION_ARGS \ +$OUTPUT_ARGS \ +--num_inference_steps $INFERENCE_STEP \ +--warmup_steps 0 \ +--prompt "The video captures a boy walking along a city street, filmed in black and white on a classic 35mm camera. His expression is thoughtful, his brow slightly furrowed as if he's lost in contemplation. The film grain adds a textured, timeless quality to the image, evoking a sense of nostalgia. Around him, the cityscape is filled with vintage buildings, cobblestone sidewalks, and softly blurred figures passing by, their outlines faint and indistinct. Streetlights cast a gentle glow, while shadows play across the boy's path, adding depth to the scene. The lighting highlights the boy's subtle smile, hinting at a fleeting moment of curiosity. The overall cinematic atmosphere, complete with classic film still aesthetics and dramatic contrasts, gives the scene an evocative and introspective feel." \ +--img_file_path "https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/example_images/2.png?raw=true" \ +$CFG_ARGS \ +$PARALLLEL_VAE \ +$ENABLE_TILING \ +$COMPILE_FLAG \ No newline at end of file