Commit

update

Signed-off-by: n1ck-guo <[email protected]>
n1ck-guo committed Jan 3, 2025
1 parent f1c9dff commit 5504509
Showing 6 changed files with 49 additions and 55 deletions.
4 changes: 3 additions & 1 deletion auto_round/mllm/mllm_dataset.py
@@ -223,7 +223,9 @@ def get_mllm_dataloader(
     """
     if isinstance(template, str):
         from .template import get_template
-        template = get_template(template, model=model, tokenizer=tokenizer, processor=processor, image_processor=image_processor)
+        template = get_template(
+            template, model=model, tokenizer=tokenizer,
+            processor=processor, image_processor=image_processor)
 
     if os.path.isfile(dataset) or dataset in MLLM_DATASET.keys():
         dataset = MLLM_DATASET['liuhaotian/llava'](
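For reference, the reformatted call resolves a template name against the registry populated in auto_round/mllm/template.py (see the registrations added below). A minimal usage sketch, assuming a Qwen2-VL checkpoint as the example model; the checkpoint id and image_processor=None are assumptions, not part of this commit:

from transformers import AutoProcessor, AutoTokenizer, Qwen2VLForConditionalGeneration
from auto_round.mllm.template import get_template

name = "Qwen/Qwen2-VL-2B-Instruct"  # example checkpoint, not from the diff
model = Qwen2VLForConditionalGeneration.from_pretrained(name)
tokenizer = AutoTokenizer.from_pretrained(name)
processor = AutoProcessor.from_pretrained(name)

# Passing a string triggers the lookup shown in the hunk above.
template = get_template("qwen2_vl", model=model, tokenizer=tokenizer,
                        processor=processor, image_processor=None)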
13 changes: 12 additions & 1 deletion auto_round/mllm/processor.py
@@ -157,7 +157,7 @@ def get_input(
 
 
 @regist_processor("qwen2_vl")
-class Qwen2VLProcessor(BasicProcessor):
+class Qwen2VLProcessor(HFProcessor):
     @staticmethod
     def squeeze_result(ret):
         for key in ret:
@@ -290,3 +290,14 @@ class DataArgs:
 
     def data_collator(self, batch):
         return self.collator_func(batch)
+
+
+@regist_processor("deepseek_vl_v2")
+class DeepseekVL2Processor(BasicProcessor):
+    def get_input(
+            self,
+            text,
+            images,
+            return_tensors="pt",
+            squeeze=True, max_length=None, truncation=False, truncation_strategy="text", **kwargs):
+        breakpoint()
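As committed, the get_input body is only a breakpoint() stub. A rough sketch of what such a method might eventually do, assuming the two-turn conversation format and the conversations=/images=/force_batchify= keywords documented in the deepseek_vl2 package, an assumed self.processor attribute on BasicProcessor, and images given as file paths; none of this is the committed implementation:

from PIL import Image
from auto_round.mllm.processor import BasicProcessor

class DeepseekVL2ProcessorSketch(BasicProcessor):
    # Would be registered with @regist_processor("deepseek_vl_v2").
    def get_input(self, text, images, return_tensors="pt", squeeze=True,
                  max_length=None, truncation=False,
                  truncation_strategy="text", **kwargs):
        # Assumption: wrap the calibration sample in the two-turn chat
        # format that DeepseekVLV2Processor expects.
        conversation = [
            {"role": "<|User|>", "content": text, "images": images or []},
            {"role": "<|Assistant|>", "content": ""},
        ]
        pil_images = [Image.open(p).convert("RGB") for p in (images or [])]
        # force_batchify=True pads the single sample into a batch of one.
        return self.processor(conversations=conversation, images=pil_images,
                              force_batchify=True)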
2 changes: 2 additions & 0 deletions auto_round/mllm/template.py
@@ -115,6 +115,8 @@ def _register_template(
     )
     return TEMPLATES[model_type]
 
+_register_template("qwen2_vl", default_dataset="NeelNanda/pile-10k", processor=PROCESSORS["qwen2_vl"])
+_register_template("mllama", default_dataset="liuhaotian/llava", processor=PROCESSORS["hf"])
 
 def load_template(path: str):
     """Load template information from a json file."""
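These registrations replace the two deleted JSON template files below. For context, regist_processor and _register_template both follow a plain name-to-object registry: store each class or template under a string key, then look the key up later. A minimal sketch of the mechanism, illustrative rather than the repository's exact code:

# A decorator that records each processor class in a module-level dict,
# e.g. PROCESSORS["deepseek_vl_v2"] -> DeepseekVL2Processor.
PROCESSORS = {}

def regist_processor(name):
    def wrapper(cls):
        PROCESSORS[name] = cls
        return cls
    return wrapper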
10 changes: 0 additions & 10 deletions auto_round/mllm/templates/mllama.json

This file was deleted.

13 changes: 0 additions & 13 deletions auto_round/mllm/templates/qwen2_vl.json

This file was deleted.

62 changes: 32 additions & 30 deletions auto_round/script/mllm.py
@@ -283,42 +283,44 @@ def tune(args):
 
     # load_model
     processor, image_processor = None, None
-    config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-    if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
-        from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
-        tokenizer, model, image_processor, _ = load_pretrained_model(
-            model_name, model_base=None, model_name=model_name,
-            torch_dtype=torch_dtype)
-        model_type = "llava"
-    elif "deepseek" in model_name.lower():
-        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM
-        processor = DeepseekVLV2Processor.from_pretrained(model_name)
+    if "deepseek" in model_name.lower():
+        from deepseek_vl2.models import DeepseekVLV2Processor, DeepseekVLV2ForCausalLM  # pylint: disable=E0401
+        processor = DeepseekVLV2Processor.from_pretrained(model_name)
         tokenizer = processor.tokenizer
         model: DeepseekVLV2ForCausalLM = AutoModelForCausalLM.from_pretrained(
             model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
             device_map="auto" if use_auto_mapping else None)
         model_type = "deepseek_vl_v2"
     else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
-        model_type = config.model_type
-        if "llava" in model_type:
-            from transformers import LlavaForConditionalGeneration
-            cls = LlavaForConditionalGeneration
-        elif "qwen2_vl" in model_type:
-            from transformers import Qwen2VLForConditionalGeneration
-            cls = Qwen2VLForConditionalGeneration
-        elif "mllama" in model_type:
-            from transformers import MllamaForConditionalGeneration
-            cls = MllamaForConditionalGeneration
-        elif "idefics3" in model_type:
-            from transformers import AutoModelForVision2Seq
-            cls = AutoModelForVision2Seq
-        else:
-            cls = AutoModelForCausalLM
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+        if "llava" in model_name and config.architectures[0] != "LlavaForConditionalGeneration":
+            from llava.model.builder import load_pretrained_model  # pylint: disable=E0401
+            tokenizer, model, image_processor, _ = load_pretrained_model(
+                model_name, model_base=None, model_name=model_name,
+                torch_dtype=torch_dtype)
+            model_type = "llava"
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=not args.disable_trust_remote_code)
+            model_type = config.model_type
+            if "llava" in model_type:
+                from transformers import LlavaForConditionalGeneration
+                cls = LlavaForConditionalGeneration
+            elif "qwen2_vl" in model_type:
+                from transformers import Qwen2VLForConditionalGeneration
+                cls = Qwen2VLForConditionalGeneration
+            elif "mllama" in model_type:
+                from transformers import MllamaForConditionalGeneration
+                cls = MllamaForConditionalGeneration
+            elif "idefics3" in model_type:
+                from transformers import AutoModelForVision2Seq
+                cls = AutoModelForVision2Seq
+            else:
+                cls = AutoModelForCausalLM
 
-        model = cls.from_pretrained(
-            model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
-            device_map="auto" if use_auto_mapping else None)
+            model = cls.from_pretrained(
+                model_name, trust_remote_code=not args.disable_trust_remote_code, torch_dtype=torch_dtype,
+                device_map="auto" if use_auto_mapping else None)
     if "cogvlm2" in model_name:
         model.config.model_type = "cogvlm2"
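Two things about the restructured loader are worth noting. First, the deepseek branch now runs before AutoConfig.from_pretrained, presumably because DeepseekVL2 checkpoints are handled by the standalone deepseek_vl2 package rather than a registered transformers architecture. Second, the inner else-branch is a substring-based model_type-to-class dispatch; an equivalent table-driven sketch, with the classes taken from the imports in the diff and the helper name pick_model_class made up here:

from transformers import (AutoModelForCausalLM, AutoModelForVision2Seq,
                          LlavaForConditionalGeneration,
                          MllamaForConditionalGeneration,
                          Qwen2VLForConditionalGeneration)

# Substring keys mirror the if/elif chain in tune(); first match wins.
_MODEL_CLS = (
    ("llava", LlavaForConditionalGeneration),
    ("qwen2_vl", Qwen2VLForConditionalGeneration),
    ("mllama", MllamaForConditionalGeneration),
    ("idefics3", AutoModelForVision2Seq),
)

def pick_model_class(model_type: str):
    for key, cls in _MODEL_CLS:
        if key in model_type:
            return cls
    return AutoModelForCausalLM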
