Load ckpt of megatron format for vllm module v2. (#181)
charles9304 authored Dec 17, 2024
1 parent 7bb2113 commit 1081492
Showing 4 changed files with 134 additions and 9 deletions.
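
The commit reuses vLLM's `dummy` load format as the entry point: `_init_args` defaults `load_format` to `dummy` and passes the model args through `model_loader_extra_config`, and a hooked `DummyModelLoader` then converts and loads the Megatron checkpoint instead of leaving random weights; the engine's `--model` now points at the HF directory given by `tokenizer`, since weights come from the Megatron checkpoint. A minimal sketch of the model-args side, using only keys visible in this diff (paths are placeholders):

# Illustrative model_args for the vLLM policy model (paths are placeholders).
model_args = {
    "load": "/path/to/megatron/checkpoint",  # Megatron-format ckpt consumed by load_checkpoint
    "tokenizer": "/path/to/hf/model_dir",    # HF dir handed to vLLM as --model (config/tokenizer only)
    "vllm_load_format": "dummy",             # the default; routes loading through the hooked DummyModelLoader
}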
1 change: 1 addition & 0 deletions chatlearn/models/vllm/hooks/__init__.py
@@ -25,6 +25,7 @@
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
if CURRENT_VLLM_VERSION == VLLMVersion.v_0_6_3:
    from chatlearn.models.vllm.hooks import input_preprocess
+    from chatlearn.models.vllm.hooks import loader
else:
    if importlib.util.find_spec("vllm"):
        import vllm
108 changes: 108 additions & 0 deletions chatlearn/models/vllm/hooks/loader.py
@@ -0,0 +1,108 @@
# Copyright 2024 Alibaba Group Holding Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Hooks of vllm-0.6.3 loader to load ckpt of megatron format."""


import torch

# pylint: disable=unused-import,wildcard-import,unused-argument
from vllm.model_executor.model_loader import loader
from vllm.model_executor.model_loader.loader import device_loading_context, _initialize_model
from vllm.model_executor.model_loader.weight_utils import initialize_dummy_weights
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import qwen2

from chatlearn.utils.vllm_import_helper import LlamaForCausalLM
from chatlearn.utils.vllm_import_helper import QWenLMHeadModel
from chatlearn.utils.vllm_import_helper import Qwen2ForCausalLM
from chatlearn.utils.vllm_import_helper import get_model_architecture
from chatlearn.utils.utils import get_use_legacy_models

# Guarded import: Qwen2MoeForCausalLM may not exist in every vllm build; fall back
# to None so the isinstance check below simply skips it.
try:
    from chatlearn.utils.vllm_import_helper import Qwen2MoeForCausalLM
except ImportError:
    Qwen2MoeForCausalLM = None

from chatlearn.utils.vllm_utils import (
convert_llama_state_dict_from_megatron_to_vllm,
convert_llama_state_dict_from_mcore_to_vllm,
convert_qwen_state_dict_from_megatron_to_vllm,
load_checkpoint
)

def load_weights(self, model_args):
    torch.distributed.barrier()
    self.model_args = model_args
    load_checkpoint(self, None, None, model_args=model_args)
    torch.distributed.barrier()

def load_state_dict(self, state_dict, strict=True, assign=False):
    qwen_version = None
    if isinstance(self, LlamaForCausalLM):
        use_legacy_models = get_use_legacy_models(self.model_args)
        if use_legacy_models:
            convert_state_dict_internal = convert_llama_state_dict_from_megatron_to_vllm
        else:
            convert_state_dict_internal = convert_llama_state_dict_from_mcore_to_vllm
    elif isinstance(self, QWenLMHeadModel):
        qwen_version = 1.0
        convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm
    elif isinstance(self, Qwen2ForCausalLM) or (Qwen2MoeForCausalLM is not None and isinstance(self, Qwen2MoeForCausalLM)):
        qwen_version = 2.0
        convert_state_dict_internal = convert_qwen_state_dict_from_megatron_to_vllm
    else:
        raise RuntimeError(
            "Unsupported model for the vllm backend: only [LlamaForCausalLM, QWenLMHeadModel, "
            f"Qwen2ForCausalLM, Qwen2MoeForCausalLM] are supported, got {self}.")

    state_dict = convert_state_dict_internal(self.model_args, self.config, qwen_version=qwen_version)
    super(type(self), self).load_state_dict(state_dict, strict=strict)


def init(self, load_config):
    # remove 'Model loader extra config' assert.
    self.load_config = load_config

loader.DummyModelLoader.__init__ = init


# add ckpt loading of megatron format
def load_model(self, *, model_config,
               device_config,
               lora_config,
               parallel_config,
               scheduler_config,
               cache_config):
    with set_default_torch_dtype(model_config.dtype):
        with torch.device(device_config.device):
            model = _initialize_model(model_config, self.load_config,
                                      lora_config, cache_config,
                                      scheduler_config)
        if self.load_config.model_loader_extra_config.get("need_load_ckpt", True):
            qwen2.Qwen2ForCausalLM.load_state_dict = load_state_dict
            qwen2.Qwen2ForCausalLM.load_weights = load_weights
            model.load_weights(self.load_config.model_loader_extra_config)
        else:
            # For accurate performance evaluation, we assign
            # random values to the weights.
            initialize_dummy_weights(model)

        for _, module in model.named_modules():
            quant_method = getattr(module, "quant_method", None)
            if quant_method is not None:
                # When quant methods need to process weights after loading
                # (for repacking, quantizing, etc), they expect parameters
                # to be on the global target device. This scope is for the
                # case where cpu offloading is used, where we will move the
                # parameters onto device for processing and back off after.
                with device_loading_context(
                        module, torch.device(device_config.device)):
                    quant_method.process_weights_after_loading(module)
    return model.eval()

loader.DummyModelLoader.load_model = load_model
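
The hook takes effect through monkey-patching alone: the assignments to `loader.DummyModelLoader.__init__` and `.load_model` above run at import time, so importing `chatlearn.models.vllm.hooks.loader` (which the `__init__.py` change does for vLLM 0.6.3) is enough to turn the dummy loader into a Megatron-checkpoint loader. A stripped-down sketch of the same pattern with a placeholder class, not the real vLLM one:

# Minimal illustration of the import-time patching used above (placeholder class).
class DummyLoader:
    def load_model(self):
        return "random weights"

def patched_load_model(self):
    # mimic the hook: produce real weights instead of dummy ones
    return "weights converted from a megatron checkpoint"

DummyLoader.load_model = patched_load_model
assert DummyLoader().load_model() == "weights converted from a megatron checkpoint"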
9 changes: 8 additions & 1 deletion chatlearn/models/vllm_module_v2.py
@@ -24,6 +24,7 @@
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.config import EngineConfig
+from vllm.config import LoadFormat
from vllm.executor.ray_utils import RayWorkerWrapper
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -67,6 +68,12 @@ def _init_args(self, args):
        # logger config
        args.disable_log_requests = True

+        # load format: 'dummy' for megatron ckpt or mock weight; others for hf ckpt.
+        args.load_format = self.model_args.get("vllm_load_format", LoadFormat.DUMMY)
+        if args.load_format == LoadFormat.DUMMY:
+            args.model_loader_extra_config = self.model_args
+            self.model_args["need_load_ckpt"] = self.src_parameter_model is None
+
        # engine config
        args.enforce_eager = self.model_args.get("enforce_eager", False)
@@ -81,7 +88,7 @@ def setup_vllm(self, workers):
if self.model_args.get("fp16", False):
dtype = "float16"
vllm_sys_argv = ["",
f"--model={self.model_args['load']}",
f"--model={self.model_args['tokenizer']}",
f"--tensor_parallel_size={self.module_args.tensor_model_parallel_size}",
f"--pipeline_parallel_size={self.module_args.pipeline_model_parallel_size}",
f"--dtype={dtype}",
25 changes: 17 additions & 8 deletions chatlearn/utils/vllm_utils.py
@@ -30,6 +30,7 @@
import torch
import torch.distributed

+from chatlearn.models.vllm import is_vllm_v2
from chatlearn.utils.constant import CURRENT_VLLM_VERSION, VLLMVersion
from chatlearn.utils.utils import get_use_legacy_models

@@ -1136,7 +1137,7 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=
final_norm = "ln_f"
func_map = megatron_qwen_to_transformers
elif qwen_version == QwenVersion.v_2:
prefix_name = "model.model."
prefix_name = "model." if is_vllm_v2() else "model.model."
embed_name = "embed_tokens"
layer_prefix = "layers"
final_norm = "norm"
@@ -1393,15 +1394,16 @@ def convert_qwen_state_dict_from_megatron_to_vllm(args, hf_config, qwen_version=

    # For the LM head, transformers wants the matrix tied to the word embeddings.
    print("Converting LM head")
+    lm_head_name = "lm_head.weight" if is_vllm_v2() else "model.lm_head.weight"
    if megatron_args.untie_embeddings_and_output_weights:
        if hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts:
            params = get_element_from_dict_by_path(final_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
        else:
            params = get_element_from_dict_by_path(tp_state_dicts[tp_rank], 'model.language_model.output_layer.weight')
        if (isinstance(params, dict) and len(params.keys())) or (params is not None and not isinstance(params, dict)):
-            output_state_dict["model.lm_head.weight"] = params.to(hf_config.torch_dtype)
+            output_state_dict[lm_head_name] = params.to(hf_config.torch_dtype)
    elif pp_rank == 0 or (pp_rank == pp_size - 1) or (hasattr(megatron_args, "moe_num_experts") and megatron_args.moe_num_experts):
-        output_state_dict["model.lm_head.weight"] = word_embeddings
+        output_state_dict[lm_head_name] = word_embeddings

    # It should be done!
    print("Conversion from Megatron-LM to Transformers is done!")
@@ -1491,9 +1493,12 @@ def _load_base_checkpoint(load_dir, rank0=False):
    return state_dict, checkpoint_name, release


-def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
+def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, model_args=None):
    """Transform parallel strategy for checkpoint if needed."""
-    args = model.model_args
+    if model_args is not None:
+        args = model_args
+    else:
+        args = model.model_args
    if args.get("adaptive_parallel_strategy_on_checkpoint"):
        load_dir = args[load_arg]
        target_tp = args.get("tensor_model_parallel_size")
@@ -1522,16 +1527,20 @@ def load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', stri
            torch.distributed.barrier()
        args[load_arg] = save_dir
        print_rank_0(f"Using transformed checkpoint {save_dir}")
-    return vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg=load_arg, strict=strict)
+    return vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg=load_arg, strict=strict, model_args=model_args)


-def vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True):
+def vllm_load_checkpoint(model, optimizer, opt_param_scheduler, load_arg='load', strict=True, model_args=None):
    """Load a model checkpoint and return the iteration.
    strict (bool): whether to strictly enforce that the keys in
        :attr:`state_dict` of the checkpoint match the names of
        parameters and buffers in model.
    """
-    args = model.model_args
+    if model_args is not None:
+        args = model_args
+    else:
+        args = model.model_args

    load_dir = args[load_arg]

    model = [unwrap_model(model)]
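
The new optional `model_args` parameter lets callers that keep no `model_args` attribute on the model, such as the vLLM loader hook calling `load_checkpoint(self, None, None, model_args=model_args)`, still supply the checkpoint settings, while existing callers behave as before. A toy illustration of the precedence rule (stand-in class, not ChatLearn code):

# Toy illustration of the argument resolution added above.
class _FakeModel:
    model_args = {"load": "/ckpt/from/model"}

def resolve_args(model, model_args=None):
    # explicit model_args wins; otherwise fall back to the model's own settings
    return model_args if model_args is not None else model.model_args

assert resolve_args(_FakeModel())["load"] == "/ckpt/from/model"
assert resolve_args(_FakeModel(), {"load": "/ckpt/override"})["load"] == "/ckpt/override"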
