Skip to content

Commit

Permalink
improve display of analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
L-M-Sherlock committed May 10, 2024
1 parent d334f97 commit 37fce1b
Showing 1 changed file with 41 additions and 11 deletions.
52 changes: 41 additions & 11 deletions src/fsrs_optimizer/fsrs_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -773,12 +773,17 @@ def cum_concat(x):
if not analysis:
return

df["retention"] = df.groupby(
by=["i", "r_history", "delta_t"], group_keys=False
)["y"].transform("mean")
df["total_cnt"] = df.groupby(
by=["i", "r_history", "delta_t"], group_keys=False
)["review_time"].transform("count")
df["r_history"] = df.apply(
lambda row: wrap_short_term_ratings(row["r_history"], row["t_history"]),
axis=1,
)

df["retention"] = df.groupby(by=["r_history", "delta_t"], group_keys=False)[
"y"
].transform("mean")
df["total_cnt"] = df.groupby(by=["r_history", "delta_t"], group_keys=False)[
"review_time"
].transform("count")
tqdm.write("Retention calculated.")

df.drop(
Expand Down Expand Up @@ -826,15 +831,15 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
del group["delta_t"]
return group

df = df.groupby(by=["i", "r_history"], group_keys=False).progress_apply(
df = df.groupby(by=["r_history"], group_keys=False).progress_apply(
cal_stability
)
if df.empty:
return "No enough data for stability calculation."
tqdm.write("Stability calculated.")
df.reset_index(drop=True, inplace=True)
df.drop_duplicates(inplace=True)
df.sort_values(by=["i", "r_history"], inplace=True, ignore_index=True)
df.sort_values(by=["r_history"], inplace=True, ignore_index=True)

if df.shape[0] > 0:
for idx in tqdm(df.index, desc="analysis"):
Expand All @@ -848,17 +853,17 @@ def cal_stability(group: pd.DataFrame) -> pd.DataFrame:
df = df[(df["i"] >= 2) & (df["group_cnt"] >= 100)].copy()
df["last_recall"] = df["r_history"].map(lambda x: x[-1])
df = df[
df.groupby(["i", "r_history"], group_keys=False)["group_cnt"].transform(
df.groupby(["r_history"], group_keys=False)["group_cnt"].transform(
"max"
)
== df["group_cnt"]
]
df.to_csv("./stability_for_analysis.tsv", sep="\t", index=None)
tqdm.write("Analysis saved!")
caption = "1:again, 2:hard, 3:good, 4:easy\n"
df["first_rating"] = df["r_history"].map(lambda x: x[0])
df["first_rating"] = df["r_history"].map(lambda x: x[1])
analysis = (
df[df["r_history"].str.contains(r"^[1-4][^124]*$", regex=True)][
df[df["r_history"].str.contains(r"^\([1-4][^124]*$", regex=True)][
[
"first_rating",
"i",
Expand Down Expand Up @@ -1931,6 +1936,31 @@ def count_lapse(r_history, t_history):
return root_mean_squared_error(tmp["y"], tmp["p"], sample_weight=tmp["card_id"])


def wrap_short_term_ratings(r_history, t_history):
    """Group ratings whose interval is "-1" or "0" inside parentheses.

    Both arguments are comma-separated strings that are walked in lockstep
    (extra entries in the longer one are ignored). Every maximal run of
    consecutive ratings whose matching ``t_history`` entry is ``"-1"`` or
    ``"0"`` is emitted as ``(r1,r2,...)``; every other rating is emitted
    unchanged. The pieces are joined with commas again.

    Example: ``r_history="1,2,3,4"``, ``t_history="0,3,0,-1"`` gives
    ``"(1),2,(3,4)"``.

    Args:
        r_history: Comma-separated ratings.
        t_history: Comma-separated intervals, one per rating.

    Returns:
        The rating history with the "-1"/"0"-interval runs parenthesised.
    """
    pieces = []   # finished, comma-separated parts of the output
    pending = []  # ratings of the "-1"/"0" run currently being collected

    for interval, rating in zip(t_history.split(","), r_history.split(",")):
        if interval in ("-1", "0"):
            pending.append(rating)
            continue
        # A regular interval closes any run that was still open.
        if pending:
            pieces.append("(" + ",".join(pending) + ")")
            pending = []
        pieces.append(rating)

    # The history may end in the middle of a run; close it here.
    if pending:
        pieces.append("(" + ",".join(pending) + ")")
    return ",".join(pieces)


if __name__ == "__main__":
model = FSRS(DEFAULT_WEIGHT)
stability = torch.tensor([5.0] * 4)
Expand Down

0 comments on commit 37fce1b

Please sign in to comment.