Feat/L2 Regularization #270

Merged 9 commits on Jan 18, 2025
2 changes: 1 addition & 1 deletion .github/workflows/check.sh
@@ -13,5 +13,5 @@ unzip *.zip

RUSTDOCFLAGS="-D warnings" cargo doc --release

-cargo install cargo-llvm-cov --locked
+cargo install cargo-llvm-cov@0.6.15 --locked
SKIP_TRAINING=1 cargo llvm-cov --release
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "fsrs"
version = "2.0.1"
version = "2.0.2"
authors = ["Open Spaced Repetition"]
categories = ["algorithms", "science"]
edition = "2021"
84 changes: 80 additions & 4 deletions src/training.rs
@@ -25,6 +25,11 @@ use log::info;

use std::sync::{Arc, Mutex};

+static PARAMS_STDDEV: [f32; 19] = [
+    6.61, 9.52, 17.69, 27.74, 0.55, 0.28, 0.67, 0.12, 0.4, 0.18, 0.34, 0.27, 0.08, 0.14, 0.57,
+    0.25, 1.03, 0.27, 0.39,
+];
+
pub struct BCELoss<B: Backend> {
    backend: PhantomData<B>,
}
@@ -70,6 +75,21 @@ impl<B: Backend> Model<B> {
        let retention = self.power_forgetting_curve(delta_ts, state.stability);
        BCELoss::new().forward(retention, labels.float(), weights, reduce)
    }
+
+    pub(crate) fn l2_regularization(
+        &self,
+        init_w: Tensor<B, 1>,
+        params_stddev: Tensor<B, 1>,
+        batch_size: usize,
+        total_size: usize,
+        gamma: f64,
+    ) -> Tensor<B, 1> {
+        (self.w.val() - init_w)
+            .powi_scalar(2)
+            .div(params_stddev.powi_scalar(2))
+            .sum()
+            .mul_scalar(gamma * batch_size as f64 / total_size as f64)
+    }
}

impl<B: AutodiffBackend> Model<B> {
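Aside on the math: the penalty computed by l2_regularization is gamma * (batch_size / total_size) * sum_i(((w_i - w_init_i) / sigma_i)^2), with sigma taken from PARAMS_STDDEV, so drift from the default parameters is charged in units of each parameter's typical spread. The following standalone sketch mirrors that arithmetic on plain slices; the function name and sample values are illustrative only, not part of this PR:

// Sketch of the penalty computed by `l2_regularization`, on plain slices.
fn l2_penalty(w: &[f32], init_w: &[f32], stddev: &[f32], batch: usize, total: usize, gamma: f64) -> f64 {
    let sum: f64 = w
        .iter()
        .zip(init_w)
        .zip(stddev)
        .map(|((w, w0), sd)| (((w - w0) / sd) as f64).powi(2))
        .sum();
    sum * gamma * batch as f64 / total as f64
}

fn main() {
    // First parameter one stddev above its default, third one stddev below:
    // sum = 1 + 0 + 1 = 2, then scaled by gamma * batch / total = 2.0 * 0.512.
    let w0: [f32; 3] = [0.4, 0.9, 2.3];
    let w: [f32; 3] = [0.5, 0.9, 2.0];
    let sd: [f32; 3] = [0.1, 0.2, 0.3];
    println!("{}", l2_penalty(&w, &w0, &sd, 512, 1000, 2.0)); // ~2.048
}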
@@ -189,6 +209,8 @@ pub(crate) struct TrainingConfig {
    pub learning_rate: f64,
    #[config(default = 64)]
    pub max_seq_len: usize,
+    #[config(default = 1.0)]
+    pub gamma: f64,
}

pub fn calculate_average_recall(items: &[FSRSItem]) -> f32 {
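Note how gamma interacts with the batch scaling used below: because the batch sizes of one epoch sum to the dataset size, the per-batch scale factors gamma * b / N sum back to exactly gamma, so the regularizer's per-epoch strength is independent of batch size. A toy standalone check of that bookkeeping, with assumed sizes:

fn main() {
    // A 1000-item epoch at batch_size 512 yields batches of 512 and 488 items.
    // Their penalty scale factors gamma * b / N add back up to gamma.
    let (total, gamma) = (1000usize, 1.0f64);
    let weight: f64 = [512usize, 488]
        .iter()
        .map(|&b| gamma * b as f64 / total as f64)
        .sum();
    assert!((weight - gamma).abs() < 1e-12);
    println!("epoch penalty weight = {weight}"); // 1 == gamma
}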
@@ -345,7 +367,8 @@ fn train<B: AutodiffBackend>(
    B::seed(config.seed);

    // Training data
-    let iterations = (train_set.len() / config.batch_size + 1) * config.num_epochs;
+    let total_size = train_set.len();
+    let iterations = (total_size / config.batch_size + 1) * config.num_epochs;
    let batch_dataset = BatchTensorDataset::<B>::new(
        FSRSDataset::from(train_set),
        config.batch_size,
@@ -356,7 +379,7 @@
    let batch_dataset = BatchTensorDataset::<B::InnerBackend>::new(
        FSRSDataset::from(test_set.clone()),
        config.batch_size,
-        device,
+        device.clone(),
    );
    let dataloader_valid = ShuffleDataLoader::new(batch_dataset, config.seed);

@@ -371,6 +394,8 @@
    };

    let mut model: Model<B> = config.model.init();
+    let init_w = model.w.val();
+    let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);
    let mut optim = config.optimizer.init::<B, Model<B>>();

    let mut best_loss = f64::INFINITY;
@@ -380,8 +405,16 @@
    let mut iteration = 0;
    while let Some(item) = iterator.next() {
        iteration += 1;
+        let real_batch_size = item.delta_ts.shape().dims[0];
        let lr = LrScheduler::<B>::step(&mut lr_scheduler);
        let progress = iterator.progress();
+        let penalty = model.l2_regularization(
+            init_w.clone(),
+            params_stddev.clone(),
+            real_batch_size,
+            total_size,
+            config.gamma,
+        );
        let loss = model.forward_classification(
            item.t_historys,
            item.r_historys,
@@ -390,7 +423,7 @@
            item.weights,
            Reduction::Sum,
        );
-        let mut gradients = loss.backward();
+        let mut gradients = (loss + penalty).backward();
        if model.config.freeze_initial_stability {
            gradients = model.freeze_initial_stability(gradients);
        }
@@ -420,6 +453,14 @@
        let model_valid = model.valid();
        let mut loss_valid = 0.0;
        for batch in dataloader_valid.iter() {
+            let real_batch_size = batch.delta_ts.shape().dims[0];
+            let penalty = model_valid.l2_regularization(
+                init_w.valid(),
+                params_stddev.valid(),
+                real_batch_size,
+                total_size,
+                config.gamma,
+            );
            let loss = model_valid.forward_classification(
                batch.t_historys,
                batch.r_historys,
@@ -429,7 +470,8 @@
                Reduction::Sum,
            );
            let loss = loss.into_data().convert::<f64>().value[0];
-            loss_valid += loss;
+            let penalty = penalty.into_data().convert::<f64>().value[0];
+            loss_valid += loss + penalty;

            if interrupter.should_stop() {
                break;
@@ -494,6 +536,8 @@ mod tests {
        let device = NdArrayDevice::Cpu;
        type B = Autodiff<NdArray<f32>>;
        let mut model: Model<B> = config.init();
+        let init_w = model.w.val();
+        let params_stddev = Tensor::from_floats(PARAMS_STDDEV, &device);

        let item = FSRSBatch {
            t_historys: Tensor::from_floats(
@@ -563,6 +607,38 @@ mod tests {
            ])
        );

+        let penalty =
+            model.l2_regularization(init_w.clone(), params_stddev.clone(), 512, 1000, 2.0);
+        assert_eq!(
+            penalty.clone().into_data().convert::<f32>().value[0],
+            0.64689976
+        );
+
+        let gradients = penalty.backward();
+        let w_grad = model.w.grad(&gradients).unwrap();
+        Data::from([
+            0.0018749383,
+            0.00090389,
+            0.00026177685,
+            -0.00010645759,
+            0.27080965,
+            -1.0448978,
+            -0.18249036,
+            5.688889,
+            -0.5119995,
+            2.528395,
+            -0.7086509,
+            1.1237301,
+            -12.799997,
+            4.179591,
+            0.25213587,
+            1.3107198,
+            -0.07721739,
+            -1.1237309,
+            -0.5385926,
+        ])
+        .assert_approx_eq(&w_grad.clone().into_data(), 5);
+
        let item = FSRSBatch {
            t_historys: Tensor::from_floats(
                Data::from([
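For reference, the gradient asserted in the new test has a closed form: the derivative with respect to w_j of gamma * (b/N) * sum(((w - w0) / sigma)^2) is 2 * gamma * (b/N) * (w_j - w0_j) / sigma_j^2, which is exactly what burn's autodiff produces from penalty.backward(). With the test's gamma = 2.0 and b/N = 512/1000, the common scale factor is 2.048. A standalone sketch of that closed form; the slice names are illustrative, not the crate's API:

// Analytic gradient of the L2 penalty above.
fn l2_penalty_grad(w: &[f32], init_w: &[f32], stddev: &[f32], batch: usize, total: usize, gamma: f64) -> Vec<f64> {
    let scale = 2.0 * gamma * batch as f64 / total as f64; // 2.048 for the test's values
    w.iter()
        .zip(init_w)
        .zip(stddev)
        .map(|((w, w0), sd)| scale * ((w - w0) / (sd * sd)) as f64)
        .collect()
}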