Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Binary Classification Confusion Matrix and AUC Aggregators #633

Open
wants to merge 3 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

/**
* Curve is a list of Confusion Matrices with different
* thresholds
*
* @param matrices List of Matrices
*/
case class Curve(matrices: List[ConfusionMatrix])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you add an object with an implicit def monoid: Monoid[Curve] = CurveMonoid you won't need to declare the same in the tests.


/**
* Given a List of (x,y) this functions computes the
* Area Under the Curve
*/
object AreaUnderCurve {
private def trapezoid(points: Seq[(Double, Double)]): Double = {
require(points.length == 2)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you want to take the two points directly instead of using this assertion, head, last combo?

val x = points.head
val y = points.last
(y._1 - x._1) * (y._2 + x._2) / 2.0
}

def of(curve: List[(Double, Double)]): Double = {
curve.toIterator.sliding(2).withPartial(false).aggregate(0.0)(
seqop = (auc: Double, points: Seq[(Double, Double)]) => auc + trapezoid(points),
combop = _ + _)
}
}

sealed trait AUCMetric
case object ROC extends AUCMetric
case object PR extends AUCMetric

/**
* Sums Curves which are a series of Confusion Matrices
* with different thresholds
*/
case object CurveMonoid extends Monoid[Curve] {
def zero = Curve(Nil)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looked like we don't have a test yet that exercises the zero.

override def plus(left: Curve, right: Curve): Curve = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a sumOption implementation that would be more efficient that we should add here as well, for adding multiple curves at once?

val sg = BinaryClassificationConfusionMatrixMonoid
Curve(
left.matrices.zip(right.matrices)
.map{ case (cl, cr) => sg.plus(cl, cr) })
}
}

/**
* AUCAggregator computes the Area Under the Curve
* for a given metric by sampling along that curve.
*
* The number of samples is taken and is used to compute
* the thresholds to use. A confusion matrix is then computed
* for each threshold and finally that is used to compute the
* Area Under the Curve.
*
* Note this is for Binary Classifications Tasks
*
* @param metric Which Metric to compute
* @param samples Number of samples, defaults to 100
*/
case class BinaryClassificationAUCAggregator(metric: AUCMetric, samples: Int = 100)
extends Aggregator[BinaryPrediction, Curve, Double]
with Serializable {

private def linspace(a: Double, b: Double, length: Int = 100): Array[Double] = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we never call this without default args - should we remove the default 100?

val increment = (b - a) / (length - 1)
Array.tabulate(length)(i => a + increment * i)
}

private lazy val thresholds = linspace(0.0, 1.0, samples)
private lazy val aggregators = thresholds.map(BinaryClassificationConfusionMatrixAggregator(_)).toList

def prepare(input: BinaryPrediction): Curve = Curve(aggregators.map(_.prepare(input)))

def semigroup: Semigroup[Curve] = CurveMonoid

def present(c: Curve): Double = {
val total = c.matrices.map { matrix =>
val scores = BinaryClassificationConfusionMatrixAggregator().present(matrix)
metric match {
case ROC => (scores.falsePositiveRate, scores.recall)
case PR => (scores.recall, scores.precision)
}
}.reverse

val combined = metric match {
case ROC => total ++ List((1.0, 1.0))
case PR => List((0.0, 1.0)) ++ total
}

AreaUnderCurve.of(combined)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

/**
* A BinaryPrediction is a label with a score
*
* @param score Score of the classifier
* @param label Is this in the positive or negative class.
*/
case class BinaryPrediction(score: Double, label: Boolean) extends Serializable {
override def toString: String = s"$label,$score"
}

/**
* Confusion Matrix itself with the statistics to be aggregated
*/
case class ConfusionMatrix(
truePositive: Int = 0,
falsePositive: Int = 0,
falseNegative: Int = 0,
trueNegative: Int = 0)
extends Serializable

/**
* After the aggregation this generates some common statistics
*
* @param fscore F Score based on the alpha given to the Aggregator
* @param precision Precision Score
* @param recall Recall Score
* @param falsePositiveRate False Positive Rate
* @param matrix Confusion Matrix
*/
case class Scores(
fscore: Double,
precision: Double,
recall: Double,
falsePositiveRate: Double,
matrix: ConfusionMatrix)
extends Serializable

case object BinaryClassificationConfusionMatrixMonoid extends Monoid[ConfusionMatrix] {
def zero: ConfusionMatrix = ConfusionMatrix()
override def plus(left: ConfusionMatrix, right: ConfusionMatrix): ConfusionMatrix = {
val tp = left.truePositive + right.truePositive
val fp = left.falsePositive + right.falsePositive
val fn = left.falseNegative + right.falseNegative
val tn = left.trueNegative + right.trueNegative

ConfusionMatrix(tp, fp, fn, tn)
}
}

/**
* A Confusion Matrix Aggregator creates a Confusion Matrix and
* relevant scores for a given threshold given predictions from
* a binary classifier.
*
* @param threshold Threshold to use for the predictions
* @param beta Beta used in the FScore Calculation.
*/
case class BinaryClassificationConfusionMatrixAggregator(threshold: Double = 0.5, beta: Double = 1.0)
extends Aggregator[BinaryPrediction, ConfusionMatrix, Scores]
with Serializable {

def prepare(input: BinaryPrediction): ConfusionMatrix =
(input.label, input.score) match {
case (true, score) if score > threshold =>
ConfusionMatrix(truePositive = 1)
case (true, score) if score < threshold =>
ConfusionMatrix(falseNegative = 1)
case (false, score) if score < threshold =>
ConfusionMatrix(trueNegative = 1)
case (false, score) if score > threshold =>
ConfusionMatrix(falsePositive = 1)
}

def semigroup: Semigroup[ConfusionMatrix] =
BinaryClassificationConfusionMatrixMonoid

def present(m: ConfusionMatrix): Scores = {
val precDenom = m.truePositive.toDouble + m.falsePositive.toDouble
val precision = if (precDenom > 0.0) m.truePositive.toDouble / precDenom else 1.0

val recallDenom = m.truePositive.toDouble + m.falseNegative.toDouble
val recall = if (recallDenom > 0.0) m.truePositive.toDouble / recallDenom else 1.0

val fpDenom = m.falsePositive.toDouble + m.trueNegative.toDouble
val fpr = if (fpDenom > 0.0) m.falsePositive.toDouble / fpDenom else 0.0

val betaSqr = Math.pow(beta, 2.0)

val fScoreDenom = (betaSqr * precision) + recall

val fscore = if (fScoreDenom > 0.0) {
(1 + betaSqr) * ((precision * recall) / fScoreDenom)
} else { 1.0 }

Scores(fscore, precision, recall, fpr, m)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

import org.scalacheck.Arbitrary
import org.scalacheck.Gen.choose
import org.scalactic.TolerantNumerics
import org.scalatest.{ Matchers, _ }

class CurveMonoidLaws extends CheckProperties {
import BaseProperties._

implicit val semigroup = CurveMonoid
implicit val gen = Arbitrary {
for (
v <- choose(0, 10000)
) yield Curve(List(ConfusionMatrix(truePositive = v)))
}

property("Curve is associative") {
isAssociative[Curve]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we use the monoid laws?

}
}

class BinaryClassificationAUCTest extends WordSpec with Matchers {
lazy val data =
List(
BinaryPrediction(0.1, false),
BinaryPrediction(0.1, true),
BinaryPrediction(0.4, false),
BinaryPrediction(0.6, false),
BinaryPrediction(0.6, true),
BinaryPrediction(0.6, true),
BinaryPrediction(0.8, true))

"BinaryClassificationAUC" should {
implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(0.1)

"return roc auc" in {
val aggregator = BinaryClassificationAUCAggregator(ROC, samples = 50)
assert(aggregator(data) === 0.708)
}

"return pr auc" in {
val aggregator = BinaryClassificationAUCAggregator(PR, samples = 50)
assert(aggregator(data) === 0.833)
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
Copyright 2012 Twitter, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package com.twitter.algebird

import org.scalacheck.Arbitrary
import org.scalacheck.Gen.choose
import org.scalactic.TolerantNumerics
import org.scalatest.{ Matchers, WordSpec }

class BinaryClassificationConfusionMatrixMonoidLaws extends CheckProperties {
import BaseProperties._

implicit val semigroup = BinaryClassificationConfusionMatrixMonoid
implicit val gen = Arbitrary {
for (
v <- choose(0, 10000)
) yield ConfusionMatrix(truePositive = v)
}

property("ConfusionMatrix is associative") {
isAssociative[ConfusionMatrix]
}
}

class BinaryClassificationConfusionMatrixTest extends WordSpec with Matchers {
lazy val data =
List(
BinaryPrediction(0.1, false),
BinaryPrediction(0.1, true),
BinaryPrediction(0.4, false),
BinaryPrediction(0.6, false),
BinaryPrediction(0.6, true),
BinaryPrediction(0.6, true),
BinaryPrediction(0.8, true))

"BinaryClassificationConfusionMatrix" should {
implicit val doubleEq = TolerantNumerics.tolerantDoubleEquality(0.1)

"return a correct confusion matrix" in {
val aggregator = BinaryClassificationConfusionMatrixAggregator()
val scored = aggregator(data)

assert(scored.recall === 0.75)
assert(scored.precision === 0.75)
assert(scored.fscore === 0.75)
assert(scored.falsePositiveRate === 0.333)
}
}
}