feat: support loading vision model #451
base: main
Changes from all commits: 30326cd, e166e86, de15409, 802405a, 95f62ca, 79e9ecd, c1f8e5f
File diff:
@@ -61,6 +61,10 @@ class ModelArguments:
                 tokenizer classes."
         },
     )
+    multimodal: bool = field(
+        default=False,
+        metadata={"help": "Load multimodal model and processor"},
+    )


 @dataclass

Review comment on the multimodal field:
Is there a better way to tell if a user is loading a vision model and to then load the processor and load the model with
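The comment above asks how to detect a vision model rather than rely on a user-supplied flag. As a point of reference, here is a minimal sketch (not part of this PR) of one way to do that by inspecting the checkpoint config; the helper name and the assumption that multimodal configs expose a nested vision_config (as Llava-style configs do) are illustrative only.

```python
# Hypothetical sketch: detect a vision model from its config rather than a
# user flag, then pick the matching auto classes. Names are illustrative.
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)


def load_model_and_preprocessor(model_name_or_path: str):
    config = AutoConfig.from_pretrained(model_name_or_path)
    # Llava-style multimodal configs carry a nested vision_config.
    if getattr(config, "vision_config", None) is not None:
        processor = AutoProcessor.from_pretrained(model_name_or_path)
        model = AutoModelForVision2Seq.from_pretrained(model_name_or_path)
        return model, processor
    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    return model, tokenizer
```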
@@ -122,6 +126,20 @@ class DataArguments:
                 Passed in conjunction with response_template"
         },
     )
+    text_field_name: str = field(
+        default=None,
+        metadata={
+            "help": "Required for running with vision models. \
+                The column name of the text data in the multi-modal dataset."
+        },
+    )
+    image_field_name: str = field(
+        default=None,
+        metadata={
+            "help": "Required for running with vision models. \
+                The column name of the vision data in the multi-modal dataset."
+        },
+    )


 @dataclass

Review comment on lines +129 to +136:
Will comment on these fields in the data collator section -- questions around how to load a multimodal dataset that contains images and text, and how to preprocess the text portions.
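Since these are plain dataclass fields, they surface to users through the usual HfArgumentParser wiring (an assumption about the surrounding CLI; the argument values below are placeholders for illustration):

```python
# Hypothetical sketch: the new ModelArguments/DataArguments fields become CLI
# flags via HfArgumentParser; model and dataset names are placeholders.
from transformers import HfArgumentParser

from tuning.config import configs

parser = HfArgumentParser((configs.ModelArguments, configs.DataArguments))
model_args, data_args = parser.parse_args_into_dataclasses(
    args=[
        "--model_name_or_path", "llava-hf/llava-1.5-7b-hf",
        "--multimodal", "True",
        "--training_data_path", "HuggingFaceH4/llava-instruct-mix-vsft",
        "--text_field_name", "messages",
        "--image_field_name", "images",
    ]
)
print(model_args.multimodal, data_args.text_field_name, data_args.image_field_name)
```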
Next file diff:
@@ -13,14 +13,17 @@
 # limitations under the License.
 # Standard
 from typing import Callable, Optional
+import logging

 # Third Party
-from transformers import AutoTokenizer, DataCollatorForSeq2Seq
+from transformers import AutoTokenizer, DataCollatorForSeq2Seq, LlavaProcessor
 from trl import DataCollatorForCompletionOnlyLM

 # Local
 from tuning.config import configs

+logger = logging.getLogger(__name__)
+

 def get_data_collator(
     packing: bool,
@@ -29,6 +32,9 @@ def get_data_collator(
     is_traindata_tokenized: bool,
     max_seq_length: int,
     instruction_template: Optional[str],
+    text_field_name: Optional[str],
+    image_field_name: Optional[str],
+    processor=None,
     is_padding_free: bool = False,
 ) -> Callable:
     """Create and return the the appropriate collator type based on the configuration for packing,
@@ -49,12 +55,27 @@
         str representing the human response in a chat template
     is_padding_free: bool
         if padding free plugin is used or not
+    text_field_name: str
+        Field name for the text used in multi-modal dataset.
+    image_field_name: str
+        Field name for the images used in multi-modal dataset.
+    processor:
+        Model processor to combine text and image data if using
+        multi-modal vision model.

     Returns:
         Callable
             Callable collator to be leveraged by the trainer.
     """

+    if processor:
+        if not (text_field_name or image_field_name):
+            logger.error(
+                "When training a vision model, you must pass in the \
+                text_field_name and image_field_name of the dataset being used."
+            )
+        return VisionDataCollator(processor, text_field_name, image_field_name)
+
     if response_template and instruction_template:
         return DataCollatorForCompletionOnlyLM(
             response_template=response_template,
@@ -96,3 +117,48 @@ def get_data_collator(
     raise ValueError(
         "Could not pick a data collator. Please refer to supported data formats"
     )
+
+
+class VisionDataCollator:
+    def __init__(self, processor, text_field_name, image_field_name):
+        self.processor = processor
+        self.text_field_name = text_field_name
+        self.image_field_name = image_field_name
+
+    def __call__(self, examples):
+        """
+        Processes both the text and images by applying the chat template
+        and tokenizing the data.
+        This collator takes a list of examples as input and
+        returns a batch of processed data.
+        """
+        # Get the texts and images, and apply the chat template
+        texts = [
+            self.processor.apply_chat_template(
+                example[self.text_field_name], tokenize=False
+            )
+            for example in examples
+        ]
+        images = [example[self.image_field_name] for example in examples]
+
+        # LLava1.5 does not support multiple images
+        if isinstance(self.processor, LlavaProcessor):
+            images = [image[0] for image in images]
+
+        # Tokenize the texts and process the images
+        batch = self.processor(
+            text=texts, images=images, return_tensors="pt", padding=True
+        )
+
+        # The labels are the input_ids, and we mask the padding tokens in the loss computation
+        # TODO: should we be ensuring the EOS token is set?
+        labels = batch["input_ids"].clone()
+        if self.processor.tokenizer.pad_token_id is not None:
+            labels[labels == self.processor.tokenizer.pad_token_id] = -100
+        # Ignore the image token index in the loss computation (model specific)
+        image_token_id = self.processor.tokenizer.convert_tokens_to_ids(
+            self.processor.image_token
+        )
+        labels[labels == image_token_id] = -100
+        batch["labels"] = labels
+
+        return batch

Review comment on lines +135 to +141 (the text extraction above):
@dushyantbehl This is the key part I need your input on -- this works well for a dataset like https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft, where there is a single column for images and a single column for text. However, how do we want to preprocess datasets with multiple text fields that need to be concatenated together, like https://huggingface.co/datasets/NSTiwari/DocumentIDEFICS_QA, where the query and answer are separate columns? There the data collator should look more like:

    for example in examples:
        image = example["image"]
        question = example["query"]["en"]
        answer = random.choice(example["answers"])
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            },
            {
                "role": "assistant",
                "content": [
                    {"type": "text", "text": answer},
                ],
            },
        ]
        text = processor.apply_chat_template(messages, add_generation_prompt=False)

But how do we accommodate both forms without knowing what the column field names are? Or, going further, https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm contains many fields of Amazon product information, and the processed example should look like:

    return {
        "messages": [
            {
                "role": "system",
                "content": [{"type": "text", "text": system_message}],
            },
            {
                "role": "user",
                "content": [
                    {
                        "type": "text",
                        "text": prompt.format(
                            product_name=sample["Product Name"],
                            category=sample["Category"],
                        ),
                    },
                    {
                        "type": "image",
                        "image": sample["image"],
                    },
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": sample["description"]}],
            },
        ],
    }

where only some fields are taken -- is this something the user would have to provide their own collator for?
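To make the single-column case from the comment concrete, here is a minimal usage sketch of the new collator outside the trainer. The column names ("messages", "images") follow the llava-instruct-mix-vsft dataset card, and the import path of VisionDataCollator is an assumption about where this PR places the class.

```python
# Hypothetical usage sketch for VisionDataCollator; column names and the
# import path below are assumptions, not part of this PR.
from datasets import load_dataset
from transformers import AutoProcessor

from tuning.data.data_preprocessing_utils import VisionDataCollator  # path assumed

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
dataset = load_dataset("HuggingFaceH4/llava-instruct-mix-vsft", split="train")

collator = VisionDataCollator(
    processor, text_field_name="messages", image_field_name="images"
)
batch = collator([dataset[0], dataset[1]])
print(batch["input_ids"].shape, batch["pixel_values"].shape, batch["labels"].shape)
```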
Next file diff:
@@ -17,7 +17,7 @@
 import logging

 # Third Party
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, load_dataset

 # Third Party
 from transformers import AutoTokenizer

@@ -217,6 +217,7 @@ def _process_raw_data_args(
     max_seq_length: int,
     additional_data_handlers: Dict[str, Callable] = None,
     is_padding_free: bool = False,
+    processor=None,
 ):

     # Create a data processor with default processor config
@@ -289,11 +290,18 @@
         eval_dataset_config.data_handlers = handlers

     # And let processor handle the logic
-    train_dataset = data_processor.process_dataset_configs([train_dataset_config])
+    train_dataset = None
+    if processor:
+        train_dataset = load_dataset(data_args.training_data_path, split="train")
+    else:
+        train_dataset = data_processor.process_dataset_configs([train_dataset_config])

     eval_dataset = None
     if is_eval_dataset_present:
-        eval_dataset = data_processor.process_dataset_configs([eval_dataset_config])
+        if processor:
+            eval_dataset = load_dataset(data_args.training_data_path)
+        else:
+            eval_dataset = data_processor.process_dataset_configs([eval_dataset_config])

     return (train_dataset, eval_dataset, dataset_text_field)

Review comment on lines +295 to +297:
@dushyantbehl please help in this area for verifying how this data preprocessing would look if the multimodal dataset went through the
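One direction the comments above point toward, sketched here as an assumption rather than anything in this PR: dataset-specific handling could be expressed as a data handler that collapses a multi-column VQA dataset (such as DocumentIDEFICS_QA, with image/query/answers columns) into the single text and image columns the collator expects, so it could flow through the existing handler pipeline instead of bypassing it.

```python
# Hypothetical data handler sketch: normalize a multi-column VQA example into
# "messages"/"images" columns. Input column names mirror DocumentIDEFICS_QA;
# the handler name and output schema are assumptions.
import random


def vqa_to_messages(example):
    question = example["query"]["en"]
    answer = random.choice(example["answers"])
    return {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": answer}],
            },
        ],
        "images": [example["image"]],
    }


# e.g. dataset = dataset.map(vqa_to_messages, remove_columns=dataset.column_names)
```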
@@ -310,6 +318,7 @@ def process_dataargs(
     train_args: TrainingArguments,
     additional_data_handlers: Dict[str, Callable] = None,
     is_padding_free: bool = False,
+    processor=None,
 ):
     """
     Args:

@@ -322,6 +331,9 @@
            which need to be registered with the data preprocessor
        is_padding_free: A bool representing if Padding free plugin is enabled.
            Defaults to False.
+       processor:
+           Model processor to combine text and image data if using
+           multi-modal model.
     Returns:
         Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
             tuple containing

@@ -358,13 +370,22 @@
         max_seq_length,
         additional_data_handlers,
         is_padding_free,
+        processor,
     )

     # Note: This check should not be removed.
     # Its important to recompute this post handling to
     # check if we already tokenized the dataset or not.
     is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)

+    if processor and not (data_args.text_field_name or data_args.image_field_name):
+        logger.error(
+            "When running a vision model you must provide the text_field_name and \
+            image_field_name for the columns in the dataset. Values should be from \
+            column names: %s",
+            train_dataset.column_names,
+        )
+
     data_collator = get_data_collator(
         train_args.packing,
         data_args.response_template,

@@ -373,10 +394,13 @@
         max_seq_length,
         data_args.instruction_template,
         is_padding_free=is_padding_free,
+        text_field_name=data_args.text_field_name,
+        image_field_name=data_args.image_field_name,
+        processor=processor,
     )

     dataset_kwargs = {}
-    if is_tokenized_dataset:
+    if is_tokenized_dataset or processor is not None:
         dataset_kwargs["skip_prepare_dataset"] = True

     return (
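For context on the skip_prepare_dataset change above, a hedged sketch of how that kwarg typically reaches trl, assuming the surrounding trainer code forwards dataset_kwargs through SFTConfig and that remove_unused_columns stays off so the image column survives collation (both assumptions, not shown in this diff):

```python
# Hypothetical sketch: when a processor is present, trl's own dataset
# preparation is skipped and the vision collator does the work instead.
from trl import SFTConfig, SFTTrainer


def build_trainer(model, train_dataset, data_collator):
    args = SFTConfig(
        output_dir="out",
        dataset_kwargs={"skip_prepare_dataset": True},  # mirrors the flag set above
        remove_unused_columns=False,  # keep image columns for the collator (assumption)
    )
    return SFTTrainer(
        model=model,
        args=args,
        train_dataset=train_dataset,
        data_collator=data_collator,  # e.g. the collator returned by get_data_collator
    )
```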
Review comment:
Note that both of these changes are waiting on a fix in fms-acceleration; otherwise the fused cross-entropy ops will break with transformers versions newer than 4.46. We should also update the transformers dependency to add an upper bound.