GQA Attention #59

Merged: 4 commits, Jan 11, 2024
154 changes: 140 additions & 14 deletions protein_lm/modeling/models/apt/model_pytorch.py
@@ -8,10 +8,12 @@
from transformers.pytorch_utils import Conv1D
from transformers.activations import ACT2FN
from transformers.utils import logging

from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor

from protein_lm.modeling.utils.scaled_rope_embedding import LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding
from protein_lm.modeling.utils.modules import ContactPredictionHead

logger = logging.get_logger(__name__)

@@ -28,9 +30,12 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
)
self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
self.position_embedding = config.position_embedding
self.rope_scaling_factor = config.rope_scaling_factor
self.rope_theta = config.rope_theta
self.max_sequence_length = config.max_sequence_length
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.attn_type = config.attn_type
self.head_dim = self.embed_dim // self.num_heads
self.split_size = self.embed_dim
if self.head_dim * self.num_heads != self.embed_dim:
@@ -45,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
# Layer-wise attention scaling, reordering, and upcasting
self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
self.layer_idx = layer_idx
# Exactly one attention implementation is selected via config.attn_type;
# default all flags to False so the unselected ones are still defined
# when forward() checks them.
self.gqa_attn = False
self.reorder_and_upcast_attn = False
self.standard_attn = False

if self.attn_type == "gqa":
    self.gqa_attn = True
elif self.attn_type == "reorder_and_upcast_attn":
    self.reorder_and_upcast_attn = True
elif self.attn_type == "standard":
    self.standard_attn = True
else:
    raise ValueError(f"Unknown attn_type: {self.attn_type}")

if self.is_cross_attention:
self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
@@ -64,6 +77,10 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
self.rot_emb=RotaryEmbedding(dim=self.head_dim)
elif self.position_embedding == "rerope":
self.rot_emb = RectifiedRotaryEmbedding(dim=self.head_dim,max_position_embeddings = self.max_positions)
elif self.position_embedding=="linear_rope_scaling":
self.rot_emb=LlamaLinearScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
elif self.position_embedding=="dynamic_rope_scaling":
self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)



@@ -109,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):

return attn_output, attn_weights

def _gqa_attn(self, query, key, value, attention_mask=None,
              alibi_bias=None, dropout=0.0):
    """Grouped Query Attention (GQA).

    Expected shapes (matching _upcast_and_reordered_attn):
        query: (batch_size, num_heads, query_len, head_dim)
        key, value: (batch_size, num_kv_heads, key_len, head_dim)
    num_heads must be an integer multiple of num_kv_heads; each group of
    num_heads // num_kv_heads query heads shares one key/value head.
    """
    # Check for potential issues before moving on
    if not query.ndim == key.ndim == value.ndim == 4:
        raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
                         f"{query.shape}, {key.shape}, and {value.shape}.")

    batch_size, num_heads, query_len, head_dim = query.shape
    key_len = key.size(-2)

    if self.scale_attn_weights:
        # Scale queries by 1/sqrt(head_dim) before the dot products
        query = query / (float(value.size(-1)) ** 0.5)

    # Number of query-head groups: e.g. 4 query heads and 2 key heads
    # give 2 groups, each sharing one key/value head
    n_groups = query.size(1) // key.size(1)

    if n_groups > 1:
        # query: (batch, n_groups, num_kv_heads, query_len, head_dim);
        # the unsqueezed key broadcasts across the group dimension, giving
        # grouped weights of shape (batch, n_groups, num_kv_heads, query_len, key_len)
        grouped_shape = (batch_size, n_groups, num_heads // n_groups, query_len, head_dim)
        query_grouped = query.reshape(grouped_shape)
        attn_weights_grouped = torch.matmul(query_grouped, key.unsqueeze(1).transpose(-2, -1))
        # Flatten the groups back out so there is one attention map per
        # query head: (batch, num_heads, query_len, key_len)
        attn_weights = attn_weights_grouped.reshape(batch_size, num_heads, query_len, key_len)
    else:
        # With a single group this reduces to standard multi-head attention
        attn_weights = torch.matmul(query, key.transpose(-2, -1))

    if self.scale_attn_by_inverse_layer_idx:
        attn_weights = attn_weights / float(self.layer_idx + 1)

    if attention_mask is not None:
        # Input attention_mask shape: (batch_size, query_len, key_len);
        # unsqueeze to add the head dimension
        attn_weights += attention_mask.unsqueeze(1)

    # Causal masking ensures that the attention mechanism doesn't attend to
    # "future" tokens in sequences; adapted to work with groups and ensure
    # similarity with vanilla attention
    if not self.is_cross_attention:
        causal_mask = self.bias[:, :, key_len - query_len : key_len, :key_len]
        mask_value = torch.finfo(attn_weights.dtype).min
        mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
        attn_weights = torch.where(causal_mask, attn_weights, mask_value)

    if alibi_bias is not None:
        # Add the ALiBi bias to the logits before normalization
        attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]

    # Softmax normalization to get the attention scores
    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    attn_weights = attn_weights.type(value.dtype)
    attn_weights = self.attn_dropout(attn_weights)

    if n_groups > 1:
        # Tile the shared value heads so each query head multiplies its
        # group's values: (batch, num_heads, key_len, head_dim)
        value = value.repeat(1, n_groups, 1, 1)

    # Compute the output by multiplying the attention scores with the value tensor
    attn_output = torch.matmul(attn_weights, value)

    return attn_output, attn_weights

def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
# Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
bsz, num_heads, q_seq_len, dk = query.size()
@@ -196,36 +294,40 @@ def forward(
key = self._split_heads(key, self.num_heads, self.head_dim)
value = self._split_heads(value, self.num_heads, self.head_dim)




kv_seq_len = key.shape[-2]
if layer_past is not None:
    kv_seq_len += layer_past[0].shape[-2]

# Apply rotary position embeddings to query and key
if self.rot_emb:
    bsz, q_len, _ = hidden_states.size()
    if self.position_embedding == 'rope':
        query, key = self.rot_emb(query, key)
    elif self.position_embedding == 'rerope':
        query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        query *= ((position_ids + 1)[:, None, :, None].log() / torch.log(torch.tensor(self.max_sequence_length)).item()).clip(1).to(query.dtype)
        query, key = self.rot_emb(query, key, seq_len=self.max_sequence_length, position_ids=position_ids)
    elif self.position_embedding == "linear_rope_scaling" or self.position_embedding == "dynamic_rope_scaling":
        query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        query, key = self.rot_emb(query, key, seq_len=kv_seq_len, position_ids=position_ids)

if layer_past is not None:
past_key, past_value = layer_past
key = torch.cat((past_key, key), dim=-2)
value = torch.cat((past_value, value), dim=-2)


if use_cache is True:
present = (key, value)
else:
present = None

if self.reorder_and_upcast_attn:
    attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
elif self.standard_attn:
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask, alibi_bias=alibi_bias)
elif self.gqa_attn:
    attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask, alibi_bias=alibi_bias)
attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
attn_output = self.c_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -234,7 +336,7 @@ def forward(
if output_attentions:
outputs += (attn_weights,)

return outputs  # a, present, (attentions)


class APTMLP(nn.Module):
@@ -348,7 +450,7 @@ def __init__(self, config):
self.wte = nn.Embedding(config.vocab_size, self.embed_dim)
self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope':
if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.alibi = None
elif self.position_embedding=="alibi":
@@ -357,7 +459,7 @@
alibi = create_alibi_tensor(attn_heads,maxpos)
self.register_buffer('alibi',alibi)
else:
raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of: learned, rope, rerope, linear_rope_scaling, dynamic_rope_scaling, alibi')

self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
@@ -462,7 +564,7 @@ def forward(
if inputs_embeds is None:
inputs_embeds = self.wte(input_ids)

if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' :
if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' or self.position_embedding=="linear_rope_scaling" or self.position_embedding =="dynamic_rope_scaling":
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
else:
@@ -588,6 +690,11 @@ def __init__(self, config):
# Model parallel
self.model_parallel = False
self.device_map = None

self.contact_head = ContactPredictionHead(
    config.num_hidden_layers * config.num_attention_heads,
    prepend_bos=True,
    append_eos=True,
    eos_idx=2,
)

# Initialize weights and apply final processing
self.post_init()
@@ -663,3 +770,22 @@ def forward(
attentions=transformer_outputs.attentions,
cross_attentions=transformer_outputs.cross_attentions,
)

def predict_contacts(self, input_ids):
    transformer_outputs = self.transformer(
        input_ids,
        return_dict=True,
        output_attentions=True,
    )

    # Stack the per-layer attention maps: a tuple of num_layers tensors of
    # shape (batch, heads, seq, seq) -> (batch, num_layers, heads, seq, seq)
    stacked_attentions = torch.stack(transformer_outputs.attentions, dim=1)

    contact_predictions = self.contact_head(input_ids, stacked_attentions)

    return contact_predictions
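
As a quick illustration of the grouping logic in _gqa_attn (a standalone
sketch, not code from this PR): with 12 query heads and 6 key/value heads
there are 2 groups, and one KV head broadcasts across each group of query
heads before the maps are flattened back to one per query head:

    import torch

    batch, q_heads, kv_heads, seq, dim = 2, 12, 6, 64, 64
    n_groups = q_heads // kv_heads                      # 2
    q = torch.randn(batch, q_heads, seq, dim)
    k = torch.randn(batch, kv_heads, seq, dim)

    q_grouped = q.reshape(batch, n_groups, kv_heads, seq, dim)
    w = torch.matmul(q_grouped, k.unsqueeze(1).transpose(-2, -1))
    w = w.reshape(batch, q_heads, seq, seq)             # one map per query head
    print(w.shape)  # torch.Size([2, 12, 64, 64])

And a hedged usage sketch for the new predict_contacts API (tokenization and
model construction are outside this diff, so the inputs are placeholders):

    input_ids = torch.randint(0, 30, (1, 128))    # hypothetical token ids
    contacts = model.predict_contacts(input_ids)  # per-residue contact map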
89 changes: 89 additions & 0 deletions protein_lm/tests/test_attention.py
@@ -0,0 +1,89 @@
import pytest
import torch
from torch.nn import functional as F

from protein_lm.modeling.models.apt.model_pytorch import APTAttention

class ParameterConfig:
def __init__(self):
self.max_position_embeddings = 512
self.position_embedding = 'rope'
self.max_sequence_length = 512
self.hidden_size = 768
self.num_attention_heads = 12
self.scale_attn_weights = True
self.scale_attn_by_inverse_layer_idx = True
self.reorder_and_upcast_attn = True
self.attn_pdrop = 0.1
self.resid_pdrop = 0.1
self.rope_scaling_factor = 1
self.rope_theta = 1
self.attn_type = 'gqa'


def test_vanilla_attn():
# Initialize with mock config
config = ParameterConfig()
attention = APTAttention(config, is_cross_attention=False, layer_idx=0)

# generate random input tensors
batch_size = 4
seq_length = 100
num_heads = config.num_attention_heads
query_dim = config.hidden_size // num_heads
query = torch.randn(batch_size, num_heads, seq_length, query_dim)
key = torch.randn(batch_size, num_heads, seq_length, query_dim)
value = torch.randn(batch_size, num_heads, seq_length, query_dim)

    # Additive attention mask: 0 where attending is allowed, a large negative
    # value at the padded positions (using the dtype minimum instead of
    # float('-inf') keeps fully masked rows from producing NaNs in the softmax)
    attention_mask = torch.zeros(batch_size, seq_length, seq_length)
    padding_positions = 10
    mask_value = torch.finfo(torch.float32).min
    attention_mask[:, -padding_positions:, :] = mask_value
    attention_mask[:, :, -padding_positions:] = mask_value
    attention_mask = attention_mask.unsqueeze(1)  # add the head dimension
# Pass them through the _attn method
attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask)

# Check the shapes and types of the output
assert isinstance(attn_output, torch.Tensor)
assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
assert isinstance(attn_weights, torch.Tensor)
assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
print("Test passed!")

def test_gqa_attn():
# Initialize with mock config
config = ParameterConfig()
attention = APTAttention(config, is_cross_attention=False, layer_idx=0)

# generate random input tensors
batch_size = 4
seq_length = 100
num_heads = config.num_attention_heads
query_dim = config.hidden_size // num_heads
query = torch.randn(batch_size, num_heads, seq_length, query_dim)
key = torch.randn(batch_size, num_heads, seq_length, query_dim)
value = torch.randn(batch_size, num_heads, seq_length, query_dim)

    # Additive mask as in test_vanilla_attn; _gqa_attn adds the head dimension itself
    attention_mask = torch.zeros(batch_size, seq_length, seq_length)
    padding_positions = 10
    mask_value = torch.finfo(torch.float32).min
    attention_mask[:, -padding_positions:, :] = mask_value
    attention_mask[:, :, -padding_positions:] = mask_value

# Pass them through the _gqa_attn method
attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)

# Check the shapes and types of the output
assert isinstance(attn_output, torch.Tensor)
assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
assert isinstance(attn_weights, torch.Tensor)
assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
print("Test passed!")


if __name__ == "__main__":
    test_gqa_attn()
    test_vanilla_attn()
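
# The tests can also be collected by pytest, e.g. (assuming the repo root
# as the working directory):
#
#     pytest protein_lm/tests/test_attention.py -v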