diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py
index 47c2398..f450d4a 100644
--- a/protein_lm/modeling/models/apt/model_pytorch.py
+++ b/protein_lm/modeling/models/apt/model_pytorch.py
@@ -8,10 +8,12 @@ from transformers.pytorch_utils import Conv1D
 from transformers.activations import ACT2FN
 from transformers.utils import logging
+
 from protein_lm.modeling.utils.rotary_embedding import RotaryEmbedding
 from protein_lm.modeling.utils.rerope_embedding import RectifiedRotaryEmbedding
 from protein_lm.modeling.utils.alibi_embedding import create_alibi_tensor
-
+from protein_lm.modeling.utils.scaled_rope_embedding import LlamaLinearScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding
+from protein_lm.modeling.utils.modules import ContactPredictionHead

 logger = logging.get_logger(__name__)

@@ -28,9 +30,12 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
         self.position_embedding = config.position_embedding
+        self.rope_scaling_factor = config.rope_scaling_factor
+        self.rope_theta = config.rope_theta
         self.max_sequence_length = config.max_sequence_length
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
+        self.attn_type = config.attn_type
         self.head_dim = self.embed_dim // self.num_heads
         self.split_size = self.embed_dim
         if self.head_dim * self.num_heads != self.embed_dim:
@@ -45,7 +50,15 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         # Layer-wise attention scaling, reordering, and upcasting
         self.scale_attn_by_inverse_layer_idx = config.scale_attn_by_inverse_layer_idx
         self.layer_idx = layer_idx
-        self.reorder_and_upcast_attn = config.reorder_and_upcast_attn
+
+        # Default all three flags to False so forward() can test them regardless of attn_type.
+        self.gqa_attn = False
+        self.reorder_and_upcast_attn = False
+        self.standard_attn = False
+        if self.attn_type == "gqa":
+            self.gqa_attn = True
+        elif self.attn_type == "reorder_and_upcast_attn":
+            self.reorder_and_upcast_attn = True
+        elif self.attn_type == "standard":
+            self.standard_attn = True
+
+        # self.reorder_and_upcast_attn = config.reorder_and_upcast_attn  # superseded: the config now selects the attention type via attn_type

         if self.is_cross_attention:
             self.c_attn = Conv1D(2 * self.embed_dim, self.embed_dim)
@@ -64,6 +77,10 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
             self.rot_emb=RotaryEmbedding(dim=self.head_dim)
         elif self.position_embedding == "rerope":
             self.rot_emb = RectifiedRotaryEmbedding(dim=self.head_dim,max_position_embeddings = self.max_positions)
+        elif self.position_embedding == "linear_rope_scaling":
+            self.rot_emb = LlamaLinearScalingRotaryEmbedding(dim=self.head_dim, max_position_embeddings=self.max_positions, scaling_factor=self.rope_scaling_factor, base=self.rope_theta)
+        elif self.position_embedding == "dynamic_rope_scaling":
+            self.rot_emb = LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim, max_position_embeddings=self.max_positions, scaling_factor=self.rope_scaling_factor, base=self.rope_theta)

@@ -109,6 +126,87 @@ def _attn(self, query, key, value, attention_mask=None, head_mask=None, alibi_bias=None):

         return attn_output, attn_weights

+    def _gqa_attn(self, query, key, value, attention_mask=None,
+                  alibi_bias=None, dropout=0.0):
+        """Grouped-query attention.
+
+        Expected shapes: query, key and value are (batch_size, num_heads, seq_len, head_dim),
+        the same layout as in _upcast_and_reordered_attn; for grouped-query attention the
+        key/value tensors may carry fewer heads than the query.
+        """
+        # Check for potential issues before moving on.
+        if not query.ndim == key.ndim == value.ndim == 4:
+            raise ValueError(f"Expected query, key, and value to be 4-dimensional, but got shapes "
+                             f"{query.shape}, {key.shape}, and {value.shape}.")
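+
+        # Illustrative shapes (example values, not requirements): with hidden_size=768 and
+        # num_attention_heads=12, query is (batch, 12, seq_len, 64); in a grouped-query
+        # setup key/value may carry fewer heads, e.g. (batch, 4, seq_len, 64), so that
+        # every three query heads share one key/value head.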
+        batch_size, num_heads, query_len, query_dim = query.shape
+
+        scale_factor = 1.0
+        if self.scale_attn_weights:
+            scale_factor /= float(value.size(-1)) ** 0.5
+        query = query * scale_factor
+
+        '''
+        Determine the number of groups.
+        For example, with 4 query heads and 2 key/value heads there are 2 groups.
+        The query tensor is then reshaped from (batch_size, num_heads, query_len, query_dim)
+        to (batch_size, num_groups, num_heads_per_group, query_len, query_dim),
+        i.e. (batch_size, 2, 2, query_len, query_dim) in this example, giving grouped
+        attention weights of shape (batch_size, num_groups, num_heads_per_group, query_len, key_len)
+        that are summed over the group dimension to form the final attention weights.
+        '''
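+        # When key/value carry as many heads as the query (the case exercised by
+        # protein_lm/tests/test_attention.py), n_groups == 1 and this path reduces to
+        # ordinary scaled dot-product attention.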
+        n_groups = query.size(1) // key.size(1)
+        if n_groups > 1:
+            query_shape = query.shape
+            grouped_shape = (query_shape[0], n_groups, query_shape[1] // n_groups, query_shape[2], query_shape[3])
+            query_grouped = query.reshape(grouped_shape)
+            attn_weights_grouped = torch.matmul(query_grouped, key.transpose(-2, -1))
+            attn_weights = attn_weights_grouped.sum(dim=1)
+        else:
+            # With a single group this is the normal attention product.
+            attn_weights = torch.matmul(query, key.transpose(-2, -1))
+
+        if self.scale_attn_by_inverse_layer_idx:
+            attn_weights = attn_weights / float(self.layer_idx + 1)
+
+        if attention_mask is not None:
+            # Apply the (additive) attention mask; expected shape: (batch_size, query_len, key_len).
+            attn_weights += attention_mask.unsqueeze(1)  # unsqueeze to add the head dimension
+
+        # Causal masking ensures that the attention mechanism does not attend to "future" tokens,
+        # adapted to work with groups and to match the vanilla attention path.
+        if not self.is_cross_attention:
+            query_length, key_length = query.size(-2), key.size(-2)
+            causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+            mask_value = torch.finfo(attn_weights.dtype).min
+            mask_value = torch.full([], mask_value, dtype=attn_weights.dtype).to(attn_weights.device)
+            attn_weights = torch.where(causal_mask, attn_weights.to(attn_weights.dtype), mask_value)
+
+        if alibi_bias is not None:
+            # Add the ALiBi bias to the attention logits before the softmax.
+            attn_weights = attn_weights + alibi_bias[:, :, :attn_weights.size(-1)]
+
+        # Softmax normalization to get the attention probabilities.
+        attn_weights = nn.functional.softmax(attn_weights, dim=-1)
+
+        # Apply dropout if specified.
+        attn_weights = attn_weights.type(value.dtype)
+        attn_weights = self.attn_dropout(attn_weights)
+
+        # Compute the output by multiplying the attention probabilities with the value tensor.
+        attn_output = torch.matmul(attn_weights, value)
+
+        return attn_output, attn_weights
+
     def _upcast_and_reordered_attn(self, query, key, value, attention_mask=None, head_mask=None,alibi_bias=None):
         # Use `torch.baddbmm` (a bit more efficient w/ alpha param for scaling -- from Megatron-LM)
         bsz, num_heads, q_seq_len, dk = query.size()
@@ -196,26 +294,29 @@ def forward(
         key = self._split_heads(key, self.num_heads, self.head_dim)
         value = self._split_heads(value, self.num_heads, self.head_dim)
-
-
-
+        kv_seq_len = key.shape[-2]
+        if layer_past is not None:
+            kv_seq_len += layer_past[0].shape[-2]
+
         # Apply rope embedding to query and key
         if self.rot_emb:
+            bsz, q_len, _ = hidden_states.size()
             if self.position_embedding == 'rope':
                 query, key = self.rot_emb(query,key)
             elif self.position_embedding == 'rerope':
-                bsz, q_len, _ = hidden_states.size()
                 query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
                 query *= ((position_ids + 1)[:, None, :, None].log() / torch.log(torch.tensor(self.max_sequence_length)).item()).clip(1).to(query.dtype)
                 query, key = self.rot_emb(query,key,seq_len = self.max_sequence_length,position_ids=position_ids)
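+            # For the Llama-style scaled-RoPE variants the rotary embedding is applied over the
+            # full key length, including any cached past (kv_seq_len), rather than the fixed
+            # max_sequence_length used by rerope above.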
+            elif self.position_embedding == "linear_rope_scaling" or self.position_embedding == "dynamic_rope_scaling":
+                query = query.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+                query, key = self.rot_emb(query, key, seq_len=kv_seq_len, position_ids=position_ids)

-
-
         if layer_past is not None:
             past_key, past_value = layer_past
             key = torch.cat((past_key, key), dim=-2)
             value = torch.cat((past_value, value), dim=-2)
+
         if use_cache is True:
             present = (key, value)
         else:
@@ -223,9 +324,10 @@ def forward(
         if self.reorder_and_upcast_attn:
             attn_output, attn_weights = self._upcast_and_reordered_attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-        else:
+        elif self.standard_attn:
             attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask,alibi_bias=alibi_bias)
-
+        elif self.gqa_attn:
+            attn_output, attn_weights = self._gqa_attn(query, key, value, attention_mask, alibi_bias=alibi_bias)

         attn_output = self._merge_heads(attn_output, self.num_heads, self.head_dim)
         attn_output = self.c_proj(attn_output)
         attn_output = self.resid_dropout(attn_output)
@@ -234,7 +336,7 @@ def forward(
         if output_attentions:
             outputs += (attn_weights,)

-        return outputs # a, present, (attentions)
+        return outputs  # a, present, (attentions)


 class APTMLP(nn.Module):
@@ -348,7 +450,7 @@ def __init__(self, config):
         self.wte = nn.Embedding(config.vocab_size, self.embed_dim)

         self.position_embedding = config.position_embedding if hasattr(config, "position_embedding") else "learned"
-        if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope':
+        if self.position_embedding in ("learned", "rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"):
             self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
             self.alibi = None
         elif self.position_embedding=="alibi":
             maxpos = config.n_positions
             attn_heads = config.num_attention_heads
@@ -357,7 +459,7 @@ def __init__(self, config):
             alibi = create_alibi_tensor(attn_heads,maxpos)
             self.register_buffer('alibi',alibi)
         else:
-            raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned,rope,rerope or alibi')
+            raise Exception(f'position_embedding {self.position_embedding} not supported. Please select one of learned, rope, rerope, linear_rope_scaling, dynamic_rope_scaling or alibi')

         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([APTBlock(config, layer_idx=i) for i in range(config.num_hidden_layers)])
@@ -462,7 +564,7 @@ def forward(
         if inputs_embeds is None:
             inputs_embeds = self.wte(input_ids)

-        if self.position_embedding=="learned" or self.position_embedding == 'rope' or self.position_embedding == 'rerope' :
+        if self.position_embedding in ("learned", "rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"):
             position_embeds = self.wpe(position_ids)
             hidden_states = inputs_embeds + position_embeds
         else:
@@ -588,6 +690,11 @@ def __init__(self, config):
         # Model parallel
         self.model_parallel = False
         self.device_map = None
+
+        self.contact_head = ContactPredictionHead(config.num_hidden_layers * config.num_attention_heads,
+                                                  prepend_bos=True,
+                                                  append_eos=True,
+                                                  eos_idx=2)

         # Initialize weights and apply final processing
         self.post_init()
@@ -663,3 +770,22 @@ def forward(
             attentions=transformer_outputs.attentions,
             cross_attentions=transformer_outputs.cross_attentions,
         )
+
+    def predict_contacts(self, input_ids):
+        transformer_outputs = self.transformer(
+            input_ids,
+            return_dict=True,
+            output_attentions=True,
+        )
+        # transformer_outputs.attentions is a tuple with one tensor per layer,
+        # each of shape (batch, num_heads, seq_len, seq_len).
+        attentions_list = list(transformer_outputs.attentions)
+
+        # Stack the per-layer maps into (batch, num_layers, num_heads, seq_len, seq_len).
+        stacked_attentions = torch.stack(
+            [attn for attn in attentions_list],
+            dim=1
+        )
+
+        contact_predictions = self.contact_head(input_ids, stacked_attentions)
+
+        return contact_predictions
\ No newline at end of file
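
A minimal sketch of how the new contact-prediction path might be called. Construction of the model is elided; `model` stands for an instance of the LM-head model defined above (carrying the new contact_head), and the token ids are purely illustrative:

    import torch

    input_ids = torch.randint(0, 30, (1, 128))        # a batch of encoded protein sequences (illustrative)
    with torch.no_grad():
        contacts = model.predict_contacts(input_ids)  # stacks every layer's attention maps and
                                                      # feeds them to ContactPredictionHead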
diff --git a/protein_lm/tests/test_attention.py b/protein_lm/tests/test_attention.py
new file mode 100644
index 0000000..1e85ffe
--- /dev/null
+++ b/protein_lm/tests/test_attention.py
@@ -0,0 +1,89 @@
+import pytest
+import torch
+from torch.nn import functional as F
+
+from protein_lm.modeling.models.apt.model_pytorch import APTAttention
+
+
+class ParameterConfig:
+    def __init__(self):
+        self.max_position_embeddings = 512
+        self.position_embedding = 'rope'
+        self.max_sequence_length = 512
+        self.hidden_size = 768
+        self.num_attention_heads = 12
+        self.scale_attn_weights = True
+        self.scale_attn_by_inverse_layer_idx = True
+        self.reorder_and_upcast_attn = True
+        self.attn_pdrop = 0.1
+        self.resid_pdrop = 0.1
+        self.rope_scaling_factor = 1
+        self.rope_theta = 1
+        self.attn_type = 'gqa'
+
+
+def test_vanilla_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_heads, seq_length, query_dim)
+
+    # Build an additive attention mask that treats the last 10 positions as padding
+    attention_mask = torch.ones(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, -padding_positions:, :] = float('-inf')
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+    attention_mask = attention_mask.unsqueeze(1)
+
+    # Pass the tensors through the _attn method
+    attn_output, attn_weights = attention._attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+def test_gqa_attn():
+    # Initialize with a mock config
+    config = ParameterConfig()
+    attention = APTAttention(config, is_cross_attention=False, layer_idx=0)
+
+    # Generate random input tensors
+    batch_size = 4
+    seq_length = 100
+    num_heads = config.num_attention_heads
+    query_dim = config.hidden_size // num_heads
+    query = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    key = torch.randn(batch_size, num_heads, seq_length, query_dim)
+    value = torch.randn(batch_size, num_heads, seq_length, query_dim)
+
+    # Build an additive attention mask that treats the last 10 positions as padding
+    # (_gqa_attn adds the head dimension itself, so no unsqueeze here)
+    attention_mask = torch.ones(batch_size, seq_length, seq_length)
+    padding_positions = 10
+    attention_mask[:, -padding_positions:, :] = float('-inf')
+    attention_mask[:, :, -padding_positions:] = float('-inf')
+
+    # Pass the tensors through the _gqa_attn method
+    attn_output, attn_weights = attention._gqa_attn(query, key, value, attention_mask=attention_mask)
+
+    # Check the shapes and types of the output
+    assert isinstance(attn_output, torch.Tensor)
+    assert attn_output.shape == (batch_size, num_heads, seq_length, query_dim)
+    assert isinstance(attn_weights, torch.Tensor)
+    assert attn_weights.shape == (batch_size, num_heads, seq_length, seq_length)
+    print("Test passed!")
+
+
+if __name__ == "__main__":
+    test_gqa_attn()
+    test_vanilla_attn()
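
The tests above construct their additive attention masks by hand. For reference, a helper that builds the same kind of mask, a (batch, seq_len, seq_len) additive mask with trailing padding positions set to -inf; this is only an illustration, not part of the PR:

    import torch

    def additive_padding_mask(padding_mask: torch.Tensor) -> torch.Tensor:
        """padding_mask: (batch, seq_len) bool tensor, True for real tokens."""
        batch, seq_len = padding_mask.shape
        mask = torch.zeros(batch, seq_len, seq_len)
        mask = mask.masked_fill(~padding_mask[:, None, :], float("-inf"))  # masked keys
        mask = mask.masked_fill(~padding_mask[:, :, None], float("-inf"))  # masked queries
        return mask

    padding = torch.ones(4, 100, dtype=torch.bool)
    padding[:, -10:] = False                           # last 10 positions are padding
    attention_mask = additive_padding_mask(padding)    # pass directly to _gqa_attn;
    # _attn expects an extra head dimension, i.e. attention_mask.unsqueeze(1)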