Adding MT5 support (#629)
sotwi authored Jan 28, 2024
1 parent 86a3024 commit 5f91178
Showing 12 changed files with 930 additions and 0 deletions.
24 changes: 24 additions & 0 deletions docs/classes/models/mt5.rst
@@ -0,0 +1,24 @@
MT5
=====

The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/pdf/2010.11934.pdf>`__ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou,
Aditya Siddhant, Aditya Barua, Colin Raffel.

The abstract from the paper is the following:

- The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We detail
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. We also describe a simple technique to prevent "accidental translation" in the zero-shot setting, where a
generative model chooses to (partially) translate its prediction into the wrong language. All of the code and model
checkpoints used in this work are publicly available.

MT5AdapterModel
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.MT5AdapterModel
:members:
:inherited-members: MT5PreTrainedModel
1 change: 1 addition & 0 deletions docs/model_overview.md
@@ -27,6 +27,7 @@ The table below further shows which model architectures support which adaptation
| [GPT-J](classes/models/gptj.html) |✅|✅|✅|✅|✅|✅|✅| |
| [Llama](classes/models/llama.html) |✅|✅|✅|✅|✅|✅|✅| |
| [MBart](classes/models/mbart.html) |✅|✅|✅|✅|✅|✅|✅| |
| [MT5](classes/models/mt5.html) |✅|✅|✅|✅|✅|✅|✅| |
| [RoBERTa](classes/models/roberta.html) |✅|✅|✅|✅|✅|✅|✅|✅|
| [T5](classes/models/t5.html) |✅|✅|✅|✅|✅|✅|✅| |
| [ViT](classes/models/vit.html) |✅|✅|✅|✅|✅|✅|✅|✅|
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -105,6 +105,7 @@
"models.gptj": ["GPTJAdapterModel"],
"models.llama": ["LlamaAdapterModel"],
"models.mbart": ["MBartAdapterModel"],
"models.mt5": ["MT5AdapterModel"],
"models.roberta": ["RobertaAdapterModel"],
"models.t5": ["T5AdapterModel"],
"models.vit": ["ViTAdapterModel"],
@@ -207,6 +208,7 @@
from .models.gptj import GPTJAdapterModel
from .models.llama import LlamaAdapterModel
from .models.mbart import MBartAdapterModel
from .models.mt5 import MT5AdapterModel
from .models.roberta import RobertaAdapterModel
from .models.t5 import T5AdapterModel
from .models.vit import ViTAdapterModel
1 change: 1 addition & 0 deletions src/adapters/composition.py
@@ -127,6 +127,7 @@ def __init__(
"deberta",
"bart",
"mbart",
"mt5",
"gpt2",
"gptj",
"t5",
29 changes: 29 additions & 0 deletions src/adapters/head_utils.py
@@ -369,6 +369,35 @@
},
"layers": ["lm_head"],
},
# MT5
"MT5ForConditionalGeneration": {
"config": {
"head_type": "seq2seq_lm",
},
"layers": ["lm_head"],
},
"MT5ForQuestionAnswering": {
"config": {
"head_type": "question_answering",
"activation_function": None,
"layers": 1,
},
"layers": [None, "qa_outputs"],
},
"MT5ForSequenceClassification": {
"config": {
"head_type": "classification",
"layers": 2,
"activation_function": "tanh",
},
"layers": [
None,
"classification_head.dense",
None,
None,
"classification_head.out_proj",
],
},
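The entries above record, for each static-head model class, which flex-head config it maps to and which sub-layers carry weights (`None` marks a layer without weights, such as a dropout). As a hedged illustration of how such a map can be consulted when converting a static-head checkpoint into an adapter model, here is a pure-Python sketch; the dict mirrors the entries in this diff, but the two helper functions are hypothetical, not the library's actual API:

```python
# Hypothetical helpers around a static-to-flex head map like the one in
# head_utils.py. The mapping entries are copied from this diff; the lookup
# functions are illustrative only.
STATIC_TO_FLEX_HEAD = {
    "MT5ForConditionalGeneration": {
        "config": {"head_type": "seq2seq_lm"},
        "layers": ["lm_head"],
    },
    "MT5ForSequenceClassification": {
        "config": {
            "head_type": "classification",
            "layers": 2,
            "activation_function": "tanh",
        },
        "layers": [
            None,
            "classification_head.dense",
            None,
            None,
            "classification_head.out_proj",
        ],
    },
}


def head_config_for(class_name: str) -> dict:
    """Return the flex-head config recorded for a static head class."""
    return STATIC_TO_FLEX_HEAD[class_name]["config"]


def weight_layers_for(class_name: str) -> list:
    """Return only the sub-layer names that carry weights (skip None slots)."""
    return [l for l in STATIC_TO_FLEX_HEAD[class_name]["layers"] if l is not None]
```

In this reading, converting `MT5ForSequenceClassification` would create a two-layer classification head with tanh activation and copy weights from `classification_head.dense` and `classification_head.out_proj`.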
# DistilBERT
"DistilBertForSequenceClassification": {
"config": {
5 changes: 5 additions & 0 deletions src/adapters/models/__init__.py
@@ -55,6 +55,11 @@
"MBartDecoder": BartDecoderAdaptersMixin,
"MBartDecoderWrapper": BartDecoderWrapperAdaptersMixin,
"MBartModel": BartModelAdaptersMixin,
"MT5Block": T5BlockAdaptersMixin,
"MT5Model": T5ModelAdaptersMixin,
"MT5ForConditionalGeneration": T5ForCondiditionalGenerationWithHeadsMixin,
"MT5ForQuestionAnswering": T5ForQuestionAnsweringWithHeadsMixin,
"MT5EncoderModel": T5ModelAdaptersMixin,
"GPT2Model": GPT2ModelAdapterMixin,
"GPTJMLP": GPTJMLPAdaptersMixin,
"GPTJModel": GPTJModelAdapterMixin,
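Note that the new `MT5Block`, `MT5Model`, etc. entries point at the existing T5 mixins rather than new MT5-specific ones, since mT5 shares T5's block structure. A hedged, pure-Python sketch of how such a name-to-mixin map can be used to inject adapter support into a model class (illustrative stand-ins only, not the library's actual classes or mechanism):

```python
# Illustrative sketch: one mixin serves two architectures because the map
# keys both class names to the same mixin. All classes here are stand-ins.
class T5ModelAdaptersMixin:
    """Stand-in for the real mixin; adds minimal adapter bookkeeping."""

    def add_adapter(self, name):
        if not hasattr(self, "_adapters"):
            self._adapters = []
        self._adapters.append(name)


class T5Model:  # stand-in base model
    pass


class MT5Model:  # stand-in base model
    pass


MODEL_MIXIN_MAPPING = {
    "T5Model": T5ModelAdaptersMixin,
    "MT5Model": T5ModelAdaptersMixin,  # MT5 reuses the T5 mixin
}


def with_adapters(model_cls):
    """Build a subclass that mixes in the adapter support registered for this name."""
    mixin = MODEL_MIXIN_MAPPING[model_cls.__name__]
    return type(model_cls.__name__ + "WithAdapters", (mixin, model_cls), {})
```

The mixin comes first in the bases tuple so its methods take precedence over the base model's in the MRO.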
1 change: 1 addition & 0 deletions src/adapters/models/auto/adapter_model.py
@@ -23,6 +23,7 @@
("gptj", "GPTJAdapterModel"),
("llama", "LlamaAdapterModel"),
("mbart", "MBartAdapterModel"),
("mt5", "MT5AdapterModel"),
("roberta", "RobertaAdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
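This mapping is what lets `AutoAdapterModel` pick `MT5AdapterModel` when a config's `model_type` is `"mt5"`. A hedged sketch of that dispatch step in plain Python; the mapping entries echo the diff, but `resolve_adapter_model` is a hypothetical helper, not the library's actual function:

```python
# Illustrative sketch of model_type -> adapter-model-class dispatch.
ADAPTER_MODEL_MAPPING_NAMES = {
    "mbart": "MBartAdapterModel",
    "mt5": "MT5AdapterModel",
    "t5": "T5AdapterModel",
}


def resolve_adapter_model(model_type: str) -> str:
    """Return the adapter-model class name registered for a config's model_type."""
    try:
        return ADAPTER_MODEL_MAPPING_NAMES[model_type]
    except KeyError:
        raise ValueError(f"No adapter model registered for model_type {model_type!r}")
```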
39 changes: 39 additions & 0 deletions src/adapters/models/mt5/__init__.py
@@ -0,0 +1,39 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import _LazyModule


_import_structure = {
"adapter_model": ["MT5AdapterModel"],
}


if TYPE_CHECKING:
from .adapter_model import MT5AdapterModel

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
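The `_LazyModule` pattern above defers the real submodule import until an attribute is first accessed, keeping `import adapters` cheap. A minimal sketch of the same idea (in spirit only; transformers' actual `_LazyModule` also handles `__dir__`, `__reduce__`, and more):

```python
import importlib
import types


class LazyModule(types.ModuleType):
    """Minimal lazy module: attribute access triggers the real import."""

    def __init__(self, name, import_structure):
        super().__init__(name)
        # Invert {submodule: [attr, ...]} into {attr: submodule}.
        self._attr_to_module = {
            attr: mod for mod, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr):
        # Only called when `attr` is not already in the module's __dict__.
        if attr not in self._attr_to_module:
            raise AttributeError(f"module {self.__name__!r} has no attribute {attr!r}")
        submodule = importlib.import_module(self._attr_to_module[attr])
        value = getattr(submodule, attr)
        setattr(self, attr, value)  # cache so the import runs only once
        return value
```

With an import structure like `{"adapter_model": ["MT5AdapterModel"]}`, the `adapter_model` submodule (and the transformers weights code it pulls in) is only loaded when someone actually touches `MT5AdapterModel`.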