Skip to content

Commit

Permalink
[ray] launch multiple GPU with ray (#396)
Browse files Browse the repository at this point in the history
  • Loading branch information
lihuahua123 authored Dec 20, 2024
1 parent 2c36a27 commit f58302a
Show file tree
Hide file tree
Showing 14 changed files with 815 additions and 9 deletions.
64 changes: 64 additions & 0 deletions examples/ray/ray_flux_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline

def main():
    """Run the FLUX pipeline through ``RayDiffusionPipeline`` and save the images.

    Parses the xFuser CLI arguments, loads the T5 text encoder
    (``text_encoder_2``) in bfloat16, optionally quantizes it to fp8 when
    ``--use_fp8_t5_encoder`` is given, runs one timed generation and writes
    every resulting image to ``./results/<model_name>_result_<i>.png``.
    """
    # NOTE(review): presumably read by the distributed rendezvous in the
    # workers; a fixed port means only one run per host at a time -- confirm.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()
    engine_config.runtime_config.dtype = torch.bfloat16

    model_path = engine_config.model_config.model
    # Strip trailing slashes first so "/path/to/FLUX.1-dev/" still yields
    # "FLUX.1-dev" rather than an empty name.
    model_name = model_path.rstrip("/").split("/")[-1]
    PipelineClass = xFuserFluxPipeline

    text_encoder_2 = T5EncoderModel.from_pretrained(
        model_path,
        subfolder="text_encoder_2",
        torch_dtype=torch.bfloat16,
    )
    if args.use_fp8_t5_encoder:
        # Imported lazily so optimum-quanto is only needed when requested.
        from optimum.quanto import freeze, qfloat8, quantize

        quantize(text_encoder_2, weights=qfloat8)
        freeze(text_encoder_2)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=model_path,
        engine_config=engine_config,
        torch_dtype=torch.bfloat16,
        text_encoder_2=text_encoder_2,
    )
    pipe.prepare_run(input_config)

    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        max_sequence_length=256,
        guidance_scale=0.0,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"elapsed time:{elapsed_time}")
    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("results", exist_ok=True)
    # output is a list of results from each worker, we take the last one
    for i, image in enumerate(output[-1].images):
        image_path = f"./results/{model_name}_result_{i}.png"
        image.save(image_path)
        print(f"image {i} saved to {image_path}")


if __name__ == "__main__":
    main()
68 changes: 68 additions & 0 deletions examples/ray/ray_run.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
set -x
# Launches one of the Ray example scripts on the local GPUs.
#
# If using a Ray cluster across multiple machines, you need to manually start a Ray cluster like this:
#   ray start --head --port=6379                 # on the master node
#   ray start --address='192.168.1.1:6379'       # on each worker node
# Otherwise (single node) this is not necessary.

export PYTHONPATH=$PWD:$PYTHONPATH

# Select the model type
export MODEL_TYPE="Flux"
# Configuration for different model types
# script, model_id, inference_step
declare -A MODEL_CONFIGS=(
    ["Sd3"]="ray_sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20"
    ["Flux"]="ray_flux_example.py /cfs/dit/FLUX.1-dev 28"
)

if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then
    # Split the space-separated config entry into its three fields.
    IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}"
    export SCRIPT MODEL_ID INFERENCE_STEP
else
    echo "Invalid MODEL_TYPE: $MODEL_TYPE"
    exit 1
fi

mkdir -p ./results

# task args
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


N_GPUS=2
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"

# CFG_ARGS="--use_cfg_parallel"

# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
# PIPEFUSION_ARGS="--num_pipeline_patch 8 "

# For high-resolution images, we use the latent output type to avoid running the vae module. Used for measuring speed.
# OUTPUT_ARGS="--output_type latent"

# PARALLLEL_VAE="--use_parallel_vae"

# Another compile option is `--use_onediff` which will use onediff's compiler.
# COMPILE_FLAG="--use_torch_compile"


# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality.
# QUANTIZE_FLAG="--use_fp8_t5_encoder"

export CUDA_VISIBLE_DEVICES=0,1

# The *_ARGS / *_FLAG variables are deliberately left unquoted: they must
# word-split into separate options and disappear entirely when unset.
# Single-value variables are quoted so paths with spaces survive intact.
# The last line carries no trailing backslash, so nothing appended after the
# command can be swallowed as an extra argument.
python "./examples/ray/$SCRIPT" \
    --model "$MODEL_ID" \
    $PARALLEL_ARGS \
    $TASK_ARGS \
    $PIPEFUSION_ARGS \
    $OUTPUT_ARGS \
    --num_inference_steps "$INFERENCE_STEP" \
    --warmup_steps 1 \
    --prompt "brown dog laying on the ground with a metal bowl in front of him." \
    --use_ray \
    --ray_world_size "$N_GPUS" \
    $CFG_ARGS \
    $PARALLLEL_VAE \
    $COMPILE_FLAG \
    $QUANTIZE_FLAG
77 changes: 77 additions & 0 deletions examples/ray/ray_sd3_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserArgs
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline
from xfuser.config import FlexibleArgumentParser
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline
import time
import os
import torch
import torch.distributed
from transformers import T5EncoderModel
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs
from xfuser.config import FlexibleArgumentParser
from xfuser.core.distributed import (
get_world_group,
is_dp_last_group,
get_data_parallel_rank,
get_runtime_state,
)
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size


def main():
    """Run Stable Diffusion 3 through ``RayDiffusionPipeline`` and save the images.

    Parses the xFuser CLI arguments, loads the T5 text encoder
    (``text_encoder_3``) in float16, optionally quantizes it to fp8 when
    ``--use_fp8_t5_encoder`` is given, runs one timed generation and writes
    every resulting image to ``./results/<model_name>_result_<i>.png``.
    """
    # NOTE(review): presumably read by the distributed rendezvous in the
    # workers; a fixed port means only one run per host at a time -- confirm.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"

    parser = FlexibleArgumentParser(description="xFuser Arguments")
    args = xFuserArgs.add_cli_args(parser).parse_args()
    engine_args = xFuserArgs.from_cli_args(args)
    engine_config, input_config = engine_args.create_config()

    model_path = engine_config.model_config.model
    # Strip trailing slashes first so "/path/to/model/" still yields "model"
    # rather than an empty name.
    model_name = model_path.rstrip("/").split("/")[-1]
    PipelineClass = xFuserStableDiffusion3Pipeline

    text_encoder_3 = T5EncoderModel.from_pretrained(
        model_path,
        subfolder="text_encoder_3",
        torch_dtype=torch.float16,
    )
    if args.use_fp8_t5_encoder:
        # Imported lazily so optimum-quanto is only needed when requested.
        from optimum.quanto import freeze, qfloat8, quantize

        # The previous message interpolated an undefined `local_rank`
        # (NameError) and named the wrong encoder; this function defines no
        # rank, so the message only names the encoder being quantized.
        print("quantizing text encoder 3")
        quantize(text_encoder_3, weights=qfloat8)
        freeze(text_encoder_3)

    pipe = RayDiffusionPipeline.from_pretrained(
        PipelineClass=PipelineClass,
        pretrained_model_name_or_path=model_path,
        engine_config=engine_config,
        torch_dtype=torch.float16,
        text_encoder_3=text_encoder_3,
    )
    pipe.prepare_run(input_config)

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()
    output = pipe(
        height=input_config.height,
        width=input_config.width,
        prompt=input_config.prompt,
        num_inference_steps=input_config.num_inference_steps,
        output_type=input_config.output_type,
        generator=torch.Generator(device="cuda").manual_seed(input_config.seed),
    )
    end_time = time.time()
    elapsed_time = end_time - start_time
    print(f"elapsed time:{elapsed_time}")

    # exist_ok avoids the check-then-create race of exists() + mkdir().
    os.makedirs("results", exist_ok=True)
    # output is a list of results from each worker, we take the last one
    for i, image in enumerate(output[-1].images):
        image_path = f"./results/{model_name}_result_{i}.png"
        image.save(image_path)
        print(f"image {i} saved to {image_path}")


if __name__ == "__main__":
    main()
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ def get_cuda_version():
"imageio",
"imageio-ffmpeg",
"optimum-quanto",
"flash_attn>=2.6.3" # flash_attn>=2.7.0 with torch>=2.4.0 wraps ops with torch.ops
"flash_attn>=2.6.3",
"ray"
],
extras_require={
"diffusers": [
Expand Down
25 changes: 24 additions & 1 deletion xfuser/config/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ class xFuserArgs:
# tensor parallel
tensor_parallel_degree: int = 1
split_scheme: Optional[str] = "row"
# ray arguments
use_ray: bool = False
ray_world_size: int = 1
# pipefusion parallel
pipefusion_parallel_degree: int = 1
num_pipeline_patch: Optional[int] = None
Expand Down Expand Up @@ -151,6 +154,17 @@ def add_cli_args(parser: FlexibleArgumentParser):

# Parallel arguments
parallel_group = parser.add_argument_group("Parallel Processing Options")
runtime_group.add_argument(
"--use_ray",
action="store_true",
help="Enable ray to run inference in multi-card",
)
parallel_group.add_argument(
"--ray_world_size",
type=int,
default=1,
help="The number of ray workers (world_size for ray)",
)
parallel_group.add_argument(
"--use_cfg_parallel",
action="store_true",
Expand Down Expand Up @@ -322,11 +336,15 @@ def from_cli_args(cls, args: argparse.Namespace):
def create_config(
self,
) -> Tuple[EngineConfig, InputConfig]:
if not torch.distributed.is_initialized():
if not self.use_ray and not torch.distributed.is_initialized():
logger.warning(
"Distributed environment is not initialized. " "Initializing..."
)
init_distributed_environment()
if self.use_ray:
self.world_size = self.ray_world_size
else:
self.world_size = torch.distributed.get_world_size()

model_config = ModelConfig(
model=self.model,
Expand All @@ -348,20 +366,25 @@ def create_config(
dp_config=DataParallelConfig(
dp_degree=self.data_parallel_degree,
use_cfg_parallel=self.use_cfg_parallel,
world_size=self.world_size,
),
sp_config=SequenceParallelConfig(
ulysses_degree=self.ulysses_degree,
ring_degree=self.ring_degree,
world_size=self.world_size,
),
tp_config=TensorParallelConfig(
tp_degree=self.tensor_parallel_degree,
split_scheme=self.split_scheme,
world_size=self.world_size,
),
pp_config=PipeFusionParallelConfig(
pp_degree=self.pipefusion_parallel_degree,
num_pipeline_patch=self.num_pipeline_patch,
attn_layer_num_for_pp=self.attn_layer_num_for_pp,
world_size=self.world_size,
),
world_size=self.world_size,
)

fast_attn_config = FastAttnConfig(
Expand Down
20 changes: 13 additions & 7 deletions xfuser/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def __post_init__(self):
class DataParallelConfig:
dp_degree: int = 1
use_cfg_parallel: bool = False
world_size: int = 1

def __post_init__(self):
assert self.dp_degree >= 1, "dp_degree must greater than or equal to 1"
Expand All @@ -95,19 +96,20 @@ def __post_init__(self):
self.cfg_degree = 2
else:
self.cfg_degree = 1
assert self.dp_degree * self.cfg_degree <= dist.get_world_size(), (
assert self.dp_degree * self.cfg_degree <= self.world_size, (
"dp_degree * cfg_degree must be less than or equal to "
"world_size because of classifier free guidance"
)
assert (
dist.get_world_size() % (self.dp_degree * self.cfg_degree) == 0
self.world_size % (self.dp_degree * self.cfg_degree) == 0
), "world_size must be divisible by dp_degree * cfg_degree"


@dataclass
class SequenceParallelConfig:
ulysses_degree: Optional[int] = None
ring_degree: Optional[int] = None
world_size: int = 1

def __post_init__(self):
if self.ulysses_degree is None:
Expand Down Expand Up @@ -138,11 +140,12 @@ def __post_init__(self):
class TensorParallelConfig:
tp_degree: int = 1
split_scheme: Optional[str] = "row"
world_size: int = 1

def __post_init__(self):
assert self.tp_degree >= 1, "tp_degree must greater than 1"
assert (
self.tp_degree <= dist.get_world_size()
self.tp_degree <= self.world_size
), "tp_degree must be less than or equal to world_size"


Expand All @@ -151,13 +154,14 @@ class PipeFusionParallelConfig:
pp_degree: int = 1
num_pipeline_patch: Optional[int] = None
attn_layer_num_for_pp: Optional[List[int]] = (None,)
world_size: int = 1

def __post_init__(self):
assert (
self.pp_degree is not None and self.pp_degree >= 1
), "pipefusion_degree must be set and greater than 1 to use pipefusion"
assert (
self.pp_degree <= dist.get_world_size()
self.pp_degree <= self.world_size
), "pipefusion_degree must be less than or equal to world_size"
if self.num_pipeline_patch is None:
self.num_pipeline_patch = self.pp_degree
Expand Down Expand Up @@ -188,6 +192,8 @@ class ParallelConfig:
sp_config: SequenceParallelConfig
pp_config: PipeFusionParallelConfig
tp_config: TensorParallelConfig
world_size: int = 1 # FIXME: remove this
worker_cls: str = "xfuser.ray.worker.worker.Worker"

def __post_init__(self):
assert self.tp_config is not None, "tp_config must be set"
Expand All @@ -201,10 +207,10 @@ def __post_init__(self):
* self.tp_config.tp_degree
* self.pp_config.pp_degree
)
world_size = dist.get_world_size()
world_size = self.world_size
assert parallel_world_size == world_size, (
f"parallel_world_size {parallel_world_size} "
f"must be equal to world_size {dist.get_world_size()}"
f"must be equal to world_size {self.world_size}"
)
assert (
world_size % (self.dp_config.dp_degree * self.dp_config.cfg_degree) == 0
Expand Down Expand Up @@ -236,7 +242,7 @@ class EngineConfig:
fast_attn_config: FastAttnConfig

def __post_init__(self):
world_size = dist.get_world_size()
world_size = self.parallel_config.world_size
if self.fast_attn_config.use_fast_attn:
assert self.parallel_config.dp_degree == world_size, f"world_size must be equal to dp_degree when using DiTFastAttn"

Expand Down
Empty file added xfuser/ray/pipeline/__init__.py
Empty file.
Loading

0 comments on commit f58302a

Please sign in to comment.