-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathutils.py
256 lines (206 loc) · 8.75 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
#!/usr/bin/env python
# -*- coding:utf-8 -*-
# date: 2023/06
# author:Dingyi Hu
# email: [email protected]
import os
import pickle
import math
# Magnification names used for our gastric dataset:
#   'Large' = 40X, 'Medium' = 20X, 'Small' = 10X, 'Overview' = 5X
scales = 'Large Medium Small Overview'.split()
# Default candidate numbers of patches per kernel (anchor):
# the perfect squares 6^2, 8^2, 10^2, 12^2, 16^2 and 20^2.
PATCH_NUMBER_PER_ANCHOR = [side * side for side in (6, 8, 10, 12, 16, 20)]
def _override(args, name, default):
    """Return ``getattr(args, name)`` when it is set to a truthy value, else ``default``.

    Lets values already present on ``args`` (e.g. parsed from the command
    line) take precedence over the config. A missing attribute and falsy
    values (``None``, ``0``, ``''``, ``False``) fall back to ``default``,
    matching the previous ``'x' in args and args.x`` test. Unlike the
    ``in`` test, ``getattr`` works for any attribute container, not only
    objects implementing ``__contains__`` such as ``argparse.Namespace``.
    """
    value = getattr(args, name, None)
    return value if value else default


def merge_config_to_args(args, cfg):
    """Copy settings from ``cfg`` onto ``args`` and derive dependent values.

    ``cfg`` must support attribute access (``cfg.DATA.LABEL_ID``) and
    ``'SECTION' in cfg`` membership tests (presumably a yacs ``CfgNode`` --
    confirm). DATA is always read; the IMAGE, SAMPLE, CNN, FEATURE, VIT and
    KAT sections are merged only when present.

    Derived values depend on earlier sections. SAMPLE and FEATURE use
    ``args.rl`` (set from IMAGE), and KAT uses ``args.max_nodes`` (set from
    FEATURE).

    Args:
        args: attribute container (e.g. ``argparse.Namespace``), updated in
            place. For ``trfm_depth``, ``trfm_heads``, ``npk``, ``p_dim``,
            ``aug_rate`` and ``sl_weight``, a truthy value already on
            ``args`` overrides the config value.
        cfg: configuration tree.

    Returns:
        The same ``args`` object.
    """
    # dirs
    save_dir = cfg.DATA.DATA_SAVE_DIR
    args.data_conf_dir = cfg.DATA.DATASET_CONFIG_DIR
    args.feat_dir = os.path.join(save_dir, 'cnn_feat')
    args.graph_dir = os.path.join(save_dir, 'graph')
    args.graph_list_dir = os.path.join(save_dir, 'graph_list')
    args.kat_dir = os.path.join(save_dir, 'fcgr_model')

    # data
    args.slide_list = os.path.join(args.data_conf_dir, 'slide_list.pkl')
    args.label_id = cfg.DATA.LABEL_ID
    args.test_ratio = cfg.DATA.TEST_RATIO
    args.fold_num = cfg.DATA.FOLD_NUM

    # image
    if 'IMAGE' in cfg:
        args.level = cfg.IMAGE.LEVEL
        args.mask_level = cfg.IMAGE.MASK_LEVEL
        args.imsize = cfg.IMAGE.PATCH_SIZE
        args.tile_size = cfg.IMAGE.LOCAL_TILE_SIZE
        # Level gap between the sampling level and the mask level. Sizes are
        # right-shifted by this gap to express them at mask level (assumes
        # each level step halves the resolution -- confirm).
        args.rl = args.mask_level - args.level
        args.msize = args.imsize >> args.rl
        args.mhalfsize = args.msize >> 1

    # sampling (requires args.rl from the IMAGE section)
    if 'SAMPLE' in cfg:
        args.positive_ratio = cfg.SAMPLE.POS_RAT
        args.negative_ratio = cfg.SAMPLE.NEG_RAT
        args.intensity_thred = cfg.SAMPLE.INTENSITY_THRED
        args.sample_step = cfg.SAMPLE.STEP
        args.max_per_class = cfg.SAMPLE.MAX_PER_CLASS
        args.save_mask = cfg.SAMPLE.SAVE_MASK
        args.srstep = args.sample_step >> args.rl
        args.filter_size = (args.imsize >> args.rl, args.imsize >> args.rl)

    # CNN
    if 'CNN' in cfg:
        args.arch = cfg.CNN.ARCH
        args.pretrained = cfg.CNN.PRETRAINED
        args.cl = cfg.CNN.CONTRASTIVE
        # The config key really is spelled CDC_FINETUE; keep it in sync with the schema.
        args.cdc = cfg.CNN.CDC_FINETUE
        args.freeze_feat = cfg.CNN.FREEZE_FEAT
        if args.cl:
            # BYOL hyper-parameters for contrastive pre-training.
            args.hidden_dim = cfg.CNN.BYOL.HIDDEN_DIM
            args.pred_dim = cfg.CNN.BYOL.PRE_DIM
            args.momentum_decay = cfg.CNN.BYOL.M_DECAY
            args.fix_pred_lr = cfg.CNN.BYOL.FIX_PRED_LR
        if args.cdc:
            args.cdc_t_neg = cfg.CNN.CDC.NEG_THRED
            args.cdc_t_pos = cfg.CNN.CDC.POS_THRED
            args.cdc_top_k = cfg.CNN.CDC.TOP_K

    # WSI feature (requires args.rl from the IMAGE section)
    if 'FEATURE' in cfg:
        args.step = cfg.FEATURE.STEP
        args.frstep = args.step >> args.rl
        args.max_nodes = cfg.FEATURE.MAX_NODES

    # transformer
    if 'VIT' in cfg:
        args.trfm_depth = _override(args, 'trfm_depth', cfg.VIT.DEPTH)
        args.trfm_heads = _override(args, 'trfm_heads', cfg.VIT.HEADS)
        args.trfm_dim = cfg.VIT.DIM
        args.trfm_mlp_dim = cfg.VIT.MLP_DIM
        args.trfm_dim_head = cfg.VIT.HEAD_DIM
        args.trfm_pool = cfg.VIT.POOL

    # KAT (requires args.max_nodes from the FEATURE section)
    if 'KAT' in cfg:
        args.npk = _override(args, 'npk', cfg.KAT.PATCH_PER_KERNEL)
        # Kernel count: max_nodes / npk truncated, plus one. This still adds
        # an extra kernel when max_nodes is an exact multiple of npk.
        args.kn = int(args.max_nodes / args.npk) + 1
        args.p_dim = _override(args, 'p_dim', cfg.KAT.BYOL.PROJECTOR_DIM)
        args.aug_rate = _override(args, 'aug_rate', cfg.KAT.BYOL.NODE_AUG)
        args.sl_weight = _override(args, 'sl_weight', cfg.KAT.BYOL.SL_WEIGHT)
    return args
def get_sampling_path(args):
    """Return the patch directory for the current sampling settings.

    The leaf directory name encodes level, patch size, sampling step,
    per-class cap, positive/negative ratios (as integer percentages) and
    intensity threshold, e.g. ``[l1t256s128m100][p50n25i25]``. It is joined
    under ``args.patch_dir``.

    NOTE(review): ``int(ratio * 100)`` truncates, so float representation
    can shave a percent (e.g. 0.29 -> 28).
    """
    fields = (
        args.level,
        args.imsize,
        args.sample_step,
        args.max_per_class,
        int(args.positive_ratio * 100),
        int(args.negative_ratio * 100),
        args.intensity_thred,
    )
    sampling_tag = '[l{}t{}s{}m{}][p{}n{}i{}]'.format(*fields)
    return os.path.join(args.patch_dir, sampling_tag)
def get_data_list_path(args):
    """Return the path for the train/test list files of the current split.

    Reuses the bracketed sampling tag from ``get_sampling_path`` and appends
    ``[f<fold_num>_t<test %>]``. The result is joined under ``args.list_dir``.

    NOTE(review): ``find('[')`` locates the first ``[`` in the whole path,
    so a ``patch_dir`` that itself contains ``[`` would leak into the tag.
    """
    sampling_path = get_sampling_path(args)
    sampling_tag = sampling_path[sampling_path.find('['):]
    list_tag = '{}[f{}_t{}]'.format(sampling_tag, args.fold_num,
                                    int(args.test_ratio * 100))
    return os.path.join(args.list_dir, list_tag)
def get_cnn_path(args):
    """Return the checkpoint directory for the CNN trained on the current fold.

    Side effect: sets ``args.fold_name`` to ``'list_fold_all'`` when
    ``args.fold == -1``, otherwise to ``'list_fold_<fold>'``.

    The leaf name is the data-list tag followed by
    ``[<arch>_td_<label_id>_<fold_name>]``, plus ``[frz]`` when
    ``args.freeze_feat`` is set. It is joined under ``args.cnn_dir``.
    """
    list_path = get_data_list_path(args)
    if args.fold == -1:
        args.fold_name = 'list_fold_all'
    else:
        args.fold_name = 'list_fold_{}'.format(args.fold)
    list_tag = list_path[list_path.find('['):]
    cnn_tag = '{}[{}_td_{}_{}]'.format(list_tag, args.arch, args.label_id,
                                       args.fold_name)
    if args.freeze_feat:
        cnn_tag += '[frz]'
    return os.path.join(args.cnn_dir, cnn_tag)
def get_cdc_path(args):
    """Return the checkpoint directory for a CDC-fine-tuned CNN.

    Side effect: sets ``args.fold_name`` (``'list_fold_all'`` for
    ``args.fold == -1``, else ``'list_fold_<fold>'``).

    The leaf name is the data-list tag followed by
    ``[<arch>_td_<label>][cdc_tp<pos %>_tn<neg %>_k<top_k>]<fold_name>``,
    plus ``[frz]`` when ``args.freeze_feat`` is set. ``<label>`` is ``cl``
    when ``args.cl`` is truthy, otherwise ``args.label_id``. The result is
    joined under ``args.contrst_dir``.

    NOTE(review): unlike the CDC tag built in ``get_feature_path``, the fold
    name here is appended without surrounding brackets. It is left as is
    because changing it would rename any directories already created.
    """
    list_path = get_data_list_path(args)
    if args.fold == -1:
        args.fold_name = 'list_fold_all'
    else:
        args.fold_name = 'list_fold_{}'.format(args.fold)
    label_part = 'cl' if args.cl else args.label_id
    fields = (
        list_path[list_path.find('['):],
        args.arch,
        label_part,
        int(args.cdc_t_pos * 100),
        int(args.cdc_t_neg * 100),
        args.cdc_top_k,
        args.fold_name,
    )
    cdc_tag = '{}[{}_td_{}][cdc_tp{}_tn{}_k{}]{}'.format(*fields)
    if args.freeze_feat:
        cdc_tag += '[frz]'
    return os.path.join(args.contrst_dir, cdc_tag)
def get_feature_path(args):
    """Return the directory holding CNN features for the current settings.

    With ``args.pretrained`` the leaf name is ``[<arch>_pre][fs<step>]``.
    Otherwise it extends the data-list tag with the arch, label (``cl`` when
    ``args.cl``), feature step, optional CDC parameters and fold name. In
    that branch it also sets ``args.fold_name`` and appends ``[frz]`` when
    ``args.freeze_feat`` is set. The result is joined under ``args.feat_dir``.

    NOTE(review): the source this was recovered from had lost its
    indentation. Whether ``[frz]`` should also apply to the pretrained
    branch could not be determined -- confirm against the original.
    """
    if args.pretrained:
        feature_tag = '[{}_pre][fs{}]'.format(args.arch, args.step)
    else:
        list_path = get_data_list_path(args)
        if args.fold == -1:
            args.fold_name = 'list_fold_all'
        else:
            args.fold_name = 'list_fold_{}'.format(args.fold)
        list_tag = list_path[list_path.find('['):]
        label_part = 'cl' if args.cl else args.label_id
        if args.cdc:
            feature_tag = '{}[{}_td_{}][fs{}][cdc_tp{}_tn{}_k{}][{}]'.format(
                list_tag,
                args.arch,
                label_part,
                args.step,
                int(args.cdc_t_pos * 100),
                int(args.cdc_t_neg * 100),
                args.cdc_top_k,
                args.fold_name,
            )
        else:
            feature_tag = '{}[{}_td_{}][fs{}][{}]'.format(
                list_tag, args.arch, label_part, args.step, args.fold_name)
        if args.freeze_feat:
            feature_tag += '[frz]'
    return os.path.join(args.feat_dir, feature_tag)
def get_graph_path(args):
    """Return the directory of WSI graphs built from the current features.

    The leaf name is the bracketed feature tag from ``get_feature_path``
    followed by ``[m<max_nodes>]``, joined under ``args.graph_dir``.
    Inherits the ``args.fold_name`` side effect of ``get_feature_path``.
    """
    feature_path = get_feature_path(args)
    feature_tag = feature_path[feature_path.find('['):]
    graph_tag = '{}[m{}]'.format(feature_tag, args.max_nodes)
    return os.path.join(args.graph_dir, graph_tag)
def get_graph_list_path(args):
    """Return the directory of graph list files for the current features.

    Uses the same leaf name as ``get_graph_path`` (feature tag plus
    ``[m<max_nodes>]``), joined under ``args.graph_list_dir`` instead.
    """
    feature_path = get_feature_path(args)
    feature_tag = feature_path[feature_path.find('['):]
    list_tag = '{}[m{}]'.format(feature_tag, args.max_nodes)
    return os.path.join(args.graph_list_dir, list_tag)
def get_slide_config(config_path):
    """Load a pickled slide configuration and return ``(tasks, lesions)``.

    The file must unpickle to a mapping with ``'tasks'`` and ``'lesions'``
    keys; a missing key raises ``KeyError``.

    Security: ``pickle.load`` can execute arbitrary code, so only pass files
    produced by this project's own tooling.
    """
    with open(config_path, 'rb') as handle:
        slide_config = pickle.load(handle)
    tasks = slide_config['tasks']
    lesions = slide_config['lesions']
    return tasks, lesions
def get_kat_path(args, prefix_name=''):
    """Return the directory for a KAT model trained on the current graph lists.

    The leaf name is ``prefix_name`` followed by the bracketed tag from
    ``get_graph_list_path``. It is joined under ``args.kat_dir``.

    Bug fix: the tag is now stripped of its ``graph_list_dir`` component, as
    every other path helper (including ``get_kat_byol_path``) does. The
    previous code concatenated the full graph-list path. That nested
    ``graph_list_dir`` inside ``kat_dir``, and when ``graph_list_dir`` was
    absolute, ``os.path.join`` discarded ``kat_dir`` entirely, pointing the
    model into the graph-list directory.
    """
    graph_list_path = get_graph_list_path(args)
    kat_tag = prefix_name + graph_list_path[graph_list_path.find('['):]
    return os.path.join(args.kat_dir, kat_tag)
def get_kat_byol_path(args, prefix_name=''):
    """Return the directory for a KAT model trained with the BYOL objective.

    The leaf name is ``prefix_name`` + the graph-list tag, followed by:
    transformer settings (depth, heads, dim, MLP dim, head dim, pool),
    ``[npk_<npk>]``, the BYOL settings (node augmentation rate, projector
    dim, self-supervised loss weight) and ``[t<label_id>]``. It is joined
    under ``args.kat_dir``.
    """
    graph_list_path = get_graph_list_path(args)
    base = prefix_name + graph_list_path[graph_list_path.find('['):]
    fields = (
        base,
        # transformer
        args.trfm_depth, args.trfm_heads, args.trfm_dim,
        args.trfm_mlp_dim, args.trfm_dim_head, args.trfm_pool,
        # kernels
        args.npk,
        # BYOL
        args.aug_rate, args.p_dim, args.sl_weight,
        # task
        args.label_id,
    )
    kat_tag = '{}[d{}_h_{}_de{}dm{}dh{}_{}][npk_{}][ar_{}_pd_{}_slw{}][t{}]'.format(*fields)
    return os.path.join(args.kat_dir, kat_tag)
class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    Args:
        name: label printed by ``__str__``.
        fmt: format-spec suffix, including the leading ``:``, applied to both
            the current value and the average (e.g. ``':.4f'``).
    """

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the current value, average, running sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the average."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

    def __str__(self):
        # e.g. fmt=':.2f' -> '{name} {val:.2f} ({avg:.2f})'
        template = ''.join(('{name} {val', self.fmt, '} ({avg', self.fmt, '})'))
        return template.format(**vars(self))
class ProgressMeter(object):
    """Print one tab-separated progress line per call to ``print``.

    Args:
        num_batches: total batch count, used to right-align the batch index.
        *meters: objects rendered with ``str()`` after the batch counter.
        prefix: text placed before the ``[batch/total]`` counter.
    """

    def __init__(self, num_batches, *meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def print(self, batch):
        """Write ``<prefix>[batch/total]`` followed by every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template like ``'[{:3d}/100]'`` sized to ``num_batches``."""
        width = len(str(num_batches // 1))
        cell = '{:' + str(width) + 'd}'
        return ''.join(['[', cell, '/', cell.format(num_batches), ']'])
def adjust_learning_rate(optimizer, init_lr, epoch, args):
    """Set each param group's ``lr`` from a half-cosine decay schedule.

    The scheduled rate is ``init_lr * 0.5 * (1 + cos(pi * epoch / args.epochs))``,
    falling from ``init_lr`` at epoch 0 to 0 at ``args.epochs``. Param groups
    with a truthy ``'fix_lr'`` entry are pinned to ``init_lr`` instead.
    """
    scheduled_lr = init_lr * 0.5 * (1. + math.cos(math.pi * epoch / args.epochs))
    for group in optimizer.param_groups:
        group['lr'] = init_lr if group.get('fix_lr') else scheduled_lr