Skip to content

Commit

Permalink
Fix regularization_loss() (#135)
Browse files Browse the repository at this point in the history
Optimize tokenizer (#134)
  • Loading branch information
xpai committed Dec 29, 2024
1 parent 556160e commit fb162f5
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 19 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
[Doing] Add support for saving pb file, exporting embeddings
[Doing] Add support of multi-gpu training

**FuxiCTR v2.3.7, 2024-12-29**
+ [Fix] Fix regularization_loss() when feature_encoders exist ([#135](https://github.com/reczoo/FuxiCTR/issues/135))

**FuxiCTR v2.3.6, 2024-12-28**
+ [Fix] Fix init_weights() for PretrainedEmbedding by modifying embedding_initializer ([#126](https://github.com/reczoo/FuxiCTR/issues/126))
+ [Fix] Fix get_mask issue when num_heads > 1 ([#130](https://github.com/reczoo/FuxiCTR/issues/130))
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Click-through rate (CTR) prediction is a critical task for various industrial ap
| 38 | SIGIR'23 | [EulerNet](./model_zoo/EulerNet) | [EulerNet: Adaptive Feature Interaction Learning via Euler's Formula for CTR Prediction](https://dl.acm.org/doi/10.1145/3539618.3591681) :triangular_flag_on_post:**Huawei** | [:arrow_upper_right:](https://github.com/Ethan-TZ/EulerNet/tree/main/%23Code4FuxiCTR%23) | `torch` |
| 39 | CIKM'23 | [GDCN](./model_zoo/GDCN) | [Towards Deeper, Lighter and Interpretable Cross Network for CTR Prediction](https://dl.acm.org/doi/pdf/10.1145/3583780.3615089) :triangular_flag_on_post:**Microsoft** | | `torch` |
| 40 | ICML'24 | [WuKong](./model_zoo/WuKong) | [Wukong: Towards a Scaling Law for Large-Scale Recommendation](https://arxiv.org/abs/2403.02545) :triangular_flag_on_post:**Meta** | | `torch` |
| 41 | Arxiv'24 | [DCNv3](./model_zoo/DCNv3) | [DCNv3: Towards Next Generation Deep Cross Network for Click-Through Rate Prediction](https://arxiv.org/abs/2407.13349) | [:arrow_upper_right:](https://github.com/salmon1802/DCNv3/tree/master/checkpoints) | `torch` |
<!-- | 41 | Arxiv'24 | [DCNv3](./model_zoo/DCNv3) | [DCNv3: Towards Next Generation Deep Cross Network for Click-Through Rate Prediction](https://arxiv.org/abs/2407.13349) | [:arrow_upper_right:](https://github.com/salmon1802/DCNv3/tree/master/checkpoints) | `torch` | -->
|<tr><th colspan=6 align="center">:open_file_folder: **Behavior Sequence Modeling**</th></tr>|
| 42 | KDD'18 | [DIN](./model_zoo/DIN) | [Deep Interest Network for Click-Through Rate Prediction](https://www.kdd.org/kdd2018/accepted-papers/view/deep-interest-network-for-click-through-rate-prediction) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/DIN) | `torch` |
| 43 | AAAI'19 | [DIEN](./model_zoo/DIEN) | [Deep Interest Evolution Network for Click-Through Rate Prediction](https://arxiv.org/abs/1809.03672) :triangular_flag_on_post:**Alibaba** | [:arrow_upper_right:](https://github.com/reczoo/BARS/tree/main/ranking/ctr/DIEN) | `torch` |
Expand Down
3 changes: 1 addition & 2 deletions fuxictr/preprocess/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ def fit_on_texts(self, series):
self.build_vocab(word_counts)

def build_vocab(self, word_counts):
word_counts = word_counts.items()
# sort to guarantee the determinism of index order
word_counts = sorted(word_counts, key=lambda x: (-x[1], x[0]))
word_counts = word_counts.most_common()
if self._max_features: # keep the most frequent features
word_counts = word_counts[0:self._max_features]
words = []
Expand Down
30 changes: 16 additions & 14 deletions fuxictr/pytorch/models/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import torch
import os, sys
import logging
from fuxictr.pytorch.layers import FeatureEmbeddingDict
from fuxictr.metrics import evaluate_metrics
from fuxictr.pytorch.torch_utils import get_device, get_optimizer, get_loss, get_regularizer
from fuxictr.utils import Monitor, not_in_whitelist
Expand Down Expand Up @@ -65,23 +66,24 @@ def compile(self, optimizer, loss, lr):
self.loss_fn = get_loss(loss)

def regularization_loss(self):
reg_loss = 0
reg_term = 0
if self._embedding_regularizer or self._net_regularizer:
emb_reg = get_regularizer(self._embedding_regularizer)
net_reg = get_regularizer(self._net_regularizer)
for _, module in self.named_modules():
for p_name, param in module.named_parameters():
if param.requires_grad:
if p_name in ["weight", "bias"]:
if type(module) == nn.Embedding:
if self._embedding_regularizer:
for emb_p, emb_lambda in emb_reg:
reg_loss += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
else:
if self._net_regularizer:
for net_p, net_lambda in net_reg:
reg_loss += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
return reg_loss
emb_params = set()
for m_name, module in self.named_modules():
if type(module) == FeatureEmbeddingDict:
for p_name, param in module.named_parameters():
if param.requires_grad:
emb_params.add(".".join([m_name, p_name]))
for emb_p, emb_lambda in emb_reg:
reg_term += (emb_lambda / emb_p) * torch.norm(param, emb_p) ** emb_p
for name, param in self.named_parameters():
if param.requires_grad:
if name not in emb_params:
for net_p, net_lambda in net_reg:
reg_term += (net_lambda / net_p) * torch.norm(param, net_p) ** net_p
return reg_term

def add_loss(self, return_dict, y_true):
loss = self.loss_fn(return_dict["y_pred"], y_true, reduction='mean')
Expand Down
2 changes: 1 addition & 1 deletion fuxictr/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__="2.3.6"
__version__="2.3.7"
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="fuxictr",
version="2.3.6",
version="2.3.7",
author="RECZOO",
author_email="[email protected]",
description="A configurable, tunable, and reproducible library for CTR prediction",
Expand Down

0 comments on commit fb162f5

Please sign in to comment.