Commit

Merge branch 'feature/decision_transformer' into 'main'
Decision Transformer

See merge request ai-lab-pmo/mltools/recsys/RePlay!138
shashist committed Jan 12, 2024
2 parents 976f779 + f1ec698 commit 842abdd
Showing 9 changed files with 1,272 additions and 2 deletions.
7 changes: 5 additions & 2 deletions docs/pages/modules/models.rst
@@ -26,6 +26,7 @@ ___________________
"Neural Matrix Factorization (Experimental)", "Python CPU/GPU"
"MultVAE (Experimental)", "Python CPU/GPU"
"DDPG (Experimental)", "Python CPU"
"DT4Rec (Experimental)", "Python CPU/GPU"
"ADMM SLIM (Experimental)", "Python CPU"
"Wrapper for implicit (Experimental)", "Python CPU"
"Wrapper for LightFM (Experimental)", "Python CPU"
@@ -276,14 +277,16 @@ DDPG (Experimental)
.. autoclass:: replay.experimental.models.DDPG
:special-members: __init__

DT4Rec (Experimental)
```````````````````````````
.. autoclass:: replay.experimental.models.dt4rec.dt4rec.DT4Rec
:special-members: __init__
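
A minimal usage sketch (illustrative, not part of the documentation diff), mirroring the API exercised by examples/train_dt4rec.py added in this commit; train, test, item_num and user_num are assumed to come from the preprocessing shown in that example::

    from replay.experimental.models.dt4rec.dt4rec import DT4Rec

    # item_num / user_num: index sizes of the training log (assumed precomputed)
    model = DT4Rec(item_num, user_num, use_cuda=True)
    model.fit(train)
    recs = model.predict(log=train, k=10, users=test.select("user_idx").distinct())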

CQL Recommender (Experimental)
```````````````````````````````````
Conservative Q-Learning (CQL) is a SAC-based, data-driven deep reinforcement learning algorithm
that achieves state-of-the-art performance in offline RL problems.
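
For reference, a sketch of the conservative regularizer that CQL adds to the ordinary Bellman error (following Kumar et al., 2020; :math:`\alpha` is the penalty weight and :math:`\mathcal{D}` the offline dataset); it pushes Q-values down on out-of-data actions while keeping them high on logged ones:

.. math::

    \min_{Q}\; \alpha \left( \mathbb{E}_{s \sim \mathcal{D}} \left[ \log \sum_{a} \exp Q(s, a) \right]
    - \mathbb{E}_{(s, a) \sim \mathcal{D}} \left[ Q(s, a) \right] \right)
    + \frac{1}{2}\, \mathbb{E}_{(s, a, s') \sim \mathcal{D}} \left[ \left( Q(s, a) - \hat{\mathcal{B}}^{\pi} \hat{Q}(s, a) \right)^{2} \right]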

\* incompatible with python 3.10

.. image:: /images/cql_comparison.png

.. autoclass:: replay.experimental.models.cql.CQL
73 changes: 73 additions & 0 deletions examples/train_dt4rec.py
@@ -0,0 +1,73 @@
from rs_datasets import MovieLens

from replay.metrics import MAP, MRR, NDCG, Coverage, HitRate, Surprisal
from replay.metrics.experiment import Experiment
from replay.experimental.models.dt4rec.dt4rec import DT4Rec
from replay.experimental.preprocessing.data_preparator import DataPreparator, Indexer
from replay.splitters import TimeSplitter
from replay.utils import PYSPARK_AVAILABLE
if PYSPARK_AVAILABLE:
from pyspark.sql import functions as sf


K = 10
K_list_metrics = [1, 5, 15]


df = MovieLens("1m").ratings

preparator = DataPreparator()
log = preparator.transform(
columns_mapping={
"user_id": "user_id",
"item_id": "item_id",
"relevance": "rating",
"timestamp": "timestamp",
},
data=df,
)
indexer = Indexer(user_col="user_id", item_col="item_id")
indexer.fit(users=log.select("user_id"), items=log.select("item_id"))

# Consider ratings >= 3 as positive feedback and assign it relevance = 1.
only_positives_log = log.filter(sf.col('relevance') >= 3).withColumn('relevance', sf.lit(1.))

indexed_log = indexer.transform(only_positives_log)

date_splitter = TimeSplitter(
time_threshold=0.2,
drop_cold_items=True,
drop_cold_users=True,
query_column="user_idx",
item_column="item_idx"
)
train, test = date_splitter.split(indexed_log)


# Convert to pandas once to get the number of items and users in the training log
train_pd = train.toPandas()
item_num = train_pd["item_idx"].max() + 1
user_num = train_pd["user_idx"].max() + 1

experiment = Experiment(
{
MAP(K),
NDCG(K),
HitRate(K_list_metrics),
Coverage(K),
Surprisal(K),
MRR(K)
},
test,
train,
query_column="user_idx",
item_column="item_idx",
rating_column="relevance"
)

rec_sys = DT4Rec(item_num, user_num, use_cuda=True)
rec_sys.fit(train)
pred = rec_sys.predict(log=train, k=K, users=test.select("user_idx").distinct())

name = "DT4Rec"
experiment.add_result(name, pred)
experiment.results.sort_values(f"NDCG@{K}", ascending=False).to_csv("results.csv")
1 change: 1 addition & 0 deletions replay/experimental/models/__init__.py
@@ -1,6 +1,7 @@
from replay.experimental.models.admm_slim import ADMMSLIM
from replay.experimental.models.base_torch_rec import TorchRecommender
from replay.experimental.models.ddpg import DDPG
from replay.experimental.models.dt4rec.dt4rec import DT4Rec
from replay.experimental.models.implicit_wrap import ImplicitWrap
from replay.experimental.models.lightfm_wrap import LightFMWrap
from replay.experimental.models.mult_vae import MultVAE
Empty file.
193 changes: 193 additions & 0 deletions replay/experimental/models/dt4rec/dt4rec.py
@@ -0,0 +1,193 @@
from typing import List, Optional

import pandas as pd
from tqdm import tqdm

from ..base_rec import Recommender
from replay.utils import SparkDataFrame, PYSPARK_AVAILABLE, TORCH_AVAILABLE

if PYSPARK_AVAILABLE:
from replay.utils.spark_utils import convert2spark

if TORCH_AVAILABLE:
import torch
from torch.utils.data.dataloader import DataLoader

from .gpt1 import GPT, GPTConfig
from .trainer import Trainer, TrainerConfig
from .utils import (
Collator,
StateActionReturnDataset,
ValidateDataset,
WarmUpScheduler,
create_dataset,
matrix2df,
set_seed,
)


# pylint: disable=too-many-instance-attributes
class DT4Rec(Recommender):
"""
Decision Transformer for Recommendations
General Idea:
`Decision Transformer: Reinforcement Learning
via Sequence Modeling <https://arxiv.org/pdf/2106.01345.pdf>`_.
Ideas for improvements:
`User Retention-oriented Recommendation with Decision
Transformer <https://arxiv.org/pdf/2303.06347.pdf>`_.
Also, some sources are listed in their respective classes
"""

optimizer = None
train_batch_size = 128
val_batch_size = 128
lr_scheduler = None

# pylint: disable=too-many-arguments
def __init__(
self,
item_num,
user_num,
seed=123,
trajectory_len=30,
epochs=1,
batch_size=64,
use_cuda=True,
):
self.item_num = item_num
self.user_num = user_num
self.seed = seed
self.trajectory_len = trajectory_len
self.epochs = epochs
self.batch_size = batch_size
self.tconf: TrainerConfig = TrainerConfig(epochs=epochs)
self.mconf: GPTConfig = GPTConfig(
user_num=user_num,
item_num=item_num,
vocab_size=self.item_num + 1,
block_size=self.trajectory_len * 3,
max_timestep=self.item_num,
)
self.model: GPT
self.user_trajectory: List
self.trainer: Trainer
self.use_cuda = use_cuda
set_seed(self.seed)

# pylint: disable=invalid-overridden-method
def _init_args(self):
pass

def _update_mconf(self, **kwargs):
self.mconf.update(**kwargs)

def _update_tconf(self, **kwargs):
self.tconf.update(**kwargs)

def _make_prediction_dataloader(self, users, items, max_context_len=30):
val_dataset = ValidateDataset(
self.user_trajectory,
max_context_len=max_context_len - 1,
val_items=items,
val_users=users,
)

val_dataloader = DataLoader(
val_dataset,
pin_memory=True,
batch_size=self.val_batch_size,
collate_fn=Collator(self.item_num),
)

return val_dataloader

def train(
self,
log,
val_users=None,
val_items=None,
experiment=None,
):
"""
Run training loop
"""
assert (val_users is None) == (val_items is None) == (experiment is None)
with_validate = experiment is not None
df = log.toPandas()[["user_idx", "item_idx", "relevance", "timestamp"]]
self.user_trajectory = create_dataset(df, user_num=self.user_num, item_pad=self.item_num)

train_dataset = StateActionReturnDataset(self.user_trajectory, self.trajectory_len)

train_dataloader = DataLoader(
train_dataset,
shuffle=True,
pin_memory=True,
batch_size=self.train_batch_size,
collate_fn=Collator(self.item_num),
)

if with_validate:
val_dataloader = self._make_prediction_dataloader(val_users, val_items, max_context_len=self.trajectory_len)
else:
val_dataloader = None

self.model = GPT(self.mconf)

optimizer = torch.optim.AdamW(
self.model.configure_optimizers(),
lr=3e-4,
betas=(0.9, 0.95),
)
lr_scheduler = WarmUpScheduler(optimizer, dim_embed=768, warmup_steps=4000)

self.tconf.update(optimizer=optimizer, lr_scheduler=lr_scheduler)
self.trainer = Trainer(
self.model,
train_dataloader,
self.tconf,
val_dataloader,
experiment,
self.use_cuda,
)
self.trainer.train()

def _fit(
self,
log: SparkDataFrame,
user_features: Optional[SparkDataFrame] = None,
item_features: Optional[SparkDataFrame] = None,
) -> None:
self.train(log)

# pylint: disable=too-many-arguments
def _predict(
self,
log: SparkDataFrame,
k: int,
users: SparkDataFrame,
items: SparkDataFrame,
user_features: Optional[SparkDataFrame] = None,
item_features: Optional[SparkDataFrame] = None,
filter_seen_items: bool = True,
) -> SparkDataFrame:
items_consider_in_pred = items.toPandas()["item_idx"].values
users_consider_in_pred = users.toPandas()["user_idx"].values
ans = self._predict_helper(users_consider_in_pred, items_consider_in_pred)
return convert2spark(ans)

def _predict_helper(self, users, items, max_context_len=30):
predict_dataloader = self._make_prediction_dataloader(users, items, max_context_len)
self.model.eval()
        # Accumulate per-batch prediction frames; pandas.DataFrame.append was
        # removed in pandas 2.0, so the batches are concatenated once at the end.
        batch_dfs = []
        with torch.no_grad():
            for batch in tqdm(predict_dataloader):
                states, actions, rtgs, timesteps, batch_users = self.trainer._move_batch(batch)
                logits = self.model(states, actions, rtgs, timesteps, batch_users)
                # Last-position logits, restricted to the candidate items
                items_relevances = logits[:, -1, :][:, items]
                batch_dfs.append(matrix2df(items_relevances, batch_users.squeeze(), items))

        return pd.concat(batch_dfs, ignore_index=True)
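
To make the docstring's idea concrete: a Decision Transformer conditions an autoregressive model on returns-to-go, i.e. at every step the model is told how much cumulative reward is still desired. A minimal sketch of the return-to-go computation for one toy user trajectory (illustrative only; the real token construction lives in create_dataset and StateActionReturnDataset imported above, which this snippet does not reproduce):

# Illustrative sketch, not part of this commit.
def returns_to_go(rewards):
    """Return-to-go at step t is the sum of rewards from t to the end of the trajectory."""
    rtgs = []
    running = 0.0
    for reward in reversed(rewards):
        running += reward
        rtgs.append(running)
    return rtgs[::-1]


rewards = [1.0, 0.0, 1.0, 1.0]   # e.g. 1.0 per positive interaction
print(returns_to_go(rewards))    # [3.0, 2.0, 2.0, 1.0]
# A Decision Transformer is fed interleaved (rtg, state, action) tokens per step and
# trained to predict the next action; the (states, actions, rtgs, timesteps) tensors
# passed to the GPT above carry exactly this information.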