diff --git a/megatron/model/transformer.py b/megatron/model/transformer.py
index 7627e13b6..5a4586309 100644
--- a/megatron/model/transformer.py
+++ b/megatron/model/transformer.py
@@ -750,9 +750,13 @@ def sparse_attention(self, query_layer, key_layer, value_layer, attention_mask):
             rpe = self.rpe(query_layer.size(0), key_layer.size(0))
         else:
             rpe = None
-        return self.sparse_attn(
+        attn_scores = self.sparse_attn(
             query_layer, key_layer, value_layer, attn_mask=attn_mask, rpe=rpe
         )
+        # apply dropout
+        if self.training:
+            attn_scores = self.attention_dropout(attn_scores)
+        return attn_scores
 
     def gqa_project(self, hidden_states, attention_mask, layer_past=None):
         # QKV projection and separation into separate Q/K/V layers for GQA,
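
For context, a minimal standalone sketch of the return path this patch introduces, assuming `self.attention_dropout` is a standard `torch.nn.Dropout` constructed elsewhere in the attention module's `__init__` (the class name and dropout probability below are illustrative, not taken from this patch):

```python
import torch
import torch.nn as nn


class SparseAttentionSketch(nn.Module):
    """Illustrative stand-in for the attention module touched by the patch."""

    def __init__(self, attention_dropout_prob: float = 0.1):
        super().__init__()
        # Assumption: attention_dropout is a plain nn.Dropout module.
        self.attention_dropout = nn.Dropout(attention_dropout_prob)

    def forward(self, attn_scores: torch.Tensor) -> torch.Tensor:
        # Mirrors the patched return path: dropout is applied to the
        # sparse-attention result only while the module is in training mode.
        if self.training:
            attn_scores = self.attention_dropout(attn_scores)
        return attn_scores
```

Note that `nn.Dropout` is already a no-op after `model.eval()`, so the explicit `if self.training:` guard in the patch is redundant but harmless; it just makes the train-only behavior obvious at the call site.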