diff --git a/src/transformers/adapters/mixins/m2m_100.py b/src/transformers/adapters/mixins/m2m_100.py new file mode 100644 index 0000000000..7aded68bb3 --- /dev/null +++ b/src/transformers/adapters/mixins/m2m_100.py @@ -0,0 +1,59 @@ +from typing import Iterable, Tuple + +import torch.nn as nn + +from ..layer import AdapterLayer +from ..model_mixin import ( + EmbeddingAdaptersMixin, + EmbeddingAdaptersWrapperMixin, + InvertibleAdaptersMixin, + ModelAdaptersMixin, + ModelWithHeadsAdaptersMixin, +) + +class M2M100EncoderLayerAdaptersMixin: + """Adds adapters to the M2M100EncoderLayer module of M2M100.""" + + def _init_adapter_modules(self): + self.attention_adapters = AdapterLayer("mh_adapter", self.config) + self.output_adapters = AdapterLayer("output_adapter", self.config) + self.attention_adapters._init_adapter_modules() + self.output_adapters._init_adapter_modules() + + +class M2M100DecoderLayerAdaptersMixin(M2M100EncoderLayerAdaptersMixin): + """Adds adapters to the M2M100DecoderLayer module of M2M100.""" + + def _init_adapter_modules(self): + super()._init_adapter_modules() + self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config) + self.cross_attention_adapters._init_adapter_modules() + + +class M2M100ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin): + """Adds adapters to the M2M100Model class.""" + + def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]: + if hasattr(self, "encoder"): + for i, layer in enumerate(self.encoder.layers): + yield i, layer + for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)): + yield i, layer + else: + for i, layer in enumerate(self.decoder.layers): + yield i, layer + + def _init_adapter_modules(self): + if hasattr(self, "encoder"): + # In M2M100, the invertible adapters are implemented by the encoder module. + # Therefore, relay mixin calls to the encoder here. 
+ self.invertible_adapters = self.encoder.invertible_adapters + self.add_invertible_adapter = self.encoder.add_invertible_adapter + self.get_invertible_adapter = self.encoder.get_invertible_adapter + self.enable_invertible_adapters = self.encoder.enable_invertible_adapters + self.invertible_adapters_forward = self.encoder.invertible_adapters_forward + super()._init_adapter_modules() + + +class M2M100ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin): + pass diff --git a/src/transformers/adapters/model_mixin.py b/src/transformers/adapters/model_mixin.py index 95b07ac849..6df0c2b60c 100644 --- a/src/transformers/adapters/model_mixin.py +++ b/src/transformers/adapters/model_mixin.py @@ -335,7 +335,8 @@ def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], tra # use the adapters to be trained by default in every forward pass self.set_active_adapters(adapter_setup) if train_embeddings: - self.get_input_embeddings().train() + for param in self.get_input_embeddings().parameters(): + param.requires_grad = True def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False): """Sets the model into mode for training of adapter fusion determined by a list of adapter names.""" diff --git a/src/transformers/adapters/models/auto/adapter_model.py b/src/transformers/adapters/models/auto/adapter_model.py index 7d403f72cd..b9136f4249 100644 --- a/src/transformers/adapters/models/auto/adapter_model.py +++ b/src/transformers/adapters/models/auto/adapter_model.py @@ -19,6 +19,7 @@ ("gpt2", "GPT2AdapterModel"), ("t5", "T5AdapterModel"), ("vit", "ViTAdapterModel"), + ("m2m_100", "M2M100AdapterModel"), ] ) MODEL_WITH_HEADS_MAPPING_NAMES = OrderedDict( diff --git a/src/transformers/adapters/models/m2m_100/__init__.py b/src/transformers/adapters/models/m2m_100/__init__.py new file mode 100644 index 0000000000..eda17f86b6 --- /dev/null +++ b/src/transformers/adapters/models/m2m_100/__init__.py @@ -0,0 +1,42 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + +# Copyright 2020 The Adapter-Hub Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.
+ +from typing import TYPE_CHECKING + +from ....utils import _LazyModule + + +_import_structure = { + "adapter_model": [ + "M2M100AdapterModel", + "M2M100ModelWithHeads", + ], +} + + +if TYPE_CHECKING: + from .adapter_model import M2M100AdapterModel, M2M100ModelWithHeads + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + ) diff --git a/src/transformers/adapters/models/m2m_100/adapter_model.py b/src/transformers/adapters/models/m2m_100/adapter_model.py new file mode 100644 index 0000000000..0448531158 --- /dev/null +++ b/src/transformers/adapters/models/m2m_100/adapter_model.py @@ -0,0 +1,245 @@ +import warnings + +import torch + +from ....models.m2m_100.modeling_m2m_100 import ( + M2M_100_INPUTS_DOCSTRING, + M2M_100_START_DOCSTRING, + M2M100Config, + M2M100Model, + M2M100PreTrainedModel, + shift_tokens_right, +) +from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...composition import adjust_tensors_for_parallel +from ...heads import ( + ClassificationHead, + ModelWithFlexibleHeadsAdaptersMixin, + MultiLabelClassificationHead, + QuestionAnsweringHead, + Seq2SeqLMHead, +) +from ...model_mixin import EmbeddingAdaptersWrapperMixin + + +@add_start_docstrings( + "M2M100 Model with the option to add multiple flexible prediction heads on top.", M2M_100_START_DOCSTRING +) +class M2M100AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, M2M100PreTrainedModel): + def __init__(self, config: M2M100Config, **kwargs): + super().__init__(config, **kwargs) + self.model = M2M100Model(config) + + self._init_head_modules() + + def get_encoder(self): + return self.model.get_encoder() + + def get_decoder(self): + return self.model.get_decoder() + + @add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING) + def forward( + self, + input_ids=None, + attention_mask=None, + decoder_input_ids=None, + decoder_attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + encoder_outputs=None, + inputs_embeds=None, + decoder_inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + past_key_values=None, + head=None, + output_adapter_gating_scores=False, + **kwargs + ): + r""" + labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`): + Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ..., + config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy). 
+ """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: + use_cache = False + + outputs = self.model( + input_ids, + attention_mask=attention_mask, + decoder_input_ids=decoder_input_ids, + decoder_attention_mask=decoder_attention_mask, + head_mask=head_mask, + decoder_head_mask=decoder_head_mask, + cross_attn_head_mask=cross_attn_head_mask, + encoder_outputs=encoder_outputs, + inputs_embeds=inputs_embeds, + decoder_inputs_embeds=decoder_inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + past_key_values=past_key_values, + output_adapter_gating_scores=output_adapter_gating_scores, + ) + # sequence classification based on last token in sequence + x = outputs[0] # last hidden state + if input_ids is not None and x.shape[1] == input_ids.shape[1]: + eos_mask = input_ids.eq(self.config.eos_token_id) + (eos_mask,) = adjust_tensors_for_parallel(x, eos_mask) + if len(torch.unique(eos_mask.sum(1))) > 1: + raise ValueError("All examples must have the same number of tokens.") + cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :] + else: + cls_representation = x + + head_outputs = self.forward_head( + outputs, + head_name=head, + cls_output=cls_representation, + attention_mask=attention_mask, + return_dict=return_dict, + **kwargs, + ) + + return head_outputs + + # Copied from M2M100ForConditionalGeneration + def prepare_inputs_for_generation( + self, + decoder_input_ids, + past=None, + attention_mask=None, + head_mask=None, + decoder_head_mask=None, + cross_attn_head_mask=None, + use_cache=None, + encoder_outputs=None, + **kwargs, + ): + # cut decoder_input_ids if past is used + if past is not None: + decoder_input_ids = decoder_input_ids[:, -1:] + + return { + "input_ids": None, # encoder_outputs is defined. input_ids not needed + "encoder_outputs": encoder_outputs, + "past_key_values": past, + "decoder_input_ids": decoder_input_ids, + "attention_mask": attention_mask, + "head_mask": head_mask, + "decoder_head_mask": decoder_head_mask, + "cross_attn_head_mask": cross_attn_head_mask, + "use_cache": use_cache, # change this to avoid caching (presumably for debugging) + } + + + # Copied from M2M100ForConditionalGeneration + @staticmethod + def _reorder_cache(past, beam_idx): + reordered_past = () + for layer_past in past: + reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),) + return reordered_past + + head_types = { + "classification": ClassificationHead, + "multilabel_classification": MultiLabelClassificationHead, + "question_answering": QuestionAnsweringHead, + "seq2seq_lm": Seq2SeqLMHead, + } + + def add_classification_head( + self, + head_name, + num_labels=2, + layers=2, + activation_function="tanh", + overwrite_ok=False, + multilabel=False, + id2label=None, + ): + """ + Adds a sequence classification head on top of the model. + + Args: + head_name (str): The name of the head. + num_labels (int, optional): Number of classification labels. Defaults to 2. + layers (int, optional): Number of layers. Defaults to 2. + activation_function (str, optional): Activation function. Defaults to 'tanh'. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + multilabel (bool, optional): Enable multilabel classification setup. Defaults to False. 
+ """ + + if multilabel: + head = MultiLabelClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + else: + head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_qa_head( + self, + head_name, + num_labels=2, + layers=1, + activation_function="tanh", + overwrite_ok=False, + id2label=None, + ): + head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label) + self.add_prediction_head(head, overwrite_ok) + + def add_seq2seq_lm_head( + self, + head_name, + overwrite_ok=False, + ): + """ + Adds a sequence-to-sequence language modeling head on top of the model. + + Args: + head_name (str): The name of the head. + overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False. + """ + head = Seq2SeqLMHead(self, head_name) + self.add_prediction_head(head, overwrite_ok=overwrite_ok) + + +class M2M100ModelWithHeads(M2M100AdapterModel): + def __init__(self, *args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + self.__class__.__bases__[0].__name__ + ), + FutureWarning, + ) + super().__init__(*args, **kwargs) + + @classmethod + def from_config(cls, config): + warnings.warn( + "This class has been renamed to `{}` in v3. " + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_config(config) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs): + warnings.warn( + "This class has been renamed to `{}` in v3. 
" + "Please use the new class instead as this class might be removed in a future version.".format( + cls.__bases__[0].__name__ + ), + FutureWarning, + ) + return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs) diff --git a/src/transformers/models/m2m_100/modeling_m2m_100.py b/src/transformers/models/m2m_100/modeling_m2m_100.py index b6d97180ee..f14216cf85 100755 --- a/src/transformers/models/m2m_100/modeling_m2m_100.py +++ b/src/transformers/models/m2m_100/modeling_m2m_100.py @@ -24,6 +24,17 @@ from torch.nn import CrossEntropyLoss from ...activations import ACT2FN +from ...adapters.composition import adjust_tensors_for_parallel +from ...adapters.context import ForwardContext +from ...adapters.lora import Linear as LoRALinear +from ...adapters.mixins.m2m_100 import ( + M2M100DecoderLayerAdaptersMixin, + M2M100EncoderLayerAdaptersMixin, + M2M100ModelAdaptersMixin, + M2M100ModelWithHeadsAdaptersMixin, +) +from ...adapters.model_mixin import InvertibleAdaptersMixin +from ...adapters.prefix_tuning import PrefixTuningShim from ...deepspeed import is_deepspeed_zero3_enabled from ...modeling_outputs import ( BaseModelOutput, @@ -199,11 +210,13 @@ class M2M100Attention(nn.Module): def __init__( self, + config: M2M100Config, embed_dim: int, num_heads: int, dropout: float = 0.0, is_decoder: bool = False, bias: bool = True, + location_key: Optional[str] = None, ): super().__init__() self.embed_dim = embed_dim @@ -219,11 +232,13 @@ def __init__( self.scaling = self.head_dim**-0.5 self.is_decoder = is_decoder - self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) - self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.k_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="k", bias=bias) + self.v_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="v", bias=bias) + self.q_proj = LoRALinear(embed_dim, embed_dim, "selfattn", config, attn_key="q", bias=bias) self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + self.prefix_tuning = PrefixTuningShim(location_key + "_prefix" if location_key else None, config) + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() @@ -342,23 +357,28 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartEncoderLayer with MBart->M2M100 -class M2M100EncoderLayer(nn.Module): +class M2M100EncoderLayer(M2M100EncoderLayerAdaptersMixin, nn.Module): def __init__(self, config: M2M100Config): super().__init__() + self.config = config self.embed_dim = config.d_model self.self_attn = M2M100Attention( + config, embed_dim=self.embed_dim, num_heads=config.encoder_attention_heads, dropout=config.attention_dropout, + location_key="encoder", ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] self.activation_dropout = config.activation_dropout - self.fc1 = nn.Linear(self.embed_dim, config.encoder_ffn_dim) - self.fc2 = nn.Linear(config.encoder_ffn_dim, self.embed_dim) + self.fc1 = LoRALinear(self.embed_dim, config.encoder_ffn_dim, "intermediate", config) + self.fc2 = LoRALinear(config.encoder_ffn_dim, self.embed_dim, "output", config) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self._init_adapter_modules() + def forward( self, hidden_states: torch.Tensor, @@ -386,7 +406,8 @@ def forward( output_attentions=output_attentions, ) 
hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + #hidden_states = residual + hidden_states + hidden_states = self.attention_adapters(hidden_states, residual, None) residual = hidden_states hidden_states = self.final_layer_norm(hidden_states) @@ -394,7 +415,8 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + #hidden_states = residual + hidden_states + hidden_states = self.output_adapters(hidden_states, residual, None) if hidden_states.dtype == torch.float16 and ( torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any() @@ -411,16 +433,20 @@ def forward( # Copied from transformers.models.mbart.modeling_mbart.MBartDecoderLayer with MBart->M2M100 -class M2M100DecoderLayer(nn.Module): +class M2M100DecoderLayer(M2M100DecoderLayerAdaptersMixin, nn.Module): def __init__(self, config: M2M100Config): super().__init__() + self.config = config + self.embed_dim = config.d_model self.self_attn = M2M100Attention( + config, embed_dim=self.embed_dim, num_heads=config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + location_key="self", ) self.dropout = config.dropout self.activation_fn = ACT2FN[config.activation_function] @@ -428,16 +454,20 @@ def __init__(self, config: M2M100Config): self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.encoder_attn = M2M100Attention( + config, self.embed_dim, config.decoder_attention_heads, dropout=config.attention_dropout, is_decoder=True, + location_key="cross", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.fc1 = nn.Linear(self.embed_dim, config.decoder_ffn_dim) self.fc2 = nn.Linear(config.decoder_ffn_dim, self.embed_dim) self.final_layer_norm = nn.LayerNorm(self.embed_dim) + self._init_adapter_modules() + def forward( self, hidden_states: torch.Tensor, @@ -483,7 +513,8 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + #hidden_states = residual + hidden_states + hidden_states = self.attention_adapters(hidden_states, residual, None) # Cross-Attention Block cross_attn_present_key_value = None @@ -503,7 +534,8 @@ def forward( output_attentions=output_attentions, ) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + #hidden_states = residual + hidden_states + hidden_states = self.cross_attention_adapters(hidden_states, residual, None) # add cross-attn to positions 3,4 of present_key_value tuple present_key_value = present_key_value + cross_attn_present_key_value @@ -515,7 +547,8 @@ def forward( hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training) hidden_states = self.fc2(hidden_states) hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) - hidden_states = residual + hidden_states + #hidden_states = residual + hidden_states + hidden_states = self.output_adapters(hidden_states, residual, None) outputs = (hidden_states,) @@ -673,7 +706,7 @@ def _set_gradient_checkpointing(self, module, value=False): """ -class M2M100Encoder(M2M100PreTrainedModel): +class 
M2M100Encoder(InvertibleAdaptersMixin, M2M100PreTrainedModel): """ Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a [`M2M100EncoderLayer`]. @@ -685,6 +718,7 @@ class M2M100Encoder(M2M100PreTrainedModel): def __init__(self, config: M2M100Config, embed_tokens: Optional[nn.Embedding] = None): super().__init__(config) + self.config = config self.dropout = config.dropout self.layerdrop = config.encoder_layerdrop @@ -782,6 +816,8 @@ def forward( hidden_states = inputs_embeds + embed_pos hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training) + hidden_states = self.invertible_adapters_forward(hidden_states) + # expand attention_mask if attention_mask is not None: # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] @@ -1125,7 +1161,7 @@ def custom_forward(*inputs): "The bare M2M100 Model outputting raw hidden-states without any specific head on top.", M2M_100_START_DOCSTRING, ) -class M2M100Model(M2M100PreTrainedModel): +class M2M100Model(M2M100ModelAdaptersMixin, M2M100PreTrainedModel): def __init__(self, config: M2M100Config): super().__init__(config) @@ -1135,6 +1171,8 @@ def __init__(self, config: M2M100Config): self.encoder = M2M100Encoder(config, self.shared) self.decoder = M2M100Decoder(config, self.shared) + self._init_adapter_modules() + # Initialize weights and apply final processing self.post_init() @@ -1159,6 +1197,7 @@ def get_decoder(self): output_type=Seq2SeqModelOutput, config_class=_CONFIG_FOR_DOC, ) + @ForwardContext.wrap def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -1236,7 +1275,7 @@ def forward( @add_start_docstrings( "The M2M100 Model with a language modeling head. Can be used for summarization.", M2M_100_START_DOCSTRING ) -class M2M100ForConditionalGeneration(M2M100PreTrainedModel): +class M2M100ForConditionalGeneration(M2M100ModelWithHeadsAdaptersMixin, M2M100PreTrainedModel): base_model_prefix = "model" _keys_to_ignore_on_load_missing = [ r"encoder.version", @@ -1329,7 +1368,8 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, ) - lm_logits = self.lm_head(outputs[0]) + lm_logits = self.model.encoder.invertible_adapters_forward(outputs[0], rev=True) + lm_logits = self.lm_head(lm_logits) masked_lm_loss = None if labels is not None: diff --git a/tests_adapters/test_m2m_100.py b/tests_adapters/test_m2m_100.py new file mode 100644 index 0000000000..d437fa8fab --- /dev/null +++ b/tests_adapters/test_m2m_100.py @@ -0,0 +1,63 @@ +import unittest + +from tests.models.m2m_100.test_modeling_m2m_100 import * +from transformers import M2M100AdapterModel +from transformers.testing_utils import require_torch + +from .methods import BottleneckAdapterTestMixin, UniPELTTestMixin, CompacterTestMixin, IA3TestMixin, LoRATestMixin, PrefixTuningTestMixin +from .test_adapter import AdapterTestBase, make_config +from .test_adapter_composition import ParallelAdapterInferenceTestMixin +from .test_adapter_conversion import ModelClassConversionTestMixin +from .test_adapter_fusion_common import AdapterFusionModelTestMixin +from .test_adapter_heads import PredictionHeadModelTestMixin +from .test_common import AdapterModelTesterMixin + + +@require_torch +class M2M100AdapterModelTest(AdapterModelTesterMixin, M2M100ModelTest): + all_model_classes = ( + M2M100AdapterModel, + ) + fx_compatible = False + + +class M2M100AdapterTestBase(AdapterTestBase): + config_class = M2M100Config + config = make_config( + M2M100Config, + d_model=16, + encoder_layers=2, + 
decoder_layers=2, + encoder_attention_heads=4, + decoder_attention_heads=4, + encoder_ffn_dim=4, + decoder_ffn_dim=4, + vocab_size=128112, + ) + tokenizer_name = "facebook/m2m100_418M" + + +@require_torch +class M2M100AdapterTest( + BottleneckAdapterTestMixin, + CompacterTestMixin, + IA3TestMixin, + LoRATestMixin, + PrefixTuningTestMixin, + UniPELTTestMixin, + AdapterFusionModelTestMixin, + PredictionHeadModelTestMixin, + ParallelAdapterInferenceTestMixin, + M2M100AdapterTestBase, + unittest.TestCase, +): + pass + + +@require_torch +class M2M100ClassConversionTest( + ModelClassConversionTestMixin, + M2M100AdapterTestBase, + unittest.TestCase, +): + pass diff --git a/utils/check_adapters.py b/utils/check_adapters.py index 3c34fb39cc..9c0a820435 100644 --- a/utils/check_adapters.py +++ b/utils/check_adapters.py @@ -16,6 +16,7 @@ "deberta", "deberta-v2", "vit", + "m2m_100", ] IGNORE_NOT_IMPLEMENTING_MIXIN = [ @@ -23,6 +24,8 @@ "BartDecoder", "MBartEncoder", "MBartDecoder", + "M2M100Encoder", + "M2M100Decoder", "T5Stack", ]
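
A minimal end-to-end usage sketch of the M2M100 adapter support added above. It assumes this branch of adapter-transformers is installed; the adapter name, head name, language pair and example sentence are illustrative only, and generation through the flexible seq2seq head is expected to behave like the existing BART/MBart adapter models.

from transformers import M2M100AdapterModel, M2M100Tokenizer

model = M2M100AdapterModel.from_pretrained("facebook/m2m100_418M")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M", src_lang="en", tgt_lang="de")

# Bottleneck adapters in every encoder/decoder layer plus a seq2seq LM head.
model.add_adapter("nmt_en_de")
model.add_seq2seq_lm_head("nmt_en_de")

# Freezes the pretrained weights, activates the adapter and marks it trainable;
# pass train_embeddings=True to also unfreeze the input embeddings (the
# model_mixin.py change above).
model.train_adapter("nmt_en_de")

# ... fine-tune on parallel data here ...

model.eval()
batch = tokenizer("The weather is nice today.", return_tensors="pt")
generated = model.generate(**batch, forced_bos_token_id=tokenizer.get_lang_id("de"))
print(tokenizer.batch_decode(generated, skip_special_tokens=True))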
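
Because M2M100Attention and the encoder layers are wired up with LoRALinear and PrefixTuningShim above, the standard PELT configurations should apply without further changes (the new tests exercise LoRA, (IA)^3, prefix tuning, compacter and UniPELT). A short sketch with illustrative hyperparameters:

from transformers import M2M100AdapterModel
from transformers.adapters import LoRAConfig, PrefixTuningConfig

model = M2M100AdapterModel.from_pretrained("facebook/m2m100_418M")

# LoRA on the self-attention projections, routed through the LoRALinear layers.
model.add_adapter("lora_mt", config=LoRAConfig(r=8, alpha=16))

# Prefix tuning, handled by the PrefixTuningShim registered in each attention module.
model.add_adapter("prefix_mt", config=PrefixTuningConfig(prefix_length=30))

model.train_adapter("lora_mt")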
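
A quick sanity check of the iter_layers contract from M2M100ModelAdaptersMixin, using the same tiny configuration as the new test base class: encoder layers are enumerated first and the decoder layers continue the index, which is what the per-layer adapter methods rely on.

from transformers import M2M100AdapterModel, M2M100Config

tiny_config = M2M100Config(
    d_model=16,
    encoder_layers=2,
    decoder_layers=2,
    encoder_attention_heads=4,
    decoder_attention_heads=4,
    encoder_ffn_dim=4,
    decoder_ffn_dim=4,
    vocab_size=128112,
)
model = M2M100AdapterModel(tiny_config)

for idx, layer in model.model.iter_layers():
    # Expected: 0-1 -> M2M100EncoderLayer, 2-3 -> M2M100DecoderLayer
    print(idx, type(layer).__name__)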