Merge branch 'feature/save_and_load_for_label_encoder' into 'main'
Feature/save and load for label encoder

See merge request ai-lab-pmo/mltools/recsys/RePlay!234
OnlyDeniko committed Nov 5, 2024
2 parents 3604485 + 7a173eb commit e841510
Showing 5 changed files with 193 additions and 53 deletions.
8 changes: 4 additions & 4 deletions examples/01_replay_basics.ipynb
@@ -70,7 +70,7 @@
"\n",
"from replay.metrics import Coverage, HitRate, NDCG, MAP, Experiment, OfflineMetrics\n",
"from replay.utils.model_handler import save, load, save_encoder, load_encoder\n",
"from replay.utils.session_handler import get_spark_session, State \n",
"from replay.utils.session_handler import get_spark_session, State\n",
"from replay.splitters import TwoStageSplitter\n",
"from replay.utils.spark_utils import convert2spark, get_log_info\n",
"\n",
@@ -998,9 +998,9 @@
"source": [
"metrics = Experiment(\n",
" [\n",
" NDCG(K), \n",
" MAP(K), \n",
" HitRate([1, K]), \n",
" NDCG(K),\n",
" MAP(K),\n",
" HitRate([1, K]),\n",
" Coverage(K)\n",
" ],\n",
" test_dataset.interactions,\n",
56 changes: 43 additions & 13 deletions examples/04_splitters.ipynb
@@ -43,10 +43,11 @@
"outputs": [],
"source": [
"import pandas as pd\n",
"from matplotlib import pyplot as plt\n",
"import seaborn as sns\n",
"from matplotlib import pyplot as plt\n",
"\n",
"from replay.preprocessing import LabelEncoder, LabelEncodingRule\n",
"from replay.utils.common import load_from_replay, save_to_replay\n",
"\n",
"sns.set_theme(style=\"whitegrid\", palette=\"muted\")"
]
@@ -57,17 +58,24 @@
"source": [
"## Get started\n",
"\n",
"Download the dataset **MovieLens** and preprocess it with `LabelEncoder`"
"Download the dataset **MovieLens** and preprocess it with `LabelEncoder`.\n",
"\n",
"`LabelEncoder` is similar to `DatasetLabelEncoder` except that `LabelEncoder` processes Pandas, Polars or Spark datasets when `DatasetLabelEncoder` processes `replay.data.Dataset`. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ratings = pd.read_csv('./data/ml1m_ratings.dat',sep=\"\\t\",names=[\"user_id\", \"item_id\", \"rating\", \"timestamp\"], engine='python')\n",
"ratings.timestamp = pd.to_datetime(ratings.timestamp, unit='s')"
"ratings = pd.read_csv(\n",
" \"./data/ml1m_ratings.dat\",\n",
" sep=\"\\t\",\n",
" names=[\"user_id\", \"item_id\", \"rating\", \"timestamp\"],\n",
" engine=\"python\",\n",
")\n",
"ratings.timestamp = pd.to_datetime(ratings.timestamp, unit=\"s\")"
]
},
{
@@ -175,6 +183,25 @@
"ratings = encoder.fit_transform(ratings)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LabelEncoder` can be saved in JSON format with `save_to_replay()` or `.save()` methods. Than it can be loaded with `load_from_replay()` or `.load()` methods.\n",
"\n",
"Note that `save_to_replay()`and `load_from_replay()` are functions from utils while `.save()`, `.load()` are methods of the class."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"save_to_replay(encoder, \"./encoder\") # or encoder.save(\"./encoder\")\n",
"loaded_encoder = load_from_replay(\"./encoder\") # or LabelEncoder.load(\"./encoder\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
@@ -189,7 +216,8 @@
"outputs": [],
"source": [
"top_users = (\n",
" ratings.groupby(\"user_id\")[[\"item_id\"]].count()\n",
" ratings.groupby(\"user_id\")[[\"item_id\"]]\n",
" .count()\n",
" .nlargest(n=20, columns=[\"item_id\"])\n",
" .index\n",
")"
@@ -238,15 +266,15 @@
"outputs": [],
"source": [
"def show_train_test(train, test):\n",
" plt.figure(figsize=(25,12))\n",
" plt.figure(figsize=(25, 12))\n",
" train_plt = train\n",
" train_plt[\"split\"] = 'train'\n",
" train_plt[\"split\"] = \"train\"\n",
" test_plt = test\n",
" test_plt[\"split\"] = 'test'\n",
" test_plt[\"split\"] = \"test\"\n",
" pd_for_print = pd.concat((train_plt, test_plt), axis=0)\n",
" pd_for_print[\"user_id\"] = pd_for_print[\"user_id\"].astype(str)\n",
" sns.scatterplot(data=pd_for_print, x=\"timestamp\", y=\"user_id\", hue=\"split\",s=8*8)\n",
" plt.autoscale(enable=True, axis='x')\n",
" sns.scatterplot(data=pd_for_print, x=\"timestamp\", y=\"user_id\", hue=\"split\", s=8 * 8)\n",
" plt.autoscale(enable=True, axis=\"x\")\n",
" plt.grid(False)\n",
" plt.show()"
]
@@ -265,7 +293,9 @@
"outputs": [],
"source": [
"def get_df_info(df: pd.DataFrame):\n",
" print(f\"Total rows {len(df)}, unique users: {df.user_id.nunique()}, unique items: {df.item_id.nunique()}\")"
" print(\n",
" f\"Total rows {len(df)}, unique users: {df.user_id.nunique()}, unique items: {df.item_id.nunique()}\"\n",
" )"
]
},
{
@@ -1237,7 +1267,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.9.19"
}
},
"nbformat": 4,
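Expanding on the markdown cell above: a minimal self-contained round trip through the new API on a plain pandas frame (which is exactly the case `LabelEncoder`, unlike `DatasetLabelEncoder`, is meant for). The toy frame and the final `assert` are illustrative assumptions, not part of the commit:

import pandas as pd

from replay.preprocessing import LabelEncoder, LabelEncodingRule
from replay.utils.common import load_from_replay, save_to_replay

# Toy interaction log standing in for the MovieLens ratings used in the notebook.
events = pd.DataFrame({"user_id": ["u1", "u2", "u1"], "item_id": [10, 20, 30]})

encoder = LabelEncoder([LabelEncodingRule("user_id"), LabelEncodingRule("item_id")])
encoded = encoder.fit_transform(events)

save_to_replay(encoder, "./encoder")      # equivalent to encoder.save("./encoder")
restored = load_from_replay("./encoder")  # equivalent to LabelEncoder.load("./encoder")

# The restored encoder reproduces the original integer ids.
assert restored.transform(events).equals(encoded)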
102 changes: 102 additions & 0 deletions replay/preprocessing/label_encoder.py
@@ -7,7 +7,10 @@
"""

import abc
import json
import os
import warnings
from pathlib import Path
from typing import Dict, List, Literal, Mapping, Optional, Sequence, Union

import polars as pl
@@ -484,6 +487,65 @@ def set_handle_unknown(self, handle_unknown: HandleUnknownStrategies) -> None:
            raise ValueError(msg)
        self._handle_unknown = handle_unknown

    def save(
        self,
        path: str,
    ) -> None:
        encoder_rule_dict = {}
        encoder_rule_dict["_class_name"] = self.__class__.__name__
        encoder_rule_dict["init_args"] = {
            "column": self._col,
            "mapping": self._mapping,
            "handle_unknown": self._handle_unknown,
            "default_value": self._default_value,
        }

        # Check the type of the mapping keys themselves, not the string produced
        # by str(type(...)), which would always pass an isinstance check for str.
        column_key = next(iter(self._mapping))
        if not isinstance(column_key, (str, int, float)):  # pragma: no cover
            msg = (
                f"LabelEncodingRule.save() is not implemented for column type {type(column_key)}. "
                "Convert type to string, integer, or float."
            )
            raise NotImplementedError(msg)
        column_type = str(type(column_key))

        encoder_rule_dict["fitted_args"] = {
            "target_col": self._target_col,
            "is_fitted": self._is_fitted,
            "column_type": column_type,
        }

        base_path = Path(path).with_suffix(".replay").resolve()
        if os.path.exists(base_path):  # pragma: no cover
            msg = "There is already a LabelEncodingRule object saved at the given path. The file will be overwritten."
            warnings.warn(msg)
        else:  # pragma: no cover
            base_path.mkdir(parents=True, exist_ok=True)

        with open(base_path / "init_args.json", "w+") as file:
            json.dump(encoder_rule_dict, file)

    @classmethod
    def load(cls, path: str) -> "LabelEncodingRule":
        base_path = Path(path).with_suffix(".replay").resolve()
        with open(base_path / "init_args.json", "r") as file:
            encoder_rule_dict = json.loads(file.read())

        # JSON object keys are always strings, so restore the original key type
        # recorded at save time before rebuilding the mapping.
        string_column_type = encoder_rule_dict["fitted_args"]["column_type"]
        if "str" in string_column_type:
            column_type = str
        elif "int" in string_column_type:
            column_type = int
        elif "float" in string_column_type:
            column_type = float
        else:  # pragma: no cover
            msg = f"Cannot restore mapping keys of type {string_column_type}."
            raise NotImplementedError(msg)

        encoder_rule_dict["init_args"]["mapping"] = {
            column_type(key): int(value) for key, value in encoder_rule_dict["init_args"]["mapping"].items()
        }

        encoding_rule = cls(**encoder_rule_dict["init_args"])
        encoding_rule._target_col = encoder_rule_dict["fitted_args"]["target_col"]
        encoding_rule._is_fitted = encoder_rule_dict["fitted_args"]["is_fitted"]
        return encoding_rule


class LabelEncoder:
"""
@@ -650,3 +712,43 @@ def set_default_values(self, default_value_rules: Dict[str, Optional[Union[int,
                raise ValueError(msg)
            rule = list(filter(lambda x: x.column == column, self.rules))
            rule[0].set_default_value(default_value)

    def save(
        self,
        path: str,
    ) -> None:
        encoder_dict = {}
        encoder_dict["_class_name"] = self.__class__.__name__

        base_path = Path(path).with_suffix(".replay").resolve()
        if os.path.exists(base_path):  # pragma: no cover
            msg = "There is already a LabelEncoder object saved at the given path. The file will be overwritten."
            warnings.warn(msg)
        else:  # pragma: no cover
            base_path.mkdir(parents=True, exist_ok=True)

        encoder_dict["rule_names"] = []

        # Each rule is saved in its own subdirectory under rules/,
        # named "<RuleClassName>_<column>".
        for rule in self.rules:
            path_suffix = f"{rule.__class__.__name__}_{rule.column}"
            rule.save(str(base_path) + f"/rules/{path_suffix}")
            encoder_dict["rule_names"].append(path_suffix)

        with open(base_path / "init_args.json", "w+") as file:
            json.dump(encoder_dict, file)

    @classmethod
    def load(cls, path: str) -> "LabelEncoder":
        base_path = Path(path).with_suffix(".replay").resolve()
        with open(base_path / "init_args.json", "r") as file:
            encoder_dict = json.loads(file.read())
        rules = []
        # Walk rules/ and re-instantiate every saved rule via the class name
        # recorded in its own init_args.json.
        for root, dirs, _files in os.walk(str(base_path) + "/rules/"):
            for d in dirs:
                if d.split(".")[0] in encoder_dict["rule_names"]:
                    with open(root + d + "/init_args.json", "r") as file:
                        encoder_rule_dict = json.loads(file.read())
                    rules.append(globals()[encoder_rule_dict["_class_name"]].load(root + d))

        encoder = cls(rules=rules)
        return encoder
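Reading `save()` and `load()` together: the encoder is persisted as a small directory tree rather than a single file. A sketch of the expected layout for the two-rule encoder above (inferred from the code, not shown in the diff; rule directories are named `<RuleClassName>_<column>`):

encoder.replay/
    init_args.json          # {"_class_name": "LabelEncoder", "rule_names": [...]}
    rules/
        LabelEncodingRule_user_id.replay/
            init_args.json  # init_args (column, mapping, ...) and fitted_args
        LabelEncodingRule_item_id.replay/
            init_args.json

A single rule can also be round-tripped on its own. Since JSON object keys are always strings, `load()` uses the recorded `column_type` to cast mapping keys back to their original type; a hedged check, assuming the integer `item_id` column of the toy `events` frame above:

rule = LabelEncodingRule("item_id")
rule.fit(events)          # mapping keys are Python ints here
rule.save("./item_rule")  # keys get serialized as JSON strings
restored_rule = LabelEncodingRule.load("./item_rule")

# The recorded column_type restores int keys, so the transforms agree.
assert restored_rule.transform(events).equals(rule.transform(events))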
15 changes: 7 additions & 8 deletions replay/utils/common.py
@@ -7,6 +7,10 @@
from polars import from_pandas as pl_from_pandas

from replay.data.dataset import Dataset
from replay.preprocessing import (
    LabelEncoder,
    LabelEncodingRule,
)
from replay.splitters import (
    ColdUserRandomSplitter,
    KFolds,
@@ -38,20 +42,15 @@
    TimeSplitter,
    TwoStageSplitter,
    Dataset,
    LabelEncoder,
    LabelEncodingRule,
]

if TORCH_AVAILABLE:
    from replay.data.nn import PandasSequentialDataset, PolarsSequentialDataset, SequenceTokenizer

    SavableObject = Union[
        ColdUserRandomSplitter,
        KFolds,
        LastNSplitter,
        NewUsersSplitter,
        RandomSplitter,
        RatioSplitter,
        TimeSplitter,
        TwoStageSplitter,
        SavableObject,
        SequenceTokenizer,
        PandasSequentialDataset,
        PolarsSequentialDataset,
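With `LabelEncoder` and `LabelEncodingRule` added to the `SavableObject` union, `save_to_replay`/`load_from_replay` now accept encoders alongside splitters, datasets, and tokenizers. A minimal sketch of a generic helper written against that union (the `checkpoint` name and the direct `SavableObject` import are assumptions, not part of the diff):

from replay.utils.common import SavableObject, load_from_replay, save_to_replay

def checkpoint(obj: SavableObject, path: str) -> SavableObject:
    # Persist any RePlay-savable object and immediately read it back.
    save_to_replay(obj, path)
    return load_from_replay(path)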