Skip to content

Commit

Permalink
add next_d_short_term back
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed May 9, 2024
1 parent 639614c commit 3e7d5ea
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
2.8755,
0,
0,
0,
]

S_MIN = 0.01
Expand Down Expand Up @@ -111,9 +112,9 @@ def next_d(self, state: Tensor, rating: Tensor) -> Tensor:
new_d = self.mean_reversion(self.w[4], new_d)
return new_d

# def next_d_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
# new_d = state[:, 1] - self.w[18] * (rating - 3)
# return new_d
def next_d_short_term(self, state: Tensor, rating: Tensor) -> Tensor:
new_d = state[:, 1] - self.w[19] * (rating - 3)
return new_d

def step(self, X: Tensor, state: Tensor) -> Tensor:
"""
Expand Down Expand Up @@ -145,7 +146,7 @@ def step(self, X: Tensor, state: Tensor) -> Tensor:
)
new_d = torch.where(
short_term,
state[:, 1],
self.next_d_short_term(state, X[:, 1]),
self.next_d(state, X[:, 1]),
)
new_d = new_d.clamp(1, 10)
Expand Down Expand Up @@ -190,6 +191,7 @@ def __call__(self, module):
w[16] = w[16].clamp(1, 6)
w[17] = w[17].clamp(0, 1)
w[18] = w[18].clamp(0, 1)
w[19] = w[19].clamp(0, 1)
module.w.data = w


Expand Down

0 comments on commit 3e7d5ea

Please sign in to comment.