
Commit

Merge branch 'main' into vinamarora8-pro-158
mazabou committed Nov 29, 2024
2 parents 366dee0 + 0141f20 commit b6f864a
Showing 3 changed files with 180 additions and 51 deletions.
74 changes: 40 additions & 34 deletions examples/poyo_plus/finetune.py
@@ -1,21 +1,21 @@
import logging
from typing import List, Optional

import hydra
import lightning as L
import torch
import torch.nn as nn
from torch_optimizer import Lamb
from lightning.pytorch.callbacks import (
LearningRateMonitor,
ModelCheckpoint,
ModelSummary,
)
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig

from torch_brain.nn import compute_loss_or_metric
from torch_brain.registry import MODALITIY_REGISTRY
from torch_brain.utils import callbacks as tbrain_callbacks
from torch_brain.utils import seed_everything, DataModule
from torch_brain.utils import seed_everything
from torch_brain.utils.datamodules import DataModule
from torch_brain.utils.stitcher import StitchEvaluator

from train import POYOTrainWrapper
@@ -24,15 +24,19 @@
torch.set_float32_matmul_precision("medium")


class FreezeUnfreezePOYO(L.Callback):
class GradualUnfreezing(L.Callback):
r"""A Lightning callback to handle freezing and unfreezing of the model for the
purpose of finetuning the model to new sessions. If this callback is used,
most of the model weights will be frozen initially.
The only parts of the model that will be left unfrozen are the unit and session embeddings.
Once we reach the specified epoch (`unfreeze_at_epoch`), the entire model will be unfrozen.
"""

_has_been_frozen: bool = False
frozen_params: Optional[List[nn.Parameter]] = None

def __init__(self, unfreeze_at_epoch: int):
self.enabled = unfreeze_at_epoch != 0
self.unfreeze_at_epoch = unfreeze_at_epoch
self.cli_log = logging.getLogger(__name__)

@@ -62,32 +66,47 @@ def freeze(cls, model):
return frozen_params

def on_train_start(self, trainer, pl_module):
model = pl_module.model
self.frozen_params = self.freeze(model)
self.cli_log.info(f"POYO Perceiver frozen at epoch 0")
if self.enabled:
self.frozen_params = self.freeze(pl_module.model)
self._has_been_frozen = True
self.cli_log.info(
f"POYO+ Perceiver frozen at epoch 0. "
f"Will stay frozen until epoch {self.unfreeze_at_epoch}."
)

def on_train_epoch_start(self, trainer, pl_module):
if trainer.current_epoch == self.unfreeze_at_epoch:
if not hasattr(self, "frozen_params"):
raise RuntimeError(
"Model has not been frozen yet. Missing `frozen_params` attribute."
)
if self.enabled and (trainer.current_epoch == self.unfreeze_at_epoch):
if not self._has_been_frozen:
raise RuntimeError("Model has not been frozen yet.")

for param in self.frozen_params:
param.requires_grad = True

del self.frozen_params
self.cli_log.info(f"POYO unfrozen at epoch {trainer.current_epoch}")
self.frozen_params = None
self.cli_log.info(
f"POYO+ Perceiver unfrozen at epoch {trainer.current_epoch}"
)


def load_model_from_ckpt(model: nn.Module, ckpt_path: str) -> None:
if ckpt_path is None:
raise ValueError("Must provide a checkpoint path to finetune the model.")

ckpt = torch.load(ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]
state_dict = {
k.replace("model.", ""): v
for k, v in state_dict.items()
if k.startswith("model.")
}
model.load_state_dict(state_dict)


@hydra.main(version_base="1.3", config_path="./configs", config_name="train.yaml")
def main(cfg: DictConfig):
# fix random seed, skipped if cfg.seed is None
seed_everything(cfg.seed)

if cfg.fast_dev_run:
cfg.wandb.enable = False

# setup loggers
log = logging.getLogger(__name__)
wandb_logger = None
@@ -102,17 +121,8 @@ def main(cfg: DictConfig):

# make model
model = hydra.utils.instantiate(cfg.model, readout_specs=MODALITIY_REGISTRY)

# load weights from checkpoint
if cfg.ckpt_path is None:
raise ValueError("Must provide a checkpoint path to finetune the model.")

ckpt = torch.load(cfg.ckpt_path, map_location="cpu")
state_dict = ckpt["state_dict"]
state_dict = {
k.replace("model.", ""): v for k, v in state_dict.items() if "model." in k
}
model.load_state_dict(state_dict)
load_model_from_ckpt(model, cfg.ckpt_path)
log.info(f"Loaded model weights from {cfg.ckpt_path}")

# setup data module
data_module = DataModule(cfg, model.unit_emb.tokenizer, model.session_emb.tokenizer)
@@ -153,12 +163,9 @@ def main(cfg: DictConfig):
tbrain_callbacks.MemInfo(),
tbrain_callbacks.EpochTimeLogger(),
tbrain_callbacks.ModelWeightStatsLogger(),
GradualUnfreezing(cfg.freeze_perceiver_until_epoch),
]

if cfg.freeze_perceiver_until_epoch != 0:
log.info(f"Freezing model until epoch {cfg.freeze_perceiver_until_epoch}")
callbacks.append(FreezeUnfreezePOYO(cfg.freeze_perceiver_until_epoch))

trainer = L.Trainer(
logger=wandb_logger,
default_root_dir=cfg.log_dir,
@@ -175,7 +182,6 @@
num_nodes=cfg.nodes,
num_sanity_val_steps=0,
limit_val_batches=None, # Ensure no limit on validation batches
fast_dev_run=cfg.fast_dev_run,
)

log.info(
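For context, here is a minimal sketch of the freeze/unfreeze pattern that the renamed GradualUnfreezing callback implements. It assumes the wrapped model exposes unit_emb and session_emb modules (as used elsewhere in this diff); the class name and the body of freeze() are illustrative, since the actual classmethod is collapsed in the hunk above.

import logging
from typing import List, Optional

import lightning as L
import torch.nn as nn


class GradualUnfreezingSketch(L.Callback):
    """Illustrative callback: freeze everything except the unit and session
    embeddings, then unfreeze the rest once `unfreeze_at_epoch` is reached."""

    def __init__(self, unfreeze_at_epoch: int):
        self.enabled = unfreeze_at_epoch != 0
        self.unfreeze_at_epoch = unfreeze_at_epoch
        self.frozen_params: Optional[List[nn.Parameter]] = None
        self.cli_log = logging.getLogger(__name__)

    @classmethod
    def freeze(cls, model) -> List[nn.Parameter]:
        # Keep the unit/session embedding parameters trainable, freeze the rest.
        trainable_ids = {
            id(p)
            for module in (model.unit_emb, model.session_emb)
            for p in module.parameters()
        }
        frozen_params = []
        for param in model.parameters():
            if param.requires_grad and id(param) not in trainable_ids:
                param.requires_grad = False
                frozen_params.append(param)
        return frozen_params

    def on_train_start(self, trainer, pl_module):
        if self.enabled:
            self.frozen_params = self.freeze(pl_module.model)
            self.cli_log.info("Perceiver frozen at epoch 0.")

    def on_train_epoch_start(self, trainer, pl_module):
        if self.enabled and trainer.current_epoch == self.unfreeze_at_epoch:
            for param in self.frozen_params:
                param.requires_grad = True
            self.frozen_params = None
            self.cli_log.info(f"Unfrozen at epoch {trainer.current_epoch}")

Note that in main() the callback is now always appended to the callback list and disables itself when cfg.freeze_perceiver_until_epoch is 0, replacing the earlier conditional append.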
136 changes: 120 additions & 16 deletions tests/test_infinite_vocab_embedding.py
@@ -122,19 +122,123 @@ def test_checkpointing():
"word3": 3,
}

### UPDATE: The below test is no longer valid, since words are always sorted.
# load checkpoint after vocab is initialized but the order of the words is different
# emb = InfiniteVocabEmbedding(embedding_dim=128)
# emb.initialize_vocab(["word3", "word1", "word2"])
# state_dict = torch.load("checkpoint.pth")
# emb.load_state_dict(state_dict)

# assert emb.vocab == {
# "NA": 0,
# "word3": 1,
# "word1": 2,
# "word2": 3,
# }

# assert torch.allclose(emb.weight, state_dict["weight"][[0, 3, 1, 2]])
###
# Test extend_vocab with exist_ok=True
emb = InfiniteVocabEmbedding(embedding_dim=128)
emb.initialize_vocab(["word1", "word2", "word3"])
original_weights = emb.weight.clone()

# Try extending with mix of new and existing words
emb.extend_vocab(["word2", "word4", "word1", "word5"], exist_ok=True)

# Check vocab was extended correctly
assert emb.vocab == {
"NA": 0,
"word1": 1,
"word2": 2,
"word3": 3,
"word4": 4,
"word5": 5,
}

# Check original embeddings were preserved
assert torch.allclose(emb.weight[:4], original_weights)

# Test extend_vocab with exist_ok=False raises error
emb = InfiniteVocabEmbedding(embedding_dim=128)
emb.initialize_vocab(["word1", "word2", "word3"])

with pytest.raises(ValueError):
emb.extend_vocab(["word2", "word4"])

# Test subset_vocab with inplace=True
emb = InfiniteVocabEmbedding(embedding_dim=128)
emb.initialize_vocab(["word1", "word2", "word3", "word4"])
original_weights = emb.weight.clone()

emb.subset_vocab(["word2", "word4"], inplace=True)

assert emb.vocab == {
"NA": 0,
"word2": 1,
"word4": 2,
}
# Check embeddings were preserved for kept words
assert torch.allclose(emb.weight[1], original_weights[2]) # word2
assert torch.allclose(emb.weight[2], original_weights[4]) # word4

# Test subset_vocab with inplace=False
emb = InfiniteVocabEmbedding(embedding_dim=128)
emb.initialize_vocab(["word1", "word2", "word3", "word4"])
original_weights = emb.weight.clone()

new_emb = emb.subset_vocab(["word2", "word4"], inplace=False)

# Original embedding should be unchanged
assert emb.vocab == {
"NA": 0,
"word1": 1,
"word2": 2,
"word3": 3,
"word4": 4,
}
assert torch.allclose(emb.weight, original_weights)

# New embedding should have subset
assert new_emb.vocab == {
"NA": 0,
"word2": 1,
"word4": 2,
}
assert torch.allclose(new_emb.weight[1], original_weights[2]) # word2
assert torch.allclose(new_emb.weight[2], original_weights[4]) # word4

# Test subset_vocab with invalid words
with pytest.raises(ValueError):
emb.subset_vocab(["word2", "nonexistent"])

# Test subset_vocab with duplicate words
with pytest.raises(ValueError):
emb.subset_vocab(["word2", "word2"])

# Test subset_vocab with empty list
with pytest.raises(AssertionError):
emb.subset_vocab([])


def test_vocab_ordering():
"""Test that vocabulary ordering behavior works consistently across all operations"""

# Test initial vocab creation maintains order
emb = InfiniteVocabEmbedding(embedding_dim=128)
emb.initialize_vocab(["word3", "word1", "word2"])
assert list(emb.vocab.keys()) == ["NA", "word3", "word1", "word2"]

# Test extend_vocab maintains existing order and appends new words
emb.extend_vocab(["word5", "word4"])
assert list(emb.vocab.keys()) == ["NA", "word3", "word1", "word2", "word5", "word4"]

# Test subset_vocab maintains relative order of selected words
subset_emb = emb.subset_vocab(["word2", "word5", "word1"], inplace=False)
assert list(subset_emb.vocab.keys()) == ["NA", "word2", "word5", "word1"]


# TODO: fix InfiniteVocabEmbedding.load_state_dict()
# The below test is currently failing.
# def test_state_dict_loading():
# # Test state dict loading preserves embeddings while allowing different order
# emb1 = InfiniteVocabEmbedding(embedding_dim=128)
# emb1.initialize_vocab(["word1", "word2", "word3"])
# original_weights = emb1.weight.clone()

# emb2 = InfiniteVocabEmbedding(embedding_dim=128)
# emb2.initialize_vocab(["word3", "word1", "word2"])

# # Load state dict and verify embeddings are correctly remapped
# emb2.load_state_dict(emb1.state_dict())
# # Need to use tokenizer() since vocab dict order may be different
# assert torch.allclose(emb2.weight[emb2.tokenizer("word1")],
# original_weights[emb1.tokenizer("word1")])
# assert torch.allclose(emb2.weight[emb2.tokenizer("word2")],
# original_weights[emb1.tokenizer("word2")])
# assert torch.allclose(emb2.weight[emb2.tokenizer("word3")],
# original_weights[emb1.tokenizer("word3")])
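The disabled test above pins down the intended semantics of load_state_dict when two embeddings were initialized with the same words in a different order: rows should be matched by word, not by raw index. A rough illustration of that remapping, written only against the pieces the tests already use (vocab, weight, tokenizer), could look like the helper below; it is a sketch of the expected behavior, not the library's implementation.

import torch


def remap_weight_by_word(incoming_vocab: dict, incoming_weight: torch.Tensor,
                         target_vocab: dict) -> torch.Tensor:
    # Place each word's incoming embedding row at the index that word
    # occupies in the target vocabulary.
    remapped = incoming_weight.new_empty(len(target_vocab), incoming_weight.shape[1])
    for word, target_idx in target_vocab.items():
        remapped[target_idx] = incoming_weight[incoming_vocab[word]]
    return remapped

With a remap like this, emb2.weight[emb2.tokenizer("word1")] would equal original_weights[emb1.tokenizer("word1")], which is exactly what the commented-out assertions check.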
21 changes: 20 additions & 1 deletion torch_brain/nn/infinite_vocab_embedding.py
@@ -161,10 +161,22 @@ def extend_vocab(self, vocab: List[str], exist_ok=False):
# update tokenizer
self.vocab.update(
OrderedDict(
zip(vocab, range(len(self.vocab), len(self.vocab) + len(vocab)))
zip(new_words, range(len(self.vocab), len(self.vocab) + len(new_words)))
)
)

# check that the dictionary update was done correctly
assert len(self.vocab) == len(self.weight) + len(new_words), (
f"Expected vocabulary length to be {len(self.weight) + len(new_words)}, "
f"but got {len(self.vocab)}"
)

# check that the largest value in the vocab matches its size
assert max(self.vocab.values()) == len(self.vocab) - 1, (
f"Expected maximum vocabulary index to be {len(self.vocab) - 1}, "
f"but got {max(self.vocab.values())}"
)

# make a copy of existing embeddings
embeddings_for_existing_words = self.weight.clone().detach()

@@ -350,6 +362,13 @@ def _hook_vocab_on_load_state_dict(
error_msgs,
):
if not self.is_lazy():
# pop the vocabulary from the state_dict
# popping is done because by default pytorch does not know how to load
# vocab which is an OrderedDict, so we cannot keep it inside of the state_dict.
# however, if we pop it, it will also be removed in the original state_dict
# which can have unintended side effects; notably, this will alter the actual
# state_dict in the checkpoint or the other module being used.
# TODO: find a way to keep the vocab in the state_dict
incoming_vocab = state_dict.pop(prefix + "vocab")

# incoming_vocab and self.vocab might have the same keys but in a different order
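The comment above notes that popping vocab mutates the caller's state_dict, and the TODO asks for a way to keep the vocab inside it. One possible direction, sketched here purely as an untested idea (not the module's actual code), is PyTorch's get_extra_state / set_extra_state pair, which lets a module round-trip non-tensor state through the state_dict without manually removing keys:

from collections import OrderedDict

import torch.nn as nn


class VocabAsExtraState(nn.Module):
    """Hypothetical mixin: let PyTorch carry the vocab as module 'extra state'."""

    def get_extra_state(self):
        # Saved automatically under the '<prefix>_extra_state' key.
        return {"vocab": None if self.is_lazy() else OrderedDict(self.vocab)}

    def set_extra_state(self, state):
        incoming_vocab = state.get("vocab")
        if incoming_vocab is not None:
            # Reconcile the incoming vocab here; the caller's state_dict is
            # never mutated. Reordering the weight rows (as the existing
            # pre-hook does) would still need to happen separately.
            self.vocab = OrderedDict(incoming_vocab)

Whether this fits the lazy-initialization path of InfiniteVocabEmbedding is untested; the sketch only shows that the vocab could stay in the state_dict without the pop.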
