Adding MT5 support #629

Merged · 11 commits · Jan 28, 2024
24 changes: 24 additions & 0 deletions docs/classes/models/mt5.rst
@@ -0,0 +1,24 @@
MT5
=====

The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/pdf/2010.11934.pdf>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou,
Aditya Siddhant, Aditya Barua, Colin Raffel.

The abstract from the paper is the following:


- The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We detail
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. We also describe a simple technique to prevent "accidental translation" in the zero-shot setting, where a
generative model chooses to (partially) translate its prediction into the wrong language. All of the code and model
checkpoints used in this work are publicly available.

MT5AdapterModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.MT5AdapterModel
:members:
:inherited-members: MT5PreTrainedModel
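
For orientation, a minimal usage sketch of the new class (not part of the diff; the checkpoint and the adapter/head names below are illustrative):

```python
from adapters import MT5AdapterModel

# "google/mt5-small" is one public mT5 checkpoint; other sizes work the same way
model = MT5AdapterModel.from_pretrained("google/mt5-small")

# add a bottleneck adapter plus a matching seq2seq LM head, then freeze
# the base model so only the adapter and head are trained
model.add_adapter("summarization")
model.add_seq2seq_lm_head("summarization")
model.train_adapter("summarization")
```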
1 change: 1 addition & 0 deletions docs/model_overview.md
@@ -27,6 +27,7 @@ The table below further shows which model architectures support which adaptation
| [GPT-J](classes/models/gptj.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [Llama](classes/models/llama.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [MBart](classes/models/mbart.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [MT5](classes/models/mt5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [RoBERTa](classes/models/roberta.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
| [T5](classes/models/t5.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | |
| [ViT](classes/models/vit.html) | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
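The column headers of this table fall outside the hunk, but in the adapters docs they cover the supported adaptation methods (bottleneck adapters, prefix tuning, LoRA, and so on). A sketch of what the new MT5 row promises, using method config strings from the adapters library:

```python
from adapters import MT5AdapterModel

model = MT5AdapterModel.from_pretrained("google/mt5-small")

# each call uses one of the method config strings the table's columns refer to
model.add_adapter("bottleneck", config="seq_bn")     # sequential bottleneck adapter
model.add_adapter("lora", config="lora")             # LoRA
model.add_adapter("prefix", config="prefix_tuning")  # prefix tuning
```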
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -105,6 +105,7 @@
"models.gptj": ["GPTJAdapterModel"],
"models.llama": ["LlamaAdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
"models.roberta": ["RobertaAdapterModel"],
"models.t5": ["T5AdapterModel"],
"models.vit": ["ViTAdapterModel"],
@@ -207,6 +208,7 @@
    from .models.gptj import GPTJAdapterModel
    from .models.llama import LlamaAdapterModel
    from .models.mbart import MBartAdapterModel
    from .models.mt5 import MT5AdapterModel
    from .models.roberta import RobertaAdapterModel
    from .models.t5 import T5AdapterModel
    from .models.vit import ViTAdapterModel
1 change: 1 addition & 0 deletions src/adapters/composition.py
@@ -127,6 +127,7 @@ def __init__(
"deberta",
"bart",
"mbart",
"mt5",
"gpt2",
"gptj",
"t5",
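This list appears to whitelist model types for composition blocks that are not supported by every architecture (e.g. `Parallel`). A sketch of what registering `"mt5"` enables, with illustrative adapter names:

```python
from adapters import MT5AdapterModel
from adapters.composition import Parallel, Stack

model = MT5AdapterModel.from_pretrained("google/mt5-small")
model.add_adapter("task_a")
model.add_adapter("task_b")

# Stack feeds one adapter's output into the next; Parallel runs both
# side by side in a single forward pass
model.set_active_adapters(Stack("task_a", "task_b"))
model.set_active_adapters(Parallel("task_a", "task_b"))
```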
29 changes: 29 additions & 0 deletions src/adapters/head_utils.py
@@ -369,6 +369,35 @@
        },
        "layers": ["lm_head"],
    },
    # MT5
    "MT5ForConditionalGeneration": {
        "config": {
            "head_type": "seq2seq_lm",
        },
        "layers": ["lm_head"],
    },
    "MT5ForQuestionAnswering": {
        "config": {
            "head_type": "question_answering",
            "activation_function": None,
            "layers": 1,
        },
        "layers": [None, "qa_outputs"],
    },
    "MT5ForSequenceClassification": {
        "config": {
            "head_type": "classification",
            "layers": 2,
            "activation_function": "tanh",
        },
        "layers": [
            None,
            "classification_head.dense",
            None,
            None,
            "classification_head.out_proj",
        ],
    },
    # DistilBERT
    "DistilBertForSequenceClassification": {
        "config": {
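These entries extend the static-to-flex head mapping: when weights saved from one of the static classes (e.g. `MT5ForConditionalGeneration`) are loaded into `MT5AdapterModel`, the loader uses them to translate the static output layer into an equivalent flexible prediction head. A sketch, assuming a checkpoint whose `config.architectures` names one of these classes:

```python
from adapters import MT5AdapterModel

# google/mt5-small was saved as MT5ForConditionalGeneration, so the
# "lm_head" weights should be converted into a "seq2seq_lm" flex head
# rather than dropped
model = MT5AdapterModel.from_pretrained("google/mt5-small")
print(model.heads)  # expected: one auto-converted seq2seq LM head
```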
5 changes: 5 additions & 0 deletions src/adapters/models/__init__.py
@@ -55,6 +55,11 @@
"MBartDecoder": BartDecoderAdaptersMixin,
"MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin,
"MBartModel": BartModelAdaptersMixin,
"MT5Block": T5BlockAdaptersMixin,
"MT5Model": T5ModelAdaptersMixin,
"MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin,
"MT5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin,
"MT5EncoderModel": T5ModelAdaptersMixin,
"GPT2Model": GPT2ModelAdapterMixin,
"GPTJMLP": GPTJMLPAdaptersMixin,
"GPTJModel": GPTJModelAdapterMixin,
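This mapping reuses the existing T5 mixins for the MT5 classes; it is what lets `adapters.init()` retrofit adapter support onto a plain `transformers` MT5 model. A minimal sketch:

```python
import adapters
from transformers import MT5Model

model = MT5Model.from_pretrained("google/mt5-small")

# init() walks the module tree and attaches the mixins registered above
adapters.init(model)
model.add_adapter("demo")
model.set_active_adapters("demo")
```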
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -23,6 +23,7 @@
("gptj", "GPTJAdapterModel"),
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mt5", "MT5AdapterModel"),
("roberta", "RobertaAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
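With the `("mt5", "MT5AdapterModel")` entry registered, the auto class can route mT5 checkpoints to the new model. A sketch:

```python
from adapters import AutoAdapterModel

# dispatches to MT5AdapterModel based on the checkpoint's model_type ("mt5")
model = AutoAdapterModel.from_pretrained("google/mt5-small")
print(type(model).__name__)  # expected: "MT5AdapterModel"
```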
39 changes: 39 additions & 0 deletions src/adapters/models/mt5/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
    "adapter_model": ["MT5AdapterModel"],
}


if TYPE_CHECKING:
    from .adapter_model import MT5AdapterModel

else:
    import sys

    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
    )
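
The `_LazyModule` wiring above defers the actual import of `adapter_model` until the attribute is first accessed. A sketch of the observable behavior:

```python
import adapters.models.mt5 as mt5_pkg

# the submodule import above is cheap; this attribute access is what
# triggers loading adapter_model and the MT5AdapterModel class
model_cls = mt5_pkg.MT5AdapterModel
```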