-
Notifications
You must be signed in to change notification settings - Fork 19
/
Copy pathbenchmark.py
193 lines (153 loc) · 7.2 KB
/
benchmark.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
'''
Usage:
   benchmark --gold=GOLD_OIE --out=OUTPUT_FILE (--stanford=STANFORD_OIE | --ollie=OLLIE_OIE |--reverb=REVERB_OIE | --clausie=CLAUSIE_OIE | --openiefour=OPENIEFOUR_OIE | --props=PROPS_OIE)

Options:
  --gold=GOLD_OIE              The gold reference Open IE file (by default, it should be under ./oie_corpus/all.oie).
  --out=OUTPUT_FILE            The output file, into which the precision recall curve will be written.
  --clausie=CLAUSIE_OIE        Read ClausIE format from file CLAUSIE_OIE.
  --ollie=OLLIE_OIE            Read OLLIE format from file OLLIE_OIE.
  --openiefour=OPENIEFOUR_OIE  Read Open IE 4 format from file OPENIEFOUR_OIE.
  --props=PROPS_OIE            Read PropS format from file PROPS_OIE
  --reverb=REVERB_OIE          Read ReVerb format from file REVERB_OIE
  --stanford=STANFORD_OIE      Read Stanford format from file STANFORD_OIE
'''
import docopt
import string
import numpy as np
from sklearn.metrics import precision_recall_curve
import re
import logging
logging.basicConfig(level = logging.INFO)
from oie_readers.stanfordReader import StanfordReader
from oie_readers.ollieReader import OllieReader
from oie_readers.reVerbReader import ReVerbReader
from oie_readers.clausieReader import ClausieReader
from oie_readers.openieFourReader import OpenieFourReader
from oie_readers.propsReader import PropSReader
from oie_readers.goldReader import GoldReader
from matcher import Matcher
class Benchmark:
    ''' Compare the gold OIE dataset against a predicted equivalent '''

    def __init__(self, gold_fn):
        ''' Load gold Open IE, this will serve to compare against using the compare function.

        gold_fn: path of the gold Open IE file, parsed with GoldReader;
                 the parsed extractions (GoldReader.oie) are kept in self.gold.
        '''
        gr = GoldReader()
        gr.read(gold_fn)
        self.gold = gr.oie

    def compare(self, predicted, matchingFunc, output_fn):
        ''' Compare gold against predicted using a specified matching function.
            Outputs PR curve to output_fn.

        predicted:    mapping from sentence to a list of predicted extractions.
                      Each extraction is read for .confidence, .pred,
                      .splits_conjunctions and has output_fn appended to its
                      .matched list when it is matched.
        matchingFunc: callable(goldEx, predictedEx, ignoreStopwords=True,
                      ignoreCase=True) returning truthy on a match.
        output_fn:    path of the tab-separated "Precision\tRecall" file to
                      write; also used as the tag stored in .matched.
        '''
        y_true = []
        y_scores = []
        correctTotal = 0
        unmatchedCount = 0
        predicted = Benchmark.normalizeDict(predicted)
        gold = Benchmark.normalizeDict(self.gold)

        for sent, goldExtractions in gold.items():
            if sent not in predicted:
                # The extractor didn't find any extractions for this sentence
                unmatchedCount += len(goldExtractions)
                correctTotal += len(goldExtractions)
                continue

            predictedExtractions = predicted[sent]
            # Indices of predictions already credited as a true positive for a
            # gold extraction of this sentence. A single prediction must not be
            # counted as a true positive for several gold extractions, which
            # would inflate both the TP count and recall. This is tracked
            # separately from .matched so that conjunction siblings (which are
            # only excused from being false positives) can still match a later
            # gold extraction on their own.
            consumed = set()

            for goldEx in goldExtractions:
                correctTotal += 1
                found = False
                for idx, predictedEx in enumerate(predictedExtractions):
                    if idx in consumed:
                        continue
                    if matchingFunc(goldEx,
                                    predictedEx,
                                    ignoreStopwords = True,
                                    ignoreCase = True):
                        y_true.append(1)
                        y_scores.append(predictedEx.confidence)
                        consumed.add(idx)
                        predictedEx.matched.append(output_fn)
                        # Also mark any other predictions with the
                        # same exact predicate as matched.
                        # This is to support packages that do conjunction
                        # splitting, and doesn't affect the results for
                        # packages that don't.
                        if predictedEx.splits_conjunctions:
                            for otherPredictedEx in predictedExtractions:
                                if otherPredictedEx.pred == predictedEx.pred:
                                    otherPredictedEx.matched.append(output_fn)
                        found = True
                        break
                if not found:
                    unmatchedCount += 1

            for predictedEx in predictedExtractions:
                if output_fn not in predictedEx.matched:
                    # Add false positives
                    y_true.append(0)
                    y_scores.append(predictedEx.confidence)

        # y_true contains exactly one positive per gold extraction that some
        # prediction matched, so precision_recall_curve's recall r' is
        #   (matched gold above threshold) / (matched gold).
        # Scaling by (matched gold) / (all gold) = (correctTotal - unmatchedCount) / correctTotal
        # yields the true recall: (matched gold above threshold) / (all gold).
        p, r = Benchmark.prCurve(np.array(y_true), np.array(y_scores),
                                 recallMultiplier = ((correctTotal - unmatchedCount) / float(correctTotal)))

        # write PR to file, ordered by increasing recall
        with open(output_fn, 'w') as fout:
            fout.write('{0}\t{1}\n'.format("Precision", "Recall"))
            for cur_p, cur_r in sorted(zip(p, r), key = lambda cur_p_cur_r: cur_p_cur_r[1]):
                fout.write('{0}\t{1}\n'.format(cur_p, cur_r))

    @staticmethod
    def prCurve(y_true, y_scores, recallMultiplier):
        ''' Return (precision, recall) from sklearn's precision_recall_curve,
        with recall scaled by recallMultiplier.

        recallMultiplier accounts for the fraction of gold extractions that no
        prediction matched: those never appear in y_true, so the raw recall is
        relative to the matched gold extractions only.
        '''
        precision, recall, _ = precision_recall_curve(y_true, y_scores)
        recall = recall * recallMultiplier
        return precision, recall

    # Helper functions:
    @staticmethod
    def normalizeDict(d):
        ''' Return a new dict whose keys are normalizeKey(k); values are kept.
        If two keys normalize to the same string, the later one wins. '''
        return {Benchmark.normalizeKey(k): v for k, v in d.items()}

    @staticmethod
    def normalizeKey(k):
        ''' Normalize a sentence key: drop spaces, turn PTB bracket escapes
        (-LRB- etc.) back into brackets, then strip all string.punctuation
        characters (presumably so differently tokenized copies of the same
        sentence line up). '''
        return Benchmark.removePunct(str(Benchmark.PTB_unescape(k.replace(' ', ''))))

    @staticmethod
    def PTB_escape(s):
        ''' Replace each bracket character with its PTB escape token. '''
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(u, e)
        return s

    @staticmethod
    def PTB_unescape(s):
        ''' Replace each PTB escape token with its bracket character. '''
        for u, e in Benchmark.PTB_ESCAPES:
            s = s.replace(e, u)
        return s

    @staticmethod
    def removePunct(s):
        ''' Remove every character found in string.punctuation. '''
        return Benchmark.regex.sub('', s)

    # CONSTANTS
    # Character class matching any single string.punctuation character.
    regex = re.compile('[%s]' % re.escape(string.punctuation))

    # Penn treebank bracket escapes
    # Taken from: https://github.com/nlplab/brat/blob/master/server/src/gtbtokenize.py
    PTB_ESCAPES = [('(', '-LRB-'),
                   (')', '-RRB-'),
                   ('[', '-LSB-'),
                   (']', '-RSB-'),
                   ('{', '-LCB-'),
                   ('}', '-RCB-'),]
if __name__ == '__main__':
    args = docopt.docopt(__doc__)
    logging.debug(args)

    # Command-line flag -> reader class for that system's output format.
    # Flags are checked in this order; if more than one were set, the last
    # one checked determines `predicted`.
    readers = [
        ('--stanford', StanfordReader),
        ('--props', PropSReader),
        ('--ollie', OllieReader),
        ('--reverb', ReVerbReader),
        ('--clausie', ClausieReader),
        ('--openiefour', OpenieFourReader),
    ]
    for flag, reader_cls in readers:
        input_fn = args[flag]
        if input_fn:
            predicted = reader_cls()
            predicted.read(input_fn)

    b = Benchmark(args['--gold'])
    out_filename = args['--out']
    logging.info("Writing PR curve of {} to {}".format(predicted.name, out_filename))
    b.compare(predicted = predicted.oie,
              matchingFunc = Matcher.lexicalMatch,
              output_fn = out_filename)