diff --git a/protein_lm/configs/train/toy_hf.yaml b/protein_lm/configs/train/toy_hf.yaml
index 8ee01d4..f1c922e 100644
--- a/protein_lm/configs/train/toy_hf.yaml
+++ b/protein_lm/configs/train/toy_hf.yaml
@@ -9,14 +9,6 @@ dataset:
   sequence_column_name: "sequence"
   max_sequence_length: 10
   do_curriculum_learning: false
-  curriculum_learning_strategy:
-  - 'sequence_length'
-  - 'ppl'
-  - 'plddt'
-  curriculum_learning_column_name:
-  - 'sequence_length'
-  - 'ppl'
-  - 'plddt'
 
 # corresponds to HuggingFace's TrainingArguments
 training_arguments:
diff --git a/protein_lm/configs/train/toy_localcsv.yaml b/protein_lm/configs/train/toy_localcsv.yaml
index a2f9d8b..116a1f4 100644
--- a/protein_lm/configs/train/toy_localcsv.yaml
+++ b/protein_lm/configs/train/toy_localcsv.yaml
@@ -9,14 +9,6 @@ dataset:
   sequence_column_name: "sequence"
   max_sequence_length: 10
   do_curriculum_learning: false
-  curriculum_learning_strategy:
-  - 'sequence_length'
-  - 'ppl'
-  - 'plddt'
-  curriculum_learning_column_name:
-  - 'sequence_length'
-  - 'ppl'
-  - 'plddt'
 
 # corresponds to HuggingFace's TrainingArguments
 training_arguments:
diff --git a/protein_lm/modeling/getters/dataset.py b/protein_lm/modeling/getters/dataset.py
index 04c1cbe..57cc6ba 100644
--- a/protein_lm/modeling/getters/dataset.py
+++ b/protein_lm/modeling/getters/dataset.py
@@ -30,8 +30,8 @@ class DatasetConfig(BaseModel):
     max_sequence_length: int
 
     do_curriculum_learning: bool
-    curriculum_learning_strategy: str
-    curriculum_learning_column_name: str
+    curriculum_learning_strategy: Optional[str] = None
+    curriculum_learning_column_name: Optional[str] = None
 
 
 def set_input_ids(
diff --git a/protein_lm/modeling/models/apt/config.py b/protein_lm/modeling/models/apt/config.py
index 36f2c04..4c9d337 100644
--- a/protein_lm/modeling/models/apt/config.py
+++ b/protein_lm/modeling/models/apt/config.py
@@ -1,5 +1,5 @@
 from transformers import GPT2Config
-
+from typing import Literal
 
 class APTConfig(GPT2Config):
     """
@@ -8,9 +8,10 @@ class APTConfig(GPT2Config):
 
     def __init__(
         self,
-        position_embedding="learned",
+        position_embedding: Literal["alibi", "learned", "rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"]="learned",
         tokenizer=None,
         max_sequence_length = 1024,
+        attn_type="standard",
         **kwargs
     ):
         super().__init__(**kwargs)
@@ -18,4 +19,5 @@ def __init__(
         self.position_embedding = position_embedding
         self.tokenizer = tokenizer
         self.max_sequence_length = max_sequence_length
+        self.attn_type = attn_type
 
diff --git a/protein_lm/modeling/models/apt/model_pytorch.py b/protein_lm/modeling/models/apt/model_pytorch.py
index f450d4a..11de2cc 100644
--- a/protein_lm/modeling/models/apt/model_pytorch.py
+++ b/protein_lm/modeling/models/apt/model_pytorch.py
@@ -30,8 +30,7 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         )
         self.register_buffer("masked_bias", torch.tensor(-1e4), persistent=False)
         self.position_embedding = config.position_embedding
-        self.rope_scaling_factor=config.rope_scaling_factor
-        self.rope_theta=config.rope_theta
+        self.max_sequence_length = config.max_sequence_length
 
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -72,15 +71,18 @@ def __init__(self, config, is_cross_attention=False, layer_idx=None):
         self.pruned_heads = set()
 
-        self.rot_emb=None
-        if self.position_embedding == "rope":
-            self.rot_emb=RotaryEmbedding(dim=self.head_dim)
-        elif self.position_embedding == "rerope":
-            self.rot_emb = RectifiedRotaryEmbedding(dim=self.head_dim,max_position_embeddings = self.max_positions)
-        elif self.position_embedding=="linear_rope_scaling":
-            self.rot_emb=LlamaLinearScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
-        elif self.position_embedding=="dynamic_rope_scaling":
-            self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
+        self.rot_emb = None
+        if self.position_embedding in ["rope", "rerope", "linear_rope_scaling", "dynamic_rope_scaling"]:
+            self.rope_scaling_factor = config.rope_scaling_factor
+            self.rope_theta = config.rope_theta
+            if self.position_embedding == "rope":
+                self.rot_emb=RotaryEmbedding(dim=self.head_dim)
+            elif self.position_embedding == "rerope":
+                self.rot_emb = RectifiedRotaryEmbedding(dim=self.head_dim,max_position_embeddings = self.max_positions)
+            elif self.position_embedding=="linear_rope_scaling":
+                self.rot_emb=LlamaLinearScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
+            elif self.position_embedding=="dynamic_rope_scaling":
+                self.rot_emb=LlamaDynamicNTKScalingRotaryEmbedding(dim=self.head_dim,max_position_embeddings=self.max_positions,scaling_factor=self.rope_scaling_factor,base=self.rope_theta)
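
A minimal usage sketch of the updated APTConfig (the import path is taken from this change; forwarding rope_scaling_factor and rope_theta through **kwargs assumes transformers' PretrainedConfig stores unknown kwargs as config attributes, and the values shown are illustrative only, not defaults from this repo):

from protein_lm.modeling.models.apt.config import APTConfig

# position_embedding is now annotated with the Literal choices added above, and the
# new attn_type field defaults to "standard".
config = APTConfig(
    position_embedding="rope",
    max_sequence_length=1024,
    # With a rope-family embedding, APTAttention now reads these two values from the
    # config (see model_pytorch.py above); they are illustrative assumptions here.
    rope_scaling_factor=1.0,
    rope_theta=10000,
)
assert config.attn_type == "standard"
assert config.position_embedding == "rope"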