Add Prompt Tuning (#595)
This PR adds support for Prompt Tuning
(https://aclanthology.org/2021.emnlp-main.243/)

---------

Co-authored-by: calpt <[email protected]>
lenglaender and calpt authored Nov 19, 2023
1 parent d45d951 commit 84f4f04
Showing 77 changed files with 799 additions and 114 deletions.
2 changes: 1 addition & 1 deletion .gitignore
@@ -74,7 +74,7 @@ instance/

# Sphinx documentation
docs/_build/
docs/_build/
adapter_docs/_build/

# PyBuilder
target/
7 changes: 7 additions & 0 deletions docs/classes/adapter_config.rst
@@ -55,6 +55,13 @@ IA3Config
:members:
:inherited-members: Mapping

PromptTuningConfig
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: adapters.PromptTuningConfig
:members:
:inherited-members: Mapping

Combined configurations
~~~~~~~~~~~~~~~~~~~~~~~~

28 changes: 28 additions & 0 deletions docs/methods.md
@@ -267,3 +267,31 @@ model.reset_adapter()

_Papers:_
- [Few-Shot Parameter-Efficient Fine-Tuning is Better and Cheaper than In-Context Learning](https://arxiv.org/pdf/2205.05638.pdf) (Liu et al., 2022)

## Prompt Tuning
Prompt Tuning is an efficient fine-tuning technique proposed by Lester et al. (2021). It adds tunable tokens, called soft prompts, that are prepended to the input text.
First, the input sequence $\{x_1, x_2, \dots, x_n\}$ is embedded, resulting in the matrix $X_e \in \mathbb{R}^{n \times e}$, where $e$ is the dimension of
the embedding space. The soft prompts of length $p$ are represented as $P_e \in \mathbb{R}^{p \times e}$.
$P_e$ and $X_e$ are concatenated to form the input of the following encoder or decoder:

$$
\left[P_e; X_e\right] \in \mathbb{R}^{\left(p + n\right) \times e}
$$
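
The concatenation above can be sketched in a few lines. This is only an illustration with NumPy arrays standing in for the embedding tensors; the sizes `n`, `p` and `e` and the variable names are made up for the example and are not part of the library.

```python
import numpy as np

# Illustrative sizes: sequence length n, prompt length p, embedding size e.
n, p, e = 6, 4, 8

rng = np.random.default_rng(0)
X_e = rng.normal(size=(n, e))               # embedded input sequence
P_e = rng.uniform(-0.5, 0.5, size=(p, e))   # soft prompts, the trainable part

# [P_e; X_e]: prepend the soft prompts along the sequence axis.
model_input = np.concatenate([P_e, X_e], axis=0)
print(model_input.shape)  # (10, 8)
```

Only `P_e` would receive gradient updates during training; the embedding that produces `X_e` stays frozen.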

The `PromptTuningConfig` has the following properties:
- `prompt_length`: the length $p$ of the soft prompt
- `prompt_init`: the weight initialisation method, either "random_uniform" or "from_string"; the latter initializes each prompt token with an embedding drawn from the model’s vocabulary
- `prompt_init_text`: the text used for initialisation if `prompt_init="from_string"`
- `combine`: whether the prompt is added before the embedded input sequence ("prefix") or after the BOS token ("prefix_after_bos")
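
The two initialisation options can be sketched as follows. This is a simplified illustration, not the library's implementation: `init_soft_prompt` and its arguments are hypothetical names, the uniform range `[-scale, scale]` is an assumption based on the 0.5 default, and repeating the token ids when the text is shorter than the prompt is just one possible choice.

```python
import numpy as np

def init_soft_prompt(prompt_length, embeddings, prompt_init="random_uniform",
                     init_token_ids=None, scale=0.5, seed=0):
    """Return a (prompt_length, e) matrix of initial soft-prompt weights.

    `embeddings` is the (vocab_size, e) embedding matrix of the model.
    """
    rng = np.random.default_rng(seed)
    e = embeddings.shape[1]
    if prompt_init == "random_uniform":
        # Assumption: sample uniformly from [-scale, scale].
        return rng.uniform(-scale, scale, size=(prompt_length, e))
    if prompt_init == "from_string":
        if not init_token_ids:
            raise ValueError("from_string needs the token ids of the init text")
        # Assumption: repeat the ids if the text is too short, then truncate.
        reps = -(-prompt_length // len(init_token_ids))  # ceiling division
        ids = (list(init_token_ids) * reps)[:prompt_length]
        return embeddings[ids].copy()
    raise ValueError(f"unknown prompt_init: {prompt_init}")

# Toy vocabulary of 5 tokens with 4-dimensional embeddings.
vocab_embeddings = np.arange(20.0).reshape(5, 4)
random_prompt = init_soft_prompt(3, vocab_embeddings)
text_prompt = init_soft_prompt(5, vocab_embeddings, "from_string", init_token_ids=[1, 3])
```

In practice the token ids would come from tokenizing `prompt_init_text` with the model's tokenizer.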

To add Prompt Tuning to your model, create a `PromptTuningConfig` and pass it to `add_adapter()`:
```python
from adapters import PromptTuningConfig

config = PromptTuningConfig(prompt_length=10)
model.add_adapter("dummy", config=config)
```

_Papers:_
- [The Power of Scale for Parameter-Efficient Prompt Tuning](https://aclanthology.org/2021.emnlp-main.243/) (Lester et al., 2021)

44 changes: 22 additions & 22 deletions docs/model_overview.md
@@ -10,28 +10,28 @@ The table below further shows which model architectures support which adaptation
E.g., for BERT, this means adapters provides a ``BertAdapterModel`` class, but you can also use ``BertModel``, ``BertForSequenceClassification`` etc. together with adapters.
```

| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block |
| --------------------------------------- | -| - | - | - | - | - | - |
| [ALBERT](classes/models/albert.html) ||||||||
| [BART](classes/models/bart.html) ||||||||
| [BEIT](classes/models/beit.html) |||||| | |
| [BERT-Generation](classes/models/bert-generation.html) ||||||||
| [BERT](classes/models/bert.html) ||||||||
| [CLIP](classes/models/clip.html) ||||||| |
| [DeBERTa](classes/models/deberta.html) ||||||||
| [DeBERTa-v2](classes/models/debertaV2.html) ||||||||
| [DistilBERT](classes/models/distilbert.html) ||||||||
| [Electra](classes/models/electra.html) ||||||||
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | |
| [GPT-2](classes/models/gpt2.html) ||||||||
| [GPT-J](classes/models/gptj.html) ||||||||
| [Llama](classes/models/llama.html) ||||||||
| [MBart](classes/models/mbart.html) ||||||||
| [RoBERTa](classes/models/roberta.html) ||||||||
| [T5](classes/models/t5.html) ||||||||
| [ViT](classes/models/vit.html) ||||||||
| [XLM-RoBERTa](classes/models/xlmroberta.html) ||||||||
| [X-MOD](classes/models/xmod.html) ||||||||
| Model | (Bottleneck)<br> Adapters | Prefix<br> Tuning | LoRA | Compacter | Adapter<br> Fusion | Invertible<br> Adapters | Parallel<br> block | Prompt<br> Tuning |
| --------------------------------------- | -| - | - | - | - | - | - |- |
| [ALBERT](classes/models/albert.html) |||||||||
| [BART](classes/models/bart.html) |||||||| |
| [BEIT](classes/models/beit.html) |||||| | ||
| [BERT-Generation](classes/models/bert-generation.html) |||||||||
| [BERT](classes/models/bert.html) |||||||||
| [CLIP](classes/models/clip.html) ||||||| | |
| [DeBERTa](classes/models/deberta.html) |||||||||
| [DeBERTa-v2](classes/models/debertaV2.html) |||||||||
| [DistilBERT](classes/models/distilbert.html) |||||||||
| [Electra](classes/models/electra.html) |||||||||
| [Encoder Decoder](classes/models/encoderdecoder.html) | (*) | (*) | (*) | (*) | (*) | (*) | | |
| [GPT-2](classes/models/gpt2.html) |||||||| |
| [GPT-J](classes/models/gptj.html) |||||||| |
| [Llama](classes/models/llama.html) |||||||| |
| [MBart](classes/models/mbart.html) |||||||| |
| [RoBERTa](classes/models/roberta.html) |||||||||
| [T5](classes/models/t5.html) |||||||| |
| [ViT](classes/models/vit.html) |||||||||
| [XLM-RoBERTa](classes/models/xlmroberta.html) |||||||||
| [X-MOD](classes/models/xmod.html) |||||||||

(*) If the used encoder and decoder model class are supported.

2 changes: 2 additions & 0 deletions src/adapters/__init__.py
@@ -54,6 +54,7 @@
"ModelAdaptersConfig",
"ParBnConfig",
"PrefixTuningConfig",
"PromptTuningConfig",
"SeqBnConfig",
"SeqBnInvConfig",
"StaticAdapterFusionConfig",
@@ -161,6 +162,7 @@
ModelAdaptersConfig,
ParBnConfig,
PrefixTuningConfig,
PromptTuningConfig,
SeqBnConfig,
SeqBnInvConfig,
StaticAdapterFusionConfig,
30 changes: 30 additions & 0 deletions src/adapters/configuration/adapter_config.py
@@ -84,6 +84,8 @@ def _get_config_class(config_dict):
cls_new = LoRAConfig
elif architecture == "union":
cls_new = ConfigUnion
elif architecture == "prompt_tuning":
cls_new = PromptTuningConfig
else:
cls_new = BnConfig

@@ -395,6 +397,33 @@ class PrefixTuningConfig(AdapterConfig):
shared_gating: bool = True


@dataclass(eq=False)
class PromptTuningConfig(AdapterConfig):
"""
The Prompt Tuning architecture proposed by Lester et al. (2021). See https://arxiv.org/pdf/2104.08691.pdf
Args:
prompt_length (int): The number of tokens in the prompt.
Defaults to 10.
prompt_init (str): The initialization method for the prompt. Can be either "random_uniform" or "from_string".
Defaults to "random_uniform".
prompt_init_text (str): The text to use for prompt initialization if prompt_init="from_string".
random_uniform_scale (float): The scale of the random uniform initialization if prompt_init="random_uniform".
Defaults to 0.5 as in the paper.
combine (str):
The method used to combine the prompt with the input. Can be either "prefix" or "prefix_after_bos".
Defaults to "prefix".
"""

architecture: str = "prompt_tuning"

prompt_length: int = 10
prompt_init: str = "random_uniform"
prompt_init_text: Optional[str] = None
random_uniform_scale: float = 0.5
combine: str = "prefix"


@dataclass(eq=False)
class LoRAConfig(AdapterConfig):
"""
@@ -612,6 +641,7 @@ def __init__(
"compacter": CompacterConfig(),
"prefix_tuning": PrefixTuningConfig(),
"prefix_tuning_flat": PrefixTuningConfig(flat=True),
"prompt_tuning": PromptTuningConfig(),
"lora": LoRAConfig(),
"ia3": IA3Config(),
"mam": MAMConfig(),
19 changes: 17 additions & 2 deletions src/adapters/context.py
@@ -78,7 +78,13 @@ class ForwardContext:
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = ["adapter_gating_scores", "adapter_fusion_attentions", "adapter_input_parallelized"]
context_attributes = [
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
# Additional used attributes not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
@@ -102,6 +108,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
@@ -116,7 +124,14 @@ def wrapper_func(self, *args, **kwargs):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))
return results

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
else:
return f(self, *args, **kwargs)
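
The `output_context` handling in this hunk follows a common decorator pattern: pop a control keyword before calling the wrapped function, then optionally return internal state alongside the results. A minimal standalone sketch of that pattern is given below; `with_context` and `forward` are hypothetical names, and a plain dict stands in for `ForwardContext`.

```python
import functools

def with_context(f):
    """Wrap f so callers can request the forward context alongside the results."""
    @functools.wraps(f)
    def wrapper(*args, output_context=False, **kwargs):
        context = {}  # stand-in for the ForwardContext object
        results = f(*args, context=context, **kwargs)
        if output_context:
            # Return the internal state together with the regular results.
            return results, context
        return results
    return wrapper

@with_context
def forward(x, context):
    # The wrapped function records internal state in the context.
    context["prompt_tokens_length"] = 10
    return x * 2

plain = forward(3)
result, ctx = forward(3, output_context=True)
```

Callers that do not pass `output_context=True` see the same return value as before, so the extra output stays opt-in.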

28 changes: 27 additions & 1 deletion src/adapters/heads/base.py
@@ -326,6 +326,19 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = kwargs.pop("labels", None)
if labels is not None:
loss_fct = CrossEntropyLoss()
# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)
if attention_mask is not None:
attention_mask = torch.cat(
(torch.ones_like(prompt_labels, dtype=torch.long, device=labels.device), attention_mask),
dim=-1,
)

# Only keep active parts of the loss
if attention_mask is not None:
active_loss = attention_mask.view(-1) == 1
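
The label adjustment in this hunk can be illustrated in isolation. The sketch below uses NumPy instead of PyTorch to stay dependency-light; `pad_for_prompt` is a hypothetical helper, and `-100` is the default `ignore_index` of `torch.nn.CrossEntropyLoss`.

```python
import numpy as np

IGNORE_INDEX = -100  # default ignore_index of torch.nn.CrossEntropyLoss

def pad_for_prompt(labels, attention_mask, prompt_length):
    """Prepend ignore labels and an all-ones mask for the prompt positions."""
    batch_size = labels.shape[0]
    prompt_labels = np.full((batch_size, prompt_length), IGNORE_INDEX, dtype=np.int64)
    prompt_mask = np.ones((batch_size, prompt_length), dtype=np.int64)
    # The prompt tokens sit in front of the sequence, so both tensors grow on the left.
    labels = np.concatenate([prompt_labels, labels], axis=-1)
    attention_mask = np.concatenate([prompt_mask, attention_mask], axis=-1)
    return labels, attention_mask

labels = np.array([[3, 1, 0]])
mask = np.array([[1, 1, 0]])
new_labels, new_mask = pad_for_prompt(labels, mask, prompt_length=2)
```

Since the prompt positions carry the ignore index, they stay in the logits but contribute nothing to the loss, which keeps the labels aligned with the longer output sequence.
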
@@ -763,7 +776,14 @@ def _get_used_heads(self, head_name: str = None):
return head_modules

def forward_head(
self, all_outputs, head_name=None, cls_output=None, attention_mask=None, return_dict=False, **kwargs
self,
all_outputs,
head_name=None,
cls_output=None,
attention_mask=None,
return_dict=False,
context=None,
**kwargs
):
"""
The forward pass through a prediction head configuration. There are three ways to specify the used prediction
@@ -811,6 +831,12 @@ def _get_head_input(outputs, cls_out, batch):
if inv_adapter:
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

if isinstance(self.active_head, BatchSplit):
if sum(self.active_head.batch_sizes) != all_outputs[0].size()[0]:
raise ValueError(
10 changes: 10 additions & 0 deletions src/adapters/heads/language_modeling.py
@@ -1,3 +1,4 @@
import torch
import torch.nn as nn

from transformers.modeling_outputs import CausalLMOutput, CausalLMOutputWithPast, MaskedLMOutput, Seq2SeqLMOutput
@@ -118,6 +119,15 @@ def forward(self, outputs, cls_output=None, attention_mask=None, return_dict=Fal
labels = labels[..., 1:].contiguous()
else:
logits_for_loss = lm_logits

# adjust labels for prompt tuning
if kwargs.get("prompt_tokens_length", 0) > 0:
prompt_length = kwargs.get("prompt_tokens_length")
prompt_labels = torch.full(
(labels.shape[0], prompt_length), loss_fct.ignore_index, dtype=torch.long, device=labels.device
)
labels = torch.cat((prompt_labels, labels), dim=-1)

loss = loss_fct(logits_for_loss.view(-1, self.config["vocab_size"]), labels.view(-1))

if return_dict:
1 change: 1 addition & 0 deletions src/adapters/loading.py
@@ -307,6 +307,7 @@ def filter_func(self, adapter_name):
or ".prefix_tunings.{}.".format(adapter_name) in x
or ".prefix_gates.{}.".format(adapter_name) in x
or ".loras.{}.".format(adapter_name) in x
or ".prompt_tunings.{}.".format(adapter_name) in x
)

# This dict maps the original weight names to the currently used equivalents.
2 changes: 1 addition & 1 deletion src/adapters/methods/adapter_layer_base.py
@@ -346,7 +346,7 @@ def compose_batch_split(self, adapter_setup: BatchSplit, state: NamedTuple, lvl:
# sequentially feed different parts of the blown-up batch into different adapters
children_states = []
for i, child in enumerate(adapter_setup):
# compute ids of sequences thet should be passed to the ith adapter
# compute ids of sequences that should be passed to the ith adapter
batch_idx = (
sum(adapter_setup.batch_sizes[:i]),
sum(adapter_setup.batch_sizes[: i + 1]),