Merge branch 'fix/embeddings_load_shape_mismatch' into 'main'
Fix loading checkpoint with changed item embeddings

See merge request ai-lab-pmo/mltools/recsys/RePlay!201
OnlyDeniko committed Apr 10, 2024
2 parents 065b81a + 90a33ed commit 42a7c0c
Showing 8 changed files with 46 additions and 9 deletions.
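Context for the diff below: judging by the branch name and the changes, the bug is a state-dict shape mismatch that appears when a checkpoint saved after growing the item-embedding table (fine-tuning on new items) is loaded into a model rebuilt from a schema that still reports the old item cardinality. A minimal sketch of that failure in plain PyTorch (hypothetical sizes, not RePlay code):

import torch

# The checkpoint holds an embedding grown for new items; the model is
# rebuilt from a schema that still reports the old item cardinality.
old_vocab, new_vocab, dim = 4, 6, 50
grown = torch.nn.Embedding(new_vocab, dim)    # state saved in the checkpoint
rebuilt = torch.nn.Embedding(old_vocab, dim)  # model rebuilt from the stale schema

try:
    rebuilt.load_state_dict(grown.state_dict())
except RuntimeError as err:
    print(err)  # size mismatch for weight: torch.Size([6, 50]) vs torch.Size([4, 50])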
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@
 lightning_logs/
 .logs/
 .checkpoints/
+*.ckpt
 
 # Poetry (since they are generated from template automatically)
 /poetry.lock
5 changes: 2 additions & 3 deletions examples/09_sasrec_example.ipynb
@@ -861,12 +861,11 @@
 "    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]\n",
 ")\n",
 "\n",
-"csv_logger = CSVLogger(save_dir=\".logs/train\", name=\"SASRec_example\")\n",
-"\n",
 "trainer = L.Trainer(\n",
 "    max_epochs=100,\n",
 "    callbacks=[checkpoint_callback, validation_metrics_callback],\n",
-"    logger=csv_logger,\n",
+"    logger=False,\n",
+"    log_every_n_steps=1000,\n",
 ")\n",
 "\n",
 "train_dataloader = DataLoader(\n",
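A note on the notebook change above: passing logger=False to the Lightning Trainer disables logging entirely (standard Lightning API). If the CSV logs are still wanted, the previous setup can be restored along these lines (a sketch reusing the notebook's own names):

from lightning.pytorch.loggers import CSVLogger

csv_logger = CSVLogger(save_dir=".logs/train", name="SASRec_example")
trainer = L.Trainer(
    max_epochs=100,
    callbacks=[checkpoint_callback, validation_metrics_callback],
    logger=csv_logger,
    log_every_n_steps=1000,
)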
2 changes: 1 addition & 1 deletion replay/models/nn/sequential/bert4rec/lightning.py
@@ -506,7 +506,6 @@ def append_item_embeddings(self, item_embeddings: torch.Tensor):
 
     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
-
         if self._model.enable_embedding_tying is True:
             self._model._head._item_embedder = self._model.item_embedder
             new_bias = torch.Tensor(new_vocab_size)
@@ -521,3 +520,4 @@ def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
 
         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
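The added _set_cardinality call follows a general rule: when an embedding table is resized, every piece of configuration persisted with the checkpoint must report the new size, or the model rebuilt at load time will not match the saved weights. A generic sketch of the pattern (hypothetical names, not the RePlay implementation):

import torch

def grow_embedding(old: torch.nn.Embedding, new_vocab: int, config: dict) -> torch.nn.Embedding:
    # Grow the table, keep the already-trained rows, and keep the persisted
    # config in sync so a reload rebuilds a table of the same shape.
    new = torch.nn.Embedding(new_vocab, old.embedding_dim)
    with torch.no_grad():
        new.weight[: old.num_embeddings] = old.weight
    config["item_cardinality"] = new_vocab
    return new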
8 changes: 8 additions & 0 deletions replay/models/nn/sequential/sasrec/lightning.py
@@ -486,3 +486,11 @@ def _set_new_item_embedder_to_model(self, new_embedding: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_count = new_vocab_size
         self._model.padding_idx = new_vocab_size
         self._model.masking.padding_idx = new_vocab_size
+        self._model.candidates_to_score = torch.tensor(
+            list(range(new_embedding.weight.data.shape[0] - 1)),
+            device=self._model.candidates_to_score.device,
+            dtype=torch.long,
+        )
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
+            new_embedding.weight.data.shape[0] - 1
+        )
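The rebuilt candidates_to_score enumerates every real item id; the "- 1" skips the padding row, since padding_idx is set to new_vocab_size just above. An equivalent construction would be torch.arange, as this quick check illustrates:

import torch

n_items = 6  # stands in for new_embedding.weight.data.shape[0] - 1
assert torch.equal(
    torch.tensor(list(range(n_items)), dtype=torch.long),
    torch.arange(n_items, dtype=torch.long),
)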
2 changes: 1 addition & 1 deletion tests/models/nn/sequential/bert4rec/conftest.py
@@ -44,7 +44,7 @@ def fitted_bert4rec(feature_schema_for_bert4rec):
     train_dataset = Dataset(feature_schema=feature_schema_for_bert4rec, interactions=data)
     tensor_schema = TensorSchema(
         TensorFeatureInfo(
-            name="item_id_seq",
+            name="item_id",
             is_seq=True,
             cardinality=train_dataset.item_count,
             feature_type=FeatureType.CATEGORICAL,
14 changes: 14 additions & 0 deletions tests/models/nn/sequential/bert4rec/test_bert4rec_lightning.py
@@ -265,6 +265,20 @@ def test_bert4rec_fine_tuning_on_new_items_by_appending(request, fitted_bert4rec
     assert torch.eq(only_new_items_tensor, new_items_data[4]).all()
 
 
+@pytest.mark.torch
+def test_bert4rec_fine_tuning_save_load(fitted_bert4rec, new_items_dataset, train_loader):
+    model, tokenizer = fitted_bert4rec
+    trainer = L.Trainer(max_epochs=1)
+    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
+    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])
+    model.set_item_embeddings_by_size(new_vocab_size)
+    trainer.fit(model, train_loader)
+    trainer.save_checkpoint("bert_test.ckpt")
+    best_model = Bert4Rec.load_from_checkpoint("bert_test.ckpt")
+
+    assert best_model.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size
+
+
 @pytest.mark.torch
 def test_bert4rec_fine_tuning_errors(fitted_bert4rec):
     model, _ = fitted_bert4rec
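The new test writes bert_test.ckpt into the working directory, which is what the *.ckpt rule added to .gitignore above accounts for. A variant that avoids the stray file (a sketch, not part of this commit) would route the checkpoint through pytest's tmp_path fixture:

@pytest.mark.torch
def test_bert4rec_fine_tuning_save_load_tmp(fitted_bert4rec, new_items_dataset, train_loader, tmp_path):
    model, tokenizer = fitted_bert4rec
    trainer = L.Trainer(max_epochs=1)
    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])
    model.set_item_embeddings_by_size(new_vocab_size)
    trainer.fit(model, train_loader)
    ckpt_path = tmp_path / "bert_test.ckpt"  # cleaned up automatically by pytest
    trainer.save_checkpoint(ckpt_path)
    best_model = Bert4Rec.load_from_checkpoint(ckpt_path)

    assert best_model.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size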
6 changes: 3 additions & 3 deletions tests/models/nn/sequential/sasrec/conftest.py
@@ -44,7 +44,7 @@ def fitted_sasrec(feature_schema_for_sasrec):
     train_dataset = Dataset(feature_schema=feature_schema_for_sasrec, interactions=data)
     tensor_schema = TensorSchema(
         TensorFeatureInfo(
-            name="item_id_seq",
+            name="item_id",
             is_seq=True,
             cardinality=train_dataset.item_count,
             feature_type=FeatureType.CATEGORICAL,
@@ -59,9 +59,9 @@
     tokenizer.fit(train_dataset)
     sequential_train_dataset = tokenizer.transform(train_dataset)
 
-    model = SasRec(tensor_schema)
+    model = SasRec(tensor_schema, max_seq_len=5)
     trainer = L.Trainer(max_epochs=1)
-    train_loader = torch.utils.data.DataLoader(SasRecTrainingDataset(sequential_train_dataset, 200))
+    train_loader = torch.utils.data.DataLoader(SasRecTrainingDataset(sequential_train_dataset, 5))
 
     trainer.fit(model, train_dataloaders=train_loader)
 
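Note on the fixture change: the old fixture relied on SasRec's default max_seq_len of 200 (as the old hparams assertion below confirms); pinning it to 5 matches the tiny test dataset and the second argument of SasRecTrainingDataset (presumably the sequence length), and the assertion in test_sasrec_get_init_parameters is updated accordingly.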
17 changes: 16 additions & 1 deletion tests/models/nn/sequential/sasrec/test_sasrec_lightning.py
@@ -6,6 +6,7 @@
 from replay.models.nn.optimizer_utils import FatLRSchedulerFactory, FatOptimizerFactory
 from replay.models.nn.sequential.sasrec import SasRec, SasRecPredictionBatch, SasRecPredictionDataset
 
+
 torch = pytest.importorskip("torch")
 L = pytest.importorskip("lightning")
 
@@ -236,6 +237,20 @@ def test_sasrec_fine_tuning_on_new_items_by_appending(fitted_sasrec, new_items_d
     assert torch.eq(only_new_items_tensor, new_items_data[4]).all()
 
 
+@pytest.mark.torch
+def test_sasrec_fine_tuning_save_load(fitted_sasrec, new_items_dataset, train_sasrec_loader):
+    model, tokenizer = fitted_sasrec
+    trainer = L.Trainer(max_epochs=1)
+    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
+    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])
+    model.set_item_embeddings_by_size(new_vocab_size)
+    trainer.fit(model, train_sasrec_loader)
+    trainer.save_checkpoint("test.ckpt")
+    best_model = SasRec.load_from_checkpoint("test.ckpt")
+
+    assert best_model.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size
+
+
 @pytest.mark.torch
 def test_sasrec_fine_tuning_errors(fitted_sasrec):
     model, _ = fitted_sasrec
@@ -260,7 +275,7 @@ def test_sasrec_get_init_parameters(fitted_sasrec):
     params = model.hparams
 
     assert params["tensor_schema"].item().cardinality == 4
-    assert params["max_seq_len"] == 200
+    assert params["max_seq_len"] == 5
     assert params["hidden_size"] == 50
 
