diff --git a/.gitignore b/.gitignore
index 1ebbc2512..2ce2e5ce8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -9,6 +9,7 @@
 lightning_logs/
 .logs/
 .checkpoints/
+*.ckpt
 
 # Poetry (since they are generated from template automatically)
 /poetry.lock
diff --git a/examples/09_sasrec_example.ipynb b/examples/09_sasrec_example.ipynb
index 288c88f17..c9774e031 100644
--- a/examples/09_sasrec_example.ipynb
+++ b/examples/09_sasrec_example.ipynb
@@ -861,12 +861,11 @@
     "    postprocessors=[RemoveSeenItems(sequential_validation_dataset)]\n",
     ")\n",
     "\n",
-    "csv_logger = CSVLogger(save_dir=\".logs/train\", name=\"SASRec_example\")\n",
-    "\n",
     "trainer = L.Trainer(\n",
     "    max_epochs=100,\n",
     "    callbacks=[checkpoint_callback, validation_metrics_callback],\n",
-    "    logger=csv_logger,\n",
+    "    logger=False,\n",
+    "    log_every_n_steps=1000,\n",
     ")\n",
     "\n",
     "train_dataloader = DataLoader(\n",
diff --git a/replay/models/nn/sequential/bert4rec/lightning.py b/replay/models/nn/sequential/bert4rec/lightning.py
index 80656bc56..d1d9f972d 100644
--- a/replay/models/nn/sequential/bert4rec/lightning.py
+++ b/replay/models/nn/sequential/bert4rec/lightning.py
@@ -506,7 +506,6 @@ def append_item_embeddings(self, item_embeddings: torch.Tensor):
 
     def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_vocab_size: int):
         self._model.item_embedder.cat_embeddings[self._model.schema.item_id_feature_name] = weights_new
-
         if self._model.enable_embedding_tying is True:
             self._model._head._item_embedder = self._model.item_embedder
             new_bias = torch.Tensor(new_vocab_size)
@@ -521,3 +520,4 @@ def _set_new_item_embedder_to_model(self, weights_new: torch.nn.Embedding, new_v
 
         self._vocab_size = new_vocab_size
         self._model.item_count = new_vocab_size
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(new_vocab_size)
diff --git a/replay/models/nn/sequential/sasrec/lightning.py b/replay/models/nn/sequential/sasrec/lightning.py
index 7c1febf50..cc36593f3 100644
--- a/replay/models/nn/sequential/sasrec/lightning.py
+++ b/replay/models/nn/sequential/sasrec/lightning.py
@@ -486,3 +486,11 @@ def _set_new_item_embedder_to_model(self, new_embedding: torch.nn.Embedding, new
         self._model.item_count = new_vocab_size
         self._model.padding_idx = new_vocab_size
         self._model.masking.padding_idx = new_vocab_size
+        self._model.candidates_to_score = torch.tensor(
+            list(range(new_embedding.weight.data.shape[0] - 1)),
+            device=self._model.candidates_to_score.device,
+            dtype=torch.long,
+        )
+        self._schema.item_id_features[self._schema.item_id_feature_name]._set_cardinality(
+            new_embedding.weight.data.shape[0] - 1
+        )
diff --git a/tests/models/nn/sequential/bert4rec/conftest.py b/tests/models/nn/sequential/bert4rec/conftest.py
index 5fe354abf..40ce1f17d 100644
--- a/tests/models/nn/sequential/bert4rec/conftest.py
+++ b/tests/models/nn/sequential/bert4rec/conftest.py
@@ -44,7 +44,7 @@ def fitted_bert4rec(feature_schema_for_bert4rec):
     train_dataset = Dataset(feature_schema=feature_schema_for_bert4rec, interactions=data)
     tensor_schema = TensorSchema(
         TensorFeatureInfo(
-            name="item_id_seq",
+            name="item_id",
             is_seq=True,
             cardinality=train_dataset.item_count,
             feature_type=FeatureType.CATEGORICAL,
diff --git a/tests/models/nn/sequential/bert4rec/test_bert4rec_lightning.py b/tests/models/nn/sequential/bert4rec/test_bert4rec_lightning.py
index 67b19dada..983a6b3f1 100644
--- a/tests/models/nn/sequential/bert4rec/test_bert4rec_lightning.py
+++ b/tests/models/nn/sequential/bert4rec/test_bert4rec_lightning.py
@@ -265,6 +265,20 @@ def test_bert4rec_fine_tuning_on_new_items_by_appending(request, fitted_bert4rec
     assert torch.eq(only_new_items_tensor, new_items_data[4]).all()
 
 
+@pytest.mark.torch
+def test_bert4rec_fine_tuning_save_load(fitted_bert4rec, new_items_dataset, train_loader):
+    model, tokenizer = fitted_bert4rec
+    trainer = L.Trainer(max_epochs=1)
+    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
+    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])
+    model.set_item_embeddings_by_size(new_vocab_size)
+    trainer.fit(model, train_loader)
+    trainer.save_checkpoint("bert_test.ckpt")
+    best_model = Bert4Rec.load_from_checkpoint("bert_test.ckpt")
+
+    assert best_model.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size
+
+
 @pytest.mark.torch
 def test_bert4rec_fine_tuning_errors(fitted_bert4rec):
     model, _ = fitted_bert4rec
diff --git a/tests/models/nn/sequential/sasrec/conftest.py b/tests/models/nn/sequential/sasrec/conftest.py
index 96d663f6d..c12490b5d 100644
--- a/tests/models/nn/sequential/sasrec/conftest.py
+++ b/tests/models/nn/sequential/sasrec/conftest.py
@@ -44,7 +44,7 @@ def fitted_sasrec(feature_schema_for_sasrec):
     train_dataset = Dataset(feature_schema=feature_schema_for_sasrec, interactions=data)
     tensor_schema = TensorSchema(
         TensorFeatureInfo(
-            name="item_id_seq",
+            name="item_id",
             is_seq=True,
             cardinality=train_dataset.item_count,
             feature_type=FeatureType.CATEGORICAL,
@@ -59,9 +59,9 @@ def fitted_sasrec(feature_schema_for_sasrec):
     tokenizer.fit(train_dataset)
     sequential_train_dataset = tokenizer.transform(train_dataset)
 
-    model = SasRec(tensor_schema)
+    model = SasRec(tensor_schema, max_seq_len=5)
     trainer = L.Trainer(max_epochs=1)
-    train_loader = torch.utils.data.DataLoader(SasRecTrainingDataset(sequential_train_dataset, 200))
+    train_loader = torch.utils.data.DataLoader(SasRecTrainingDataset(sequential_train_dataset, 5))
 
     trainer.fit(model, train_dataloaders=train_loader)
diff --git a/tests/models/nn/sequential/sasrec/test_sasrec_lightning.py b/tests/models/nn/sequential/sasrec/test_sasrec_lightning.py
index 1ec029a25..0a68fc68a 100644
--- a/tests/models/nn/sequential/sasrec/test_sasrec_lightning.py
+++ b/tests/models/nn/sequential/sasrec/test_sasrec_lightning.py
@@ -6,6 +6,7 @@
 from replay.models.nn.optimizer_utils import FatLRSchedulerFactory, FatOptimizerFactory
 from replay.models.nn.sequential.sasrec import SasRec, SasRecPredictionBatch, SasRecPredictionDataset
 
+
 torch = pytest.importorskip("torch")
 L = pytest.importorskip("lightning")
 
@@ -236,6 +237,20 @@ def test_sasrec_fine_tuning_on_new_items_by_appending(fitted_sasrec, new_items_d
     assert torch.eq(only_new_items_tensor, new_items_data[4]).all()
 
 
+@pytest.mark.torch
+def test_sasrec_fine_tuning_save_load(fitted_sasrec, new_items_dataset, train_sasrec_loader):
+    model, tokenizer = fitted_sasrec
+    trainer = L.Trainer(max_epochs=1)
+    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
+    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])
+    model.set_item_embeddings_by_size(new_vocab_size)
+    trainer.fit(model, train_sasrec_loader)
+    trainer.save_checkpoint("test.ckpt")
+    best_model = SasRec.load_from_checkpoint("test.ckpt")
+
+    assert best_model.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size
+
+
 @pytest.mark.torch
 def test_sasrec_fine_tuning_errors(fitted_sasrec):
     model, _ = fitted_sasrec
@@ -260,7 +275,7 @@ def test_sasrec_get_init_parameters(fitted_sasrec):
     params = model.hparams
 
     assert params["tensor_schema"].item().cardinality == 4
-    assert params["max_seq_len"] == 200
+    assert params["max_seq_len"] == 5
     assert params["hidden_size"] == 50
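Usage sketch (not part of the patch): the end-to-end flow the new save/load tests exercise — fine-tuning on newly appeared items, then round-tripping the model through a Lightning checkpoint (now covered by the *.ckpt gitignore rule). Every API call below is taken from the tests above; the function name, checkpoint path, and the model/tokenizer/loader arguments are illustrative.

import lightning as L

from replay.models.nn.sequential.sasrec import SasRec


def fine_tune_and_checkpoint(model, tokenizer, new_items_dataset, train_loader):
    # Extend the item-id encoder with the newly observed items.
    tokenizer.item_id_encoder.partial_fit(new_items_dataset)
    new_vocab_size = len(tokenizer.item_id_encoder.mapping["item_id"])

    # Resize the model's item embedding table to the enlarged vocabulary;
    # with this patch, the schema cardinality (and for SasRec the
    # candidates_to_score tensor) is updated alongside it.
    model.set_item_embeddings_by_size(new_vocab_size)

    # Fine-tune briefly, then save and restore via a Lightning checkpoint.
    trainer = L.Trainer(max_epochs=1)
    trainer.fit(model, train_loader)
    trainer.save_checkpoint("finetuned.ckpt")  # illustrative path
    restored = SasRec.load_from_checkpoint("finetuned.ckpt")

    # The restored embedding table should match the enlarged vocabulary.
    assert restored.get_all_embeddings()["item_embedding"].shape[0] == new_vocab_size
    return restored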