Skip to content

Commit

Permalink
Add support for dropout in sparse attention (#1312)
Browse files Browse the repository at this point in the history
  • Loading branch information
michaelc-yu authored Nov 16, 2024
1 parent 797a4ab commit 50e74cd
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion megatron/model/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,9 +750,13 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
rpe = self.rpe(query_layer.size(0), key_layer.size(0))
else:
rpe = None
return self.sparse_attn(
attn_scores = self.sparse_attn(
query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe
)
# apply dropout
if self.training:
attn_scores = self.attention_dropout(attn_scores)
return attn_scores

def gqa_project(self, hidden_states, attention_mask, layer_past=None):
# QKV projection and separation into separate Q/K/V layers for GQA,
Expand Down

0 comments on commit 50e74cd

Please sign in to comment.