Support EMS for VLLMModuleV2 (CUDA graph not supported yet).
adoda committed Jan 10, 2025
1 parent 452d52c commit edaa0c6
Showing 5 changed files with 119 additions and 7 deletions.
97 changes: 94 additions & 3 deletions chatlearn/models/vllm_module_v2.py
@@ -14,6 +14,7 @@
# ==============================================================================
"""VLLM module v2"""

import gc
import inspect
import os

@@ -30,7 +31,7 @@
from chatlearn.utils.vllm_import_helper import TextTokensPrompt
from chatlearn.utils.vllm_utils import initialize_vllm
from .torch_module import TorchModule

from .megatron.memory_manager import InferenceMemoryManager

class VLLMModuleV2(TorchModule, RayWorkerWrapper):
"""VLLMModuleV2"""
Expand All @@ -51,6 +52,7 @@ def __init__(self, *args, **kwargs):

self.tokenizer = None
self._model = None
self.llm = None
self.set_vllm_pp_layer_partition()

def add_extra_args(self, parser):
@@ -80,13 +82,19 @@ def init(self):
args_dict=self.model_args)

def setup(self):
"""Set up model and load checkpoint"""
"""Set up tokenizer."""
super().setup()
tokenizer = AutoTokenizer.from_pretrained(self.model_args['tokenizer'])
tokenizer.tokenizer = tokenizer
self.tokenizer = tokenizer

def model_setup(self):
"""Set up model and enable EMS(Efficient Memory Sharing)"""
super().model_setup()

def setup_vllm(self, workers):
if self.llm is not None: # for evaluator
return
# setup vllm engine in rank 0
os.environ['VLLM_HOST_IP'] = self.get_address()
set_vllm_actors(workers)
Expand Down Expand Up @@ -131,7 +139,18 @@ def setup_vllm(self, workers):
enforce_eager=self.model_args.get("enforce_eager", False),
disable_custom_all_reduce=True,
distributed_executor_backend="ray")
self.tokenizer = self.llm.llm_engine.tokenizer
self.llm.llm_engine.model_executor._run_workers("init_memory_manager")
self.offload_for_workers()
self.empty_cache_for_workers()

def init_memory_manager(self):
if self.module_args.offload_weights:
if InferenceMemoryManager is None:
raise Exception("Import InferenceMemoryManager failed, you may need to set right Megatron path first.")
self._memory_manager = InferenceMemoryManager(
self.model,
self.runtime_args.bucket_size_mb_in_memory_manager,
)

def set_vllm_pp_layer_partition(self):
pipeline_world_size = self.module_args.pipeline_model_parallel_size
@@ -238,6 +257,7 @@ def _convert_v1_inputs(self, prompts, prompt_token_ids):
return inputs

def generate_vllm(self, query, is_eval):
self.reinit_cache_engine()
prompt_key = self.model_args.get("vllm_prompt_key", "prompt")
input_ids_key = self.model_args.get("vllm_input_ids_key", "input_ids")

@@ -321,3 +341,74 @@ def pipeline_parallel_rank(self):
:meta private:
"""
return get_pipeline_model_parallel_rank()

def model_setup_for_workers(self):
self.llm.llm_engine.model_executor._run_workers("model_setup")

def offload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call offload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("offload")

def onload_for_workers(self,
to_onload_weights=None,
to_build_grad_buffers=None,
to_onload_main_weights=None,
to_onload_optimizer_states=None):
"""
call onload for all workers
"""
self.llm.llm_engine.model_executor._run_workers("onload")

def empty_cache_for_workers(self):
"""
call empty cache for all workers
"""
self.llm.llm_engine.model_executor._run_workers("empty_cache")

def offload_weights(self):
"""
offload weights
"""
if self.module_args.offload_weights:
self._memory_manager.offload_weights()

def onload_weights(self):
"""
onload weights
"""
if self.module_args.offload_weights:
self._memory_manager.onload_weights()

def empty_cache(self):
if self.worker.gpu_cache is not None:
for ele in self.worker.gpu_cache: # pylint: disable=unused-variable
ele = None
self.worker.gpu_cache = None # pylint: disable=access-member-before-definition

if hasattr(self.worker, "cache_engine") and self.worker.cache_engine is not None:
for c_e in self.worker.cache_engine:
c_e.cpu_cache = None
c_e.gpu_cache = None
self.worker.cache_engine = None

self.clear_cache()

def clear_cache(self):
if not self.timers("gc").started_:
self.timers("gc").start()
gc.collect()
self.timers("gc").stop()

super().empty_cache()

def reinit_cache_engine(self):
# reinit cache engine
self.llm.llm_engine.model_executor._run_workers("clear_cache")
self.llm.llm_engine._initialize_kv_caches()
self.llm.llm_engine.model_executor._run_workers("clear_cache")
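
For readers unfamiliar with EMS, the sketch below (not part of the commit) illustrates the memory lifecycle these methods implement: weights are offloaded and the KV cache released while a colocated trainer uses the GPU, then weights are onloaded and the cache engine rebuilt right before generation. The FakeWorker class and generation_round helper are hypothetical stand-ins, not the ChatLearn or vLLM API.

# Hypothetical stand-ins for a vLLM worker and the EMS cycle; not the real API.
class FakeWorker:
    """Mimics the GPU footprint of a vLLM worker: weights plus a KV cache."""

    def __init__(self):
        self.weights_on_gpu = True
        self.gpu_cache = ["kv-block-0", "kv-block-1"]

    def offload(self):
        # Analogous to InferenceMemoryManager.offload_weights():
        # park the weights on the host so a colocated trainer can use the GPU.
        self.weights_on_gpu = False

    def onload(self):
        # Analogous to InferenceMemoryManager.onload_weights().
        self.weights_on_gpu = True

    def empty_cache(self):
        # Analogous to VLLMModuleV2.empty_cache(): drop KV-cache references.
        self.gpu_cache = None

    def reinit_cache_engine(self):
        # Analogous to reinit_cache_engine(): rebuild the KV cache before generation.
        self.gpu_cache = ["kv-block-0", "kv-block-1"]


def generation_round(workers, prompts):
    """One EMS cycle: onload weights, rebuild the cache, generate, release memory."""
    for w in workers:
        w.onload()
        w.reinit_cache_engine()
    outputs = [f"response to {p}" for p in prompts]  # placeholder for llm.generate(...)
    for w in workers:
        w.empty_cache()
        w.offload()
    return outputs


if __name__ == "__main__":
    print(generation_round([FakeWorker(), FakeWorker()], ["hello", "world"]))
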
15 changes: 12 additions & 3 deletions chatlearn/runtime/decorator.py
@@ -135,7 +135,10 @@ def get_kwarg(key):
is_eval = get_kwarg('is_eval')

if to_onload:
self.onload()
if isinstance(self, VLLMModuleV2):
self.onload_for_workers()
else:
self.onload()
generation_batch_size = self.module_args.generation_batch_size
final_results = None
if not trainable and generation_batch_size:
@@ -187,9 +190,15 @@ def get_kwarg(key):
if self.is_last_rank():
final_results = ret
if to_empty_cache:
self.empty_cache()
if isinstance(self, VLLMModuleV2):
self.empty_cache_for_workers()
else:
self.empty_cache()
if to_offload:
self.offload()
if isinstance(self, VLLMModuleV2):
self.offload_for_workers()
else:
self.offload()
if is_last_batch and not is_eval:
self.runtime_args.consumed_samples += self.runtime_args.sample_per_episode
return final_results
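
A rough sketch of the dispatch pattern the decorator now uses (not part of the commit): a VLLMModuleV2 engine actor fans memory calls out to every worker it owns, while an ordinary module still acts on its own rank. The Module/VLLMEngine classes below are hypothetical simplifications of the ChatLearn types.

# Hypothetical simplification of the decorator's isinstance-based dispatch.
class Module:
    def offload(self):
        print("offload on this rank")


class VLLMEngine(Module):
    """Engine actor that forwards memory calls to all of its workers."""

    def __init__(self, workers):
        self.workers = workers

    def offload_for_workers(self):
        # Mirrors _run_workers("offload") on the vLLM model executor.
        for w in self.workers:
            w.offload()


def maybe_offload(module):
    # Engine actors fan out; plain modules handle only their own rank.
    if isinstance(module, VLLMEngine):
        module.offload_for_workers()
    else:
        module.offload()
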
6 changes: 6 additions & 0 deletions chatlearn/runtime/dist_actor.py
@@ -274,6 +274,12 @@ def add_remote_func(self):
continue
if func_name in ["timer_summary"]:
dist_call = partial(self.call_vllm_engine_remote_funcs, func_name)
elif func_name in ["onload", "offload"]:
if func_name == "onload":
new_func_name = "onload_for_workers"
else:
new_func_name = "offload_for_workers"
dist_call = partial(self.call_vllm_engine_remote_funcs, new_func_name)
elif func_name in ["model_setup"]:
dist_call = partial(self.call_vllm_engine_and_workers_remote_funcs, func_name)
else: # needed to check for other call_funs.
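
A minimal sketch of the remapping above (not part of the commit): onload/offload calls made against the engine actor are rewritten to their *_for_workers counterparts before being bound with functools.partial. The remote helper and REMAP table are hypothetical, standing in for call_vllm_engine_remote_funcs.

from functools import partial

# Hypothetical remap table mirroring add_remote_func(): memory calls on the
# engine actor are redirected to their *_for_workers counterparts.
REMAP = {"onload": "onload_for_workers", "offload": "offload_for_workers"}


def remote(engine, func_name, *args, **kwargs):
    # Stand-in for call_vllm_engine_remote_funcs: invoke func_name on the engine.
    return getattr(engine, func_name)(*args, **kwargs)


def make_dist_call(engine, func_name):
    return partial(remote, engine, REMAP.get(func_name, func_name))
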
6 changes: 6 additions & 0 deletions chatlearn/runtime/executor.py
@@ -205,6 +205,12 @@ def get_next_data():
kwargs["is_last_batch"] = is_last_batch
if is_eval is not None:
kwargs["is_eval"] = is_eval
if to_empty_cache is not None:
kwargs["to_empty_cache"] = to_empty_cache
if to_onload is not None:
kwargs["to_onload"] = to_onload
if to_offload is not None:
kwargs["to_offload"] = to_offload
mb, query = get_next_data()
assert isinstance(query, list)
ret = replica.call_actor_remote_func(replica.vllm_engine, func_name, *query, **kwargs)
2 changes: 1 addition & 1 deletion chatlearn/schedule/model_manager.py
@@ -336,9 +336,9 @@ def _get_model_replica_from_pack(gpu_index, model_pack):
gpu_to_replicas.append(colocate_models)

for i, replicas in enumerate(gpu_to_replicas):
num_gpus = 1.0 / len(replicas)
group = i // self.resouce_manager.gpu_per_node
for replica in replicas:
num_gpus = 1.0 / len(replicas)
if isinstance(replica.model, VLLMModuleV2) and replica.vllm_engine is None:
num_gpus = num_gpus / 2
replica.create_engine_actor(num_gpus, placement_group, group)
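
For the scheduling change above, a small worked example of the GPU-fraction arithmetic (plain Python, not the ChatLearn scheduler): each colocated replica gets an equal share of a GPU, and a VLLMModuleV2 replica whose engine actor has not been created yet reserves only half of that share so the engine actor can take the rest.

def gpu_fraction(num_colocated_replicas, is_vllm_v2_without_engine):
    # Each colocated replica gets an equal share of one GPU ...
    num_gpus = 1.0 / num_colocated_replicas
    # ... and a VLLMModuleV2 replica that still needs an engine actor
    # keeps only half, leaving the other half for the engine actor itself.
    if is_vllm_v2_without_engine:
        num_gpus /= 2
    return num_gpus


# Two colocated replicas: a plain module gets 0.5 GPU, while a VLLMModuleV2
# replica (engine not yet created) gets 0.25 for its worker actor.
assert gpu_fraction(2, False) == 0.5
assert gpu_fraction(2, True) == 0.25
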
