test_bscdfsl.py (forked from hushell/pmf_cvpr22)
import os
import sys
import json
import time
import random
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from tabulate import tabulate

from engine import evaluate
import utils.deit_util as utils
from utils.args import get_args_parser
from models import get_model
from datasets import get_bscd_loader
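# Example invocation (a sketch only; the exact flag spellings are defined in
# utils/args.py via get_args_parser and may differ from the attribute names
# used below):
#
#   python test_bscdfsl.py --deploy finetune --arch dino_small_patch16 \
#       --resume path/to/checkpoint.pth --cdfsl_domains EuroSAT ISIC \
#       --test_n_way 5 --n_shot 5 --image_size 224 --output_dir outputs/bscdfsl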
def main(args):
    args.distributed = False  # the CDFSL dataloader does not support DDP
    args.eval = True
    print(args)

    # mirror the module-level output_dir so main() does not rely on a global
    output_dir = Path(args.output_dir)

    device = torch.device(args.device)
    cudnn.benchmark = True

    ##############################################
    # Model
    print(f"Creating model: {args.deploy} {args.arch}")
    model = get_model(args)
    model.to(device)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'], strict=True)
        print(f'Loaded checkpoint from {args.resume}')

    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print('number of params:', n_parameters)
    ##############################################
    # Test
    criterion = torch.nn.CrossEntropyLoss()
    #datasets = ["EuroSAT", "ISIC", "CropDisease", "ChestX"]
    datasets = args.cdfsl_domains
    var_accs = {}

    for domain in datasets:
        print(f'\n# Testing {domain} starts...\n')
        data_loader_val = get_bscd_loader(domain, args.test_n_way, args.n_shot, args.image_size)

        # validate lr: sweep a few candidates on a handful of episodes
        # (ep=5) with a fixed seed; lr=0 means no fine-tuning at all
        best_lr = args.ada_lr
        if args.deploy == 'finetune':
            print("Start selecting the best lr...")
            best_acc = 0
            for lr in [0, 0.0001, 0.0005, 0.001]:
                model.lr = lr
                test_stats = evaluate(data_loader_val, model, criterion, device, seed=1234, ep=5)
                acc = test_stats['acc1']
                print(f"*lr = {lr}: acc1 = {acc}")
                if acc > best_acc:
                    best_acc = acc
                    best_lr = lr
            model.lr = best_lr
            print(f"### Selected lr = {best_lr}")

        # final classification: reseed the episode generator so every domain
        # is evaluated on the same sequence of tasks
        data_loader_val.generator.manual_seed(args.seed + 10000)
        test_stats = evaluate(data_loader_val, model, criterion, device)
        var_accs[domain] = (test_stats['acc1'], test_stats['acc_std'], best_lr)
        print(f"{domain}: acc1 on {len(data_loader_val.dataset)} test images: {test_stats['acc1']:.1f}%")

        if args.output_dir and utils.is_main_process():
            test_stats['domain'] = domain
            test_stats['lr'] = best_lr
            with (output_dir / f"log_test_{args.deploy}_{args.train_tag}.txt").open("a") as f:
                f.write(json.dumps(test_stats) + "\n")
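    # The "+-" column below is a 95% confidence half-width: 1.96 * std / sqrt(n).
    # A worked check with hypothetical numbers: std = 9.0 over 600 episodes gives
    # 1.96 * 9.0 / sqrt(600) ≈ 0.72, so a row would render as e.g. "85.33 +- 0.72".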
    # print results as a table
    if utils.is_main_process():
        rows = []
        for dataset_name in datasets:
            row = [dataset_name]
            acc, std, lr = var_accs[dataset_name]
            # note: len(data_loader_val.dataset) is the episode count of the
            # *last* domain tested; this assumes all domains use the same count
            conf = (1.96 * std) / np.sqrt(len(data_loader_val.dataset))
            row.append(f"{acc:0.2f} +- {conf:0.2f}")
            row.append(f"{lr}")
            rows.append(row)
        np.save(os.path.join(output_dir, f'test_results_{args.deploy}_{args.train_tag}.npy'), {'rows': rows})

        table = tabulate(rows, headers=['Domain', args.arch, 'lr'], floatfmt=".2f")
        print(table)
        print("\n")

        if args.output_dir:
            with (output_dir / f"log_test_{args.deploy}_{args.train_tag}.txt").open("a") as f:
                f.write(table)
if __name__ == '__main__':
    parser = get_args_parser()
    args = parser.parse_args()

    args.train_tag = 'pt' if args.resume == '' else 'ep'
    args.train_tag += f'_step{args.ada_steps}_lr{args.ada_lr}_prob{args.aug_prob}'

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # record the exact command line in the test log before running
    with (output_dir / f"log_test_{args.deploy}_{args.train_tag}.txt").open("a") as f:
        f.write(" ".join(sys.argv) + "\n")

    main(args)
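# To reload the saved results later (a minimal sketch; the actual filename
# depends on the output_dir, deploy mode, and train_tag used at run time):
#
#   import numpy as np
#   results = np.load('outputs/bscdfsl/test_results_finetune_<train_tag>.npy',
#                     allow_pickle=True).item()
#   print(results['rows'])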