support for deepseek vl2 (#401)
n1ck-guo authored Jan 7, 2025
1 parent 6ac7a2b commit bf8d68d
Showing 6 changed files with 131 additions and 33 deletions.
22 changes: 15 additions & 7 deletions auto_round/auto_quantizer.py
@@ -413,16 +413,24 @@ def convert_model(self, model: nn.Module):
         data_type = quantization_config.data_type if hasattr(quantization_config,
                                                               "data_type") else "int"  # pragma: no cover
         sym = quantization_config.sym
-        to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
-                                                                                   "to_quant_block_names") else None
 
+        quant_block_list = quantization_config.quant_block_list if hasattr(quantization_config,
+                                                                           "quant_block_list") else None
-        if to_quant_block_names is None:  # TODO check compatibility
-            all_blocks = get_block_names(model)
-        else:
-            all_blocks = get_multimodal_block_names(model, quant_vision=True)
-        quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
 
+        if quant_block_list is None:
+            to_quant_block_names = quantization_config.to_quant_block_names if hasattr(quantization_config,
+                                                                                       "to_quant_block_names") else None
+            if to_quant_block_names is not None:
+                if isinstance(to_quant_block_names, (list, tuple)):
+                    quant_block_list = to_quant_block_names
+                else:
+                    quant_block_list = []
+                    for block in to_quant_block_names.split(','):
+                        quant_block_list.append([f'{block}.{i}' for i in range(len(get_module(model, block)))])
+            else:
+                all_blocks = get_block_names(model)
+                quant_block_list = find_matching_blocks(model, all_blocks, to_quant_block_names)
 
         layer_names = get_layer_names_in_block(model, quant_block_list=quant_block_list)
 
         extra_config = {}
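For context, the rewritten branch above now accepts `to_quant_block_names` either as a ready-made list/tuple or as a comma-separated string of block prefixes that is expanded into per-layer names before falling back to `find_matching_blocks`. Below is a minimal, self-contained sketch (not part of the commit) of that string expansion, using a hypothetical toy model and a simplified stand-in for the `get_module` helper the diff relies on:

import torch.nn as nn

class ToyLanguage(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.language = ToyLanguage()

def get_module(module, key):
    # Simplified stand-in: resolve a dotted attribute path on the model.
    for name in key.split("."):
        module = getattr(module, name)
    return module

model = ToyModel()
to_quant_block_names = "language.layers"   # comma-separated block prefixes

quant_block_list = []
for block in to_quant_block_names.split(','):
    quant_block_list.append([f'{block}.{i}' for i in range(len(get_module(model, block)))])

print(quant_block_list)
# [['language.layers.0', 'language.layers.1', 'language.layers.2']]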
1 change: 1 addition & 0 deletions auto_round/mllm/template.py
@@ -166,6 +166,7 @@ def get_template(template_or_path: str, model=None, tokenizer=None, processor=No
     else:
         logger.warning(f"Unable to recognize {template_or_path}, using default template instead.")
         template = TEMPLATES["default"]
+        template.model_type = template_or_path
 
     template.processor.post_init(model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)

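The single added line changes the fallback path: an unrecognized template name still resolves to the default template, but the requested name is now remembered on `template.model_type`, so a model such as DeepSeek-VL2 can reuse the default template while later code still branches on the model type. A small self-contained illustration of the pattern (the `Template` stub and the name "my_new_vlm" are hypothetical stand-ins, not the project's real classes):

from dataclasses import dataclass

@dataclass
class Template:                      # stand-in for auto_round's template object
    model_type: str = "default"

TEMPLATES = {"default": Template()}  # stand-in registry

template_or_path = "my_new_vlm"      # hypothetical, unregistered name
if template_or_path in TEMPLATES:
    template = TEMPLATES[template_or_path]
else:
    # falls back to the default template, as before ...
    template = TEMPLATES["default"]
    # ... but now records which model asked for it (the line this commit adds)
    template.model_type = template_or_path

assert template.model_type == "my_new_vlm"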
57 changes: 33 additions & 24 deletions auto_round/script/mllm.py
@@ -281,32 +281,41 @@ def tune(args):

     # load_model
     processor, image_processor = None, None
-    config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
-        from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
-        tokenizer, model, image_processor, _ = load_pretrained_model(
-            model_name, model_base=None, model_name=model_name,
-            torch_dtype=torch_dtype)
-        model_type = "llava"
-    else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        model_type = config.model_type
-        if "llava" in model_type:
-            from transformers import LlavaForConditionalGeneration
-            cls = LlavaForConditionalGeneration
-        elif "qwen2_vl" in model_type:
-            from transformers import Qwen2VLForConditionalGeneration
-            cls = Qwen2VLForConditionalGeneration
-        elif "mllama" in model_type:
-            from transformers import MllamaForConditionalGeneration
-            cls = MllamaForConditionalGeneration
-        else:
-            cls = AutoModelForCausalLM
-
-        model = cls.from_pretrained(
+    if "deepseek-vl2" in model_name.lower():
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM  # pylint: disable=E0401
+        processor = DeepseekVLV2Processor.from_pretrained(model_name)
+        tokenizer = processor.tokenizer
+        model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+            model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+            device_map="auto" if use_auto_mapping else None)
+        model_type = "deepseek_vl_v2"
+    else:
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+        if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
+            from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
+            tokenizer, model, image_processor, _ = load_pretrained_model(
+                model_name, model_base=None, model_name=model_name,
+                torch_dtype=torch_dtype)
+            model_type = "llava"
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            model_type = config.model_type
+            if "llava" in model_type:
+                from transformers import LlavaForConditionalGeneration
+                cls = LlavaForConditionalGeneration
+            elif "qwen2_vl" in model_type:
+                from transformers import Qwen2VLForConditionalGeneration
+                cls = Qwen2VLForConditionalGeneration
+            elif "mllama" in model_type:
+                from transformers import MllamaForConditionalGeneration
+                cls = MllamaForConditionalGeneration
+            else:
+                cls = AutoModelForCausalLM
+
+            model = cls.from_pretrained(
         model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
         device_map="auto" if use_auto_mapping else None)
     if "cogvlm2" in model_name:
         model.config.model_type = "cogvlm2"
 
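DeepSeek-VL2 ships its own processor class instead of the `AutoProcessor` path used for the other models here, so `tune()` now loads it through the `deepseek_vl2` package first and pins `model_type` to "deepseek_vl_v2", the key the template fallback and the block-name handler below can branch on. A standalone sketch of that load path, mirroring the new branch above (assumes the `deepseek_vl2` package from the DeepSeek-VL2 repository is installed; the checkpoint id is illustrative):

from transformers import AutoModelForCausalLM
from deepseek_vl2.models import DeepseekVLV2Processor

model_name = "deepseek-ai/deepseek-vl2-tiny"   # illustrative checkpoint id

# The processor bundles the tokenizer and the image pre-processing logic.
processor = DeepseekVLV2Processor.from_pretrained(model_name)
tokenizer = processor.tokenizer

# The model itself loads through AutoModelForCausalLM with trust_remote_code,
# mirroring the new branch in tune() above.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype="auto",
    device_map="auto",
)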
14 changes: 14 additions & 0 deletions auto_round/special_model_handler.py
@@ -28,6 +28,20 @@
"idefics3"
]

def _get_deepseek_vl2_multimodal_block(model, quant_vision=False):
model.forward = model.language.forward
block_names = []
if quant_vision:
block_names.append([f"vision.blocks.{i}" for i in range(len(model.vision.blocks))])
block_names.append([f"projector.layers.{i}" for i in range(len(model.projector.layers))])
block_names.append([f"language.model.layers.{i}" for i in range(len(model.language.model.layers))])
return block_names

SPECIAL_MULTIMODAL_BLOCK = {
"deepseek_vl_v2": _get_deepseek_vl2_multimodal_block
}


def to_device(input, device=torch.device("cpu")):
"""Moves input data to the specified device.
4 changes: 3 additions & 1 deletion auto_round/utils.py
@@ -29,7 +29,7 @@
 from functools import lru_cache
 from packaging import version
 import gc
-from .special_model_handler import shareable_keywords
+from .special_model_handler import shareable_keywords, SPECIAL_MULTIMODAL_BLOCK
 
 
 @lru_cache(None)
@@ -402,6 +402,8 @@ def get_multimodal_block_names(model, quant_vision=False):
     Returns:
         block_names: A list whose elements are list of block's layer names
     """
+    if hasattr(model, "config") and model.config.model_type in SPECIAL_MULTIMODAL_BLOCK.keys():
+        return SPECIAL_MULTIMODAL_BLOCK.get(model.config.model_type)(model, quant_vision=quant_vision)
     block_names = []
     target_modules = []
     vison_blocks_tuple = ("vision", "visual",)
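With this hook, `get_multimodal_block_names` short-circuits for any model type registered in `SPECIAL_MULTIMODAL_BLOCK` instead of scanning module names heuristically; the DeepSeek-VL2 handler also reassigns `model.forward` to the language sub-model's forward, so callers of `model.forward` reach the text decoder directly. Purely as an illustration of the returned structure (the layer counts below are hypothetical placeholders, not taken from a real checkpoint):

# Rough shape of get_multimodal_block_names(model, quant_vision=True) for DeepSeek-VL2
# under the new hook; layer counts are hypothetical.
example_block_names = [
    [f"vision.blocks.{i}" for i in range(2)],            # vision tower blocks
    [f"projector.layers.{i}" for i in range(1)],         # vision-to-language projector
    [f"language.model.layers.{i}" for i in range(4)],    # language decoder blocks
]
# With quant_vision=False, only the language.model.layers.* group is returned.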
66 changes: 65 additions & 1 deletion test_cuda/test_support_vlms.py
@@ -13,7 +13,7 @@
 class TestSupportVLMS(unittest.TestCase):
     @classmethod
     def setUpClass(self):
-        self.save_dir = os.path.join(os.path.dirname(__file__), "./ut_saved")
+        self.save_dir = os.path.join(os.path.dirname(__file__), "ut_saved")
         self.python_path = sys.executable
         self.device = 0
 
@@ -333,6 +333,70 @@ def test_72b(self):
         )
         self.assertFalse(res > 0 or res == -1, msg="qwen2-72b tuning fail")
         shutil.rmtree(self.save_dir, ignore_errors=True)
+
+    def test_deepseek_vl2(self):
+        model_path = "/models/deepseek-vl2-tiny"
+        res = os.system(
+            f"cd .. && {self.python_path} -m auto_round --mllm "
+            f"--model {model_path} --iter 3 --nsamples 10 --bs 4 --output_dir {self.save_dir} --device auto --group_size 32 "
+            f"--fp_layers language.model.layers.4,language.model.layers.6"
+        )
+        self.assertFalse(res > 0 or res == -1, msg="deepseek vl2 tuning fail")
+
+        quantized_model_path = os.path.join(self.save_dir, "deepseek-vl2-tiny-w4g32-auto_round")
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
+        from transformers import AutoModelForCausalLM
+        vl_chat_processor: DeepseekVLV2Processor = DeepseekVLV2Processor.from_pretrained(quantized_model_path)
+        tokenizer = vl_chat_processor.tokenizer
+
+        vl_gpt: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
+            quantized_model_path,
+            trust_remote_code=True,
+            device_map=f"cuda:{self.device}",
+            torch_dtype="auto",
+        )
+        vl_gpt = vl_gpt.eval()
+
+        image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
+        content = "Describe this image."
+
+        ## single image conversation example
+        conversation = [
+            {
+                "role": "<|User|>",
+                "content": content,
+            },
+            {"role": "<|Assistant|>", "content": ""},
+        ]
+
+        # load images and prepare for inputs
+        pil_images = Image.open(requests.get(image_url, stream=True).raw)
+        prepare_inputs = vl_chat_processor(
+            conversations=conversation,
+            images=[pil_images],
+            force_batchify=True,
+            system_prompt=""
+        )
+        prepare_inputs = prepare_inputs.to(vl_gpt.device)
+
+        # run image encoder to get the image embeddings
+        inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+        # run the model to get the response
+        outputs = vl_gpt.language.generate(
+            input_ids = prepare_inputs["input_ids"],
+            inputs_embeds=inputs_embeds,
+            attention_mask=prepare_inputs.attention_mask,
+            pad_token_id=tokenizer.eos_token_id,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            max_new_tokens=512,
+            do_sample=False,
+            use_cache=True
+        )
+
+        answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+        print(f"{prepare_inputs['sft_format'][0]}", answer)
 
 if __name__ == "__main__":
     unittest.main()
