feat: support loading vision model #451

Draft · wants to merge 7 commits into base: main
5 changes: 3 additions & 2 deletions pyproject.toml
@@ -28,16 +28,17 @@ classifiers=[
dependencies = [
    "numpy>=1.26.4,<2.0",
    "accelerate>=0.20.3,!=0.34,<1.1",
-    "transformers>=4.45,<4.46",
+    "transformers",
    "torch>=2.2.0,<2.5",
    "sentencepiece>=0.1.99,<0.3",
    "tokenizers>=0.13.3,<1.0",
    "tqdm>=4.66.2,<5.0",
-    "trl>=0.9.3,<0.12",
+    "trl==0.13",
Comment on lines +31 to +36 (Collaborator Author):
Note that both of these changes are waiting on a fix in fms-acceleration; otherwise the fused cross-entropy ops will break with transformers versions newer than 4.46. We should also update this to give transformers an upper-bound version.

"peft>=0.8.0,<0.14",
"protobuf>=5.28.0,<6.0.0",
"datasets>=2.15.0,<3.0",
"simpleeval>=0.9.13,<1.0",
"pillow>=11.0.0,<12.0",
]

[project.optional-dependencies]
18 changes: 18 additions & 0 deletions tuning/config/configs.py
@@ -61,6 +61,10 @@ class ModelArguments:
            tokenizer classes."
        },
    )
    multimodal: bool = field(
Comment (Collaborator Author):
Is there a better way to tell whether a user is loading a vision model, so that we can then load the processor and load the model with AutoModelForVision2Seq? We should update this to a vision-specific flag.
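One possible approach (a sketch, not part of this PR) is to infer this from the model config alone, using the model-type keys behind AutoModelForVision2Seq:

    # Hypothetical helper: detect a vision-to-text model from its config.
    from transformers import AutoConfig
    from transformers.models.auto.modeling_auto import (
        MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES,
    )

    def is_vision_model(model_name_or_path: str) -> bool:
        config = AutoConfig.from_pretrained(model_name_or_path)
        # model_type is e.g. "llava"; the mapping keys are the model types
        # that AutoModelForVision2Seq knows how to load.
        return config.model_type in MODEL_FOR_VISION_2_SEQ_MAPPING_NAMES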

        default=False,
        metadata={"help": "Load multimodal model and processor"},
    )


@dataclass
@@ -122,6 +126,20 @@ class DataArguments:
            Passed in conjunction with response_template"
        },
    )
    text_field_name: str = field(
        default=None,
        metadata={
            "help": "Required for running with vision models. \
                The column name of the text data in the multi-modal dataset."
        },
    )
    image_field_name: str = field(
Comment on lines +129 to +136 (Collaborator Author):
I will comment on these fields in the data collator section -- there are questions around how to load a multi-modal dataset that contains images and text, and how to preprocess the text portions.

        default=None,
        metadata={
            "help": "Required for running with vision models. \
                The column name of the vision data in the multi-modal dataset."
        },
    )


@dataclass
68 changes: 67 additions & 1 deletion tuning/data/data_preprocessing_utils.py
@@ -13,14 +13,17 @@
# limitations under the License.
# Standard
from typing import Callable, Optional
import logging

# Third Party
-from transformers import AutoTokenizer, DataCollatorForSeq2Seq
+from transformers import AutoTokenizer, DataCollatorForSeq2Seq, LlavaProcessor
from trl import DataCollatorForCompletionOnlyLM

# Local
from tuning.config import configs

logger = logging.getLogger(__name__)


def get_data_collator(
    packing: bool,
@@ -29,6 +32,9 @@ def get_data_collator(
    is_traindata_tokenized: bool,
    max_seq_length: int,
    instruction_template: Optional[str],
    text_field_name: Optional[str],
    image_field_name: Optional[str],
    processor=None,
    is_padding_free: bool = False,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
Expand All @@ -49,12 +55,27 @@ def get_data_collator(
            str representing the human response in a chat template
        is_padding_free: bool
            if padding free plugin is used or not
        text_field_name: str
            Field name for the text used in the multi-modal dataset.
        image_field_name: str
            Field name for the images used in the multi-modal dataset.
        processor:
            Model processor to combine text and image data if using
            a multi-modal vision model.

    Returns:
        Callable
            Callable collator to be leveraged by the trainer.
    """

    if processor:
        if not (text_field_name and image_field_name):
            logger.error(
                "When training a vision model, you must pass in the \
                text_field_name and image_field_name of the dataset being used."
            )
        return VisionDataCollator(processor, text_field_name, image_field_name)

    if response_template and instruction_template:
        return DataCollatorForCompletionOnlyLM(
            response_template=response_template,
@@ -96,3 +117,48 @@ def get_data_collator(
    raise ValueError(
        "Could not pick a data collator. Please refer to supported data formats"
    )

class VisionDataCollator:
    def __init__(self, processor, text_field_name, image_field_name):
        self.processor = processor
        self.text_field_name = text_field_name
        self.image_field_name = image_field_name

    def __call__(self, examples):
        """
        Processes both the text and images by applying the chat template
        and tokenizing the data.
        This collator takes a list of examples as input and
        returns a batch of processed data.
        """
        # Get the texts and images, and apply the chat template
        texts = [
            self.processor.apply_chat_template(
                example[self.text_field_name], tokenize=False
            )
            for example in examples
        ]
        images = [example[self.image_field_name] for example in examples]
Comment on lines +135 to +141 (Collaborator Author):
@dushyantbehl This is the key part I need your input on -- this works well for a dataset like https://huggingface.co/datasets/HuggingFaceH4/llava-instruct-mix-vsft, where there is a single column for images and a single column for text. However, for datasets with multiple text fields that need to be concatenated together, like https://huggingface.co/datasets/NSTiwari/DocumentIDEFICS_QA, where the query and answer are separate columns, the data collator would have to look more like:

        for example in examples:
            image = example["image"]
            question = example["query"]['en']
            answer = random.choice(example["answers"])
            messages = [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Answer briefly."},
                        {"type": "image"},
                        {"type": "text", "text": question}
                    ]
                },
                {
                    "role": "assistant",
                    "content": [
                        {"type": "text", "text": answer}
                    ]
                }
            ]
            text = processor.apply_chat_template(messages, add_generation_prompt=False)

but how do we accommodate both forms when we do not know what the column field names are?

Or, going even further, there is https://huggingface.co/datasets/philschmid/amazon-product-descriptions-vlm, which contains many fields of Amazon product information; a processed sample should look like:

    return {"messages": [
                {
                    "role": "system",
                    "content": [{"type": "text", "text": system_message}],
                },
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": prompt.format(product_name=sample["Product Name"], category=sample["Category"]),
                        },{
                            "type": "image",
                            "image": sample["image"],
                        }
                    ],
                },
                {
                    "role": "assistant",
                    "content": [{"type": "text", "text": sample["description"]}],
                },
            ],
        }

where only some fields are taken -- is this something the user would have to provide their own collator for?
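One possible middle ground (a sketch under stated assumptions, not something this PR implements) is to keep VisionDataCollator generic and have users pre-map arbitrary schemas into the two expected columns with datasets.map; the column names below are hypothetical:

    # Hypothetical pre-processing for the DocumentIDEFICS_QA schema: collapse the
    # query/answers columns into one "messages" chat column plus an "images" column.
    import random

    def to_chat_columns(example):
        question = example["query"]["en"]
        answer = random.choice(example["answers"])
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "Answer briefly."},
                    {"type": "image"},
                    {"type": "text", "text": question},
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": answer}]},
        ]
        return {"messages": messages, "images": [example["image"]]}

    dataset = dataset.map(to_chat_columns)
    # then: text_field_name="messages", image_field_name="images"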


        # LLaVA-1.5 does not support multiple images
        if isinstance(self.processor, LlavaProcessor):
            images = [image[0] for image in images]

        # Tokenize the texts and process the images
        batch = self.processor(
            text=texts, images=images, return_tensors="pt", padding=True
        )

        # The labels are the input_ids; padding tokens are masked out of the
        # loss computation with -100, cross entropy's ignore_index.
        # TODO: should we be ensuring the EOS token is set?
        labels = batch["input_ids"].clone()
        if self.processor.tokenizer.pad_token_id is not None:
            labels[labels == self.processor.tokenizer.pad_token_id] = -100
        # Ignore the image token index in the loss computation (model specific)
        image_token_id = self.processor.tokenizer.convert_tokens_to_ids(
            self.processor.image_token
        )
        labels[labels == image_token_id] = -100
        batch["labels"] = labels

        return batch
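A minimal usage sketch for the new collator (assuming a llava-instruct-mix-vsft-style dataset with "messages" and "images" columns; the checkpoint and the `examples` variable are illustrative):

    from transformers import AutoProcessor

    # One row of such a dataset looks roughly like:
    # {"messages": [{"role": "user", "content": [{"type": "image"},
    #                {"type": "text", "text": "What is shown?"}]}, ...],
    #  "images": [<PIL.Image>]}
    processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
    collator = VisionDataCollator(
        processor, text_field_name="messages", image_field_name="images"
    )
    # `examples` is a list of such rows; the result is a dict of tensors with
    # input_ids, attention_mask, pixel_values, and the masked labels.
    batch = collator(examples)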
32 changes: 28 additions & 4 deletions tuning/data/setup_dataprocessor.py
@@ -17,7 +17,7 @@
import logging

# Third Party
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, load_dataset

# Third Party
from transformers import AutoTokenizer
@@ -217,6 +217,7 @@ def _process_raw_data_args(
    max_seq_length: int,
    additional_data_handlers: Dict[str, Callable] = None,
    is_padding_free: bool = False,
    processor=None,
):

    # Create a data processor with default processor config
@@ -289,11 +290,18 @@
        eval_dataset_config.data_handlers = handlers

    # And let processor handle the logic
-    train_dataset = data_processor.process_dataset_configs([train_dataset_config])
+    train_dataset = None
+    if processor:
+        train_dataset = load_dataset(data_args.training_data_path, split="train")
+    else:
+        train_dataset = data_processor.process_dataset_configs([train_dataset_config])
Comment on lines +295 to +297 (Collaborator Author):
@dushyantbehl Please help in this area by verifying how this data preprocessing would look if the multi-modal dataset went through process_dataset_configs.


    eval_dataset = None
    if is_eval_dataset_present:
-        eval_dataset = data_processor.process_dataset_configs([eval_dataset_config])
+        if processor:
+            eval_dataset = load_dataset(data_args.validation_data_path, split="train")
+        else:
+            eval_dataset = data_processor.process_dataset_configs([eval_dataset_config])

    return (train_dataset, eval_dataset, dataset_text_field)

@@ -310,6 +318,7 @@ def process_dataargs(
    train_args: TrainingArguments,
    additional_data_handlers: Dict[str, Callable] = None,
    is_padding_free: bool = False,
    processor=None,
):
    """
    Args:
@@ -322,6 +331,9 @@
            which need to be registered with the data preprocessor
        is_padding_free: A bool representing if Padding free plugin is enabled.
            Defaults to False.
        processor:
            Model processor to combine text and image data if using
            a multi-modal model.
    Returns:
        Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
            tuple containing
@@ -358,13 +370,22 @@
        max_seq_length,
        additional_data_handlers,
        is_padding_free,
        processor,
    )

    # Note: This check should not be removed.
    # It is important to recompute this post-handling to
    # check whether we already tokenized the dataset or not.
    is_tokenized_dataset = is_pretokenized_dataset(train_dataset or eval_dataset)

    if processor and not (data_args.text_field_name and data_args.image_field_name):
        logger.error(
            "When running a vision model you must provide the text_field_name and \
            image_field_name for the columns in the dataset. Values should be from \
            column names: %s",
            train_dataset.column_names,
        )

    data_collator = get_data_collator(
        train_args.packing,
        data_args.response_template,
@@ -373,10 +394,13 @@
        max_seq_length,
        data_args.instruction_template,
        is_padding_free=is_padding_free,
        text_field_name=data_args.text_field_name,
        image_field_name=data_args.image_field_name,
        processor=processor,
    )

    dataset_kwargs = {}
    if is_tokenized_dataset or processor is not None:
        # The vision collator tokenizes and processes images itself,
        # so the trainer's dataset preparation must be skipped as well.
        dataset_kwargs["skip_prepare_dataset"] = True

    return (
10 changes: 9 additions & 1 deletion tuning/sft_trainer.py
@@ -28,6 +28,8 @@
from torch.cuda import OutOfMemoryError
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    GPT2Tokenizer,
    GPTNeoXTokenizerFast,
@@ -214,6 +216,11 @@ def train(
    ).get_framework()

    model_loader = AutoModelForCausalLM.from_pretrained
    processor = None
    if model_args.multimodal:
        model_loader = AutoModelForVision2Seq.from_pretrained
        processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)

    if framework is not None and framework.requires_custom_loading:
        model_loader = framework.model_loader  # drop-in new loader
    model_load_time = time.time()
@@ -327,6 +334,7 @@
        train_args,
        additional_data_handlers,
        is_padding_free=is_padding_free,
        processor=processor,
    )
    additional_metrics["data_preprocessing_time"] = (
        time.time() - data_preprocessing_time
@@ -362,7 +370,7 @@

    trainer = SFTTrainer(
        model=model,
-        tokenizer=tokenizer,
+        processing_class=tokenizer,
        train_dataset=formatted_train_dataset,
        eval_dataset=formatted_validation_dataset,
        data_collator=data_collator,
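Putting the pieces together, a minimal invocation sketch for the new vision path (the flag and field names are the ones this PR adds; the checkpoint, dataset, and output directory are illustrative, and passing exactly these three arguments to train() is an assumption about its signature):

    # Hypothetical end-to-end call, assuming a llava-instruct-mix-vsft-style dataset.
    from tuning import sft_trainer
    from tuning.config import configs

    model_args = configs.ModelArguments(
        model_name_or_path="llava-hf/llava-1.5-7b-hf",
        multimodal=True,  # load with AutoModelForVision2Seq + AutoProcessor
    )
    data_args = configs.DataArguments(
        training_data_path="HuggingFaceH4/llava-instruct-mix-vsft",
        text_field_name="messages",
        image_field_name="images",
    )
    train_args = configs.TrainingArguments(output_dir="/tmp/llava-vision-sft")
    sft_trainer.train(model_args, data_args, train_args)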