Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pluggable Model Integration Interface #738

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/adapters/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
"Seq2SeqLMHead",
"TaggingHead",
],
"interface": ["AdapterMethod", "AdapterModelInterface"],
"methods.adapter_layer_base": ["AdapterLayerBase", "ComposableAdapterLayerBase"],
"model_mixin": [
"EmbeddingAdaptersMixin",
Expand Down Expand Up @@ -198,6 +199,7 @@
Seq2SeqLMHead,
TaggingHead,
)
from .interface import AdapterMethod, AdapterModelInterface
from .methods.adapter_layer_base import AdapterLayerBase, ComposableAdapterLayerBase
from .model_mixin import (
EmbeddingAdaptersMixin,
Expand Down
85 changes: 85 additions & 0 deletions src/adapters/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import List, Optional


class AdapterMethod:
"""
Enum of all supported adapter method types.
"""

bottleneck = "bottleneck"
prefix_tuning = "prefix_tuning"
lora = "lora"
prompt_tuning = "prompt_tuning"
reft = "reft"
invertible = "invertible"

@staticmethod
def get_from_config(config) -> List[str]:
"""
Get the adapter type from a given adapter config.

Args:
config: The adapter config.

Returns:
List[str]: The adapter type.
"""
methods = []
if getattr(config, "inv_adapter", False):
methods.append(AdapterMethod.invertible)
if config.architecture is None:
methods.append(AdapterMethod.bottleneck)
elif config.architecture == "union":
methods.extend([AdapterMethod.get_from_config(sub_config) for sub_config in config.configs])
else:
methods.append(config.architecture)
return methods


@dataclass
class AdapterModelInterface:
"""
Defines the main interface for integrating adapter methods into a model class.
This interface translates generic accessor names to model-specific attribute names.

Args:
adapter_types (List[str]): List of adapter types that are supported by the model.
model_embeddings (str): Name of the model's embedding layer.
model_layers (str): Name of the model's layer list.
layer_self_attn (str): Name of the self-attention layer in a transformer layer.
layer_cross_attn (str): Name of the cross-attention layer in a transformer layer.
attn_k_proj (str): Name of the key projection layer in an attention layer.
attn_q_proj (str): Name of the query projection layer in an attention layer.
attn_v_proj (str): Name of the value projection layer in an attention layer.
attn_o_proj (str): Name of the output projection layer in an attention layer.
layer_intermediate_proj (str): Name of the intermediate projection layer in a transformer layer.
layer_output_proj (str): Name of the output projection layer in a transformer layer.
layer_pre_self_attn (Optional[str]): Hook point directly before the self attention layer. Used for extended bottleneck adapter support.
layer_pre_cross_attn (Optional[str]): Hook point directly before the cross attention layer. Used for extended bottleneck adapter support.
layer_pre_ffn (Optional[str]): Hook point directly before the feed forward layer. Used for extended bottleneck adapter support.
layer_ln_1 (Optional[str]): Layer norm *after* the self-attention layer. Used for extended bottleneck adapter support.
layer_ln_2 (Optional[str]): Layer norm *after* the feed forward layer. Used for extended bottleneck adapter support.
"""

adapter_types: List[str]

model_embeddings: str
model_layers: str

layer_self_attn: str
layer_cross_attn: str
attn_k_proj: str
attn_q_proj: str
attn_v_proj: str
attn_o_proj: str

layer_intermediate_proj: str
layer_output_proj: str

# Optional attributes for extended bottleneck adapter support
layer_pre_self_attn: Optional[str] = None
layer_pre_cross_attn: Optional[str] = None
layer_pre_ffn: Optional[str] = None
layer_ln_1: Optional[str] = None
layer_ln_2: Optional[str] = None
14 changes: 14 additions & 0 deletions src/adapters/methods/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
from .bottleneck import init_bottleneck
from .invertible import init_invertible_adapters
from .lora import init_lora
from .prompt_tuning import init_prompt_tuning
from .reft import init_reft


METHOD_INIT_MAPPING = {
"bottleneck": init_bottleneck,
"lora": init_lora,
"prompt_tuning": init_prompt_tuning,
"reft": init_reft,
"invertible": init_invertible_adapters,
}
88 changes: 82 additions & 6 deletions src/adapters/methods/bottleneck.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from functools import partial
from typing import List, Mapping, NamedTuple, Optional, Union

import torch
Expand All @@ -15,10 +16,16 @@
)
from ..configuration import BnConfig
from ..context import ForwardContext
from ..utils import multigetattr
from .adapter_layer_base import ComposableAdapterLayerBase
from .modeling import Adapter, BertFusion, ParallelAdapter


LAYER_HOOK_UNSUPPORTED = [
("original_ln_after", False),
]


class BottleneckState(NamedTuple):
"""
Models the input and output states of a bottleneck adapter layer.
Expand All @@ -45,16 +52,19 @@ class BottleneckLayer(ComposableAdapterLayerBase, nn.Module):
adapter_modules_name = "adapters"
supported_compositions = [Stack, Fuse, Split, Parallel, BatchSplit, Average]

def __init__(self, location_key: str):
def __init__(self, location_key: str, is_layer_hooked: bool = False):
super().__init__()
self.location_key = location_key
self.is_layer_hooked = is_layer_hooked

def init_adapters(self, model_config, adapters_config):
self._init_mapping()
self.model_config = model_config
self.adapters_config = adapters_config
self.adapters = nn.ModuleDict(dict())
self.adapter_fusion_layer = nn.ModuleDict(dict())
if not hasattr(self, "is_layer_hooked"):
self.is_layer_hooked = False

def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
self.layer_idx = layer_idx
Expand All @@ -78,6 +88,15 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
'{"1": 16, "default": 16}'
)

# check unsupported configurations for layer hooking mode
if self.is_layer_hooked:
for key, value in LAYER_HOOK_UNSUPPORTED:
if adapter_config.get(key, None) == value:
raise ValueError(
f"Unsupported configuration for bottleneck layer hooking mode: {key}={value}. "
"Please set this configuration to a supported value."
)

if adapter_config.is_parallel:
adapter_class = ParallelAdapter
else:
Expand All @@ -88,6 +107,10 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
down_sample=int(self.model_config.hidden_size // reduction_factor),
config=adapter_config,
)
# for adapters hooked via interface:
# residual & LN are applied by model, so don't apply in adapters
if self.is_layer_hooked:
adapter.original_ln_after = False
adapter.train(self.training) # make sure training mode is consistent
self.adapters[adapter_name] = adapter
return True
Expand Down Expand Up @@ -321,9 +344,10 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
torch.Tensor: Output hidden states of the adapter layer.
"""
# Batch sizes might be different due to prefix tuning w. Parallel block
(residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input)
# Replicate in both directions as residual might be larger (e.g. GPT-J)
(hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states)
if residual_input is not None:
(residual_input,) = adjust_tensors_for_parallel(hidden_states, residual_input)
# Replicate in both directions as residual might be larger (e.g. GPT-J)
(hidden_states,) = adjust_tensors_for_parallel(residual_input, hidden_states)
adapter_setup = self.get_active_setup()
if adapter_setup is not None:
input_hidden_states = hidden_states
Expand All @@ -335,9 +359,9 @@ def bottleneck_layer_forward(self, hidden_states, residual_input, layer_norm):
last_adapter = self.adapters[last]
hidden_states = last_adapter.post_forward(hidden_states, input_hidden_states, residual_input, layer_norm)

elif layer_norm:
elif layer_norm is not None and not self.is_layer_hooked:
hidden_states = layer_norm(hidden_states + residual_input)
else:
elif residual_input is not None and not self.is_layer_hooked:
hidden_states = hidden_states + residual_input

return hidden_states
Expand All @@ -354,3 +378,55 @@ def forward(self, hidden_states, residual_input, layer_norm):
torch.Tensor: Output hidden states of the adapter layer.
"""
return self.bottleneck_layer_forward(hidden_states, residual_input, layer_norm)


def hook_fn(adapter_layer, ln_get_fn, module, args, output):
# Retrieve residual from previous hook, if existing
context = ForwardContext.get_context()
residual_input = getattr(context, f"{adapter_layer.location_key}_residual_input", None)
# Retrieve layer norm from getter fn
if ln_get_fn is not None:
layer_norm = ln_get_fn()
else:
layer_norm = None
# Call adapter layer
if isinstance(output, torch.Tensor):
return adapter_layer(output, residual_input, layer_norm)
else:
return (adapter_layer(output[0], residual_input, layer_norm),) + output[1:]


def _residual_hook_fn(location_key, module, args):
context = ForwardContext.get_context()
if context is not None:
setattr(context, f"{location_key}_residual_input", args[0])


def init_bottleneck(model):
model = model.base_model
for _, layer in model.iter_layers():
if self_attn := multigetattr(layer, model.adapter_interface.layer_self_attn, None):
if o_proj := multigetattr(self_attn, model.adapter_interface.attn_o_proj, None):
if not hasattr(layer, "attention_adapters"):
layer.attention_adapters = BottleneckLayer("mh_adapter", is_layer_hooked=True)
ln_1_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_1, None)
o_proj.register_forward_hook(partial(hook_fn, layer.attention_adapters, ln_1_get_fn))
if layer_output_proj := multigetattr(layer, model.adapter_interface.layer_output_proj, None):
if not hasattr(layer, "output_adapters"):
layer.output_adapters = BottleneckLayer("output_adapter", is_layer_hooked=True)
ln_2_get_fn = lambda: multigetattr(layer, model.adapter_interface.layer_ln_2, None)
layer_output_proj.register_forward_hook(partial(hook_fn, layer.output_adapters, ln_2_get_fn))
if cross_attn := multigetattr(layer, model.adapter_interface.layer_cross_attn, None):
if not hasattr(cross_attn, "cross_attention_adapters"):
layer.attention_adapters = BottleneckLayer("cross_adapter", is_layer_hooked=True)
cross_attn.register_forward_hook(partial(hook_fn, layer.attention_adapters, None))

if model.adapter_interface.layer_pre_self_attn is not None:
if pre_self_attn := multigetattr(layer, model.adapter_interface.layer_pre_self_attn, None):
pre_self_attn.register_forward_pre_hook(partial(_residual_hook_fn, "mh_adapter"))
if model.adapter_interface.layer_pre_cross_attn is not None:
if pre_cross_attn := multigetattr(layer, model.adapter_interface.layer_pre_cross_attn, None):
pre_cross_attn.register_forward_pre_hook(partial(_residual_hook_fn, "cross_adapter"))
if model.adapter_interface.layer_pre_ffn is not None:
if pre_ffn := multigetattr(layer, model.adapter_interface.layer_pre_ffn, None):
pre_ffn.register_forward_pre_hook(partial(_residual_hook_fn, "output_adapter"))
104 changes: 104 additions & 0 deletions src/adapters/methods/invertible.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import types
from functools import partial

import torch
import torch.nn as nn

from ..configuration.adapter_config import BnConfig
from ..utils import multigetattr
from .adapter_layer_base import AdapterLayerBase
from .modeling import Adapter, GLOWCouplingBlock, NICECouplingBlock


class InvertibleAdapterLayer(AdapterLayerBase, nn.ModuleDict):
adapter_modules_name = "_modules"

def __init__(self, model_config, adapters_config):
super().__init__()
self.location_key = "inv_adapter"
self.model_config = model_config
self.adapters_config = adapters_config

def add_adapter(self, adapter_name: str, layer_idx: int) -> bool:
self.layer_idx = layer_idx
embedding_size = getattr(self.model_config, "embedding_size", self.model_config.hidden_size)
adapter_config = self.adapters_config.match(
adapter_name,
config_type=BnConfig,
location_key="inv_adapter",
)
if adapter_config is not None and adapter_config["inv_adapter"]:
if adapter_config["inv_adapter"] == "nice":
inv_adap = NICECouplingBlock(
[[embedding_size]],
non_linearity=adapter_config["non_linearity"],
reduction_factor=adapter_config["inv_adapter_reduction_factor"],
)
elif adapter_config["inv_adapter"] == "glow":
inv_adap = GLOWCouplingBlock(
[[embedding_size]],
non_linearity=adapter_config["non_linearity"],
reduction_factor=adapter_config["inv_adapter_reduction_factor"],
)
else:
raise ValueError(f"Invalid invertible adapter type '{adapter_config['inv_adapter']}'.")
self[adapter_name] = inv_adap
self[adapter_name].apply(Adapter.init_bert_weights)
return True

return False

def get_invertible_adapter(self):
# HACK: returns the first adapter of the currently active setup. for backwards compatibility
adapter_setup = self.get_active_setup()
if adapter_setup is not None and len(adapter_setup) > 0:
first_adapter = adapter_setup.first()
if first_adapter in self:
return self[first_adapter]
return None

def forward(self, hidden_states: torch.Tensor, rev=False):
adapter_setup = self.get_active_setup()
if adapter_setup is not None and len(adapter_setup) > 0:
first_adapter = adapter_setup.first()
if first_adapter in self:
hidden_states = self[first_adapter](hidden_states, rev=rev)
return hidden_states


def hook_fn(model, module, args, embedding_output):
embedding_output = model.invertible_adapters(embedding_output)
return embedding_output


def inv_hook_fn(model, module, args):
inv_output = model.invertible_adapters(args[0], rev=True)
return (inv_output,) + args[1:]


def init_invertible_adapters(model):
base_model = model.base_model
if not hasattr(base_model, "invertible_adapters"):
base_model.invertible_adapters = InvertibleAdapterLayer(base_model.config, base_model.adapters_config)

embed_layer = multigetattr(base_model, base_model.adapter_interface.model_embeddings)
embed_layer.register_forward_hook(partial(hook_fn, base_model))

# Add methods from original invertible adapter mixin.
# This is primarily for backwards compatibility and internal use.
base_model.add_invertible_adapter = types.MethodType(
lambda self, *args, **kwargs: self.invertible_adapters.add_adapter(*args, **kwargs), base_model
)
base_model.delete_invertible_adapter = types.MethodType(
lambda self, *args, **kwargs: self.invertible_adapters.delete_adapter(*args, **kwargs), base_model
)
base_model.get_invertible_adapter = types.MethodType(
lambda self: self.invertible_adapters.get_invertible_adapter(), base_model
)
base_model.invertible_adapters_forward = types.MethodType(
lambda self, *args, **kwargs: self.invertible_adapters(*args, **kwargs), base_model
)

# Register reverse forward pass
if output_embedding := model.get_output_embeddings():
output_embedding.register_forward_pre_hook(partial(inv_hook_fn, base_model))
Loading
Loading