From 6d0eaf9856001712abbe24bf4aafca900dde76fa Mon Sep 17 00:00:00 2001 From: sdelcore Date: Thu, 6 May 2021 20:25:03 -0400 Subject: [PATCH] added wrapper function in utils.py for getting accuracy of bins --- calibration/util_test.py | 14 ++++++++++++++ calibration/utils.py | 25 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/calibration/util_test.py b/calibration/util_test.py index 42c6d8f..cf79769 100644 --- a/calibration/util_test.py +++ b/calibration/util_test.py @@ -249,5 +249,19 @@ def test_missing_classes_ece(self): true_ece = 0.15 self.assertAlmostEqual(pred_ece, true_ece) + def test_get_bin_accuracies(self): + prob_labels = np.array([[0.2, 0], + [0.2, 1], + [0.6, 1], + [0.6, 0], + [0.7, 0], + [0.7, 1], + [0.7, 1], + [0.7, 1]]) + accuracies = get_bin_accuracies(prob_labels) + true_accuracies = [0.0, 0.5, 0.0, 0.0, 0.0, 0.5, 0.75, 0.0, 0.0, 0.0] + self.assertEqual(accuracies, true_accuracies) + + if __name__ == '__main__': unittest.main() diff --git a/calibration/utils.py b/calibration/utils.py index 1cfeb59..0a5f1c6 100644 --- a/calibration/utils.py +++ b/calibration/utils.py @@ -314,6 +314,31 @@ def get_bin_probs(binned_data: BinnedData) -> List[float]: assert(abs(sum(bin_probs) - 1.0) < eps) return list(bin_probs) +def get_bin_accuracies(prob_labels: List[float], num_bins: int=10): + """Seperate binary classification probabilities into + bins and calculate the accuracy for each bin + + Args: + prob_labels: An array of shape (n,2), where each element is a pair of + (probability,label) where label is 1 the prediction was correct, 0 if incorrect + num_bins: the number of bins to seperate the probabilities + + Returns: + An array of accuracies for each bin + """ + prob_labels = np.array(prob_labels) + prob_bins = get_equal_prob_bins(prob_labels, num_bins) + binned_data = bin(prob_labels, prob_bins) + assert(len(binned_data) == num_bins) + accuracies = [] + for b in binned_data: + b = np.array(b) + if len(b) == 0: + accuracies.append(0.0) + else: + accuracies.append(np.mean(b[:,1])) + return accuracies + def plugin_ce(binned_data: BinnedData, power=2) -> float: def bin_error(data: Data):