Skip to content

Commit

Permalink
Merge pull request PKU-YuanGroup#29 from PKU-YuanGroup/on_going
Browse files Browse the repository at this point in the history
add parallel inference
  • Loading branch information
SHYuanBest authored Dec 24, 2024
2 parents 838cb77 + c4ae88c commit ad672b9
Show file tree
Hide file tree
Showing 7 changed files with 518 additions and 5 deletions.
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,16 @@ This repository is the official implementation of ConsisID, a tuning-free DiT-ba
## 📣 News

* ⏳⏳⏳ Release the full code & datasets & weights.
* `[2024.12.22]` 🔥ConsisID will be merged into [diffusers](https://github.com/huggingface/diffusers) in the next version. So for now, please use `pip install git+https://github.com/SHYuanBest/ConsisID_diffusers.git` to install diffusers dev version. And we have reorganized the code and weight configs, so it's better to update your local files if you have cloned them previously.
* `[2024.12.24]` 🚀 We release the [parallel inference code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/parallel_inference) for ConsisID powered by [xDiT](https://github.com/xdit-project/xDiT). Thanks [@feifeibear](https://github.com/feifeibear) for his help.
* `[2024.12.22]` 🤗 ConsisID will be merged into [diffusers](https://github.com/huggingface/diffusers) in the next version. So for now, please use `pip install git+https://github.com/SHYuanBest/ConsisID_diffusers.git` to install diffusers dev version. And we have reorganized the code and weight configs, so it's better to update your local files if you have cloned them previously.
* `[2024.12.09]` 🔥We release the [test set](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data/tree/main/eval) and [metric calculation code](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) used in the paper, now your can measure the metrics on your own machine. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/eval) for more details.
* `[2024.12.08]` 🔥The code for <u>data preprocessing</u> is out, which is used to obtain the [training data](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data) required by ConsisID. Please refer to [this guide](https://github.com/PKU-YuanGroup/ConsisID/tree/main/data_preprocess) for more details.
* `[2024.12.04]` Thanks [@shizi](https://www.bilibili.com/video/BV1v3iUY4EeQ/?vd_source=ae3f2652765c02e41cdd698b311989e3) for providing [🤗Windows-ConsisID](https://huggingface.co/pkuhexianyi/ConsisID-Windows/tree/main) and [🟣Windows-ConsisID](https://www.wisemodel.cn/models/PkuHexianyi/ConsisID-Windows/file), which make it easy to run ConsisID on Windows.
* `[2024.12.01]` 🔥 We provide full text prompts corresponding to all the videos on project page. Click [here](https://github.com/PKU-YuanGroup/ConsisID/blob/main/asserts/prompt.xlsx) to get and try the demo.
* `[2024.11.30]` 🔥 We have fixed the [huggingface demo](https://huggingface.co/spaces/BestWishYsh/ConsisID-preview-Space), welcome to try it.
* `[2024.11.29]` 🔥 The current code and weights are our early versions, and the differences with the latest version in [arxiv](https://github.com/PKU-YuanGroup/ConsisID) can be viewed [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/util/on_going_module). And we will release the full code in the next few days.
* `[2024.11.30]` 🤗 We have fixed the [huggingface demo](https://huggingface.co/spaces/BestWishYsh/ConsisID-preview-Space), welcome to try it.
* `[2024.11.29]` 🏃‍♂️ The current code and weights are our early versions, and the differences with the latest version in [arxiv](https://github.com/PKU-YuanGroup/ConsisID) can be viewed [here](https://github.com/PKU-YuanGroup/ConsisID/tree/main/util/on_going_module). And we will release the full code in the next few days.
* `[2024.11.28]` Thanks [@camenduru](https://twitter.com/camenduru) for providing [Jupyter Notebook](https://colab.research.google.com/github/camenduru/ConsisID-jupyter/blob/main/ConsisID_jupyter.ipynb) and [@Kijai](https://github.com/kijai) for providing ComfyUI Extension [ComfyUI-ConsisIDWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper). If you find related work, please let us know.
* `[2024.11.27]` 🔥 Due to policy restrictions, we only open-source part of the dataset. You can download it by clicking [here](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data). And we will release the data processing code in the next few days.
* `[2024.11.27]` 🏃‍♂️ Due to policy restrictions, we only open-source part of the dataset. You can download it by clicking [here](https://huggingface.co/datasets/BestWishYsh/ConsisID-preview-Data). And we will release the data processing code in the next few days.
* `[2024.11.26]` 🔥 We release the arXiv paper for ConsisID, and you can click [here](https://arxiv.org/abs/2411.17440) to see more details.
* `[2024.11.22]` 🔥 **All code & datasets** are coming soon! Stay tuned 👀!

Expand Down Expand Up @@ -279,7 +280,8 @@ We found some plugins created by community developers. Thanks for their efforts:
- ComfyUI Extension. [ComfyUI-ConsisIDWrapper](https://github.com/kijai/ComfyUI-CogVideoXWrapper) (by [@Kijai](https://github.com/kijai)).
- Jupyter Notebook. [Jupyter-ConsisID](https://colab.research.google.com/github/camenduru/ConsisID-jupyter/blob/main/ConsisID_jupyter.ipynb) (by [@camenduru](https://github.com/camenduru/consisid-tost)).
- Windows Docker. [🤗Windows-ConsisID](https://huggingface.co/pkuhexianyi/ConsisID-Windows/tree/main) and [🟣Windows-ConsisID](https://www.wisemodel.cn/models/PkuHexianyi/ConsisID-Windows/file) (by [@shizi](https://www.bilibili.com/video/BV1v3iUY4EeQ/?vd_source=ae3f2652765c02e41cdd698b311989e3)).
- xDiT. [xDit-ConsisID](https://github.com/xdit-project/xDiT) (by [pkuhxy](https://github.com/pkuhxy) and [feifeibear](https://github.com/feifeibear)).
- Diffusres. [Diffusers-ConsisID](https://github.com/huggingface/diffusers) (thanks [@arrow](https://github.com/a-r-r-o-w), [@yiyixuxu](https://github.com/yiyixuxu), [@hlky](https://github.com/hlky) and [@stevhliu](https://github.com/stevhliu) for their help).
- xDiT. [xDit-ConsisID](https://github.com/xdit-project/xDiT) (thanks [@feifeibear](https://github.com/feifeibear) for his help).
If you find related work, please let us know.
Expand Down
119 changes: 119 additions & 0 deletions parallel_inference/parallel_inference_xdit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
import logging
import os
import time
import torch
import torch.distributed

from diffusers.pipelines.consisid.consisid_utils import prepare_face_models, process_face_embeddings_infer
from diffusers.utils import export_to_video
from huggingface_hub import snapshot_download

from xfuser import xFuserConsisIDPipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
get_runtime_state,
is_dp_last_group,
)


def main():
parser = FlexibleArgumentParser(description="xFuser Arguments")
args = xFuserArgs.add_cli_args(parser).parse_args()
engine_args = xFuserArgs.from_cli_args(args)

engine_config, input_config = engine_args.create_config()
local_rank = get_world_group().local_rank

assert engine_args.pipefusion_parallel_degree == 1, "This script does not support PipeFusion."
assert engine_args.use_parallel_vae is False, "parallel VAE not implemented for ConsisID"

# 1. Prepare all the Checkpoints
if not os.path.exists(engine_config.model_config.model):
print("Base Model not found, downloading from Hugging Face...")
snapshot_download(repo_id="BestWishYsh/ConsisID-preview", local_dir=engine_config.model_config.model)
else:
print(f"Base Model already exists in {engine_config.model_config.model}, skipping download.")

# 2. Load Pipeline
device = torch.device(f"cuda:{local_rank}")
pipe = xFuserConsisIDPipeline.from_pretrained(
pretrained_model_name_or_path=engine_config.model_config.model,
engine_config=engine_config,
torch_dtype=torch.bfloat16,
)
if args.enable_sequential_cpu_offload:
pipe.enable_sequential_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} sequential CPU offload enabled")
elif args.enable_model_cpu_offload:
pipe.enable_model_cpu_offload(gpu_id=local_rank)
logging.info(f"rank {local_rank} model CPU offload enabled")
else:
pipe = pipe.to(device)

face_helper_1, face_helper_2, face_clip_model, face_main_model, eva_transform_mean, eva_transform_std = (
prepare_face_models(engine_config.model_config.model, device=device, dtype=torch.bfloat16)
)

if args.enable_tiling:
pipe.vae.enable_tiling()

if args.enable_slicing:
pipe.vae.enable_slicing()

# 3. Prepare Model Input
id_cond, id_vit_hidden, image, face_kps = process_face_embeddings_infer(
face_helper_1,
face_clip_model,
face_helper_2,
eva_transform_mean,
eva_transform_std,
face_main_model,
device,
torch.bfloat16,
input_config.img_file_path,
is_align_face=True,
)

# 4. Generate Identity-Preserving Video
torch.cuda.reset_peak_memory_stats()
start_time = time.time()

output = pipe(
image=image,
prompt=input_config.prompt[0],
id_vit_hidden=id_vit_hidden,
id_cond=id_cond,
kps_cond=face_kps,
height=input_config.height,
width=input_config.width,
num_frames=input_config.num_frames,
num_inference_steps=input_config.num_inference_steps,
generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
guidance_scale=6.0,
use_dynamic_cfg=False,
).frames[0]

end_time = time.time()
elapsed_time = end_time - start_time
peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}")

parallel_info = (
f"dp{engine_args.data_parallel_degree}_cfg{engine_config.parallel_config.cfg_degree}_"
f"ulysses{engine_args.ulysses_degree}_ring{engine_args.ring_degree}_"
f"tp{engine_args.tensor_parallel_degree}_"
f"pp{engine_args.pipefusion_parallel_degree}_patch{engine_args.num_pipeline_patch}"
)
if is_dp_last_group():
resolution = f"{input_config.width}x{input_config.height}"
output_filename = f"results/consisid_{parallel_info}_{resolution}.mp4"
export_to_video(output, output_filename, fps=8)
print(f"output saved to {output_filename}")

if get_world_group().rank == get_world_group().world_size - 1:
print(f"epoch time: {elapsed_time:.2f} sec, memory: {peak_memory/1e9} GB")
get_runtime_state().destory_distributed_env()


if __name__ == "__main__":
main()
Loading

0 comments on commit ad672b9

Please sign in to comment.