support awq format for vlms (#398)
* support awq format for vlms

Signed-off-by: Zhang, Weiwei1 <[email protected]>

* add qvision awq generation ut

Signed-off-by: Zhang, Weiwei1 <[email protected]>

---------

Signed-off-by: Zhang, Weiwei1 <[email protected]>
WeiweiZhang1 authored Dec 30, 2024
1 parent 01b779c commit c82aa88
Showing 3 changed files with 96 additions and 11 deletions.
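Taken together, the three files let a vision-language model be tuned and exported in AutoAWQ format. A minimal sketch of the new workflow, mirroring the test_phi3_vision_awq test added below; the model path, device id, and output directory are placeholders rather than values required by the tool:

import os

model_path = "/models/Phi-3.5-vision-instruct/"  # placeholder local checkpoint
save_dir = "./tmp_autoround"                     # placeholder output directory
device = 0                                       # placeholder CUDA device id

# Quantize the VLM (including its non-text modules) and export it in AutoAWQ format
# through the mllm entry point; a low iteration count keeps the sketch cheap to run.
res = os.system(
    f"python -m auto_round --mllm "
    f"--model {model_path} --iter 2 --quant_nontext_module "
    f"--nsample 64 --seqlen 32 "
    f"--format auto_awq --output_dir {save_dir} --device {device}")
assert res == 0, "auto_awq export for the VLM failed"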
27 changes: 21 additions & 6 deletions auto_round/export/export_to_awq/export.py
@@ -24,7 +24,11 @@
import torch
import torch.nn as nn

from auto_round.utils import logger, get_module, set_module, check_to_quantized
from auto_round.utils import (logger, get_module,
                              set_module,
                              check_to_quantized,
                              get_multimodal_block_names,
                              extract_block_names_to_str)
import copy
import json
from .utils import WQLinear_GEMM, clear_memory
@@ -69,15 +73,22 @@ def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs):
    """Export the model to autogptq format to easily leverage cuda kernel."""
    model = kwargs["model"]
    layer_config = kwargs["layer_config"]

    to_quant_block_names = kwargs.get("to_quant_block_names", None)
    tokenizer = kwargs.get("tokenizer", None)
    processor = kwargs.get("processor", None)

    modules_to_not_convert = []

    logger.info("Saving quantized model to auto_awq format")
    if tokenizer is not None:
        tokenizer.save_pretrained(output_dir)
    if processor is not None:
        processor.save_pretrained(output_dir)
    # mllm models
    all_blocks = get_multimodal_block_names(model, quant_vision=True)
    all_block_names = extract_block_names_to_str(all_blocks)
    all_block_names = all_block_names.split(',')
    to_quant_block_names = to_quant_block_names.split(',')
    modules_to_not_convert = list(set(all_block_names) - set(to_quant_block_names))

    if inplace:
        compressed_model = model.to("cpu")
@@ -101,11 +112,13 @@ def wrapper(name):

    if output_dir is None:
        return compressed_model
    modules_to_not_convert = []

    layer_config = kwargs["layer_config"]
    for key in layer_config.keys():
        if not check_to_quantized(layer_config[key]):
        if not check_to_quantized(layer_config[key]) and \
                not any(name in key for name in modules_to_not_convert):
            modules_to_not_convert.append(key)

    quantization_config["quant_method"] = "awq"
    quantization_config["zero_point"] = not quantization_config["sym"]
    quantization_config["version"] = "gemm"
@@ -115,7 +128,8 @@ def wrapper(name):

    if hasattr(compressed_model, "config"):
        compressed_model.config.quantization_config = quantization_config
    save(compressed_model, output_dir, safe_serialization=True)
    safe_serialization = kwargs.get('safe_serialization', True)
    save(compressed_model, output_dir, safe_serialization=safe_serialization)

    return compressed_model

@@ -146,3 +160,4 @@ def save(model: nn.Module, save_dir: str, max_shard_size: str = "5GB", safe_seri
    if hasattr(model, "config") and hasattr(model.config, "quantization_config"):
        with open(os.path.join(save_dir, config_file), "w", encoding="utf-8") as f:
            json.dump(model.config.quantization_config, f, indent=2)
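For reference, the export path above derives modules_to_not_convert as the set difference between every multimodal block name and the blocks actually selected for quantization, then stamps an AWQ-style config onto the model. An illustrative sketch with hypothetical block names; the modules_to_not_convert key in the final config is an assumption based on common AutoAWQ checkpoints, not shown in this hunk:

# Hypothetical block names, standing in for get_multimodal_block_names/extract_block_names_to_str output.
all_block_names = "model.layers,visual.blocks".split(',')
to_quant_block_names = "model.layers".split(',')  # only the text blocks were quantized

# Same set-difference logic as save_quantized_as_autoawq above.
modules_to_not_convert = list(set(all_block_names) - set(to_quant_block_names))
print(modules_to_not_convert)  # ['visual.blocks'] -> kept unquantized at load time

# Assumed shape of the resulting quantization config (zero_point mirrors "not sym").
quantization_config = {
    "quant_method": "awq",
    "zero_point": True,
    "version": "gemm",
    "modules_to_not_convert": modules_to_not_convert,
}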

9 changes: 8 additions & 1 deletion auto_round/script/mllm.py
@@ -233,7 +233,7 @@ def setup_lmeval_parser():
def tune(args):
    if args.format is None:
        args.format = "auto_round"
    supported_formats = ["auto_round", "auto_round:auto_gptq", "auto_round:auto_awq"]
    supported_formats = ["auto_round", "auto_round:auto_gptq", "auto_round:auto_awq", "auto_awq"]
    if not args.quant_nontext_module:
        supported_formats.extend(["auto_gptq", "auto_gptq:marlin"])

@@ -392,6 +392,12 @@ def tune(args):

    if "--truncation" not in sys.argv:
        args.truncation = None

    if "auto_awq" in args.format:
        from auto_round.utils import check_awq_gemm_compatibility
        awq_supported, info = check_awq_gemm_compatibility(model, args.bits, args.group_size, not args.asym, layer_config)
        if not awq_supported:
            logger.warning(f"The AutoAWQ format may not be supported due to {info}")

    autoround = round(model, tokenizer, processor=processor, image_processor=image_processor, dataset=args.dataset,
                      extra_data_dir=args.extra_data_dir, bits=args.bits, group_size=args.group_size,
@@ -512,3 +518,4 @@ def lmms_eval(args):
        apply_chat_template=False,
    )
    return results
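The compatibility check above is non-fatal: check_awq_gemm_compatibility returns a (supported, reason) pair and tune() only logs a warning before continuing. A hedged sketch of that call pattern as a standalone helper; the helper name is hypothetical, and model, bits, group_size, sym, and layer_config must come from a real tuning setup:

from auto_round.utils import check_awq_gemm_compatibility, logger

def warn_if_awq_unsupported(model, bits, group_size, sym, layer_config):
    """Warn, but do not abort, when the settings cannot use the AWQ GEMM kernel."""
    awq_supported, info = check_awq_gemm_compatibility(model, bits, group_size, sym, layer_config)
    if not awq_supported:
        logger.warning(f"The AutoAWQ format may not be supported due to {info}")
    return awq_supported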

71 changes: 67 additions & 4 deletions test_cuda/test_support_vlms.py
@@ -81,7 +81,7 @@ def test_phi3(self):
        ## test tune
        res = os.system(
            f"cd .. && {self.python_path} -m auto_round --mllm "
            f"--model {model_path} --iter 2 --output_dir {self.save_dir} --device {self.device}")
        self.assertFalse(res > 0 or res == -1, msg="Phi-3.5 tuning fail")

        ## test infer
@@ -114,11 +114,74 @@ def test_phi3(self):
        image_inputs = Image.open(requests.get(image_url, stream=True).raw)
        inputs = processor(prompt, image_inputs, return_tensors="pt").to(model.device)

        generation_args = {
            "max_new_tokens": 1000,
            "temperature": 0.0,
            "do_sample": False,
        }

        generate_ids = model.generate(**inputs,
                                      eos_token_id=processor.tokenizer.eos_token_id,
                                      **generation_args
                                      )

        # remove input tokens
        generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
        response = processor.batch_decode(generate_ids,
                                          skip_special_tokens=True,
                                          clean_up_tokenization_spaces=False)[0]
        print(response)
        shutil.rmtree(quantized_model_path, ignore_errors=True)

    def test_phi3_vision_awq(self):
        model_path = "/models/Phi-3.5-vision-instruct/"
        ## test tune
        res = os.system(
            f"cd .. && {self.python_path} -m auto_round --mllm "
            f"--model {model_path} --iter 2 --quant_nontext_module "
            f"--nsample 64 --seqlen 32 "
            f"--format auto_awq --output_dir {self.save_dir} --device {self.device}")
        self.assertFalse(res > 0 or res == -1, msg="Phi-3.5 tuning fail")

        ## test infer
        from transformers import AutoModelForCausalLM, AutoProcessor
        from auto_round.export.export_to_awq import WQLinear_GEMM
        quantized_model_path = os.path.join(self.save_dir, "Phi-3.5-vision-instruct-w4g128-auto_awq")
        res = os.system(f"cp /models/Phi-3.5-vision-instruct/*.py {quantized_model_path}")
        model = AutoModelForCausalLM.from_pretrained(
            quantized_model_path,
            device_map=f"cuda:{self.device}",
            trust_remote_code=True,
            torch_dtype="auto"
        )
        assert "WQLinear_GEMM" in str(
            type(model.model.vision_embed_tokens.img_processor.vision_model.encoder.layers[0].mlp.fc1)), \
            "model quantization failed."
        processor = AutoProcessor.from_pretrained(quantized_model_path,
                                                  trust_remote_code=True,
                                                  num_crops=4
                                                  )

        image_url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg"
        content = "Describe this image."
        messages = [
            {"role": "user",
             "content": "<|image_1|>\n"+content},
        ]

        prompt = processor.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )
        image_inputs = Image.open(requests.get(image_url, stream=True).raw)
        inputs = processor(prompt, image_inputs, return_tensors="pt").to(model.device)

        generation_args = {
            "max_new_tokens": 1000,
            "temperature": 0.0,
            "do_sample": False,
        }

        generate_ids = model.generate(**inputs,
                                      eos_token_id=processor.tokenizer.eos_token_id,
@@ -272,4 +335,4 @@ def test_72b(self):
        shutil.rmtree(self.save_dir, ignore_errors=True)

if __name__ == "__main__":
    unittest.main()
