Improved docs
fuzhanrahmanian committed Mar 16, 2024
1 parent 34b9bd5 commit 9600180
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions arcana/procedures/finetuning.py
@@ -25,20 +25,20 @@ def __init__(self, tl_strategy="decoder") -> None:


     def load_model(self):
-        """Load the model from the path"""
+        """Load the pre-trained model from the path"""
         return torch.load(self.general_config.pretrained_model)


     def unfreeze_decoder(self):
-        """Unfreeze the encoder"""
+        """Unfreeze the decoder layer by making the autograd true for the decoder layer parameters."""
         for param in self.pretrained_model.parameters():
             param.requires_grad = False
         for _, param in self.pretrained_model.decoder.named_parameters():
             param.requires_grad = True


     def unfreeze_fully_connected(self):
-        """Unfreeze the fully connected layer"""
+        """Unfreeze the fully connected layer by making the autograd true for the parameters."""
         for param in self.pretrained_model.parameters():
             param.requires_grad = False
         for param in self.pretrained_model.decoder.fc_layer_pred_1.parameters():
@@ -50,7 +50,7 @@ def unfreeze_fully_connected(self):


     def unfreeze_fc_and_attention(self):
-        """Freeze the fully connected layer and the attention layer in the decoder"""
+        """Unfreeze the fully connected layer and the attention layer in the decoder by making the autograd true for the parameters."""
         # FIXME: fix the attention for multihead
         self.unfreeze_fully_connected()
         for name, param in self.pretrained_model.decoder.named_parameters():
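For context, the methods touched by this commit follow the standard PyTorch selective-unfreezing recipe for transfer learning: disable gradients on every parameter, re-enable them on the chosen sub-module, and hand the optimizer only the still-trainable parameters. Below is a minimal runnable sketch of that recipe; the ToyModel class and its layer names are illustrative stand-ins, not the actual arcana model.

import torch
import torch.nn as nn

# Illustrative stand-in for a pretrained seq2seq-style model; the only thing
# this sketch relies on is that the model exposes a `decoder` attribute.
class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
        self.decoder = nn.Linear(16, 8)

model = ToyModel()

# Freeze everything, then unfreeze only the decoder (same idea as unfreeze_decoder).
for param in model.parameters():
    param.requires_grad = False
for param in model.decoder.parameters():
    param.requires_grad = True

# Pass only the trainable parameters to the optimizer so frozen weights stay untouched.
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4
)

# Sanity check: only the decoder's parameters should count as trainable.
print(sum(p.numel() for p in model.parameters() if p.requires_grad), "trainable parameters")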
