Add support for LoRA with GQAQKVColumnParallelLinear #690

Draft · wants to merge 14 commits into base: main
6 changes: 4 additions & 2 deletions docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -34,10 +34,12 @@ def training_function(script_args, training_args):
model = AutoModelForCausalLM.from_pretrained(script_args.model_id)

config = LoraConfig(
r=16,
r=64,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
# target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
target_modules=["q_proj", "k_proj", "v_proj"],
# target_modules=["o_proj"],
bias="none",
task_type="CAUSAL_LM",
)
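For context, a minimal sketch (not part of the diff) of how a LoraConfig like the one above is attached to a base model with peft's get_peft_model; the checkpoint name is a placeholder, and the tutorial itself loads script_args.model_id.

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM

# Placeholder checkpoint; substitute the model you are fine-tuning.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3-8B")

config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# Wrap the base model: only the LoRA adapter weights remain trainable.
model = get_peft_model(model, config)
model.print_trainable_parameters()
```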
18 changes: 18 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -26,6 +26,7 @@

import torch
from transformers import PreTrainedModel
from transformers.utils import is_peft_available

from ...utils import logging
from ..utils import is_neuronx_distributed_available, is_torch_xla_available
@@ -607,6 +608,21 @@ def parallelize(
skip_linear_weight_load = hasattr(model, "_weight_map")

requires_grad_information = {n: p.requires_grad for n, p in model.named_parameters()}
if is_peft_available():
from peft.tuners.tuners_utils import BaseTunerLayer

peft_parameters = set()
for mod in model.modules():
if isinstance(mod, BaseTunerLayer):
base_layer = mod.get_base_layer()
for m in mod.modules():
if m is base_layer:
continue
for p in m.parameters(recurse=False):  # recurse=False so the wrapped base_layer weights are not collected here
peft_parameters.add(p)
peft_parameter_names = {n for n, p in model.named_parameters() if p in peft_parameters}
else:
peft_parameter_names = set()

def should_parallelize_layer_predicate_func(layer):
if pp_size == 1:
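The block added above collects the names of the parameters introduced by peft adapters so that later steps can treat them specially. Below is a small, self-contained illustration of the same detection logic, using peft's low-level inject_adapter_in_model and BaseTunerLayer (real peft APIs) on a made-up module; DummyAttention and its attribute names are hypothetical.

```python
import torch
from peft import LoraConfig, inject_adapter_in_model
from peft.tuners.tuners_utils import BaseTunerLayer

# Dummy module standing in for an attention block; only q_proj gets a LoRA adapter.
class DummyAttention(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.q_proj = torch.nn.Linear(16, 16)
        self.o_proj = torch.nn.Linear(16, 16)

    def forward(self, x):
        return self.o_proj(self.q_proj(x))

model = inject_adapter_in_model(LoraConfig(r=8, target_modules=["q_proj"]), DummyAttention())

# Same detection idea as the diff: walk every tuner layer and collect the
# parameters it introduced, skipping the wrapped (frozen) base layer.
peft_parameters = set()
for mod in model.modules():
    if isinstance(mod, BaseTunerLayer):
        base_layer = mod.get_base_layer()
        for m in mod.modules():
            if m is base_layer:
                continue
            for p in m.parameters(recurse=False):
                peft_parameters.add(p)

peft_parameter_names = {n for n, p in model.named_parameters() if p in peft_parameters}
print(peft_parameter_names)
# Expected to contain only the adapter weights, e.g. 'q_proj.lora_A.default.weight'
# and 'q_proj.lora_B.default.weight', but not 'q_proj.base_layer.weight'.
```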
@@ -757,6 +773,8 @@ def should_parallelize_layer_predicate_func(layer):
elif gqa_qkv_names_to_original_names.get(name, None) in requires_grad_information:
gqa_qkv_name = gqa_qkv_names_to_original_names[name]
parameter.requires_grad = requires_grad_information[gqa_qkv_name]
elif name in peft_parameter_names:
continue
else:
raise ValueError(
f"Could not find information for the parameter {name} to set its `requires_grad` attribute."
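A toy illustration of the control flow in the hunk above: requires_grad flags captured before the model is rewritten are restored afterwards, and parameters that only exist because of a peft adapter are skipped instead of triggering the error. The module and names below are made up.

```python
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4))
requires_grad_information = {n: p.requires_grad for n, p in model.named_parameters()}

# Pretend a later rewrite added an adapter module whose parameters were not
# part of the original snapshot.
model.add_module("lora_A", torch.nn.Linear(4, 2, bias=False))
peft_parameter_names = {"lora_A.weight"}

for name, parameter in model.named_parameters():
    if name in requires_grad_information:
        parameter.requires_grad = requires_grad_information[name]
    elif name in peft_parameter_names:
        continue  # adapter parameters keep whatever flag peft assigned them
    else:
        raise ValueError(f"Could not find information for the parameter {name}.")
```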
139 changes: 17 additions & 122 deletions optimum/neuron/distributed/parallel_layers.py
@@ -30,14 +30,12 @@
from ..utils.misc import is_main_worker
from ..utils.require_utils import requires_neuronx_distributed
from .utils import (
FakeProj,
OptimumGQAQKVColumnParallelLinear,
WeightInformation,
embedding_to_parallel_embedding,
get_linear_weight_info,
inplace_linears_to_gqa_qkv_column_parallel_linear,
linear_to_parallel_linear,
mark_parameter_init_status_during_parallelization,
maybe_load_weights_to_gqa_qkv_column_parallel_linear,
maybe_load_weights_to_output_projection_when_using_gqa_qkv_column_parallel_linear,
)

@@ -327,124 +325,6 @@ class ParallelSelfAttention(ParallelLayer):

GQA_QKV_PROJ_NAME: str = "qkv_proj"

@classmethod
def get_layer_qualified_name(cls, model: torch.nn.Module, layer: torch.nn.Module) -> str:
layer_to_fully_qualified_name = {id(module): name for name, module in model.named_modules()}
return layer_to_fully_qualified_name[id(layer)]

@classmethod
def patch_proj_to_use_gqa_qkv_column_parallel_linear(
cls,
attention_layer: torch.nn.Module,
attention_layer_qualified_name: str,
proj_qualified_name: str,
proj_name: str,
output_index: int,
):
fake_proj = FakeProj(
proj_qualified_name,
proj_name,
output_index,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)

setattr(attention_layer, proj_name, fake_proj)

@classmethod
@requires_neuronx_distributed
def replace_qkv_by_gqa_qkv_column_parallel_linear(
cls,
model: "torch.nn.Module",
attention_layer: "torch.nn.Module",
sequence_parallel_enabled: bool = False,
kv_size_multiplier: Optional[int] = None,
skip_linear_weight_load: bool = False,
):
from neuronx_distributed.parallel_layers.parallel_state import get_tensor_model_parallel_size

if cls.NUM_KEY_VALUE_HEADS_NAME is None:
raise ValueError(f"{cls} does not defined the name of the number of key value heads.")
tp_size = get_tensor_model_parallel_size()
num_key_value_heads = getattr(attention_layer, cls.NUM_KEY_VALUE_HEADS_NAME)
if tp_size < num_key_value_heads:
raise ValueError(
f"The TP size ({tp_size}) is lower than the number of key value heads, using "
"GQAQKVColumnParallelLinear is not needed."
)

num_attention_heads = getattr(attention_layer, cls.NUM_ATTENTION_HEADS_NAME)
query_linear = getattr(attention_layer, cls.QUERIES_NAME)
key_linear = getattr(attention_layer, cls.KEYS_NAME)

hidden_size = query_linear.weight.size(1)
query_in_features = query_linear.weight.size(0)
key_value_in_features = key_linear.weight.size(0)

if kv_size_multiplier is None:
kv_size_multiplier = get_tensor_model_parallel_size() // num_key_value_heads

device = query_linear.weight.device
if device == torch.device("meta"):
device = None

gqa_qkv_column_parallel_linear = OptimumGQAQKVColumnParallelLinear(
cls.QUERIES_NAME,
cls.KEYS_NAME,
cls.VALUES_NAME,
cls.OUTPUT_PROJECTION_NAME,
num_attention_heads,
num_key_value_heads,
hidden_size,
[query_in_features, key_value_in_features],
gather_output=False,
bias=query_linear.bias is not None,
sequence_parallel_enabled=sequence_parallel_enabled,
device=device,
kv_size_multiplier=kv_size_multiplier,
)

setattr(attention_layer, cls.GQA_QKV_PROJ_NAME, gqa_qkv_column_parallel_linear)

maybe_load_weights_to_gqa_qkv_column_parallel_linear(
model,
gqa_qkv_column_parallel_linear,
try_from_checkpoint=not skip_linear_weight_load,
try_from_original_layer=not skip_linear_weight_load,
)

attention_layer_qualified_name = cls.get_layer_qualified_name(model, attention_layer)
fake_q_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.QUERIES_NAME}",
"q",
0,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.QUERIES_NAME, fake_q_proj)

fake_k_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.KEYS_NAME}",
"k",
1,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.KEYS_NAME, fake_k_proj)

fake_v_proj = FakeProj(
f"{attention_layer_qualified_name}.{cls.VALUES_NAME}",
"v",
2,
lambda: attention_layer,
attention_layer_qualified_name,
cls.GQA_QKV_PROJ_NAME,
)
setattr(attention_layer, cls.VALUES_NAME, fake_v_proj)

@classmethod
@requires_neuronx_distributed
def _transform(
@@ -504,9 +384,24 @@ def _transform(
needs_gqa_qkv_column_parallel_linear = False

if needs_gqa_qkv_column_parallel_linear:
cls.replace_qkv_by_gqa_qkv_column_parallel_linear(
tp_size = get_tensor_model_parallel_size()
if cls.NUM_KEY_VALUE_HEADS_NAME is None:
raise ValueError(f"{cls} does not defined the name of the number of key value heads.")
if tp_size < num_key_value_heads:
raise ValueError(
f"The TP size ({tp_size}) is lower than the number of key value heads, using "
"GQAQKVColumnParallelLinear is not needed."
)
inplace_linears_to_gqa_qkv_column_parallel_linear(
model,
layer,
cls.GQA_QKV_PROJ_NAME,
cls.QUERIES_NAME,
cls.KEYS_NAME,
cls.VALUES_NAME,
cls.OUTPUT_PROJECTION_NAME,
num_attention_heads,
num_key_value_heads,
sequence_parallel_enabled=sequence_parallel_enabled,
kv_size_multiplier=kv_size_multiplier,
skip_linear_weight_load=skip_linear_weight_load,
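For readers of this hunk, a framework-free toy sketch of the pattern that inplace_linears_to_gqa_qkv_column_parallel_linear consolidates: the separate q/k/v linears are fused into a single projection and the original attribute names are re-pointed at lightweight proxies. FusedQKV and ProjProxy are hypothetical stand-ins for GQAQKVColumnParallelLinear and FakeProj, and all tensor-parallel concerns (weight sharding, kv_size_multiplier, checkpoint loading) are omitted.

```python
import torch

class FusedQKV(torch.nn.Module):
    """Toy fused projection: concatenates q/k/v weights into one linear."""
    def __init__(self, q: torch.nn.Linear, k: torch.nn.Linear, v: torch.nn.Linear):
        super().__init__()
        self.sizes = [q.out_features, k.out_features, v.out_features]
        self.proj = torch.nn.Linear(q.in_features, sum(self.sizes), bias=False)
        with torch.no_grad():
            self.proj.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

    def forward(self, hidden_states):
        # Returns a (q, k, v) tuple, one slice per original projection.
        return torch.split(self.proj(hidden_states), self.sizes, dim=-1)

class ProjProxy(torch.nn.Module):
    """Stands in for the original q/k/v attribute and returns one fused output."""
    def __init__(self, get_attention_layer, fused_name, output_index):
        super().__init__()
        # Stored as a callable so the attention layer is not registered as a
        # submodule of its own proxy (which would create a module cycle).
        self._get_attention_layer = get_attention_layer
        self.fused_name = fused_name
        self.output_index = output_index

    def forward(self, hidden_states):
        fused = getattr(self._get_attention_layer(), self.fused_name)
        return fused(hidden_states)[self.output_index]

# Usage on a dummy attention layer: fuse, then re-point q/k/v at proxies.
attn = torch.nn.Module()
attn.q_proj = torch.nn.Linear(16, 16, bias=False)
attn.k_proj = torch.nn.Linear(16, 8, bias=False)
attn.v_proj = torch.nn.Linear(16, 8, bias=False)
attn.qkv_proj = FusedQKV(attn.q_proj, attn.k_proj, attn.v_proj)
for index, name in enumerate(["q_proj", "k_proj", "v_proj"]):
    setattr(attn, name, ProjProxy(lambda: attn, "qkv_proj", index))

q = attn.q_proj(torch.randn(2, 16))  # shape (2, 16), sliced from the fused output
```

Because q_proj, k_proj and v_proj become proxies over one fused projection, LoRA adapters targeting those names presumably need dedicated handling at the level of the fused layer, which is what this PR sets out to add for GQAQKVColumnParallelLinear.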