feat: support loading vision model #451
base: main
Conversation
Signed-off-by: Anh Uong <[email protected]>
Thanks for making a pull request! 😃
Tested this code by loading the llama 3.2-11b vision model as well as the llava 1.6-mistral-7b vision model; both were LoRA tuned successfully with the dataset https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft.
Note that when loading a llava model with FSDP, you need to provide the extra field fsdp_transformer_layer_cls_to_wrap: "LlamaDecoderLayer" for llava 1.5 and fsdp_transformer_layer_cls_to_wrap: "MistralDecoderLayer" for llava 1.6-mistral.
Ran with configuration:
{
"model_name_or_path": "llava-hf/llava-v1.6-mistral-7b-hf",
"training_data_path": "HuggingFaceH4/llava-instruct-mix-vsft",
"output_dir": "/fmaas-integration-tests/tuning/output/anhuong/llava1.6-mistral-7b-vision_llava-dataset_lora",
"num_train_epochs": 3.0,
"per_device_train_batch_size": 4,
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"response_template": "\n### Response:", <--- FIX: this field is not used
"dataset_text_field": "output", <--- FIX: this field is not used
"bf16": true,
"torch_dtype": "bfloat16",
"use_flash_attn": false,
"remove_unused_columns": false,
"dataset_kwargs": {"skip_prepare_dataset": true},
"multimodal": true,
"peft_method": "lora",
"r": 8,
"lora_dropout": 0.05,
"lora_alpha": 16,
"target_modules": ["all-linear"],
"lora_post_process_for_vllm": true,
"gradient_checkpointing": true,
"text_field_name": "messages",
"image_field_name": "images"
}
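For context, a rough sketch of what the multimodal flag is meant to drive at load time (the names and arguments below are illustrative assumptions, not the exact code in this PR): the processor replaces the bare tokenizer, and the model is loaded through AutoModelForVision2Seq rather than AutoModelForCausalLM.

# Sketch only -- assumes the multimodal flag simply switches loaders;
# the wiring in this PR may differ in detail.
import torch
from transformers import AutoModelForVision2Seq, AutoProcessor

model_name = "llava-hf/llava-v1.6-mistral-7b-hf"

# The processor bundles the tokenizer with the image processor.
processor = AutoProcessor.from_pretrained(model_name)

# Vision-language models are loaded through AutoModelForVision2Seq
# instead of AutoModelForCausalLM.
model = AutoModelForVision2Seq.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # matches torch_dtype/bf16 in the config above
)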
"transformers", | ||
"torch>=2.2.0,<2.5", | ||
"sentencepiece>=0.1.99,<0.3", | ||
"tokenizers>=0.13.3,<1.0", | ||
"tqdm>=4.66.2,<5.0", | ||
"trl>=0.9.3,<0.12", | ||
"trl==0.13", |
Note that both of these changes are waiting on a fix in fms-acceleration; otherwise the cross-entropy fused ops will break with transformers versions newer than 4.46. We should also add an upper-bound version for transformers.
@@ -61,6 +61,10 @@ class ModelArguments:
            tokenizer classes."
        },
    )
    multimodal: bool = field(
Is there a better way to tell whether a user is loading a vision model, so that we can then load the processor and load the model with AutoModelForVision2Seq? Should also update this to a vision-specific flag.
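One possible alternative to a user-facing flag (an assumption on my part, not something this PR does) would be to inspect the model config itself, e.g.:

from transformers import AutoConfig
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
)

def looks_like_vision_model(model_name_or_path: str) -> bool:
    # Heuristic sketch: multimodal configs usually carry a nested
    # vision_config, and their model_type tends to appear in the
    # Vision2Seq mapping.
    config = AutoConfig.from_pretrained(model_name_or_path)
    return (
        getattr(config, "vision_config", None) is not None
        or config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES
    )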
    text_field_name: str = field(
        default=None,
        metadata={
            "help": "Required for running with vision models. \
                The column name of the text data in the multi-modal dataset."
        },
    )
    image_field_name: str = field(
Will comment on these fields in the data collator section -- questions around how to load a multimodal dataset that contains images and text and how to preprocess the text portions
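For reference, a quick way to see the dataset shape these two fields currently expect (column names taken from the config above; a sketch, not code from this PR):

from datasets import load_dataset

ds = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")
print(ds.column_names)  # expected to include "messages" and "images"

example = ds[0]
# "messages" should be a chat-style list of {"role", "content"} turns and
# "images" a list of images that the {"type": "image"} turns refer to.
print(example["messages"][0]["role"], len(example["images"]))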
        texts = [
            self.processor.apply_chat_template(
                example[self.text_field_name], tokenize=False
            )
            for example in examples
        ]
        images = [example[self.image_field_name] for example in examples]
@dushyantbehl This is the key part I need your input on -- this works well for a dataset like https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft, where there is a single column for images and a single column for text. However, how do we want to preprocess datasets with multiple text fields that need to be concatenated together, such as https://huggingface.co/datasets/NSTiwari/DocumentIDEFICS_QA, where the query and answer are separate columns? For that dataset the data collator should look more like:
# Note: this sketch assumes `random` is imported and that `processor` is the
# model's AutoProcessor; the image itself is passed to the processor
# separately, the {"type": "image"} entry is only a placeholder in the chat.
for example in examples:
    image = example["image"]
    question = example["query"]["en"]
    answer = random.choice(example["answers"])
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "Answer briefly."},
                {"type": "image"},
                {"type": "text", "text": question},
            ],
        },
        {
            "role": "assistant",
            "content": [
                {"type": "text", "text": answer},
            ],
        },
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=False)
But how do we accommodate both forms without knowing in advance what the column field names are?
Or, going even further, https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm contains many fields of Amazon product information, and its preprocessing should look like:
return {"messages": [
{
"role": "system",
"content": [{"type": "text", "text": system_message}],
},
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt.format(product_name=sample["Product Name"], category=sample["Category"]),
},{
"type": "image",
"image": sample["image"],
}
],
},
{
"role": "assistant",
"content": [{"type": "text", "text": sample["description"]}],
},
],
}
where only some of the fields are used -- is this something the user would have to provide their own collator for?
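One option (purely a suggestion, with hypothetical names) would be to keep the collator generic and let the user supply a function that maps a raw example to chat messages:

def make_vision_collator(processor, to_messages, image_field_name):
    # to_messages is user-supplied and hides all dataset-specific column logic.
    def collate(examples):
        texts = [
            processor.apply_chat_template(to_messages(ex), tokenize=False)
            for ex in examples
        ]
        images = [ex[image_field_name] for ex in examples]
        batch = processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )
        # Minimal labels; a real collator would likely also mask image tokens
        # and the prompt portion.
        labels = batch["input_ids"].clone()
        labels[labels == processor.tokenizer.pad_token_id] = -100
        batch["labels"] = labels
        return batch
    return collate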
        train_dataset = load_dataset(data_args.training_data_path, split="train")
    else:
        train_dataset = data_processor.process_dataset_configs([train_dataset_config])
@dushyantbehl please help verify in this area how the data preprocessing would look if the multimodal dataset went through process_dataset_configs.
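As a point of comparison while that path is verified, an equivalent map-based preprocessing step (a hypothetical helper, assuming the processor and column names used elsewhere in this PR) might look roughly like:

from datasets import load_dataset
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-v1.6-mistral-7b-hf")
train_dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")

def render_chat(example):
    # Pre-render the chat template into a plain text column so downstream
    # preprocessing can treat it like any other text dataset.
    example["rendered_text"] = processor.apply_chat_template(
        example["messages"], tokenize=False
    )
    return example

train_dataset = train_dataset.map(render_chat)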
Description of the change
Related issue number
How to verify the PR
Was the PR tested