diff --git a/src/adapters/composition.py b/src/adapters/composition.py index a44b9c5aa..6c17fb8eb 100644 --- a/src/adapters/composition.py +++ b/src/adapters/composition.py @@ -92,13 +92,17 @@ def __init__(self, *stack_layers: List[Union[AdapterCompositionBlock, str]]): class Fuse(AdapterCompositionBlock): - def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]]): + def __init__(self, *fuse_stacks: List[Union[AdapterCompositionBlock, str]], name: Optional[str] = None): super().__init__(*fuse_stacks) + self._name = name # TODO-V2 pull this up to all block classes? @property def name(self): - return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) + if self._name: + return self._name + else: + return ",".join([c if isinstance(c, str) else c.last() for c in self.children]) class Split(AdapterCompositionBlock): diff --git a/src/adapters/configuration/model_adapters_config.py b/src/adapters/configuration/model_adapters_config.py index 3ae7dcf56..f742028b6 100644 --- a/src/adapters/configuration/model_adapters_config.py +++ b/src/adapters/configuration/model_adapters_config.py @@ -1,7 +1,7 @@ import copy import logging from collections.abc import Collection, Mapping -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from .. import __version__ from ..composition import AdapterCompositionBlock @@ -27,6 +27,7 @@ def __init__(self, **kwargs): self.fusions: Mapping[str, str] = kwargs.pop("fusions", {}) self.fusion_config_map = kwargs.pop("fusion_config_map", {}) + self.fusion_name_map = kwargs.pop("fusion_name_map", {}) # TODO-V2 Save this with config? self.active_setup: Optional[AdapterCompositionBlock] = None @@ -131,7 +132,7 @@ def add(self, adapter_name: str, config: Optional[Union[str, dict]] = None): self.adapters[adapter_name] = config_name logger.info(f"Adding adapter '{adapter_name}'.") - def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: + def get_fusion(self, fusion_name: Union[str, List[str]]) -> Tuple[Optional[dict], Optional[list]]: """ Gets the config dictionary for a given AdapterFusion. @@ -140,6 +141,7 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: Returns: Optional[dict]: The AdapterFusion configuration. + Optional[list]: The names of the adapters to fuse. """ if isinstance(fusion_name, list): fusion_name = ",".join(fusion_name) @@ -149,20 +151,31 @@ def get_fusion(self, fusion_name: Union[str, List[str]]) -> Optional[dict]: config = self.fusion_config_map.get(config_name, None) else: config = ADAPTERFUSION_CONFIG_MAP.get(config_name, None) + + if fusion_name in self.fusion_name_map: + adapter_names = self.fusion_name_map[fusion_name] + else: + adapter_names = fusion_name.split(",") + + return config, adapter_names else: - config = None - return config + return None, None - def add_fusion(self, fusion_name: Union[str, List[str]], config: Optional[Union[str, dict]] = None): + def add_fusion( + self, adapter_names: List[str], config: Optional[Union[str, dict]] = None, fusion_name: Optional[str] = None + ): """ Adds a new AdapterFusion. Args: - fusion_name (Union[str, List[str]]): The name of the AdapterFusion or the adapters to fuse. + adapter_names (List[str]): The names of the adapters to fuse. config (Optional[Union[str, dict]], optional): AdapterFusion config. Defaults to None. + fusion_name (Optional[str], optional): The name of the AdapterFusion. If not specified, will default to comma-separated adapter names. """ - if isinstance(fusion_name, list): - fusion_name = ",".join(fusion_name) + if fusion_name is None: + fusion_name = ",".join(adapter_names) + else: + self.fusion_name_map[fusion_name] = adapter_names if fusion_name in self.fusions: raise ValueError(f"An AdapterFusion with the name '{fusion_name}' has already been added.") if config is None: @@ -218,6 +231,7 @@ def to_dict(self): output_dict["fusion_config_map"][k] = v.to_dict() else: output_dict["fusion_config_map"][k] = copy.deepcopy(v) + output_dict["fusion_name_map"] = copy.deepcopy(self.fusion_name_map) return output_dict def __eq__(self, other): diff --git a/src/adapters/loading.py b/src/adapters/loading.py index 69747e04c..55ba1db45 100644 --- a/src/adapters/loading.py +++ b/src/adapters/loading.py @@ -639,7 +639,7 @@ def save_to_state_dict(self, name: str): if name not in self.model.adapters_config.fusions: raise ValueError(f"No AdapterFusion with name '{name}' available.") - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.model.adapters_config.get_fusion(name) config_dict = build_full_config( adapter_fusion_config, @@ -676,13 +676,14 @@ def save(self, save_directory: str, name: str, meta_dict=None): else: assert isdir(save_directory), "Saving path should be a directory where the head can be saved." - adapter_fusion_config = self.model.adapters_config.get_fusion(name) + adapter_fusion_config, adapter_names = self.model.adapters_config.get_fusion(name) # Save the adapter fusion configuration config_dict = build_full_config( adapter_fusion_config, self.model.config, name=name, + adapter_names=adapter_names, model_name=self.model.model_name, model_class=self.model.__class__.__name__, ) @@ -746,9 +747,14 @@ def load(self, save_directory, load_as=None, loading_info=None, **kwargs): config = self.weights_helper.load_weights_config(save_directory) adapter_fusion_name = load_as or config["name"] + adapter_names = config.get("adapter_names", adapter_fusion_name) if adapter_fusion_name not in self.model.adapters_config.fusions: self.model.add_adapter_fusion( - adapter_fusion_name, config["config"], overwrite_ok=True, set_active=kwargs.pop("set_active", True) + adapter_names, + config["config"], + name=adapter_fusion_name, + overwrite_ok=True, + set_active=kwargs.pop("set_active", True), ) else: logger.warning("Overwriting existing adapter fusion module '{}'".format(adapter_fusion_name)) diff --git a/src/adapters/methods/bottleneck.py b/src/adapters/methods/bottleneck.py index ff12a91cd..889941d2b 100644 --- a/src/adapters/methods/bottleneck.py +++ b/src/adapters/methods/bottleneck.py @@ -96,9 +96,9 @@ def add_adapter(self, adapter_name: str, layer_idx: int) -> bool: def add_fusion_layer(self, adapter_names: Union[List, str]): """See BertModel.add_fusion_layer""" - adapter_names = adapter_names if isinstance(adapter_names, list) else adapter_names.split(",") + fusion_name = ",".join(adapter_names) if isinstance(adapter_names, list) else adapter_names + fusion_config, adapter_names = self.adapters_config.get_fusion(fusion_name) if self.adapters_config.common_config_value(adapter_names, self.location_key): - fusion_config = self.adapters_config.get_fusion(adapter_names) dropout_prob = fusion_config.dropout_prob or getattr(self.model_config, "attention_probs_dropout_prob", 0) fusion = BertFusion( fusion_config, @@ -106,7 +106,7 @@ def add_fusion_layer(self, adapter_names: Union[List, str]): dropout_prob, ) fusion.train(self.training) # make sure training mode is consistent - self.adapter_fusion_layer[",".join(adapter_names)] = fusion + self.adapter_fusion_layer[fusion_name] = fusion def delete_fusion_layer(self, adapter_names: Union[List, str]): adapter_names = adapter_names if isinstance(adapter_names, str) else ",".join(adapter_names) @@ -223,7 +223,7 @@ def compose_fuse(self, adapter_setup: Fuse, state: BottleneckState, lvl: int = 0 context = ForwardContext.get_context() # config of _last_ fused adapter is significant - fusion_config = self.adapters_config.get_fusion(adapter_setup.name) + fusion_config, _ = self.adapters_config.get_fusion(adapter_setup.name) last = adapter_setup.last() last_adapter = self.adapters[last] hidden_states, query, residual = last_adapter.pre_forward( diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 3154af5ac..62de6178a 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -638,6 +638,7 @@ def add_adapter_fusion( self, adapter_names: Union[Fuse, list, str], config=None, + name: str = None, overwrite_ok: bool = False, set_active: bool = False, ): @@ -655,6 +656,8 @@ def add_adapter_fusion( - a string identifying a pre-defined adapter fusion configuration - a dictionary representing the adapter fusion configuration - the path to a file containing the adapter fusion configuration + name (str, optional): + Name of the AdapterFusion layer. If not specified, the name is generated automatically from the fused adapter names. overwrite_ok (bool, optional): Overwrite an AdapterFusion layer with the same name if it exists. By default (False), an exception is thrown. @@ -662,22 +665,24 @@ def add_adapter_fusion( Activate the added AdapterFusion. By default (False), the AdapterFusion is added but not activated. """ if isinstance(adapter_names, Fuse): + if name is None: + name = adapter_names.name adapter_names = adapter_names.children elif isinstance(adapter_names, str): adapter_names = adapter_names.split(",") + if name is None: + name = ",".join(adapter_names) if isinstance(config, dict): config = AdapterFusionConfig.from_dict(config) # ensure config is ok and up-to-date # In case adapter already exists and we allow overwriting, explicitly delete the existing one first - if overwrite_ok and self.adapters_config.get_fusion(adapter_names) is not None: - self.delete_adapter_fusion(adapter_names) - self.adapters_config.add_fusion(adapter_names, config=config) - self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(adapter_names)) - self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(adapter_names)) + if overwrite_ok and self.adapters_config.get_fusion(name)[0] is not None: + self.delete_adapter_fusion(name) + self.adapters_config.add_fusion(adapter_names, config=config, fusion_name=name) + self.apply_to_adapter_layers(lambda i, layer: layer.add_fusion_layer(name)) + self.apply_to_basemodel_childs(lambda i, child: child.add_fusion_layer(name)) if set_active: - if not isinstance(adapter_names, list): - adapter_names = adapter_names.split(",") - self.set_active_adapters(Fuse(*adapter_names)) + self.set_active_adapters(Fuse(*adapter_names, name=name)) def delete_adapter(self, adapter_name: str): """ @@ -710,7 +715,7 @@ def delete_adapter_fusion(self, adapter_names: Union[Fuse, list, str]): adapter_names (Union[Fuse, list, str]): AdapterFusion layer to delete. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -776,7 +781,7 @@ def save_adapter_fusion( ValueError: If the given AdapterFusion name is invalid. """ if isinstance(adapter_names, Fuse): - adapter_fusion_name = ",".join(adapter_names.children) + adapter_fusion_name = adapter_names.name elif isinstance(adapter_names, list): adapter_fusion_name = ",".join(adapter_names) elif isinstance(adapter_names, str): @@ -1094,7 +1099,7 @@ def save_all_adapter_fusions( """ os.makedirs(save_directory, exist_ok=True) for name in self.adapters_config.fusions: - adapter_fusion_config = self.adapters_config.get_fusion(name) + adapter_fusion_config, _ = self.adapters_config.get_fusion(name) h = get_adapter_config_hash(adapter_fusion_config) save_path = join(save_directory, name) if meta_dict: diff --git a/tests/test_adapter_fusion_common.py b/tests/test_adapter_fusion_common.py index ccc860f66..695808eb2 100644 --- a/tests/test_adapter_fusion_common.py +++ b/tests/test_adapter_fusion_common.py @@ -214,3 +214,86 @@ def test_output_adapter_fusion_attentions(self): self.assertEqual(len(per_layer_scores), 1) for k, v in per_layer_scores.items(): self.assertEqual(self.default_input_samples_shape[0], v.shape[0], k) + + def test_add_adapter_fusion_custom_name(self): + config_name = "seq_bn" + model = self.get_model() + model.eval() + + name1 = f"{config_name}-1" + name2 = f"{config_name}-2" + model.add_adapter(name1, config=config_name) + model.add_adapter(name2, config=config_name) + + # adapter is correctly added to config + self.assertTrue(name1 in model.adapters_config) + self.assertTrue(name2 in model.adapters_config) + + # add fusion with default name + model.add_adapter_fusion([name1, name2]) + model.to(torch_device) + + # check forward pass + input_data = self.get_input_samples(config=model.config) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_ref_output = model(**input_data) + + # add fusion with custom name + model.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model.to(torch_device) + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusions) + self.assertIn("custom_name_fusion", model.adapters_config.fusion_name_map) + + # check forward pass + model.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + fusion_custom_output = model(**input_data) + model.set_active_adapters(Fuse(name1, name2)) + fusion_default_output = model(**input_data) + model.set_active_adapters(None) + base_output = model(**input_data) + + self.assertFalse(torch.equal(fusion_default_ref_output[0], base_output[0])) + self.assertTrue(torch.equal(fusion_default_ref_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], fusion_default_output[0])) + self.assertFalse(torch.equal(fusion_custom_output[0], base_output[0])) + + # delete only the custom fusion + model.delete_adapter_fusion(Fuse(name1, name2, name="custom_name_fusion")) + # model.delete_adapter_fusion("custom_name_fusion") + + self.assertIn(f"{name1},{name2}", model.adapters_config.fusions) + self.assertNotIn("custom_name_fusion", model.adapters_config.fusions) + + def test_load_adapter_fusion_custom_name(self): + model1 = self.get_model() + model1.eval() + + name1 = "name1" + name2 = "name2" + model1.add_adapter(name1) + model1.add_adapter(name2) + + model2 = copy.deepcopy(model1) + model2.eval() + + model1.add_adapter_fusion([name1, name2], name="custom_name_fusion") + model1.set_active_adapters(Fuse(name1, name2, name="custom_name_fusion")) + + with tempfile.TemporaryDirectory() as temp_dir: + model1.save_adapter_fusion(temp_dir, "custom_name_fusion") + # also tests that set_active works + model2.load_adapter_fusion(temp_dir, set_active=True) + + # check if adapter was correctly loaded + self.assertEqual(model1.adapters_config.fusions.keys(), model2.adapters_config.fusions.keys()) + + # check equal output + in_data = self.get_input_samples(config=model1.config) + model1.to(torch_device) + model2.to(torch_device) + output1 = model1(**in_data) + output2 = model2(**in_data) + self.assertEqual(len(output1), len(output2)) + self.assertTrue(torch.equal(output1[0], output2[0]))