-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathedit_metrics.py
102 lines (68 loc) · 2.69 KB
/
edit_metrics.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
from abc import ABC
import numpy as np
from stringcase import snakecase
from util import SubclassRegistry
def diff_signs(values):
    """Return the signs (-1, 0 or 1) of consecutive differences of ``values``.

    ``values`` may be any array-like: a pandas Series/DataFrame, a numpy array
    or a plain list. Differences are taken along axis 0, so the result has one
    fewer entry than ``values`` along that axis (it is empty for fewer than
    two entries). NaN inputs yield NaN signs.

    Integer and boolean inputs are converted to float before differencing:
    unsigned integers would otherwise wrap around on decreasing steps, and the
    result keeps the float dtype that pandas ``Series.diff`` produces.
    """
    arr = np.asarray(values)
    if arr.dtype.kind in 'biu':
        arr = arr.astype(float)
    return np.sign(np.diff(arr, axis=0))
class EditMetric(SubclassRegistry):
    """Base class for metrics comparing edit ``weights`` with model ``outputs``.

    Subclasses implement ``__call__(weights, outputs, ...)`` and return a
    score. The registry helpers used on this class (``get_subclasses`` /
    ``get_subclass_names``) come from ``util.SubclassRegistry``.
    """

    def __call__(self, weights, outputs):
        """Score ``outputs`` against ``weights``; must be overridden."""
        raise NotImplementedError

    @classmethod
    def get_subclass_names(cls, snakecase=False):
        """Return registered subclass names, excluding ``ScipyStatsMetric``.

        Args:
            snakecase: forwarded unchanged to
                ``SubclassRegistry.get_subclass_names``; presumably selects
                snake_case names instead of class names (util is not visible
                here -- confirm).

        Returns:
            The names produced by the registry with the abstract scipy
            adapter's name removed.
        """
        # NOTE: the ``snakecase`` parameter shadows the module-level
        # ``stringcase.snakecase`` helper inside this method; the name is kept
        # because callers may pass it by keyword.
        subclass_names = super().get_subclass_names(snakecase=snakecase)
        # The adapter may be listed in either spelling depending on the flag.
        # Drop whichever is present: an unconditional ``remove`` of the
        # snake_case form raises ValueError when only the class name is listed.
        for excluded in ('scipy_stats_metric', 'ScipyStatsMetric'):
            if excluded in subclass_names:
                subclass_names.remove(excluded)
        return subclass_names
class SignConsistency(EditMetric):
    """Fraction of consecutive steps where outputs move in step with weights.

    For each pair of neighbouring points, the sign of the weight change is
    compared with the sign of the output change. The comparison is
    direction-agnostic: outputs that consistently move *against* the weights
    also score as consistent, because the better of the two alignments is
    kept. Steps where the output is effectively unchanged (absolute change
    <= ``tolerance``) count as matches unless ``strict`` is set.
    """

    def __call__(self, *args, **kwargs):
        """Score ``weights`` against ``outputs``; arguments go to ``_calculate``."""
        return self._calculate(*args, **kwargs)

    def _calculate(self, weights, outputs, strict=False, tolerance=0):
        """Compute the sign-consistency score.

        Args:
            weights: sequence of edit weights, differenced with ``diff_signs``.
            outputs: sequence of outputs aligned with ``weights``.
            strict: if True, steps whose output is (effectively) unchanged do
                not count as matches.
            tolerance: output changes with absolute value <= ``tolerance`` are
                treated as "unchanged". The threshold is applied to the raw
                output differences.

        Returns:
            A score in [0, 1], or NaN when there are no steps (fewer than two
            points).
        """
        diff_w = diff_signs(weights)
        # Keep the raw output steps: ``tolerance`` thresholds the size of a
        # change, which is lost once steps are reduced to -1/0/1 signs.
        steps_o = np.diff(np.asarray(outputs, dtype=float), axis=0)
        diff_o = np.sign(steps_o)

        n_diffs = len(diff_o)
        if n_diffs == 0:
            return float('nan')

        zero_mask = np.abs(steps_o) <= tolerance
        moving = ~zero_mask
        n_zero_matches = zero_mask.sum()
        # Count both alignments explicitly. A step where the weight did not
        # change (or the output is NaN) matches neither alignment.
        n_same = (diff_w == diff_o)[moving].sum()
        n_flipped = (diff_w == -diff_o)[moving].sum()
        n_matches = max(n_same, n_flipped)
        if not strict:
            n_matches += n_zero_matches
        return n_matches / n_diffs
class SignConsistencyStrict(SignConsistency):
    """Strict variant of ``SignConsistency``.

    Steps that ``SignConsistency`` classifies as having no output change are
    not credited as matches, so only genuinely co-moving steps score.
    """

    def __call__(self, weights, outputs, tolerance=0):
        return self._calculate(
            weights,
            outputs,
            tolerance=tolerance,
            strict=True,
        )
class Monotonicity(EditMetric):
    """How consistently the outputs move in one direction, ignoring flat steps.

    The score is the absolute mean of the signs of the consecutive output
    changes whose magnitude exceeds ``tolerance``. It is 1.0 when all such
    changes share a direction and approaches 0.0 as directions balance out.
    ``weights`` is accepted for interface compatibility and is not used.
    """

    def __call__(self, weights, outputs, tolerance=0):
        steps = np.diff(np.asarray(outputs, dtype=float), axis=0)
        # Threshold the size of each change; signs alone (-1/0/1) carry no
        # magnitude to compare against ``tolerance``. NaN steps are dropped
        # because ``abs(nan) > tolerance`` is False.
        moving = steps[np.abs(steps) > tolerance]
        if moving.size == 0:
            # No step exceeds the tolerance, so the direction is undefined.
            return float('nan')
        return np.absolute(np.sign(moving).mean())
class MonotonicityStrict(EditMetric):
    """Absolute mean sign of all consecutive output changes.

    Unlike ``Monotonicity``, flat (zero-change) steps stay in the mean, so
    they pull the score toward 0. ``weights`` is accepted for interface
    compatibility and is not used.
    """

    def __call__(self, weights, outputs):
        step_signs = diff_signs(outputs)
        mean_direction = step_signs.mean()
        return np.abs(mean_direction)
class ScipyStatsMetric(EditMetric, ABC):
    """Adapter that scores weights against outputs with a ``scipy.stats`` function.

    Concrete subclasses set ``fn_name`` to the name of a ``scipy.stats``
    function, which is called as ``fn(weights, outputs)``. The metric returns
    the ``statistic`` attribute of that call's result. This base class defines
    no ``fn_name`` itself.
    """

    def __call__(self, weights, outputs):
        # Deferred import: scipy is only loaded when such a metric is evaluated.
        from scipy import stats
        stat_fn = getattr(stats, self.fn_name)
        result = stat_fn(weights, outputs)
        return result.statistic
class SpearmanCorrelation(ScipyStatsMetric):
    """Spearman rank correlation of weights vs. outputs (``scipy.stats.spearmanr``)."""

    fn_name = "spearmanr"
class PearsonCorrelation(ScipyStatsMetric):
    """Pearson correlation of weights vs. outputs (``scipy.stats.pearsonr``)."""

    fn_name = "pearsonr"
class KendallTau(ScipyStatsMetric):
    """Kendall's tau of weights vs. outputs (``scipy.stats.kendalltau``)."""

    fn_name = "kendalltau"
def calculate_edit_metrics(weights, outputs):
    """Evaluate every registered edit metric on ``weights`` and ``outputs``.

    Each class returned by ``EditMetric.get_subclasses()`` is instantiated with
    no arguments and called as ``metric(weights, outputs)``, so every metric
    runs with its default options. The abstract ``ScipyStatsMetric`` adapter
    is skipped.

    Returns:
        dict mapping each metric's snake_case class name to its score, in
        alphabetical order of name.
    """
    metric_classes = set(EditMetric.get_subclasses())
    # The scipy adapter defines no ``fn_name`` of its own, so it is not a
    # usable metric. ``discard`` tolerates a registry that already omits it.
    metric_classes.discard(ScipyStatsMetric)

    # A set of classes iterates in hash (object id) order, which can change
    # between runs. Sort by name so the result order is reproducible.
    named_classes = sorted(
        ((snakecase(metric_cls.__name__), metric_cls)
         for metric_cls in metric_classes),
        key=lambda pair: pair[0],
    )
    return {
        name: metric_cls()(weights, outputs)
        for name, metric_cls in named_classes
    }