-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathevaluate.py
59 lines (45 loc) · 1.74 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
from pathlib import Path
import click
import pandas as pd
import rouge
import torch
from coop.util import load_tokenizer, load_data, build_model
def evaluate(model, data, num_beams=4, debug=False):
    """Generate one summary per example in ``data`` and score them with ROUGE.

    For each example, the model is run on ``x["src"]``. ``out.q.mean`` is
    averaged over dim 0 and decoded with beam search into a summary. All
    generated summaries are then scored against the collected ``x["summary"]``
    references with ROUGE-1/2/L (averaged over the corpus, with stemming).

    Args:
        model: Callable as ``model(src, do_generate=True)`` and exposing
            ``generate(z, num_beams=...)``.
        data: Iterable of examples with ``"src"`` and ``"summary"`` keys. It
            must also have a ``tokenizer`` attribute whose ``decode`` turns
            generated ids into text.
        num_beams: Beam width forwarded to ``model.generate``.
        debug: If True, print the first 10 generated summaries.

    Returns:
        dict mapping ``"<metric>_sum_<stat>"`` to a score. The keys are built
        from the nested ``{metric: {stat: value}}`` dict returned by
        ``evaluator.get_scores``.
    """
    evaluator = rouge.Rouge(metrics=["rouge-n", "rouge-l"], max_n=2, limit_length=False, apply_avg=True,
                            stemming=True, ensure_compatibility=True)

    hyp, ref = [], []
    # Evaluation is inference only: disable autograd so no computation graph
    # is built for the forward pass and beam search (saves memory and time).
    with torch.no_grad():
        for x in data:
            out = model(x["src"], do_generate=True)
            # NOTE(review): out.q.mean is presumably the latent mean for each
            # input in the example; averaging over dim 0 with keepdim=True
            # yields a single latent, i.e. one summary per example -- confirm.
            summary_avg = model.generate(out.q.mean.mean(dim=0, keepdim=True), num_beams=num_beams)
            summary_avg = data.tokenizer.decode(summary_avg)
            # decode presumably returns a list of strings (one per row), hence
            # extend; the reference is appended as-is, since it may itself be a
            # list of reference summaries -- verify against the data loader.
            hyp.extend(summary_avg)
            ref.append(x["summary"])

    sums = evaluator.get_scores(hyp, ref).items()
    # Flatten {metric: {stat: value}} into {"metric_sum_stat": value}.
    scores = {"_".join((metric, "sum", k)): v for metric, vs in sums for k, v in vs.items()}
    if debug:
        print("Generated examples:")
        print("\n".join(hyp[:10]))
    return scores
@click.command()
@click.argument("log_dir", type=click.Path(exists=True))
@click.option("--debug", is_flag=True)
def main(log_dir, debug):
    """Evaluate the ``best.th`` checkpoint in LOG_DIR on the dev and test splits.

    Reads ``config.json`` from LOG_DIR and rebuilds the tokenizers, datasets
    and model from it. It then loads the checkpoint weights and prints a
    table of ``evaluate`` scores with one column per split and rows sorted by
    metric name.
    """
    log_dir = Path(log_dir)
    checkpoint = log_dir / "best.th"
    # Use a context manager so the config file handle is closed promptly.
    # JSON is UTF-8 by spec, so decode it as such regardless of the locale.
    with open(log_dir / "config.json", encoding="utf-8") as f:
        config = json.load(f)
    src_tokenizer, tgt_tokenizer = load_tokenizer(config)
    _, dev, test = load_data(config, src_tokenizer, tgt_tokenizer)

    model = build_model(config).eval()
    # map_location returning the storage unchanged loads every tensor on CPU,
    # so GPU-trained checkpoints also load on CPU-only machines; the model is
    # moved to the GPU below when one is available.
    # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
    model.load_state_dict(torch.load(checkpoint, map_location=lambda storage, loc: storage))
    if torch.cuda.is_available():
        model.cuda()

    # Look splits up by name explicitly instead of eval()-ing the split name.
    splits = {"dev": dev, "test": test}
    scores = {}
    for data_type, data in splits.items():
        scores[data_type] = evaluate(model, data, debug=debug)
    # Outer keys (split names) become columns; metric names become the index.
    df = pd.DataFrame(scores)
    df.sort_index(inplace=True)
    print(df)
# Script entry point. ``main`` is a click command (see its decorators), so
# calling it with no arguments lets click parse the command line.
if __name__ == '__main__':
    main()