Add Support for M2M100 #445

Open · wants to merge 2 commits into base: legacy
59 changes: 59 additions & 0 deletions src/transformers/adapters/mixins/m2m_100.py
@@ -0,0 +1,59 @@
from typing import Iterable, Tuple

import torch.nn as nn

from ..layer import AdapterLayer
from ..model_mixin import (
EmbeddingAdaptersMixin,
EmbeddingAdaptersWrapperMixin,
InvertibleAdaptersMixin,
ModelAdaptersMixin,
ModelWithHeadsAdaptersMixin,
)

class M2M100EncoderLayerAdaptersMixin:
"""Adds adapters to the M2M100EncoderLayer module of M2M100."""

def _init_adapter_modules(self):
self.attention_adapters = AdapterLayer("mh_adapter", self.config)
self.output_adapters = AdapterLayer("output_adapter", self.config)
self.attention_adapters._init_adapter_modules()
self.output_adapters._init_adapter_modules()


class M2M100DecoderLayerAdaptersMixin(M2M100EncoderLayerAdaptersMixin):
"""Adds adapters to the M2M100DecoderLayer module of M2M100."""

def _init_adapter_modules(self):
super()._init_adapter_modules()
self.cross_attention_adapters = AdapterLayer("cross_adapter", self.config)
self.cross_attention_adapters._init_adapter_modules()


class M2M100ModelAdaptersMixin(EmbeddingAdaptersMixin, InvertibleAdaptersMixin, ModelAdaptersMixin):
"""Adds adapters to the M2M100Model class."""

def iter_layers(self) -> Iterable[Tuple[int, nn.Module]]:
if hasattr(self, "encoder"):
for i, layer in enumerate(self.encoder.layers):
yield i, layer
for i, layer in enumerate(self.decoder.layers, start=len(self.encoder.layers)):
yield i, layer
else:
for i, layer in enumerate(self.decoder.layers):
yield i, layer

def _init_adapter_modules(self):
if hasattr(self, "encoder"):
# In M2M100, the invertible adapters are implemented by the encoder module.
# Therefore, relay mixin calls to the encoder here.
self.invertible_adapters = self.encoder.invertible_adapters
self.add_invertible_adapter = self.encoder.add_invertible_adapter
self.get_invertible_adapter = self.encoder.get_invertible_adapter
self.enable_invertible_adapters = self.encoder.enable_invertible_adapters
self.invertible_adapters_forward = self.encoder.invertible_adapters_forward
super()._init_adapter_modules()


class M2M100ModelWithHeadsAdaptersMixin(EmbeddingAdaptersWrapperMixin, ModelWithHeadsAdaptersMixin):
pass
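As a quick orientation, here is a minimal sketch of how iter_layers() would enumerate the adapter-bearing layers once these mixins are wired into the corresponding M2M100 modeling classes (that wiring is not part of this file). The checkpoint name is an assumption for illustration; facebook/m2m100_418M has 12 encoder and 12 decoder layers.

from transformers.adapters import AutoAdapterModel

# Sketch only: assumes the rest of this PR is installed, so that AutoAdapterModel
# resolves M2M100 checkpoints and M2M100Model carries M2M100ModelAdaptersMixin.
model = AutoAdapterModel.from_pretrained("facebook/m2m100_418M")
for idx, layer in model.base_model.iter_layers():
    print(idx, type(layer).__name__)
# Expected: indices 0-11 yield M2M100EncoderLayer, indices 12-23 yield M2M100DecoderLayer.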
4 changes: 3 additions & 1 deletion src/transformers/adapters/model_mixin.py
@@ -335,7 +335,9 @@ def train_adapter(self, adapter_setup: Union[list, AdapterCompositionBlock], tra
# use the adapters to be trained by default in every forward pass
self.set_active_adapters(adapter_setup)
if train_embeddings:
self.get_input_embeddings().train()
for param in self.get_input_embeddings().parameters():
param.requires_grad = True
#self.get_input_embeddings().train()

def train_fusion(self, adapter_setup: Union[list, AdapterCompositionBlock], unfreeze_adapters=False):
"""Sets the model into mode for training of adapter fusion determined by a list of adapter names."""
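A hedged usage sketch of the code path changed above: with train_embeddings=True, the input embedding parameters are now explicitly unfrozen in addition to the adapter weights, rather than only switching the embedding module into training mode (which does not touch requires_grad). The checkpoint and adapter name are assumptions.

from transformers.adapters import AutoAdapterModel

# Sketch only; "facebook/m2m100_418M" and the adapter name "nmt" are illustrative.
model = AutoAdapterModel.from_pretrained("facebook/m2m100_418M")
model.add_adapter("nmt")
model.train_adapter("nmt", train_embeddings=True)

# Adapter weights and the input embeddings are now trainable; the rest of the model stays frozen.
assert all(p.requires_grad for p in model.get_input_embeddings().parameters())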
1 change: 1 addition & 0 deletions src/transformers/adapters/models/auto/adapter_model.py
@@ -19,6 +19,7 @@
("gpt2", "GPT2AdapterModel"),
("t5", "T5AdapterModel"),
("vit", "ViTAdapterModel"),
("M2M100", "M2M100AdapterModel"),
]
)
MODEL_WITH_HEADS_MAPPING_NAMES = OrderedDict(
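With the mapping entry above, the auto class dispatches on config.model_type ("m2m_100") to pick M2M100AdapterModel. A minimal sketch, assuming the rest of the PR is in place and using a public checkpoint name for illustration:

from transformers import AutoConfig
from transformers.adapters import AutoAdapterModel

# from_config avoids downloading weights; the config's model_type is the key
# that is matched against the mapping above.
config = AutoConfig.from_pretrained("facebook/m2m100_418M")
print(config.model_type)  # "m2m_100"
model = AutoAdapterModel.from_config(config)
print(type(model).__name__)  # expected: M2M100AdapterModel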
42 changes: 42 additions & 0 deletions src/transformers/adapters/models/m2m_100/__init__.py
@@ -0,0 +1,42 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2020 The Adapter-Hub Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from ....utils import _LazyModule


_import_structure = {
"adapter_model": [
"M2M100AdapterModel",
"M2M100ModelWithHeads",
],
}


if TYPE_CHECKING:
from .adapter_model import M2M100AdapterModel, M2M100ModelWithHeads

else:
import sys

sys.modules[__name__] = _LazyModule(
__name__,
globals()["__file__"],
_import_structure,
)
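The _LazyModule indirection defers importing the actual implementation until one of the names listed in _import_structure is first accessed, while the TYPE_CHECKING branch keeps the real symbols visible to static type checkers. A small sketch of the resulting import surface, assuming the package path added by this PR:

# Both names declared in _import_structure become importable; the module body
# of adapter_model.py only runs when an attribute is first accessed.
from transformers.adapters.models.m2m_100 import M2M100AdapterModel, M2M100ModelWithHeads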
245 changes: 245 additions & 0 deletions src/transformers/adapters/models/m2m_100/adapter_model.py
@@ -0,0 +1,245 @@
import warnings

import torch

from ....models.m2m_100.modeling_m2m_100 import (
M2M_100_INPUTS_DOCSTRING,
M2M_100_START_DOCSTRING,
M2M100Config,
M2M100Model,
M2M100PreTrainedModel,
shift_tokens_right,
)
from ....utils import add_start_docstrings, add_start_docstrings_to_model_forward
from ...composition import adjust_tensors_for_parallel
from ...heads import (
ClassificationHead,
ModelWithFlexibleHeadsAdaptersMixin,
MultiLabelClassificationHead,
QuestionAnsweringHead,
Seq2SeqLMHead,
)
from ...model_mixin import EmbeddingAdaptersWrapperMixin


@add_start_docstrings(
"M2M100 Model with the option to add multiple flexible prediction heads on top.", M2M_100_START_DOCSTRING
)
class M2M100AdapterModel(EmbeddingAdaptersWrapperMixin, ModelWithFlexibleHeadsAdaptersMixin, M2M100PreTrainedModel):
def __init__(self, config: M2M100Config, **kwargs):
super().__init__(config, **kwargs)
self.model = M2M100Model(config)

self._init_head_modules()

def get_encoder(self):
return self.model.get_encoder()

def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(M2M_100_INPUTS_DOCSTRING)
def forward(
self,
input_ids=None,
attention_mask=None,
decoder_input_ids=None,
decoder_attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
encoder_outputs=None,
inputs_embeds=None,
decoder_inputs_embeds=None,
use_cache=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
past_key_values=None,
head=None,
output_adapter_gating_scores=False,
**kwargs
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
config.num_labels - 1]`. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
decoder_attention_mask=decoder_attention_mask,
head_mask=head_mask,
decoder_head_mask=decoder_head_mask,
cross_attn_head_mask=cross_attn_head_mask,
encoder_outputs=encoder_outputs,
inputs_embeds=inputs_embeds,
decoder_inputs_embeds=decoder_inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
)
# sequence classification based on last token in sequence
x = outputs[0] # last hidden state
if input_ids is not None and x.shape[1] == input_ids.shape[1]:
eos_mask = input_ids.eq(self.config.eos_token_id)
(eos_mask,) = adjust_tensors_for_parallel(x, eos_mask)
if len(torch.unique(eos_mask.sum(1))) > 1:
raise ValueError("All examples must have the same number of <eos> tokens.")
cls_representation = x[eos_mask, :].view(x.size(0), -1, x.size(-1))[:, -1, :]
else:
cls_representation = x

head_outputs = self.forward_head(
outputs,
head_name=head,
cls_output=cls_representation,
attention_mask=attention_mask,
return_dict=return_dict,
**kwargs,
)

return head_outputs

# Copied from M2M100ForConditionalGeneration
def prepare_inputs_for_generation(
self,
decoder_input_ids,
past=None,
attention_mask=None,
head_mask=None,
decoder_head_mask=None,
cross_attn_head_mask=None,
use_cache=None,
encoder_outputs=None,
**kwargs,
):
# cut decoder_input_ids if past is used
if past is not None:
decoder_input_ids = decoder_input_ids[:, -1:]

return {
"input_ids": None, # encoder_outputs is defined. input_ids not needed
"encoder_outputs": encoder_outputs,
"past_key_values": past,
"decoder_input_ids": decoder_input_ids,
"attention_mask": attention_mask,
"head_mask": head_mask,
"decoder_head_mask": decoder_head_mask,
"cross_attn_head_mask": cross_attn_head_mask,
"use_cache": use_cache, # change this to avoid caching (presumably for debugging)
}


# Copied from M2M100ForConditionalGeneration
@staticmethod
def _reorder_cache(past, beam_idx):
reordered_past = ()
for layer_past in past:
reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
return reordered_past

head_types = {
"classification": ClassificationHead,
"multilabel_classification": MultiLabelClassificationHead,
"question_answering": QuestionAnsweringHead,
"seq2seq_lm": Seq2SeqLMHead,
}

def add_classification_head(
self,
head_name,
num_labels=2,
layers=2,
activation_function="tanh",
overwrite_ok=False,
multilabel=False,
id2label=None,
):
"""
Adds a sequence classification head on top of the model.

Args:
head_name (str): The name of the head.
num_labels (int, optional): Number of classification labels. Defaults to 2.
layers (int, optional): Number of layers. Defaults to 2.
activation_function (str, optional): Activation function. Defaults to 'tanh'.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
multilabel (bool, optional): Enable multilabel classification setup. Defaults to False.
id2label (dict, optional): Mapping from label ids to label strings. Defaults to None.
"""

if multilabel:
head = MultiLabelClassificationHead(self, head_name, num_labels, layers, activation_function, id2label)
else:
head = ClassificationHead(self, head_name, num_labels, layers, activation_function, id2label)
self.add_prediction_head(head, overwrite_ok)

def add_qa_head(
self,
head_name,
num_labels=2,
layers=1,
activation_function="tanh",
overwrite_ok=False,
id2label=None,
):
head = QuestionAnsweringHead(self, head_name, num_labels, layers, activation_function, id2label)
self.add_prediction_head(head, overwrite_ok)

def add_seq2seq_lm_head(
self,
head_name,
overwrite_ok=False,
):
"""
Adds a sequence-to-sequence language modeling head on top of the model.

Args:
head_name (str): The name of the head.
overwrite_ok (bool, optional): Force overwrite if a head with the same name exists. Defaults to False.
"""
head = Seq2SeqLMHead(self, head_name)
self.add_prediction_head(head, overwrite_ok=overwrite_ok)


class M2M100ModelWithHeads(M2M100AdapterModel):
def __init__(self, *args, **kwargs):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
self.__class__.__bases__[0].__name__
),
FutureWarning,
)
super().__init__(*args, **kwargs)

@classmethod
def from_config(cls, config):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
cls.__bases__[0].__name__
),
FutureWarning,
)
return super().from_config(config)

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
warnings.warn(
"This class has been renamed to `{}` in v3. "
"Please use the new class instead as this class might be removed in a future version.".format(
cls.__bases__[0].__name__
),
FutureWarning,
)
return super().from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
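Finally, a hedged end-to-end sketch tying this file together: instantiate the flexible-head model, add an adapter and a seq2seq LM head, and run a forward pass. The checkpoint and the adapter/head name "nmt" are assumptions, and decoder_input_ids is passed explicitly because the forward above does not derive decoder inputs from labels.

import torch

from transformers.adapters import AutoAdapterModel

# Sketch only; the checkpoint and names are illustrative, and real inputs would
# come from M2M100Tokenizer with src_lang/tgt_lang set.
model = AutoAdapterModel.from_pretrained("facebook/m2m100_418M")
model.add_adapter("nmt")
model.add_seq2seq_lm_head("nmt")
model.train_adapter("nmt")

input_ids = torch.tensor([[128022, 9, 4, 2]])           # dummy source ids ending in </s>
decoder_input_ids = torch.tensor([[2, 128028, 7, 5]])   # dummy shifted target ids

out = model(
    input_ids=input_ids,
    decoder_input_ids=decoder_input_ids,
    labels=decoder_input_ids,  # consumed by the Seq2SeqLMHead to compute the LM loss
    head="nmt",
)
print(out.loss, out.logits.shape)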