Support EMS for VLLMModuleV2 (CUDA graph is not supported yet). #205

Open
wants to merge 2 commits into base: main
94 changes: 91 additions & 3 deletions chatlearn/models/vllm_module_v2.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""VLLM module v2"""

import gc
import inspect
import os

@@ -30,7 +31,7 @@
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule

from .megatron.memory_manager import InferenceMemoryManager

class VLLMModuleV2(TorchModule, RayWorkerWrapper):
"""VLLMModuleV2"""
@@ -51,6 +52,7 @@ def __init__(self, *args, **kwargs):

self.tokenizer = None
self._model = None
self.llm = None
self.set_vllm_pp_layer_partition()

def add_extra_args(self, parser):
@@ -80,13 +82,15 @@ def init(self):
args_dict=self.model_args)

def setup(self):
"""Set up model and load checkpoint"""
"""Set up tokenizer."""
super().setup()
tokenizer = AutoTokenizer.from_pretrained(self.model_args['tokenizer'])
tokenizer.tokenizer = tokenizer
self.tokenizer = tokenizer

def setup_vllm(self, workers):
if self.llm is not None: # for evaluator
return
# setup vllm engine in rank 0
os.environ['VLLM_HOST_IP'] = self.get_address()
set_vllm_actors(workers)
@@ -131,7 +135,18 @@ def setup_vllm(self, workers):
enforce_eager=self.model_args.get("enforce_eager", False),
disable_custom_all_reduce=True,
distributed_executor_backend="ray")
self.tokenizer = self.llm.llm_engine.tokenizer
self.llm.llm_engine.model_executor._run_workers("init_memory_manager")
self.offload_for_workers()
self.empty_cache_for_workers()

def init_memory_manager(self):
if self.module_args.offload_weights:
if InferenceMemoryManager is None:
raise Exception("Failed to import InferenceMemoryManager, you may need to set the right Megatron path first.")
self._memory_manager = InferenceMemoryManager(
self.model,
self.runtime_args.bucket_size_mb_in_memory_manager,
)

def set_vllm_pp_layer_partition(self):
pipeline_world_size = self.module_args.pipeline_model_parallel_size
@@ -238,6 +253,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids):
return inputs

def generate_vllm(self, query, is_eval):
self.reinit_cache_engine()
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")

Expand Down Expand Up @@ -321,3 +337,75 @@ def pipeline_parallel_rank(self):
:meta private:
"""
return get_pipeline_model_parallel_rank()

def model_setup_for_workers(self):
self.llm.llm_engine.model_executor._run_workers("model_setup")

# pylint: disable=unused-argument
def offload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call offload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("offload")

def onload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call onload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("onload")

def empty_cache_for_workers(self):
"""
call empty cache for all workers
"""
self.llm.llm_engine.model_executor._run_workers("empty_cache")

def offload_weights(self):
"""
offload weights
"""
if self.module_args.offload_weights:
self._memory_manager.offload_weights()

def onload_weights(self):
"""
onload weights
"""
if self.module_args.offload_weights:
self._memory_manager.onload_weights()

def empty_cache(self):
if self.worker.gpu_cache is not None:
for ele in self.worker.gpu_cache: # pylint: disable=unused-variable
ele = None
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition

if hasattr(self.worker, "cache_engine") and self.worker.cache_engine is not None:
for c_e in self.worker.cache_engine:
c_e.cpu_cache = None
c_e.gpu_cache = None
self.worker.cache_engine = None

self.clear_cache()

def clear_cache(self):
if not self.timers("gc").started_:
self.timers("gc").start()
gc.collect()
self.timers("gc").stop()

super().empty_cache()

def reinit_cache_engine(self):
# reinit cache engine
self.llm.llm_engine.model_executor._run_workers("clear_cache")
self.llm.llm_engine._initialize_kv_caches()
self.llm.llm_engine.model_executor._run_workers("clear_cache")
15 changes: 12 additions & 3 deletions chatlearn/runtime/decorator.py
@@ -135,7 +135,10 @@ def get_kwarg(key):
is_eval = get_kwarg('is_eval')

if to_onload:
self.onload()
if isinstance(self, VLLMModuleV2):
self.onload_for_workers()
else:
self.onload()
generation_batch_size = self.module_args.generation_batch_size
final_results = None
if not trainable and generation_batch_size:
@@ -187,9 +190,15 @@ def get_kwarg(key):
if self.is_last_rank():
final_results = ret
if to_empty_cache:
self.empty_cache()
if isinstance(self, VLLMModuleV2):
self.empty_cache_for_workers()
else:
self.empty_cache()
if to_offload:
self.offload()
if isinstance(self, VLLMModuleV2):
self.offload_for_workers()
else:
self.offload()
if is_last_batch and not is_eval:
self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode
return final_results
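
The decorator.py change is a type-based dispatch: when the module is a `VLLMModuleV2`, the onload/offload/empty-cache request has to go through the engine so that it reaches every worker, while other module types keep handling it locally. A minimal, self-contained sketch of that branching follows; `BaseModule` and `VLLMLike` are stand-ins invented for this example.

```python
# Minimal sketch of the isinstance dispatch added in decorator.py.
# BaseModule and VLLMLike are invented stand-ins for this example.
class BaseModule:
    def offload(self):
        print("module offloads its own weights")


class VLLMLike(BaseModule):
    def offload_for_workers(self):
        print("engine forwards offload to every vLLM worker")


def maybe_offload(module, to_offload=True):
    if not to_offload:
        return
    if isinstance(module, VLLMLike):
        module.offload_for_workers()  # engine-level module: fan out to workers
    else:
        module.offload()              # ordinary module: handle it locally


maybe_offload(BaseModule())  # -> module offloads its own weights
maybe_offload(VLLMLike())    # -> engine forwards offload to every vLLM worker
```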
6 changes: 6 additions & 0 deletions chatlearn/runtime/dist_actor.py
@@ -274,6 +274,12 @@ def add_remote_func(self):
continue
if func_name in ["timer_summary"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
elif func_name in ["onload", "offload"]:
if func_name == "onload":
new_func_name = "onload_for_workers"
else:
new_func_name = "offload_for_workers"
dist_call = partial(self.call_vllm_engine_remote_funcs, new_func_name)
elif func_name in ["model_setup"]:
dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
else: # needed to check for other call_funs.
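
The dist_actor.py hunk rebinds the public names at dispatch time: a remote `onload`/`offload` request against a vLLM engine replica is routed to `onload_for_workers`/`offload_for_workers` via `functools.partial`. Here is a small illustrative sketch of that renaming pattern; the `EngineProxy` class and its methods are made up for the example, not chatlearn code.

```python
from functools import partial

# Illustrative sketch only: expose a public name but route it to a
# differently named remote function, as dist_actor.py now does for
# onload/offload on the vLLM engine replica.
class EngineProxy:
    def call_remote(self, func_name, *args, **kwargs):
        print(f"remote call -> {func_name}")

    def add_remote_funcs(self, func_names):
        rename = {"onload": "onload_for_workers",
                  "offload": "offload_for_workers"}
        for name in func_names:
            target = rename.get(name, name)
            # bind the remapped target now, keep the original attribute name
            setattr(self, name, partial(self.call_remote, target))


proxy = EngineProxy()
proxy.add_remote_funcs(["onload", "offload", "model_setup"])
proxy.onload()       # remote call -> onload_for_workers
proxy.model_setup()  # remote call -> model_setup
```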
6 changes: 6 additions & 0 deletions chatlearn/runtime/executor.py
@@ -205,6 +205,12 @@ def get_next_data():
kwargs["is_last_batch"] = is_last_batch
if is_eval is not None:
kwargs["is_eval"] = is_eval
if to_empty_cache is not None:
kwargs["to_empty_cache"] = to_empty_cache
if to_onload is not None:
kwargs["to_onload"] = to_onload
if to_offload is not None:
kwargs["to_offload"] = to_offload
mb, query = get_next_data()
assert isinstance(query, list)
ret = replica.call_actor_remote_func(replica.vllm_engine, func_name, *query, **kwargs)
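
The executor.py hunk simply threads the three memory flags through to the remote call when they were explicitly set; `None` means "not specified" and is dropped so the callee's defaults apply. A tiny sketch of that conditional-forwarding pattern, with invented helper names:

```python
# Sketch of the conditional kwargs forwarding added in executor.py.
# forward_flags and generate_step are invented names for illustration.
def forward_flags(to_empty_cache=None, to_onload=None, to_offload=None):
    kwargs = {}
    for key, value in (("to_empty_cache", to_empty_cache),
                       ("to_onload", to_onload),
                       ("to_offload", to_offload)):
        if value is not None:  # only forward flags that were explicitly set
            kwargs[key] = value
    return kwargs


def generate_step(query, **kwargs):
    print("generate_vllm would receive", kwargs)


generate_step(["some prompt"], **forward_flags(to_onload=True, to_offload=True))
# -> generate_vllm would receive {'to_onload': True, 'to_offload': True}
```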
2 changes: 1 addition & 1 deletion chatlearn/schedule/model_manager.py
@@ -336,9 +336,9 @@ def _get_model_replica_from_pack(gpu_index, model_pack):
gpu_to_replicas.append(colocate_models)

for i, replicas in enumerate(gpu_to_replicas):
num_gpus = 1.0 / len(replicas)
group = i // self.resouce_manager.gpu_per_node
for replica in replicas:
num_gpus = 1.0 / len(replicas)
if isinstance(replica.model, VLLMModuleV2) and replica.vllm_engine is None:
num_gpus = num_gpus / 2
replica.create_engine_actor(num_gpus, placement_group, group)
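
The model_manager.py change moves the `num_gpus` computation inside the replica loop, so halving the share for one `VLLMModuleV2` replica no longer leaks into the share assigned to the next replica on the same GPU; the halving itself presumably leaves room for the engine actor and its worker to split the replica's slice. The arithmetic, as a standalone sketch (not chatlearn code):

```python
# Back-of-the-envelope arithmetic behind the num_gpus change in model_manager.py
# (standalone sketch, not chatlearn code). With N colocated replicas on one GPU,
# each replica gets 1/N of the card; a VLLMModuleV2 replica whose engine actor
# has not been created yet requests half of that share for the engine actor.
def gpu_share(num_colocated_replicas, is_vllm_v2_without_engine):
    share = 1.0 / num_colocated_replicas
    if is_vllm_v2_without_engine:
        share /= 2
    return share


print(gpu_share(2, False))  # 0.5  -> ordinary colocated replica
print(gpu_share(2, True))   # 0.25 -> slice requested for the vLLM engine actor
```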