diff --git a/chatlearn/models/vllm/hooks/__init__.py b/chatlearn/models/vllm/hooks/__init__.py
index 88872939..2a4036bc 100644
--- a/chatlearn/models/vllm/hooks/__init__.py
+++ b/chatlearn/models/vllm/hooks/__init__.py
@@ -25,6 +25,7 @@ from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
 
 if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
     from chatlearn.models.vllm.hooks import input_preprocess
+    from chatlearn.models.vllm.hooks import loader
 else:
     if importlib.util.find_spec("vllm"):
         import vllm
diff --git a/chatlearn/models/vllm/hooks/loader.py b/chatlearn/models/vllm/hooks/loader.py
new file mode 100644
index 00000000..4a4d06c5
--- /dev/null
+++ b/chatlearn/models/vllm/hooks/loader.py
@@ -0,0 +1,115 @@
+# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Hooks of vllm-0.6.3 loader to load ckpt of megatron format."""
+
+
+import torch
+
+# pylint: disable=unused-import,wildcard-import,unused-argument
+from vllm.model_executor.model_loader import loader
+from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model
+from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
+from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+from vllm.model_executor.models import qwen2
+
+from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
+from chatlearn.utils.vllm_import_helper import QWenLMHeadModel
+from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM
+from chatlearn.utils.vllm_import_helper import get_model_architecture
+from chatlearn.utils.utils import get_use_legacy_models
+
+from chatlearn.utils.vllm_utils import (
+    convert_llama_state_dict_from_megatron_to_vllm,
+    convert_llama_state_dict_from_mcore_to_vllm,
+    convert_qwen_state_dict_from_megatron_to_vllm,
+    load_checkpoint
+)
+
+# Qwen2MoeForCausalLM is referenced below; import it defensively so that
+# vllm builds without the Qwen2-MoE architecture can still import this module.
+try:
+    from vllm.model_executor.models.qwen2_moe import Qwen2MoeForCausalLM
+except ImportError:
+    Qwen2MoeForCausalLM = None
+
+def load_weights(self, model_args):
+    torch.distributed.barrier()
+    self.model_args = model_args
+    load_checkpoint(self, None, None, model_args=model_args)
+    torch.distributed.barrier()
+
+def load_state_dict(self, state_dict, strict=True, assign=False):
+    qwen_version = None
+    if isinstance(self, LlamaForCausalLM):
+        use_legacy_models = get_use_legacy_models(self.model_args)
+        if use_legacy_models:
+            convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm
+        else:
+            convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm
+    elif isinstance(self, QWenLMHeadModel):
+        qwen_version = 1.0
+        convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm
+    elif isinstance(self, Qwen2ForCausalLM) or (Qwen2MoeForCausalLM is not None and isinstance(self, Qwen2MoeForCausalLM)):
+        qwen_version = 2.0
+        convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm
+    else:
+        raise RuntimeError(f"Unsupported model for vllm backend. \
+            support [LlamaForCausalLM, QWenLMHeadModel, Qwen2ForCausalLM, Qwen2MoeForCausalLM] only, while {self}")
+
+    state_dict = convert_state_dict_internal(self.model_args, self.config, qwen_version=qwen_version)
+    super(type(self), self).load_state_dict(state_dict, strict=strict)
+
+
+def init(self, load_config):
+    # remove 'Model loader extra config' assert.
+    self.load_config = load_config
+
+loader.DummyModelLoader.__init__ = init
+
+
+# add ckpt loading of megatron format
+def load_model(self, *, model_config,
+               device_config,
+               lora_config,
+               parallel_config,
+               scheduler_config,
+               cache_config):
+    with set_default_torch_dtype(model_config.dtype):
+        with torch.device(device_config.device):
+            model = _initialize_model(model_config, self.load_config,
+                                      lora_config, cache_config,
+                                      scheduler_config)
+        if self.load_config.model_loader_extra_config.get("need_load_ckpt", True):
+            qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
+            qwen2.Qwen2ForCausalLM.load_weights = load_weights
+            model.load_weights(self.load_config.model_loader_extra_config)
+        else:
+            # For accurate performance evaluation, we assign
+            # random values to the weights.
+            initialize_dummy_weights(model)
+
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                # When quant methods need to process weights after loading
+                # (for repacking, quantizing, etc), they expect parameters
+                # to be on the global target device. This scope is for the
+                # case where cpu offloading is used, where we will move the
+                # parameters onto device for processing and back off after.
+                with device_loading_context(
+                        module, torch.device(device_config.device)):
+                    quant_method.process_weights_after_loading(module)
+    return model.eval()
+loader.DummyModelLoader.load_model = load_model
diff --git a/chatlearn/models/vllm_module_v2.py b/chatlearn/models/vllm_module_v2.py
index a8ea52d4..2daf2cc4 100644
--- a/chatlearn/models/vllm_module_v2.py
+++ b/chatlearn/models/vllm_module_v2.py
@@ -24,6 +24,7 @@ from transformers import AutoTokenizer
 
 from vllm import SamplingParams
 from vllm.config import EngineConfig
+from vllm.config import LoadFormat
 from vllm.executor.ray_utils import RayWorkerWrapper
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -67,6 +68,12 @@ def _init_args(self, args):
         # logger config
         args.disable_log_requests = True
 
+        # load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
+        args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
+        if args.load_format == LoadFormat.DUMMY:
+            args.model_loader_extra_config = self.model_args
+            self.model_args["need_load_ckpt"] = self.src_parameter_model is None
+
         # engine config
         args.enforce_eager = self.model_args.get("enforce_eager", False)
 
@@ -81,7 +88,7 @@ def setup_vllm(self, workers):
         if self.model_args.get("fp16", False):
             dtype = "float16"
         vllm_sys_argv = ["",
-                         f"--model={self.model_args['load']}",
+                         f"--model={self.model_args['tokenizer']}",
                          f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}",
                          f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}",
                          f"--dtype={dtype}",
diff --git a/chatlearn/utils/vllm_utils.py b/chatlearn/utils/vllm_utils.py
index 00169785..c04880b0 100644
--- a/chatlearn/utils/vllm_utils.py
+++ b/chatlearn/utils/vllm_utils.py
@@ -30,6 +30,7 @@ import torch
 import torch.distributed
 
 
+from chatlearn.models.vllm import is_vllm_v2
 from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
 from chatlearn.utils.utils import get_use_legacy_models
 
@@ -1136,7 +1137,7 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=
         final_norm = "ln_f"
         func_map = megatron_qwen_to_transformers
     elif qwen_version == QwenVersion.v_2:
-        prefix_name = "model.model."
+        prefix_name = "model." if is_vllm_v2() else "model.model."
         embed_name = "embed_tokens"
         layer_prefix = "layers"
         final_norm = "norm"
@@ -1393,15 +1394,16 @@
 
     # For LM head, transformers' wants the matrix to weight embeddings.
     print("Converting LM head")
+    lm_head_name = "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
     if megatron_args.untie_embeddings_and_output_weights:
         if hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts:
             params = get_element_from_dict_by_path(final_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
         else:
             params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
         if (isinstance(params, dict) and len(params.keys())) or (params is not None and not isinstance(params, dict)):
-            output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype)
+            output_state_dict[lm_head_name] = params.to(hf_config.torch_dtype)
     elif pp_rank == 0 or (pp_rank == pp_size - 1) or (hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts):
-        output_state_dict["model.lm_head.weight"] = word_embeddings
+        output_state_dict[lm_head_name] = word_embeddings
 
     # It should be done!
print("Conversion from Megatron-LM to Transformers is done!") @@ -1491,9 +1493,12 @@ def _load_base_checkpoint(load_dir, rank0=False): return state_dict, checkpoint_name, release -def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): +def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, model_args=None): """"Transform parallel strategy for checkpoint if needed.""" - args = model.model_args + if model_args is not None: + args = model_args + else: + args = model.model_args if args.get("adaptive_parallel_strategy_on_checkpoint"): load_dir = args[load_arg] target_tp = args.get("tensor_model_parallel_size") @@ -1522,16 +1527,20 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri torch.distributed.barrier() args[load_arg] = save_dir print_rank_0(f"Using transformed checkpoint {save_dir}") - return vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg=load_arg, strict=strict) + return vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg=load_arg, strict=strict, model_args=model_args) -def vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True): +def vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, model_args=None): """Load a model checkpoint and return the iteration. strict (bool): whether to strictly enforce that the keys in :attr:`state_dict` of the checkpoint match the names of parameters and buffers in model. """ - args = model.model_args + if model_args is not None: + args = model_args + else: + args = model.model_args + load_dir = args[load_arg] model = [unwrap_model(model)]