Fix deepseek coder with linear rope type support on GPU (#12709)
* Fix deepseek coder with linear rope type

* Style fix

* Move to optimize_pre

* Small fix

* Small fix

* Small fix to not affect other cases

* Style fixes

* Update function name

* Small fix

* Small fix

* Small fix

* Fix for low transformers version first

* Style fix

* Small fix
Oscilloscope98 authored Jan 15, 2025
1 parent 36bf3d8 commit 9d65dcd
Showing 2 changed files with 15 additions and 3 deletions.
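The core of the fix: for DeepSeek Coder's linear rope type, Hugging Face's LlamaLinearScalingRotaryEmbedding applies the scaling by dividing the position index by scaling_factor before computing the rotary angles, while the fused GPU kernel only accepts an inv_freq tensor. Because angle = (position / scaling_factor) * inv_freq = position * (inv_freq / scaling_factor), the scaling can be folded into inv_freq once up front. A minimal sketch of that equivalence (the values below are illustrative, not taken from the commit):

import torch

# Angles computed by scaling the positions (what LlamaLinearScalingRotaryEmbedding does)
# vs. by scaling inv_freq once (what pre_compute_inv_freq bakes into inv_freq_scaled).
dim, base, scaling_factor = 128, 10000.0, 4.0   # illustrative values
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
position_ids = torch.arange(16).float()

angles_scaled_pos = torch.outer(position_ids / scaling_factor, inv_freq)
angles_scaled_freq = torch.outer(position_ids, inv_freq / scaling_factor)
assert torch.allclose(angles_scaled_pos, angles_scaled_freq)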
3 changes: 2 additions & 1 deletion python/llm/src/ipex_llm/transformers/convert.py
@@ -995,8 +995,9 @@ def _optimize_pre(model, qtype=None):
         from ipex_llm.transformers.models.gemma2 import merge_qkv
         model.apply(merge_qkv)
     elif model.config.model_type == "llama":
-        from ipex_llm.transformers.models.llama import merge_qkv
+        from ipex_llm.transformers.models.llama import merge_qkv, pre_compute_inv_freq
         model.apply(merge_qkv)
+        model.apply(pre_compute_inv_freq)
     elif model.config.model_type == "mllama":
         from ipex_llm.transformers.models.mllama import merge_qkv
         model.apply(merge_qkv)
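In _optimize_pre above, each model-type branch registers its per-module rewrite through torch.nn.Module.apply, which visits the root module and every submodule recursively. A minimal, self-contained sketch of that hook pattern (the toy model and buffer name are illustrative, not from the repository):

import torch
from torch import nn

# Toy stand-in for the real model tree; names are illustrative only.
toy = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4)))

def tag_linears(module: nn.Module):
    # apply() calls this on every submodule, so a hook can key off the
    # module class and ignore everything else, exactly like
    # merge_qkv / pre_compute_inv_freq do above.
    if isinstance(module, nn.Linear):
        module.register_buffer("visited", torch.tensor(True), persistent=False)

toy.apply(tag_linears)
assert all(hasattr(m, "visited") for m in toy.modules() if isinstance(m, nn.Linear))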
15 changes: 13 additions & 2 deletions python/llm/src/ipex_llm/transformers/models/llama.py
@@ -119,6 +119,13 @@ def merge_qkv(module: torch.nn.Module):
     merge_qkv_base(module, LlamaAttention)
 
 
+def pre_compute_inv_freq(module: torch.nn.Module):
+    if module.__class__.__name__ == "LlamaLinearScalingRotaryEmbedding":
+        if hasattr(module, "scaling_factor"):
+            module.register_buffer("inv_freq_scaled", None, persistent=False)
+            module.inv_freq_scaled = module.inv_freq / module.scaling_factor
+
+
 def llama_attention_forward(
     self,
     hidden_states: torch.Tensor,
@@ -147,8 +154,12 @@ def llama_attention_forward(
         import xe_addons
         if hasattr(self, "rotary_emb"):
             # transformers < 4.46
-            xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
-                                           query_states, key_states)
+            if hasattr(self.rotary_emb, "inv_freq_scaled"):
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq_scaled, position_ids,
+                                               query_states, key_states)
+            else:
+                xe_addons.rotary_half_inplaced(self.rotary_emb.inv_freq, position_ids,
+                                               query_states, key_states)
         else:
             # transformers >= 4.46
             cos, sin = position_embeddings
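For reference, here is a rough, unfused sketch of what a rotate-half rotary kernel computes from (inv_freq, position_ids, query, key). It follows the standard Hugging Face Llama formulation and is only an assumption about xe_addons.rotary_half_inplaced, whose real implementation is a fused in-place XPU kernel:

import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Standard rotate-half: split the head dim in two, negate and swap the halves.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def rotary_half_reference(inv_freq, position_ids, q, k):
    # Unfused CPU sketch; the real kernel updates q/k in place on the GPU.
    freqs = torch.outer(position_ids.float(), inv_freq)   # [seq, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)               # [seq, head_dim]
    cos, sin = emb.cos(), emb.sin()
    q_out = q * cos + rotate_half(q) * sin
    k_out = k * cos + rotate_half(k) * sin
    return q_out, k_out

Seen this way, passing inv_freq_scaled instead of inv_freq is the whole fix: the resulting angles come out already divided by scaling_factor.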
