Skip to content

Commit

Permalink
improve stability_short_term
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed May 9, 2024
1 parent f862b52 commit 5ab4c84
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from torch.nn.utils.rnn import pad_sequence
from sklearn.model_selection import TimeSeriesSplit
from sklearn.metrics import root_mean_squared_error, mean_absolute_error, r2_score
from scipy.optimize import minimize, curve_fit
from scipy.optimize import minimize
from itertools import accumulate
from tqdm.auto import tqdm
import warnings
Expand Down Expand Up @@ -66,6 +66,7 @@
0.2272,
2.8755,
0,
0,
]

S_MIN = 0.01
Expand Down Expand Up @@ -102,7 +103,7 @@ def stability_after_failure(self, state: Tensor, r: Tensor) -> Tensor:
return torch.minimum(new_s, state[:, 0])

def stability_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3))
new_s = state[:, 0] * torch.exp(self.w[17] * (rating - 3 + self.w[18]))
return new_s

def next_d(self, state: Tensor, rating: Tensor) -> Tensor:
Expand Down Expand Up @@ -188,6 +189,7 @@ def __call__(self, module):
w[15] = w[15].clamp(0, 1)
w[16] = w[16].clamp(1, 6)
w[17] = w[17].clamp(0, 1)
w[18] = w[18].clamp(0, 1)
module.w.data = w


Expand Down Expand Up @@ -342,8 +344,6 @@ def train(self, verbose: bool = True):
retentions = power_forgetting_curve(delta_ts, stabilities)
loss = self.loss_fn(retentions, labels).sum()
loss.backward()
for param in self.model.parameters():
param.grad[:4] = torch.zeros(4)
self.optimizer.step()
self.scheduler.step()
self.model.apply(self.clipper)
Expand Down

0 comments on commit 5ab4c84

Please sign in to comment.