Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix confusion matrix using only predictions as source for labels #249

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
fb7a1fa
Add serialization for LogisticRegression
levkk Oct 6, 2022
0930322
Merge pull request #1 from postgresml/levkk-add-ser-for-logistic
levkk Oct 6, 2022
378c3b4
Serialization for multi-class
levkk Oct 6, 2022
ecd48d0
Merge pull request #2 from postgresml/levkk-fix-missing-ser
levkk Oct 6, 2022
fe3ae53
Float type restriction with handwritten bounds
gkobeaga Oct 6, 2022
9392fe6
Merge pull request #3 from gkobeaga/serde-logistic
levkk Oct 6, 2022
3fa43a8
Merge branch 'rust-ml:master' into master
levkk Oct 16, 2022
c44940b
Confusion matrix should use labels from predictions and ground truth
levkk Oct 17, 2022
4057c2d
Merge pull request #4 from postgresml/levkk-f1-division-by-zero
levkk Oct 17, 2022
d91de55
Clippy fixes
levkk Oct 17, 2022
3356d42
Merge pull request #5 from postgresml/levkk-fix-f1-metric
levkk Oct 17, 2022
4ac3ec8
This is the correct test
levkk Oct 17, 2022
3dd71b1
Merge pull request #6 from postgresml/levkk-fix-test-not-sure
levkk Oct 17, 2022
1e8ac38
Merge branch 'rust-ml:master' into master
montanalow Jun 6, 2023
ef0a23a
Merge branch 'rust-ml:master' into master
montanalow Jul 18, 2023
01c8224
Merge branch 'rust-ml:master' into master
levkk Nov 2, 2023
4004fec
Merge branch 'rust-ml:master' into master
montanalow Jan 11, 2025
7dee254
fix warnings
montanalow Jan 11, 2025
e9904a8
remove lifetimes
montanalow Jan 11, 2025
4f8ccef
clippy lints
montanalow Jan 11, 2025
5ec7b2f
fix ownership
montanalow Jan 11, 2025
97d52e7
fix ownership
montanalow Jan 11, 2025
d4a5744
cleanup lints
montanalow Jan 11, 2025
9d615fc
Merge pull request #7 from postgresml/montana/a
montanalow Jan 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/dataset/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,18 @@ pub trait Labels {
fn labels(&self) -> Vec<Self::Elem> {
self.label_set().into_iter().flatten().collect()
Copy link
Collaborator

@YuhanLiin YuhanLiin Oct 19, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For some reason this method doesn't dedup the final vector. It should do something like union all HashSet together. Or we can just change the return type to HashSet, but that might be too invasive.

}

fn combined_labels(&self, other: Vec<Self::Elem>) -> Vec<Self::Elem> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better to have this method take &impl Labels or &Self as input. Then you can call label_set on both self and the input and union all the hashsets before converting it into a Vec.

let mut combined = self.labels();
combined.extend(other);

combined
.iter()
.collect::<HashSet<_>>()
.into_iter()
.cloned()
.collect()
}
}

#[cfg(test)]
Expand Down
13 changes: 12 additions & 1 deletion src/metrics_classification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ where
return Err(Error::MismatchedShapes(targets.len(), ground_truth.len()));
}

let classes = self.labels();
let classes = self.combined_labels(ground_truth.labels());

let indices = map_prediction_to_idx(
targets.as_slice().unwrap(),
Expand Down Expand Up @@ -636,6 +636,17 @@ mod tests {
);
}

#[test]
fn test_division_by_zero_cm() {
let ground_truth = Array1::from(vec![1, 1, 0, 1, 0, 1]);
let predicted = Array1::from(vec![0, 0, 0, 0, 0, 0]);

let x = ground_truth.confusion_matrix(predicted).unwrap();
let f1 = x.f1_score();

assert_eq!(f1, 0.5);
}

#[test]
fn test_roc_curve() {
let predicted = ArrayView1::from(&[0.1, 0.3, 0.5, 0.7, 0.8, 0.9]).mapv(Pr::new);
Expand Down