From 5ab4c84f659d41448f2acf1a77f6ad9ecc42cb46 Mon Sep 17 00:00:00 2001
From: Jarrett Ye
Date: Thu, 9 May 2024 14:05:51 +0800
Subject: [PATCH] improve stability_short_term

---
 src/fsrs_optimizer/fsrs_optimizer.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/fsrs_optimizer/fsrs_optimizer.py b/src/fsrs_optimizer/fsrs_optimizer.py
index 9506574..f8be465 100644
--- a/src/fsrs_optimizer/fsrs_optimizer.py
+++ b/src/fsrs_optimizer/fsrs_optimizer.py
@@ -18,7 +18,7 @@ from torch.nn.utils.rnn import pad_sequence
 from sklearn.model_selection import TimeSeriesSplit
 from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
-from scipy.optimize import minimize, curve_fit
+from scipy.optimize import minimize
 from itertools import accumulate
 from tqdm.auto import tqdm
 import warnings
@@ -66,6 +66,7 @@
     0.2272,
     2.8755,
     0,
+    0,
 ]
 
 S_MIN = 0.01
@@ -102,7 +103,7 @@ def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor:
         return torch.minimum(new_s, state[:, 0])
 
     def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
-        new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3))
+        new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))
         return new_s
 
     def next_d(self, state: Tensor, rating: Tensor) -> Tensor:
@@ -188,6 +189,7 @@ def __call__(self, module):
         w[15] = w[15].clamp(0, 1)
         w[16] = w[16].clamp(1, 6)
         w[17] = w[17].clamp(0, 1)
+        w[18] = w[18].clamp(0, 1)
         module.w.data = w
 
 
@@ -342,8 +344,6 @@ def train(self, verbose: bool = True):
             retentions = power_forgetting_curve(delta_ts, stabilities)
             loss = self.loss_fn(retentions, labels).sum()
             loss.backward()
-            for param in self.model.parameters():
-                param.grad[:4] = torch.zeros(4)
             self.optimizer.step()
             self.scheduler.step()
             self.model.apply(self.clipper)