GQA Attention #59

Merged Jan 11, 2024 (4 commits)

Changes from 2 commits
87 changes: 86 additions & 1 deletion protein_lm/modeling/models/apt/model_pytorch.py
@@ -164,7 +164,92 @@ def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, hea

return attn_output, attn_weights


def _gqa_attn(self, query, key, value, attention_mask=None,
              alibi_bias=None, dropout=0.0):
    """Grouped-query attention (GQA).

    Expected shapes: (batch_size, num_heads, query_len, query_dim),
    matching _upcast_and_reordered_attn.
    """
    # Check for potential issues before moving on.
    if not query.ndim == key.ndim == value.ndim == 4:
        raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
                         f"{query.shape}, {key.shape}, and {value.shape}.")
    batch_size, num_heads, query_len, query_dim = query.shape

    # Fold all scaling into a single factor applied to the query.
    scale_factor = 1.0
    if self.scale_attn_weights:
        scale_factor /= float(value.size(-1)) ** 0.5

    if self.scale_attn_by_inverse_layer_idx:
        scale_factor /= float(self.layer_idx + 1)

    query = query * scale_factor

    # Determine the number of query-head groups. For example, with 4 query heads
    # and 2 key heads there are 2 groups, and the query tensor is reshaped from
    # (batch_size, num_heads, query_len, query_dim) to
    # (batch_size, n_groups, num_heads // n_groups, query_len, query_dim).
    # The grouped attention weights have shape
    # (batch_size, n_groups, num_heads // n_groups, query_len, key_len) and are
    # flattened back to (batch_size, num_heads, query_len, key_len).
    n_groups = query.size(1) // key.size(1)

    if n_groups > 1:
        query_shape = query.shape
        grouped_shape = (query_shape[0], n_groups, query_shape[1] // n_groups, query_shape[2], query_shape[3])
        query_grouped = query.reshape(grouped_shape)
        # key has shape (batch_size, num_kv_heads, key_len, query_dim); unsqueeze a
        # group dimension so every group of query heads shares the same key heads.
        attn_weights_grouped = torch.matmul(query_grouped, key.unsqueeze(1).transpose(-2, -1))
        attn_weights = attn_weights_grouped.reshape(query_shape[0], query_shape[1], query_shape[2], key.size(-2))
    else:
        # With a single group (as many key heads as query heads) this reduces to
        # standard multi-head attention.
        attn_weights = torch.matmul(query, key.transpose(-2, -1))

    if attention_mask is not None:
        # Apply the additive attention mask.
        # Expected attention_mask shape: (batch_size, 1, query_len, key_len),
        # which broadcasts over the head dimension.
        attn_weights = attn_weights + attention_mask

    # Causal masking ensures that the attention mechanism doesn't attend to "future" tokens.
    if not self.is_cross_attention:
        # The causal mask is lower triangular: True on and below the diagonal, False above it.
        causal_mask = torch.ones(
            (query.size(0), 1, query.size(2), key.size(2)), device=query.device, dtype=torch.bool
        ).tril_()
        mask_value = torch.finfo(attn_weights.dtype).min
        # Needs to be a tensor, otherwise we get error: `RuntimeError: expected scalar type float but found double`.
        # Needs to be on the same device, otherwise `RuntimeError: ..., x and y to be on the same device`.
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

    # ALiBi is a bias on the attention logits, so it is added before the softmax.
    if alibi_bias is not None:
        attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]

    # Softmax normalization to get the attention scores.
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    # Apply dropout if specified.
    attn_weights = attn_weights.type(value.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    # Compute the output by weighting the value tensor with the attention scores.
    if n_groups > 1:
        # Value heads are shared within each group, mirroring the key grouping above.
        attn_weights_grouped = attn_weights.reshape(
            batch_size, n_groups, num_heads // n_groups, query_len, key.size(-2)
        )
        attn_output = torch.matmul(attn_weights_grouped, value.unsqueeze(1))
        attn_output = attn_output.reshape(batch_size, num_heads, query_len, value.size(-1))
    else:
        attn_output = torch.matmul(attn_weights, value)

    return attn_output, attn_weights

def forward(
    self,
    hidden_states: Optional[Tuple[torch.FloatTensor]],
53 changes: 53 additions & 0 deletions protein_lm/tests/test_attention.py
@@ -0,0 +1,53 @@
import pytest
import torch
from torch.nn import functional as F

from protein_lm.modeling.models.apt.model_pytorch import APTAttention

class ParameterConfig:
    def __init__(self):
        self.max_position_embeddings = 512
        self.position_embedding = 'rope'
        self.max_sequence_length = 512
        self.hidden_size = 768
        self.num_attention_heads = 12
        self.scale_attn_weights = True
        self.scale_attn_by_inverse_layer_idx = True
        self.reorder_and_upcast_attn = True
        self.attn_pdrop = 0.1
        self.resid_pdrop = 0.1


def test_gqa_attn():
    # 1. Initialize with mock config
    config = ParameterConfig()
    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)

    # 2. Generate random input tensors
    batch_size = 4
    seq_length = 100
    num_heads = config.num_attention_heads  # Using the number of attention heads from config
    query_dim = config.hidden_size // num_heads
    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
    value = torch.randn(batch_size, num_heads, seq_length, query_dim)

    # Create an additive attention mask: 0 = attend, -inf = masked out.
    # Only key positions are masked; masking entire query rows would make every
    # logit in those rows -inf and produce NaNs after the softmax.
    attention_mask = torch.zeros(batch_size, 1, seq_length, seq_length)
    padding_positions = 10
    attention_mask[:, :, :, -padding_positions:] = float('-inf')

    # 3. Pass them through the _gqa_attn method
    attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)

    # 4. Check the shapes and types of the output
    assert isinstance(attn_output, torch.Tensor)
    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
    assert isinstance(attn_weights, torch.Tensor)
    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
    print("Test passed!")

if __name__ == "__main__":
    test_gqa_attn()
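The test above only exercises the case where the number of query heads equals the number of key/value heads, so n_groups is 1 and the grouped branch is never taken. As a hedged sketch (not part of this PR's test), and assuming _gqa_attn is given key/value tensors whose head count divides config.num_attention_heads, the grouped path could be checked as follows; gqa_grouped_example and num_kv_heads are illustrative names:

def gqa_grouped_example():
    config = ParameterConfig()
    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
    batch_size, seq_length = 2, 16
    num_q_heads = config.num_attention_heads   # 12 query heads
    num_kv_heads = 4                           # hypothetical: 3 query heads share each key/value head
    head_dim = config.hidden_size // num_q_heads
    query = torch.randn(batch_size, num_q_heads, seq_length, head_dim)
    key = torch.randn(batch_size, num_kv_heads, seq_length, head_dim)
    value = torch.randn(batch_size, num_kv_heads, seq_length, head_dim)
    attn_output, attn_weights = attention._gqa_attn(query, key, value)
    # Outputs keep the full query-head dimension even though keys/values have fewer heads.
    assert attn_output.shape == (batch_size, num_q_heads, seq_length, head_dim)
    assert attn_weights.shape == (batch_size, num_q_heads, seq_length, seq_length)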