-
Notifications
You must be signed in to change notification settings - Fork 156
/
Copy path tps_res32_attn.py
79 lines (64 loc) · 2.13 KB
/
tps_res32_attn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""
##################################################################################################
# Copyright Info : Copyright (c) Davar Lab @ Hikvision Research Institute. All rights reserved.
# Filename : tps_res32_attn.py
# Abstract : TPS transformation recognition Model
# Current Version: 1.0.0
# Date : 2021-06-11
##################################################################################################
"""
# encoding=utf-8
# NOTE(review): PEP 263 only honours a coding declaration on the first two
# lines of a file, so the marker above has no effect at this position.

# Parent config this file inherits from; the settings below are layered on
# top of it by the (mmcv-style) config loader.
_base_ = ['./baseline.py']
# ---------------------------------------------------------------------------
# 1. Model Settings
# Model-related settings, such as the model type, the user-selected modules
# and their parameters.
# ---------------------------------------------------------------------------
# Override only the transformation stage with a TPS spatial transformer.
# ``_delete_: True`` asks the config loader to discard the inherited
# ``transformation`` dict instead of merging these keys into it.
model = {
    'transformation': {
        'type': 'TPS_SpatialTransformer',
        # presumably the number of TPS fiducial (control) points -- confirm
        'F': 20,
        # input image size; assumed (height, width) -- TODO confirm order
        'I_size': (32, 100),
        # presumably the rectified output image size
        'I_r_size': (32, 100),
        # input channel count (1 suggests grayscale input)
        'I_channel_num': 1,
        '_delete_': True,
    },
}
# Data setting override: per-GPU batch size of 64 samples.
data = {'samples_per_gpu': 64}
# Checkpoint setting for the project's ``DavarCheckpointHook``: appears to
# save by epoch (every 1) and by iteration (every 5000), in "lightweight"
# mode, tracking "accuracy" where a greater value is better.
# NOTE(review): the filename template contains ``ace`` although this file
# configures a TPS transformation; it looks copy-pasted from an ACE config.
# Left as-is because other tooling may expect these paths -- confirm.
checkpoint_config = {
    'type': "DavarCheckpointHook",
    'interval': 1,
    'iter_interval': 5000,
    'by_epoch': True,
    'by_iter': True,
    'filename_tmpl': 'ckpt/res32_ace_e{}.pth',
    'metric': "accuracy",
    'rule': "greater",
    'save_mode': "lightweight",
    'init_metric': -1,
    'model_milestone': 0.5,
}
# Logger setting: a log interval of 50, written through the plain-text
# logger hook only.
log_config = {
    'interval': 50,
    'hooks': [
        {'type': 'TextLoggerHook'},
    ],
}
# Evaluation setting: evaluate both by epoch (from epoch 3) and by iteration
# (every 5000), reporting accuracy and NED and keeping the best checkpoint by
# accuracy (greater is better).
# NOTE(review): ``start_iter: 0.5`` is presumably a fraction of an epoch --
# confirm against the evaluation hook that consumes it.
evaluation = {
    'start': 3,
    'start_iter': 0.5,
    'save_best': "accuracy",
    'iter_interval': 5000,
    'model_type': "RECOGNIZOR",
    'eval_mode': "lightweight",
    'by_epoch': True,
    'by_iter': True,
    'rule': "greater",
    'metric': ['accuracy', 'NED'],
}
# Runner setting: epoch-based training capped at 6 epochs.
runner = {'type': 'EpochBasedRunner', 'max_epochs': 6}
# Root output directory for this run (presumably where logs and the ``ckpt/``
# checkpoints are written).
# NOTE(review): machine-specific absolute path; override per environment.
work_dir = '/data1/workdir/davar_opensource/tps/'
# Distributed training setting: use the NCCL communication backend.
dist_params = {'backend': 'nccl'}