Skip to content

Commit

Permalink
make hunyuan video work with more resolutions and update the performa…
Browse files Browse the repository at this point in the history
…nce table
  • Loading branch information
chengzeyi committed Dec 23, 2024
1 parent f91943e commit bc18d27
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 15 deletions.
3 changes: 2 additions & 1 deletion docs/performance/hunyuanvideo.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil
|----------|--------|---------|---------|---------|
| H100 | 1,904.08 | 925.04 | 514.08 | 337.58 |
| H20 | 6,639.17 | 3,400.55 | 1,762.86 | 940.97 |
| L20 | 6,043.88 | | | |
| L20 | 6,043.88 | 3,271.44 | 2,080.05 | |

</center>

Expand All @@ -22,5 +22,6 @@ xDiT is [HunyuanVideo](https://github.com/Tencent/HunyuanVideo?tab=readme-ov-fil
|----------|--------|---------|---------|---------|
| H100 | 1,735.01 | 934.09 | 645.45 | 367.02 |
| H20 | 6,621.46 | 3,400.55 | 2,310.48 | 1,214.67 |
| L20 | 6,039.08 | 3,260.62 | 2,070.96 | |

</center>
42 changes: 29 additions & 13 deletions examples/hunyuan_video_usp_example.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# from https://github.com/chengzeyi/ParaAttention/blob/main/examples/run_hunyuan_video.py
import functools
from typing import Any, Dict, Union
from typing import Any, Dict, Union, Optional
import logging
import time

import torch

from diffusers import DiffusionPipeline, HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.utils import scale_lora_layers, unscale_lora_layers, USE_PEFT_BACKEND
from diffusers.utils import export_to_video

from xfuser import xFuserArgs
Expand Down Expand Up @@ -45,8 +46,22 @@ def new_forward(
encoder_attention_mask: torch.Tensor,
pooled_projections: torch.Tensor,
guidance: torch.Tensor = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0

if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
logging.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

batch_size, num_channels, num_frames, height, width = hidden_states.shape

assert batch_size % get_classifier_free_guidance_world_size(
Expand All @@ -68,13 +83,14 @@ def new_forward(
encoder_attention_mask)

hidden_states = hidden_states.reshape(batch_size, post_patch_num_frames, post_patch_height, post_patch_width, -1)
hidden_states = hidden_states.flatten(1, 3)

hidden_states = torch.chunk(hidden_states,
get_classifier_free_guidance_world_size(),
dim=0)[get_classifier_free_guidance_rank()]
hidden_states = torch.chunk(hidden_states,
get_sequence_parallel_world_size(),
dim=2)[get_sequence_parallel_rank()]
hidden_states = hidden_states.flatten(1, 3)
dim=-2)[get_sequence_parallel_rank()]

encoder_attention_mask = encoder_attention_mask[0].to(torch.bool)
encoder_hidden_states_indices = torch.arange(
Expand Down Expand Up @@ -103,11 +119,7 @@ def new_forward(
freqs_cos, freqs_sin = image_rotary_emb

def get_rotary_emb_chunk(freqs):
dim_thw = freqs.shape[-1]
freqs = freqs.reshape(num_frames, -1, dim_thw)
freqs = freqs.chunk(get_sequence_parallel_world_size(), dim=-2)[
get_sequence_parallel_rank()]
freqs = freqs.reshape(-1, dim_thw)
freqs = torch.chunk(freqs, get_sequence_parallel_world_size(), dim=0)[get_sequence_parallel_rank()]
return freqs

freqs_cos = get_rotary_emb_chunk(freqs_cos)
Expand Down Expand Up @@ -166,17 +178,21 @@ def custom_forward(*inputs):
hidden_states = self.norm_out(hidden_states, temb)
hidden_states = self.proj_out(hidden_states)

hidden_states = hidden_states.reshape(batch_size // get_classifier_free_guidance_world_size(),
hidden_states = get_sp_group().all_gather(hidden_states, dim=-2)
hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)

hidden_states = hidden_states.reshape(batch_size,
post_patch_num_frames,
post_patch_height // get_sequence_parallel_world_size(),
post_patch_height,
post_patch_width, -1, p_t, p, p)

hidden_states = get_sp_group().all_gather(hidden_states, dim=2)
hidden_states = get_cfg_group().all_gather(hidden_states, dim=0)

hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
hidden_states = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)

if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)

if not return_dict:
return (hidden_states, )

Expand Down
9 changes: 8 additions & 1 deletion xfuser/core/distributed/runtime_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,12 @@ def __init__(self, pipeline: DiffusionPipeline, config: EngineConfig):
pipeline=pipeline, parallel_config=config.parallel_config
)
self.cogvideox = False
self.hunyuan_video = False
if pipeline.__class__.__name__.startswith(("CogVideoX", "HunyuanVideo")):
if pipeline.__class__.__name__.startswith("CogVideoX"):
self.cogvideox = True
else:
self.hunyuan_video = True
self._set_cogvideox_parameters(
vae_scale_factor_spatial=pipeline.vae_scale_factor_spatial,
vae_scale_factor_temporal=pipeline.vae_scale_factor_temporal,
Expand Down Expand Up @@ -194,7 +199,6 @@ def _set_cogvideox_parameters(
self.backbone_patch_size = backbone_patch_size
self.backbone_inner_dim = backbone_inner_dim
self.backbone_in_channel = backbone_in_channel
self.cogvideox = True

def set_patched_mode(self, patch_mode: bool):
self.patch_mode = patch_mode
Expand Down Expand Up @@ -259,6 +263,9 @@ def _video_input_size_change(
self.input_config.batch_size = batch_size or self.input_config.batch_size
if self.cogvideox:
self._calc_cogvideox_patches_metadata()
elif self.hunyuan_video:
# TODO: implement the hunyuan video patches metadata
pass
else:
self._calc_patches_metadata()
self._reset_recv_buffer()
Expand Down

0 comments on commit bc18d27

Please sign in to comment.