Fix default Lora/ (IA)^3 scaling in forward #770

Merged: 9 commits, merged Dec 29, 2024. Changes shown from 5 commits.
setup.py (4 changes: 3 additions & 1 deletion)
@@ -34,6 +34,7 @@
    "isort>=5.5.4",
    "Jinja2==2.11.3",
    "nltk",
+    "packaging",
    "parameterized",
    "pillow",
    "protobuf",
@@ -136,11 +137,12 @@ def deps_list(*pkgs):
# when modifying the following list, make sure to update src/transformers/dependency_versions_check.py
install_requires = [
    deps["transformers"],
+    deps["packaging"],
]

setup(
    name="adapters",
-    version="1.0.1",
+    version="1.1.0.dev0",
    author="The AdapterHub team and community contributors",
    author_email="[email protected]",
    description="A Unified Library for Parameter-Efficient and Modular Transfer Learning",
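The new "packaging" requirement backs the version comparison introduced in src/adapters/loading.py below. As a brief sketch (illustrative, not part of the diff), a PEP 440-aware comparison is preferred over comparing version strings directly:

from packaging.version import Version

print(Version("1.0.1") < Version("1.1.0"))       # True: adapters saved before 1.1.0 get the compat fix
print(Version("1.1.0.dev0") < Version("1.1.0"))  # True: dev pre-releases sort before the final release
print(Version("1.0.10") > Version("1.0.9"))      # True, whereas "1.0.10" > "1.0.9" is False as plain strings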
src/adapters/__init__.py (2 changes: 1 addition & 1 deletion)
@@ -16,7 +16,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "1.0.1"
__version__ = "1.1.0.dev0"

from typing import TYPE_CHECKING

src/adapters/loading.py (16 changes: 16 additions & 0 deletions)
@@ -6,6 +6,7 @@
from typing import Callable, Mapping, Optional, Sequence, Tuple

import torch
+from packaging.version import Version


try:
@@ -368,6 +369,19 @@ def _rename_legacy_weights(self, k):
            k = k.replace(old, new)
        return k

+    def _fix_backward_compat(self, config):
+        # Fix error in previous versions for LoRA/ (IA)^3
+        if config.get("version", "").startswith("adapters.") and Version(config["version"][9:]) < Version("1.1.0"):
+            if (
+                config["config"].get("architecture", None) == "lora"
+                and config["config"]["r"] != config["config"]["alpha"]
+            ):
+                logger.warning(
+                    "Loading a LoRA trained using a faulty library version. Editing the configuration to make sure the adapter works as trained. "
+                    "See https://github.com/adapter-hub/adapters/pull/770 for more."
+                )
+                config["config"]["alpha"] = config["config"]["r"]
+
    # This method is used to remove unnecessary invertible adapters from task adapters using the old format.
    # In the old format, task adapters e.g. using seq_bn config specify inv. adapters but don't use them.
    # As inv. adapters would be incorrectly used in the new implementation,
@@ -560,6 +574,8 @@ def load(
                # The conversion to a set and then back to a list removes all duplicates
                leave_out = list(set(leave_out + config["config"]["leave_out"]))
            config["config"]["leave_out"] = leave_out
+        # Fix issues caused by previous library versions
+        self._fix_backward_compat(config)

        adapter_name = load_as or config["name"]
        # If the adapter is not part of the model, add it
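For context, here is a self-contained sketch of the logic added in _fix_backward_compat (names and the example config are illustrative, not the library's own code). Adapters exported before 1.1.0 never had the alpha / r factor applied in forward, so resetting alpha to r keeps the effective scaling at 1, which is how those adapters were trained:

from packaging.version import Version


def fix_backward_compat(config: dict) -> dict:
    """Standalone sketch of the hook above, operating on a saved adapter config dict."""
    version = config.get("version", "")
    if version.startswith("adapters.") and Version(version[len("adapters."):]) < Version("1.1.0"):
        inner = config["config"]
        if inner.get("architecture") == "lora" and inner["r"] != inner["alpha"]:
            # Older versions skipped the alpha / r scaling in forward, so the adapter was
            # effectively trained with scaling 1; alpha = r preserves that behaviour.
            inner["alpha"] = inner["r"]
    return config


# Hypothetical saved config head (field names follow the diff above):
saved = {"version": "adapters.1.0.1", "config": {"architecture": "lora", "r": 8, "alpha": 16}}
print(fix_backward_compat(saved)["config"]["alpha"])  # -> 8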
src/adapters/methods/lora.py (2 changes: 2 additions & 0 deletions)
@@ -99,6 +99,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens
            hidden_states = hidden_states * gate
        else:
            gate = None
+        hidden_states = hidden_states * self.scaling

        return hidden_states, gate

@@ -170,6 +171,7 @@ def forward(self, hidden_states: Optional[torch.Tensor], layer_input: torch.Tens
            hidden_states = hidden_states * gate
        else:
            gate = None
+        hidden_states = hidden_states * self.scaling

        return hidden_states, gate

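To illustrate what the two added hidden_states = hidden_states * self.scaling lines change, here is a minimal, library-independent LoRA sketch (all names are illustrative) in which the usual alpha / r factor is applied on the default, non-gated path:

import torch


class TinyLoRA(torch.nn.Module):
    """Minimal LoRA-style delta, independent of the adapters library."""

    def __init__(self, in_features: int, out_features: int, r: int = 8, alpha: int = 16):
        super().__init__()
        self.lora_A = torch.nn.Parameter(torch.randn(r, in_features) * 0.01)
        self.lora_B = torch.nn.Parameter(torch.zeros(out_features, r))
        self.scaling = alpha / r  # the factor the PR now applies in the default forward path

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        delta = x @ self.lora_A.T @ self.lora_B.T
        return delta * self.scaling  # without this multiplication, a configured alpha has no effect


x = torch.randn(2, 16)
print(TinyLoRA(16, 32)(x).shape)  # torch.Size([2, 32])

Before this change, the default path returned the LoRA output without the scaling factor, which is presumably why the loading fix above resets alpha to r for adapters trained with earlier versions.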