diff --git a/docs/classes/models/mt5.rst b/docs/classes/models/mt5.rst new file mode 100644 index 0000000000..d05542056d --- /dev/null +++ b/docs/classes/models/mt5.rst @@ -0,0 +1,24 @@ +MT5 +===== + +The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer +`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, +Aditya Siddhant, Aditya Barua, Colin Raffel. + +The abstract from the paper is the following, + + +- The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain + state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a + multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We detail + the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual + benchmarks. We also describe a simple technique to prevent "accidental translation" in the zero-shot setting, where a + generative model chooses to (partially) translate its prediction into the wrong language. All of the code and model + checkpoints used in this work are publicly available. + +MT5AdapterModel +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +.. autoclass:: adapters.MT5AdapterModel + :members: + :inherited-members: MT5PreTrainedModel \ No newline at end of file diff --git a/docs/model_overview.md b/docs/model_overview.md index a5ba7c4e8c..58ae523b43 100644 --- a/docs/model_overview.md +++ b/docs/model_overview.md @@ -27,6 +27,7 @@ The table below further shows which model architectures support which adaptation | [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | +| [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | | | [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | diff --git a/src/adapters/__init__.py b/src/adapters/__init__.py index 221b7db190..36232ada84 100644 --- a/src/adapters/__init__.py +++ b/src/adapters/__init__.py @@ -105,6 +105,7 @@ "models.gptj": ["GPTJAdapterModel"], "models.llama": ["LlamaAdapterModel"], "models.mbart": ["MBartAdapterModel"], + "models.mt5": ["MT5AdapterModel"], "models.roberta": ["RobertaAdapterModel"], "models.t5": ["T5AdapterModel"], "models.vit": ["ViTAdapterModel"], @@ -207,6 +208,7 @@ from .models.gptj import GPTJAdapterModel from .models.llama import LlamaAdapterModel from .models.mbart import MBartAdapterModel + from .models.mt5 import MT5AdapterModel from .models.roberta import RobertaAdapterModel from .models.t5 import T5AdapterModel from .models.vit import ViTAdapterModel diff --git a/src/adapters/composition.py b/src/adapters/composition.py index 937fae2685..6d37e44b1f 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -127,6 +127,7 @@ def __init__( "deberta", "bart", "mbart", + "mt5", "gpt2", "gptj", "t5", diff --git a/src/adapters/head_utils.py b/src/adapters/head_utils.py index 7673857adc..2144fbe5ee 100644 --- a/src/adapters/head_utils.py +++ b/src/adapters/head_utils.py @@ -369,6 +369,35 @@ }, "layers": ["lm_head"], }, + # MT5 + "MT5ForConditionalGeneration": { + "config": { + "head_type": "seq2seq_lm", + }, + "layers": ["lm_head"], + }, + "MT5ForQuestionAnswering": { + "config": { + "head_type": "question_answering", + "activation_function": None, + "layers": 1, + }, + "layers": [None, "qa_outputs"], + }, + "MT5ForSequenceClassification": { + "config": { + "head_type": "classification", + "layers": 2, + "activation_function": "tanh", + }, + "layers": [ + None, + "classification_head.dense", + None, + None, + "classification_head.out_proj", + ], + }, # DistilBERT "DistilBertForSequenceClassification": { "config": { diff --git a/src/adapters/models/__init__.py b/src/adapters/models/__init__.py index dd48552d23..46eba733b7 100644 --- a/src/adapters/models/__init__.py +++ b/src/adapters/models/__init__.py @@ -55,6 +55,11 @@ "MBartDecoder": BartDecoderAdaptersMixin, "MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin, "MBartModel": BartModelAdaptersMixin, + "MT5Block": T5BlockAdaptersMixin, + "MT5Model": T5ModelAdaptersMixin, + "MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin, + "MT5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin, + "MT5EncoderModel": T5ModelAdaptersMixin, "GPT2Model": GPT2ModelAdapterMixin, "GPTJMLP": GPTJMLPAdaptersMixin, "GPTJModel": GPTJModelAdapterMixin, diff --git a/src/adapters/models/auto/adapter_model.py b/src/adapters/models/auto/adapter_model.py index 5ff84de483..2d59c6da44 100644 --- a/src/adapters/models/auto/adapter_model.py +++ b/src/adapters/models/auto/adapter_model.py @@ -23,6 +23,7 @@ ("gptj", "GPTJAdapterModel"), ("llama", "LlamaAdapterModel"), ("mbart", "MBartAdapterModel"), + ("mt5", "MT5AdapterModel"), ("roberta", "RobertaAdapterModel"), ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), diff --git a/src/adapters/models/mt5/__init__.py b/src/adapters/models/mt5/__init__.py new file mode 100644 index 0000000000..1a3469d953 --- /dev/null +++ b/src/adapters/models/mt5/__init__.py @@ -0,0 +1,39 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import _LazyModule + + +_import_structure = { + "adapter_model": ["MT5AdapterModel"], +} + + +if TYPE_CHECKING: + from .adapter_model import MT5AdapterModel + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/adapters/models/mt5/adapter_model.py b/src/adapters/models/mt5/adapter_model.py new file mode 100644 index 0000000000..58bb236469 --- /dev/null +++ b/src/adapters/models/mt5/adapter_model.py @@ -0,0 +1,266 @@ +import logging + +import torch + +from transformers.models.mt5.modeling_mt5 import ( + MT5_INPUTS_DOCSTRING, + MT5_START_DOCSTRING, + MT5Model, + MT5PreTrainedModel, +) +from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward + +from ...composition import adjust_tensors_for_parallel +from ...heads import ( + ClassificationHead, + ModelWithFlexibleHeadsAdaptersMixin, + MultiLabelClassificationHead, + QuestionAnsweringHead, + Seq2SeqLMHead, +) +from ...model_mixin import EmbeddingAdaptersWrapperMixin +from ...wrappers import init + + +logger = logging.getLogger(__name__) + + +@add_start_docstrings( + "MT5 Model with the option to add multiple flexible prediction heads on top.", MT5_START_DOCSTRING +) +class MT5AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, MT5PreTrainedModel): + _tied_weights_keys = [ + "encoder.embed_tokens.weight", + "decoder.embed_tokens.weight", + ] + + _keys_to_ignore_on_load_unexpected = [ + r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight", + ] + + def __init__(self, config): + super().__init__(config) + + self.transformer = MT5Model(config) + init(self.transformer) + + self._init_head_modules() + + self.init_weights() + + # Model parallel + self.model_parallel = False + self.device_map = None + + def get_encoder(self): + return self.transformer.encoder + + def get_decoder(self): + return self.transformer.decoder + + @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + past_key_values=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + head=None, + output_adapter_gating_scores=False, + output_adapter_fusion_attentions=False, + **kwargs + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if decoder_input_ids is None and decoder_inputs_embeds is None: + # Check if we're using a LM head + if labels is not None and any([isinstance(head, Seq2SeqLMHead) for head in self._get_used_heads(head)]): + # get decoder inputs from shifting lm labels to the right + decoder_input_ids = self._shift_right(labels) + else: + # decoder_input_ids from input_ids if no decoder_input_ids are provided + decoder_input_ids = self._shift_right(input_ids) + + model_output, context = self.transformer( + input_ids=input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + output_adapter_gating_scores=output_adapter_gating_scores, + output_adapter_fusion_attentions=output_adapter_fusion_attentions, + adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), + output_context=True, + ) + # required e.g. for prompt tuning in all models + kwargs["context"] = context + sequence_output = model_output[0] + # ToDo move head to device for parallel forward pass + + if self.config.tie_word_embeddings: + # Rescale output before projecting on vocab + # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586 + new_hidden_state = sequence_output * (self.config.d_model**-0.5) + if isinstance(model_output, tuple): + model_output = (new_hidden_state,) + model_output[1:] + else: + model_output["last_hidden_state"] = new_hidden_state + + # sequence classification based on last token in sequence + if input_ids is not None and sequence_output.shape[1] == input_ids.shape[1]: + eos_mask = input_ids.eq(self.config.eos_token_id) + (eos_mask,) = adjust_tensors_for_parallel(sequence_output, eos_mask) + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + cls_representation = sequence_output[eos_mask, :].view( + sequence_output.size(0), -1, sequence_output.size(-1) + )[:, -1, :] + else: + cls_representation = sequence_output + + kwargs["labels"] = labels + head_outputs = self.forward_head( + model_output, + head_name=head, + cls_output=cls_representation, + return_dict=return_dict, + **kwargs, + ) + return head_outputs + + # Copied from T5ForConditionalGeneration + def prepare_inputs_for_generation( + self, + input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs + ): + # cut decoder_input_ids if past is used + if past is not None: + input_ids = input_ids[:, -1:] + + return { + "decoder_input_ids": input_ids, + "past_key_values": past, + "encoder_outputs": encoder_outputs, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, + "adapter_input_parallelized": kwargs.pop("adapter_input_parallelized", False), + } + + # Copied from T5ForConditionalGeneration + def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor): + return self._shift_right(labels) + + # Copied from T5ForConditionalGeneration + def _reorder_cache(self, past, beam_idx): + # if decoder past is not included in output + # speedy decoding is disabled and no need to reorder + if past is None: + logger.warning("You might want to consider setting `use_cache=True` to speed up decoding") + return past + + reordered_decoder_past = () + for layer_past_states in past: + # get the correct batch idx from layer past batch dim + # batch dim of `past` is at 2nd position + reordered_layer_past_states = () + for layer_past_state in layer_past_states: + # need to set correct `past` for each of the four key / value states + reordered_layer_past_states = reordered_layer_past_states + ( + layer_past_state.index_select(0, beam_idx.to(layer_past_state.device)), + ) + + assert reordered_layer_past_states[0].shape == layer_past_states[0].shape + assert len(reordered_layer_past_states) == len(layer_past_states) + + reordered_decoder_past = reordered_decoder_past + (reordered_layer_past_states,) + return reordered_decoder_past + + head_types = { + "seq2seq_lm": Seq2SeqLMHead, + "question_answering": QuestionAnsweringHead, + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + } + + def add_seq2seq_lm_head(self, head_name, overwrite_ok=False): + """ + Adds a seq2seq language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = Seq2SeqLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + def add_qa_head( + self, + head_name, + num_labels=2, + layers=1, + activation_function=None, + overwrite_ok=False, + id2label=None, + ): + head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. + """ + + if multilabel: + head = MultiLabelClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) diff --git a/src/adapters/models/mt5/modeling_mt5.py b/src/adapters/models/mt5/modeling_mt5.py new file mode 100644 index 0000000000..12ad630a74 --- /dev/null +++ b/src/adapters/models/mt5/modeling_mt5.py @@ -0,0 +1,484 @@ +# coding=utf-8 +# Copyright 2018 Mesh TensorFlow authors, MT5 Authors and HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch MT5 model.""" + +import torch +from torch import nn +from torch.utils.checkpoint import checkpoint + +from transformers.modeling_outputs import BaseModelOutputWithPastAndCrossAttentions +from transformers.models.mt5.modeling_mt5 import ( + MT5Attention, + MT5LayerCrossAttention, + MT5LayerFF, + MT5LayerSelfAttention, + MT5Stack, +) +from transformers.utils import logging + +from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel +from ..t5.mixin_t5 import ( + T5AttentionAdaptersMixin, + T5CrossAttentionLayerAdaptersMixin, + T5FFLayerAdaptersMixin, + T5SelfAttentionLayerAdaptersMixin, + T5StackAdaptersMixin, +) + + +logger = logging.get_logger(__name__) + + +class MT5LayerFFWithAdapters(T5FFLayerAdaptersMixin, MT5LayerFF): + def forward(self, hidden_states): + forwarded_states = self.layer_norm(hidden_states) + forwarded_states = self.DenseReluDense(forwarded_states) + hidden_states = self.bottleneck_layer_forward( + hidden_states=self.dropout(forwarded_states), residual_input=hidden_states, layer_norm=None + ) + return hidden_states + + +class MT5AttentionWithAdapters(T5AttentionAdaptersMixin, MT5Attention): + def forward( + self, + hidden_states, + mask=None, + key_value_states=None, + position_bias=None, + past_key_value=None, + layer_head_mask=None, + query_length=None, + use_cache=False, + output_attentions=False, + ): + """ + Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). + """ + # Input is (batch_size, seq_length, dim) + # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) + # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + batch_size, seq_length = hidden_states.shape[:2] + + real_seq_length = seq_length + + if past_key_value is not None: + assert ( + len(past_key_value) == 2 + ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" + real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length + + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + + def shape(states): + """projection""" + return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + + def unshape(states): + """reshape""" + return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + + def project(hidden_states, proj_layer, key_value_states, past_key_value): + """projects hidden states correctly to key/query states""" + if key_value_states is None: + # self-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(hidden_states)) + elif past_key_value is None: + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + + if past_key_value is not None: + if key_value_states is None: + # self-attn + # (batch_size, n_heads, key_length, dim_per_head) + hidden_states = torch.cat([past_key_value, hidden_states], dim=2) + elif past_key_value.shape[2] != key_value_states.shape[1]: + # checking that the `sequence_length` of the `past_key_value` is the same as + # the provided `key_value_states` to support prefix tuning + # cross-attn + # (batch_size, n_heads, seq_length, dim_per_head) + hidden_states = shape(proj_layer(key_value_states)) + else: + # cross-attn + hidden_states = past_key_value + return hidden_states + + # get query states + query_states = shape(self.q(hidden_states)) # (batch_size, n_heads, seq_length, dim_per_head) + + # get key/value states + key_states = project( + hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None + ) + value_states = project( + hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None + ) + + query_states, key_states, value_states = match_attn_matrices_for_parallel( + query_states, key_states, value_states + ) + (mask,) = adjust_tensors_for_parallel(query_states, mask) + + present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None + + key_states, value_states, mask = self.prefix_tuning(key_states, value_states, hidden_states, mask) + (query_states,) = adjust_tensors_for_parallel(key_states, query_states) + batch_size, key_length = key_states.shape[0], key_states.shape[2] + + # compute scores + scores = torch.matmul( + query_states, key_states.transpose(3, 2) + ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + + if position_bias is None: + if not self.has_relative_attention_bias: + position_bias = torch.zeros( + (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + ) + if self.gradient_checkpointing and self.training: + position_bias.requires_grad = True + else: + position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + + # if key and values are already calculated + # we want only the last query position bias + if past_key_value is not None: + position_bias = position_bias[:, :, -hidden_states.size(1) :, :] + + if mask is not None: + position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + + if self.pruned_heads: + mask = torch.ones(position_bias.shape[1]) + mask[list(self.pruned_heads)] = 0 + position_bias_masked = position_bias[:, mask.bool()] + else: + position_bias_masked = position_bias + + scores += position_bias_masked + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as( + scores + ) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=self.training + ) # (batch_size, n_heads, seq_length, key_length) + + # Mask heads if we want to + if layer_head_mask is not None: + attn_weights = attn_weights * layer_head_mask + + attn_output = unshape(torch.matmul(attn_weights, value_states)) # (batch_size, seq_length, dim) + attn_output = self.o(attn_output) + + outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + + if output_attentions: + outputs = outputs + (attn_weights,) + return outputs + + +class MT5LayerSelfAttentionWithAdapters(T5SelfAttentionLayerAdaptersMixin, MT5LayerSelfAttention): + def forward( + self, + hidden_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.SelfAttention( + normed_hidden_states, + mask=attention_mask, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + hidden_states = self.bottleneck_layer_forward( + hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None + ) + outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them + return outputs + + +class MT5LayerCrossAttentionWithAdapters(T5CrossAttentionLayerAdaptersMixin, MT5LayerCrossAttention): + def forward( + self, + hidden_states, + key_value_states, + attention_mask=None, + position_bias=None, + layer_head_mask=None, + past_key_value=None, + use_cache=False, + query_length=None, + output_attentions=False, + ): + normed_hidden_states = self.layer_norm(hidden_states) + attention_output = self.EncDecAttention( + normed_hidden_states, + mask=attention_mask, + key_value_states=key_value_states, + position_bias=position_bias, + layer_head_mask=layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + query_length=query_length, + output_attentions=output_attentions, + ) + layer_output = self.bottleneck_layer_forward( + hidden_states=self.dropout(attention_output[0]), residual_input=hidden_states, layer_norm=None + ) + outputs = (layer_output,) + attention_output[1:] # add attentions if we output them + return outputs + + +class MT5StackWithAdapters(T5StackAdaptersMixin, MT5Stack): + def forward( + self, + input_ids=None, + attention_mask=None, + encoder_hidden_states=None, + encoder_attention_mask=None, + inputs_embeds=None, + head_mask=None, + cross_attn_head_mask=None, + past_key_values=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + # Model parallel + if self.model_parallel: + torch.cuda.set_device(self.first_device) + self.embed_tokens = self.embed_tokens.to(self.first_device) + use_cache = use_cache if use_cache is not None else self.config.use_cache + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + if self.is_decoder and encoder_hidden_states is not None: + input_ids, encoder_attention_mask = adjust_tensors_for_parallel( + encoder_hidden_states, input_ids, encoder_attention_mask + ) + + if input_ids is not None and inputs_embeds is not None: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError( + f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time" + ) + elif input_ids is not None: + input_shape = input_ids.size() + input_ids = input_ids.view(-1, input_shape[-1]) + elif inputs_embeds is not None: + input_shape = inputs_embeds.size()[:-1] + else: + err_msg_prefix = "decoder_" if self.is_decoder else "" + raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds") + + if inputs_embeds is None: + if self.embed_tokens is None: + raise ValueError("You have to initialize the model with valid token embeddings") + inputs_embeds = self.embed_tokens(input_ids) + + batch_size, seq_length = input_shape + + # required mask seq length can be calculated via length of past + mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + + if use_cache is True: + if not self.is_decoder: + raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder") + + if attention_mask is None: + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) + if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None: + encoder_seq_length = encoder_hidden_states.shape[1] + encoder_attention_mask = torch.ones( + batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + ) + + # initialize past_key_values with `None` if past does not exist + if past_key_values is None: + past_key_values = [None] * len(self.block) + + # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] + # ourselves in which case we just need to make it broadcastable to all heads. + extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + + # If a 2D or 3D attention mask is provided for the cross-attention + # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] + if self.is_decoder and encoder_hidden_states is not None: + encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size() + encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length) + if encoder_attention_mask is None: + encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device) + encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask) + else: + encoder_extended_attention_mask = None + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." + ) + use_cache = False + + # Prepare head mask if needed + head_mask = self.get_head_mask(head_mask, self.config.num_layers) + cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) + present_key_value_states = () if use_cache else None + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + all_cross_attentions = () if (output_attentions and self.is_decoder) else None + position_bias = None + encoder_decoder_position_bias = None + + hidden_states = self.dropout(inputs_embeds) + if not self.is_decoder: + hidden_states = self.post_embedding_forward(hidden_states) + + for i, (layer_module, past_key_value) in enumerate(zip(self.block, past_key_values)): + layer_head_mask = head_mask[i] + cross_attn_layer_head_mask = cross_attn_head_mask[i] + # Model parallel + if self.model_parallel: + torch.cuda.set_device(hidden_states.device) + # Ensure that attention_mask is always on the same device as hidden_states + if attention_mask is not None: + attention_mask = attention_mask.to(hidden_states.device) + if position_bias is not None: + position_bias = position_bias.to(hidden_states.device) + if encoder_hidden_states is not None: + encoder_hidden_states = encoder_hidden_states.to(hidden_states.device) + if encoder_extended_attention_mask is not None: + encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device) + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device) + if layer_head_mask is not None: + layer_head_mask = layer_head_mask.to(hidden_states.device) + if cross_attn_layer_head_mask is not None: + cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device) + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + return tuple(module(*inputs, use_cache, output_attentions)) + + return custom_forward + + layer_outputs = checkpoint( + create_custom_forward(layer_module), + hidden_states, + extended_attention_mask, + position_bias, + encoder_hidden_states, + encoder_extended_attention_mask, + encoder_decoder_position_bias, + layer_head_mask, + cross_attn_layer_head_mask, + None, # past_key_value is always None with gradient checkpointing + ) + else: + layer_outputs = layer_module( + hidden_states, + attention_mask=extended_attention_mask, + position_bias=position_bias, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_extended_attention_mask, + encoder_decoder_position_bias=encoder_decoder_position_bias, + layer_head_mask=layer_head_mask, + cross_attn_layer_head_mask=cross_attn_layer_head_mask, + past_key_value=past_key_value, + use_cache=use_cache, + output_attentions=output_attentions, + ) + + # layer_outputs is a tuple with: + # hidden-states, key-value-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights) + if use_cache is False: + layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] + + hidden_states, present_key_value_state = layer_outputs[:2] + + attention_mask, extended_attention_mask = adjust_tensors_for_parallel( + hidden_states, attention_mask, extended_attention_mask + ) + + # We share the position biases between the layers - the first layer store them + # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), + # (cross-attention position bias), (cross-attention weights) + position_bias = layer_outputs[2] + if self.is_decoder and encoder_hidden_states is not None: + encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] + # append next layer key value states + if use_cache: + present_key_value_states = present_key_value_states + (present_key_value_state,) + + if position_bias is not None: + position_bias = adjust_tensors_for_parallel(hidden_states, position_bias)[0] + if encoder_decoder_position_bias is not None: + encoder_decoder_position_bias = adjust_tensors_for_parallel( + hidden_states, encoder_decoder_position_bias + )[0] + + if output_attentions: + all_attentions = all_attentions + (layer_outputs[3],) + if self.is_decoder: + all_cross_attentions = all_cross_attentions + (layer_outputs[5],) + + # Model Parallel: If it's the last layer for that device, put things on the next device + if self.model_parallel: + for k, v in self.device_map.items(): + if i == v[-1] and "cuda:" + str(k) != self.last_device: + hidden_states = hidden_states.to("cuda:" + str(k + 1)) + + hidden_states = self.final_layer_norm(hidden_states) + hidden_states = self.dropout(hidden_states) + + # Add last layer + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple( + v + for v in [ + hidden_states, + present_key_value_states, + all_hidden_states, + all_attentions, + all_cross_attentions, + ] + if v is not None + ) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=hidden_states, + past_key_values=present_key_value_states, + hidden_states=all_hidden_states, + attentions=all_attentions, + cross_attentions=all_cross_attentions, + ) diff --git a/tests/models/test_mt5.py b/tests/models/test_mt5.py new file mode 100644 index 0000000000..8d9f551e8b --- /dev/null +++ b/tests/models/test_mt5.py @@ -0,0 +1,12 @@ +# flake8: noqa: F403,F405 +from adapters import MT5AdapterModel +from hf_transformers.tests.models.mt5.test_modeling_mt5 import * +from transformers.testing_utils import require_torch + +from .base import AdapterModelTesterMixin + + +@require_torch +class MT5AdapterModelTest(AdapterModelTesterMixin, MT5IntegrationTest): + all_model_classes = (MT5AdapterModel,) + fx_compatible = False diff --git a/tests/test_mt5.py b/tests/test_mt5.py new file mode 100644 index 0000000000..67e56add5a --- /dev/null +++ b/tests/test_mt5.py @@ -0,0 +1,66 @@ +import unittest + +from transformers import MT5Config +from transformers.testing_utils import require_torch + +from .composition.test_parallel import ParallelAdapterInferenceTestMixin, ParallelTrainingMixin +from .methods import ( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, +) +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_backward_compability import CompabilityTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_embeddings import EmbeddingTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin + + +@require_torch +class MT5AdapterTestBase(AdapterTestBase): + config_class = MT5Config + config = make_config( + MT5Config, + d_model=16, + num_layers=2, + num_decoder_layers=2, + num_heads=4, + d_ff=4, + d_kv=16 // 4, + tie_word_embeddings=False, + decoder_start_token_id=0, + ) + tokenizer_name = "google/mt5-base" + + +@require_torch +class MT5AdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + EmbeddingTestMixin, + CompabilityTestMixin, + ParallelAdapterInferenceTestMixin, + ParallelTrainingMixin, + AdapterFusionModelTestMixin, + PredictionHeadModelTestMixin, + MT5AdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class MT5ClassConversionTest( + ModelClassConversionTestMixin, + MT5AdapterTestBase, + unittest.TestCase, +): + pass