-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathevaluate.py
75 lines (55 loc) · 2.89 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import logging
from pathlib import Path
import numpy as np
import torch
from helpers import init_helper, data_helper, vsumm_helper, bbox_helper
from modules.model_zoo import get_model
# Module-wide logger. getLogger() with no name returns the *root* logger, so
# whatever handlers init_helper.init_logger() installs in main() (presumably
# on the root logger — confirm) apply to every message logged below.
logger = logging.getLogger()
def evaluate(model, val_loader, nms_thresh, device):
    """Score a model's keyshot summaries over every video in a loader.

    For each video the model predicts segment scores and boxes. The boxes are
    clipped to the sequence length, snapped to integers and filtered with NMS.
    The survivors are converted into a frame-level summary, which is scored
    against the user summaries (F-score) and for diversity over the features.

    Args:
        model: network exposing ``eval()`` and ``predict(seq, seqdiff)``;
            ``predict`` is expected to return array-likes accepted by
            ``np.clip`` and ``bbox_helper.nms``.
        val_loader: iterable of 9-tuples ``(key, seq, seqdiff, gt_score,
            cps, n_frames, nfps, picks, user_summary)``. ``seq`` and
            ``seqdiff`` must be NumPy arrays (they go through
            ``torch.from_numpy``); ``gt_score`` is unused here.
        nms_thresh: overlap threshold forwarded to ``bbox_helper.nms``.
        device: torch device the input tensors are moved to.

    Returns:
        ``(fscore, diversity)`` as averaged by ``data_helper.AverageMeter``.
    """
    model.eval()
    meter = data_helper.AverageMeter('fscore', 'diversity')

    with torch.no_grad():
        for sample in val_loader:
            (video_key, features, feature_diff, _gt_score, change_points,
             frame_count, seg_frame_counts, picked_idx, user_summ) = sample

            num_steps = len(features)

            # Add a leading batch dimension of 1 before running the model.
            feat_in = torch.from_numpy(features).unsqueeze(0).to(device)
            diff_in = torch.from_numpy(feature_diff).unsqueeze(0).to(device)
            scores, boxes = model.predict(feat_in, diff_in)

            # Keep box endpoints inside [0, num_steps] and snap them to ints
            # before suppressing overlapping proposals.
            boxes = np.clip(boxes, 0, num_steps).round().astype(np.int32)
            scores, boxes = bbox_helper.nms(scores, boxes, nms_thresh)

            summary = vsumm_helper.bbox2summary(
                num_steps, scores, boxes, change_points, frame_count,
                seg_frame_counts, picked_idx)

            # 'avg' for TVSum keys, 'max' otherwise; presumably mean vs best
            # match over the user summaries — see get_summ_f1score.
            if 'tvsum' in video_key:
                metric = 'avg'
            else:
                metric = 'max'
            video_fscore = vsumm_helper.get_summ_f1score(
                summary, user_summ, metric)

            # Diversity is measured on the summary downsampled back onto the
            # feature sequence, so it is computed after the F-score.
            summary = vsumm_helper.downsample_summ(summary)
            video_diversity = vsumm_helper.get_summ_diversity(summary, features)

            meter.update(fscore=video_fscore, diversity=video_diversity)

    return meter.fscore, meter.diversity
def main():
    """Evaluate the saved checkpoints for every split file on the command line.

    Parses arguments, sets up logging and the random seed, and builds the
    model once. Then, for each split file in ``args.splits`` and each split
    inside it, it loads that split's checkpoint, evaluates on the split's
    test keys and logs per-split and averaged F-score / diversity.
    """
    args = init_helper.get_arguments()
    init_helper.init_logger(args.model_dir, args.log_file)
    init_helper.set_random_seed(args.seed)
    logger.info(args)

    model = get_model(args.model, **vars(args))
    model = model.eval().to(args.device)

    # Report the parameter count through the configured logger rather than
    # print(), so it goes through the same handlers as every other result
    # (presumably including args.log_file — see init_helper.init_logger).
    n_params = sum(p.numel() for p in model.parameters())
    logger.info('Total params: %.6fM', n_params / 1000000.0)

    for split_path in args.splits:
        split_path = Path(split_path)
        splits = data_helper.load_yaml(split_path)

        # Averages the per-split results for this split file.
        stats = data_helper.AverageMeter('fscore', 'diversity')

        for split_idx, split in enumerate(splits):
            ckpt_path = data_helper.get_ckpt_path(
                args.model_dir, split_path, split_idx)
            # Returning the storage unchanged keeps tensors on the CPU;
            # load_state_dict then copies them onto the model's device.
            state_dict = torch.load(
                str(ckpt_path), map_location=lambda storage, loc: storage)
            model.load_state_dict(state_dict)

            val_set = data_helper.VideoDataset(split['test_keys'])
            val_loader = data_helper.DataLoader(val_set, shuffle=False)

            fscore, diversity = evaluate(
                model, val_loader, args.nms_thresh, args.device)
            stats.update(fscore=fscore, diversity=diversity)

            logger.info(f'{split_path.stem} split {split_idx}: diversity: '
                        f'{diversity:.4f}, F-score: {fscore:.4f}')

        logger.info(f'{split_path.stem}: diversity: {stats.diversity:.4f}, '
                    f'F-score: {stats.fscore:.4f}')


if __name__ == '__main__':
    main()