-
Notifications
You must be signed in to change notification settings - Fork 347
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Binary Classification Confusion Matrix and AUC Aggregators #633
base: develop
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,110 @@ | ||
/* | ||
Copyright 2012 Twitter, Inc. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package com.twitter.algebird | ||
|
||
/** | ||
* Curve is a list of Confusion Matrices with different | ||
* thresholds | ||
* | ||
* @param matrices List of Matrices | ||
*/ | ||
case class Curve(matrices: List[ConfusionMatrix]) | ||
|
||
/** | ||
* Given a List of (x,y) this functions computes the | ||
* Area Under the Curve | ||
*/ | ||
object AreaUnderCurve { | ||
private def trapezoid(points: Seq[(Double, Double)]): Double = { | ||
require(points.length == 2) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. do you want to take the two points directly instead of using this assertion, head, last combo? |
||
val x = points.head | ||
val y = points.last | ||
(y._1 - x._1) * (y._2 + x._2) / 2.0 | ||
} | ||
|
||
def of(curve: List[(Double, Double)]): Double = { | ||
curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)( | ||
seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points), | ||
combop = _ + _) | ||
} | ||
} | ||
|
||
sealed trait AUCMetric | ||
case object ROC extends AUCMetric | ||
case object PR extends AUCMetric | ||
|
||
/** | ||
* Sums Curves which are a series of Confusion Matrices | ||
* with different thresholds | ||
*/ | ||
case object CurveMonoid extends Monoid[Curve] { | ||
def zero = Curve(Nil) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looked like we don't have a test yet that exercises the |
||
override def plus(left: Curve, right: Curve): Curve = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a sumOption implementation that would be more efficient that we should add here as well, for adding multiple curves at once? |
||
val sg = BinaryClassificationConfusionMatrixMonoid | ||
Curve( | ||
left.matrices.zip(right.matrices) | ||
.map{ case (cl, cr) => sg.plus(cl, cr) }) | ||
} | ||
} | ||
|
||
/** | ||
* AUCAggregator computes the Area Under the Curve | ||
* for a given metric by sampling along that curve. | ||
* | ||
* The number of samples is taken and is used to compute | ||
* the thresholds to use. A confusion matrix is then computed | ||
* for each threshold and finally that is used to compute the | ||
* Area Under the Curve. | ||
* | ||
* Note this is for Binary Classifications Tasks | ||
* | ||
* @param metric Which Metric to compute | ||
* @param samples Number of samples, defaults to 100 | ||
*/ | ||
case class BinaryClassificationAUCAggregator(metric: AUCMetric, samples: Int = 100) | ||
extends Aggregator[BinaryPrediction, Curve, Double] | ||
with Serializable { | ||
|
||
private def linspace(a: Double, b: Double, length: Int = 100): Array[Double] = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we never call this without default args - should we remove the default 100? |
||
val increment = (b - a) / (length - 1) | ||
Array.tabulate(length)(i => a + increment * i) | ||
} | ||
|
||
private lazy val thresholds = linspace(0.0, 1.0, samples) | ||
private lazy val aggregators = thresholds.map(BinaryClassificationConfusionMatrixAggregator(_)).toList | ||
|
||
def prepare(input: BinaryPrediction): Curve = Curve(aggregators.map(_.prepare(input))) | ||
|
||
def semigroup: Semigroup[Curve] = CurveMonoid | ||
|
||
def present(c: Curve): Double = { | ||
val total = c.matrices.map { matrix => | ||
val scores = BinaryClassificationConfusionMatrixAggregator().present(matrix) | ||
metric match { | ||
case ROC => (scores.falsePositiveRate, scores.recall) | ||
case PR => (scores.recall, scores.precision) | ||
} | ||
}.reverse | ||
|
||
val combined = metric match { | ||
case ROC => total ++ List((1.0, 1.0)) | ||
case PR => List((0.0, 1.0)) ++ total | ||
} | ||
|
||
AreaUnderCurve.of(combined) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
/* | ||
Copyright 2012 Twitter, Inc. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package com.twitter.algebird | ||
|
||
/** | ||
* A BinaryPrediction is a label with a score | ||
* | ||
* @param score Score of the classifier | ||
* @param label Is this in the positive or negative class. | ||
*/ | ||
case class BinaryPrediction(score: Double, label: Boolean) extends Serializable { | ||
override def toString: String = s"$label,$score" | ||
} | ||
|
||
/** | ||
* Confusion Matrix itself with the statistics to be aggregated | ||
*/ | ||
case class ConfusionMatrix( | ||
truePositive: Int = 0, | ||
falsePositive: Int = 0, | ||
falseNegative: Int = 0, | ||
trueNegative: Int = 0) | ||
extends Serializable | ||
|
||
/** | ||
* After the aggregation this generates some common statistics | ||
* | ||
* @param fscore F Score based on the alpha given to the Aggregator | ||
* @param precision Precision Score | ||
* @param recall Recall Score | ||
* @param falsePositiveRate False Positive Rate | ||
* @param matrix Confusion Matrix | ||
*/ | ||
case class Scores( | ||
fscore: Double, | ||
precision: Double, | ||
recall: Double, | ||
falsePositiveRate: Double, | ||
matrix: ConfusionMatrix) | ||
extends Serializable | ||
|
||
case object BinaryClassificationConfusionMatrixMonoid extends Monoid[ConfusionMatrix] { | ||
def zero: ConfusionMatrix = ConfusionMatrix() | ||
override def plus(left: ConfusionMatrix, right: ConfusionMatrix): ConfusionMatrix = { | ||
val tp = left.truePositive + right.truePositive | ||
val fp = left.falsePositive + right.falsePositive | ||
val fn = left.falseNegative + right.falseNegative | ||
val tn = left.trueNegative + right.trueNegative | ||
|
||
ConfusionMatrix(tp, fp, fn, tn) | ||
} | ||
} | ||
|
||
/** | ||
* A Confusion Matrix Aggregator creates a Confusion Matrix and | ||
* relevant scores for a given threshold given predictions from | ||
* a binary classifier. | ||
* | ||
* @param threshold Threshold to use for the predictions | ||
* @param beta Beta used in the FScore Calculation. | ||
*/ | ||
case class BinaryClassificationConfusionMatrixAggregator(threshold: Double = 0.5, beta: Double = 1.0) | ||
extends Aggregator[BinaryPrediction, ConfusionMatrix, Scores] | ||
with Serializable { | ||
|
||
def prepare(input: BinaryPrediction): ConfusionMatrix = | ||
(input.label, input.score) match { | ||
case (true, score) if score > threshold => | ||
ConfusionMatrix(truePositive = 1) | ||
case (true, score) if score < threshold => | ||
ConfusionMatrix(falseNegative = 1) | ||
case (false, score) if score < threshold => | ||
ConfusionMatrix(trueNegative = 1) | ||
case (false, score) if score > threshold => | ||
ConfusionMatrix(falsePositive = 1) | ||
} | ||
|
||
def semigroup: Semigroup[ConfusionMatrix] = | ||
BinaryClassificationConfusionMatrixMonoid | ||
|
||
def present(m: ConfusionMatrix): Scores = { | ||
val precDenom = m.truePositive.toDouble + m.falsePositive.toDouble | ||
val precision = if (precDenom > 0.0) m.truePositive.toDouble / precDenom else 1.0 | ||
|
||
val recallDenom = m.truePositive.toDouble + m.falseNegative.toDouble | ||
val recall = if (recallDenom > 0.0) m.truePositive.toDouble / recallDenom else 1.0 | ||
|
||
val fpDenom = m.falsePositive.toDouble + m.trueNegative.toDouble | ||
val fpr = if (fpDenom > 0.0) m.falsePositive.toDouble / fpDenom else 0.0 | ||
|
||
val betaSqr = Math.pow(beta, 2.0) | ||
|
||
val fScoreDenom = (betaSqr * precision) + recall | ||
|
||
val fscore = if (fScoreDenom > 0.0) { | ||
(1 + betaSqr) * ((precision * recall) / fScoreDenom) | ||
} else { 1.0 } | ||
|
||
Scores(fscore, precision, recall, fpr, m) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
Copyright 2012 Twitter, Inc. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package com.twitter.algebird | ||
|
||
import org.scalacheck.Arbitrary | ||
import org.scalacheck.Gen.choose | ||
import org.scalactic.TolerantNumerics | ||
import org.scalatest.{ Matchers, _ } | ||
|
||
class CurveMonoidLaws extends CheckProperties { | ||
import BaseProperties._ | ||
|
||
implicit val semigroup = CurveMonoid | ||
implicit val gen = Arbitrary { | ||
for ( | ||
v <- choose(0, 10000) | ||
) yield Curve(List(ConfusionMatrix(truePositive = v))) | ||
} | ||
|
||
property("Curve is associative") { | ||
isAssociative[Curve] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can we use the monoid laws? |
||
} | ||
} | ||
|
||
class BinaryClassificationAUCTest extends WordSpec with Matchers { | ||
lazy val data = | ||
List( | ||
BinaryPrediction(0.1, false), | ||
BinaryPrediction(0.1, true), | ||
BinaryPrediction(0.4, false), | ||
BinaryPrediction(0.6, false), | ||
BinaryPrediction(0.6, true), | ||
BinaryPrediction(0.6, true), | ||
BinaryPrediction(0.8, true)) | ||
|
||
"BinaryClassificationAUC" should { | ||
implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(0.1) | ||
|
||
"return roc auc" in { | ||
val aggregator = BinaryClassificationAUCAggregator(ROC, samples = 50) | ||
assert(aggregator(data) === 0.708) | ||
} | ||
|
||
"return pr auc" in { | ||
val aggregator = BinaryClassificationAUCAggregator(PR, samples = 50) | ||
assert(aggregator(data) === 0.833) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
/* | ||
Copyright 2012 Twitter, Inc. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
|
||
package com.twitter.algebird | ||
|
||
import org.scalacheck.Arbitrary | ||
import org.scalacheck.Gen.choose | ||
import org.scalactic.TolerantNumerics | ||
import org.scalatest.{ Matchers, WordSpec } | ||
|
||
class BinaryClassificationConfusionMatrixMonoidLaws extends CheckProperties { | ||
import BaseProperties._ | ||
|
||
implicit val semigroup = BinaryClassificationConfusionMatrixMonoid | ||
implicit val gen = Arbitrary { | ||
for ( | ||
v <- choose(0, 10000) | ||
) yield ConfusionMatrix(truePositive = v) | ||
} | ||
|
||
property("ConfusionMatrix is associative") { | ||
isAssociative[ConfusionMatrix] | ||
} | ||
} | ||
|
||
class BinaryClassificationConfusionMatrixTest extends WordSpec with Matchers { | ||
lazy val data = | ||
List( | ||
BinaryPrediction(0.1, false), | ||
BinaryPrediction(0.1, true), | ||
BinaryPrediction(0.4, false), | ||
BinaryPrediction(0.6, false), | ||
BinaryPrediction(0.6, true), | ||
BinaryPrediction(0.6, true), | ||
BinaryPrediction(0.8, true)) | ||
|
||
"BinaryClassificationConfusionMatrix" should { | ||
implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(0.1) | ||
|
||
"return a correct confusion matrix" in { | ||
val aggregator = BinaryClassificationConfusionMatrixAggregator() | ||
val scored = aggregator(data) | ||
|
||
assert(scored.recall === 0.75) | ||
assert(scored.precision === 0.75) | ||
assert(scored.fscore === 0.75) | ||
assert(scored.falsePositiveRate === 0.333) | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you add an object with an
implicit def monoid: Monoid[Curve] = CurveMonoid
you won't need to declare the same in the tests.