-
Notifications
You must be signed in to change notification settings - Fork 107
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[ray] launch multiple GPU with ray (#396)
- Loading branch information
1 parent
2c36a27
commit f58302a
Showing
14 changed files
with
815 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
import time | ||
import os | ||
import torch | ||
import torch.distributed | ||
from transformers import T5EncoderModel | ||
from xfuser import xFuserArgs | ||
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline | ||
from xfuser.config import FlexibleArgumentParser | ||
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline | ||
|
||
def main(): | ||
os.environ["MASTER_ADDR"] = "localhost" | ||
os.environ["MASTER_PORT"] = "12355" | ||
parser = FlexibleArgumentParser(description="xFuser Arguments") | ||
args = xFuserArgs.add_cli_args(parser).parse_args() | ||
engine_args = xFuserArgs.from_cli_args(args) | ||
engine_config, input_config = engine_args.create_config() | ||
engine_config.runtime_config.dtype = torch.bfloat16 | ||
model_name = engine_config.model_config.model.split("/")[-1] | ||
PipelineClass = xFuserFluxPipeline | ||
text_encoder_2 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_2", torch_dtype=torch.bfloat16) | ||
if args.use_fp8_t5_encoder: | ||
from optimum.quanto import freeze, qfloat8, quantize | ||
quantize(text_encoder_2, weights=qfloat8) | ||
freeze(text_encoder_2) | ||
|
||
pipe = RayDiffusionPipeline.from_pretrained( | ||
PipelineClass=PipelineClass, | ||
pretrained_model_name_or_path=engine_config.model_config.model, | ||
engine_config=engine_config, | ||
torch_dtype=torch.bfloat16, | ||
text_encoder_2=text_encoder_2, | ||
) | ||
pipe.prepare_run(input_config) | ||
|
||
start_time = time.time() | ||
output = pipe( | ||
height=input_config.height, | ||
width=input_config.width, | ||
prompt=input_config.prompt, | ||
num_inference_steps=input_config.num_inference_steps, | ||
output_type=input_config.output_type, | ||
max_sequence_length=256, | ||
guidance_scale=0.0, | ||
generator=torch.Generator(device="cuda").manual_seed(input_config.seed), | ||
) | ||
end_time = time.time() | ||
elapsed_time = end_time - start_time | ||
|
||
print(f"elapsed time:{elapsed_time}") | ||
if not os.path.exists("results"): | ||
os.mkdir("results") | ||
# output is a list of results from each worker, we take the last one | ||
for i, image in enumerate(output[-1].images): | ||
image.save( | ||
f"./results/{model_name}_result_{i}.png" | ||
) | ||
print( | ||
f"image {i} saved to ./results/{model_name}_result_{i}.png" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
set -x | ||
# If using a Ray cluster across multiple machines, you need to manually start a Ray cluster like this: | ||
# ray start --head --port=6379 for master node | ||
# ray start --address='192.168.1.1:6379' for worker node | ||
# otherwise, it is not necessary. (for single node) | ||
|
||
export PYTHONPATH=$PWD:$PYTHONPATH | ||
|
||
# Select the model type | ||
export MODEL_TYPE="Flux" | ||
# Configuration for different model types | ||
# script, model_id, inference_step | ||
declare -A MODEL_CONFIGS=( | ||
["Sd3"]="ray_sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20" | ||
["Flux"]="ray_flux_example.py /cfs/dit/FLUX.1-dev 28" | ||
) | ||
|
||
if [[ -v MODEL_CONFIGS[$MODEL_TYPE] ]]; then | ||
IFS=' ' read -r SCRIPT MODEL_ID INFERENCE_STEP <<< "${MODEL_CONFIGS[$MODEL_TYPE]}" | ||
export SCRIPT MODEL_ID INFERENCE_STEP | ||
else | ||
echo "Invalid MODEL_TYPE: $MODEL_TYPE" | ||
exit 1 | ||
fi | ||
|
||
mkdir -p ./results | ||
|
||
# task args | ||
TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning" | ||
|
||
|
||
N_GPUS=2 | ||
PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1" | ||
|
||
# CFG_ARGS="--use_cfg_parallel" | ||
|
||
# By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance. | ||
# PIPEFUSION_ARGS="--num_pipeline_patch 8 " | ||
|
||
# For high-resolution images, we use the latent output type to avoid runing the vae module. Used for measuring speed. | ||
# OUTPUT_ARGS="--output_type latent" | ||
|
||
# PARALLLEL_VAE="--use_parallel_vae" | ||
|
||
# Another compile option is `--use_onediff` which will use onediff's compiler. | ||
# COMPILE_FLAG="--use_torch_compile" | ||
|
||
|
||
# Use this flag to quantize the T5 text encoder, which could reduce the memory usage and have no effect on the result quality. | ||
# QUANTIZE_FLAG="--use_fp8_t5_encoder" | ||
|
||
export CUDA_VISIBLE_DEVICES=0,1 | ||
|
||
python ./examples/ray/$SCRIPT \ | ||
--model $MODEL_ID \ | ||
$PARALLEL_ARGS \ | ||
$TASK_ARGS \ | ||
$PIPEFUSION_ARGS \ | ||
$OUTPUT_ARGS \ | ||
--num_inference_steps $INFERENCE_STEP \ | ||
--warmup_steps 1 \ | ||
--prompt "brown dog laying on the ground with a metal bowl in front of him." \ | ||
--use_ray \ | ||
--ray_world_size $N_GPUS \ | ||
$CFG_ARGS \ | ||
$PARALLLEL_VAE \ | ||
$COMPILE_FLAG \ | ||
$QUANTIZE_FLAG \ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
import time | ||
import os | ||
import torch | ||
import torch.distributed | ||
from transformers import T5EncoderModel | ||
from xfuser import xFuserArgs | ||
from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline | ||
from xfuser.config import FlexibleArgumentParser | ||
from xfuser.model_executor.pipelines import xFuserPixArtAlphaPipeline, xFuserPixArtSigmaPipeline, xFuserStableDiffusion3Pipeline, xFuserHunyuanDiTPipeline, xFuserFluxPipeline | ||
import time | ||
import os | ||
import torch | ||
import torch.distributed | ||
from transformers import T5EncoderModel | ||
from xfuser import xFuserStableDiffusion3Pipeline, xFuserArgs | ||
from xfuser.config import FlexibleArgumentParser | ||
from xfuser.core.distributed import ( | ||
get_world_group, | ||
is_dp_last_group, | ||
get_data_parallel_rank, | ||
get_runtime_state, | ||
) | ||
from xfuser.core.distributed.parallel_state import get_data_parallel_world_size | ||
|
||
|
||
def main(): | ||
os.environ["MASTER_ADDR"] = "localhost" | ||
os.environ["MASTER_PORT"] = "12355" | ||
parser = FlexibleArgumentParser(description="xFuser Arguments") | ||
args = xFuserArgs.add_cli_args(parser).parse_args() | ||
engine_args = xFuserArgs.from_cli_args(args) | ||
engine_config, input_config = engine_args.create_config() | ||
model_name = engine_config.model_config.model.split("/")[-1] | ||
PipelineClass = xFuserStableDiffusion3Pipeline | ||
text_encoder_3 = T5EncoderModel.from_pretrained(engine_config.model_config.model, subfolder="text_encoder_3", torch_dtype=torch.float16) | ||
if args.use_fp8_t5_encoder: | ||
from optimum.quanto import freeze, qfloat8, quantize | ||
print(f"rank {local_rank} quantizing text encoder 2") | ||
quantize(text_encoder_3, weights=qfloat8) | ||
freeze(text_encoder_3) | ||
|
||
pipe = RayDiffusionPipeline.from_pretrained( | ||
PipelineClass=PipelineClass, | ||
pretrained_model_name_or_path=engine_config.model_config.model, | ||
engine_config=engine_config, | ||
torch_dtype=torch.float16, | ||
text_encoder_3=text_encoder_3, | ||
) | ||
pipe.prepare_run(input_config) | ||
|
||
torch.cuda.reset_peak_memory_stats() | ||
start_time = time.time() | ||
output = pipe( | ||
height=input_config.height, | ||
width=input_config.width, | ||
prompt=input_config.prompt, | ||
num_inference_steps=input_config.num_inference_steps, | ||
output_type=input_config.output_type, | ||
generator=torch.Generator(device="cuda").manual_seed(input_config.seed), | ||
) | ||
end_time = time.time() | ||
elapsed_time = end_time - start_time | ||
print(f"elapsed time:{elapsed_time}") | ||
if not os.path.exists("results"): | ||
os.mkdir("results") | ||
# output is a list of results from each worker, we take the last one | ||
for i, image in enumerate(output[-1].images): | ||
image.save( | ||
f"./results/{model_name}_result_{i}.png" | ||
) | ||
print( | ||
f"image {i} saved to ./results/{model_name}_result_{i}.png" | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.