-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtools.py
executable file
·499 lines (401 loc) · 15.7 KB
/
tools.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
import sys
from os.path import dirname, abspath, join, basename, normpath
import os
import subprocess
import glob
import shutil
import subprocess
import datetime as dt
import math
from collections.abc import Iterable
import torch
import yaml
from sklearn.metrics import auc
import numpy as np
root_dir = dirname(dirname(dirname(abspath(__file__))))
sys.path.append(root_dir)
import dlib.dllogger as DLLogger
from dlib.utils.shared import fmsg
from dlib.configure import constants
from dlib.utils.shared import is_cc
from dlib.utils.shared import get_tag_device
def get_cpu_device():
    """
    Build a torch device handle that points at the CPU.

    :return: torch.device("cpu").
    """
    cpu = torch.device("cpu")
    return cpu
def log_device(args):
    """
    Log the device tag of the current run through DLLogger.

    CUDA must be available; otherwise an AssertionError is raised before
    anything is logged.

    :param args: run configuration forwarded to ``get_tag_device``.
    """
    assert torch.cuda.is_available()
    DLLogger.log(message=get_tag_device(args=args))
def chunks_into_n(l: Iterable, n: int) -> Iterable:
    """
    Split a sequence into exactly n consecutive chunks.

    Each chunk holds ceil(len(l) / n) items, except at the end: the last
    non-empty chunk may be shorter, and when len(l) < n the trailing chunks
    are empty. Concatenating the chunks gives back l.

    :param l: sequence supporting ``len()`` and slicing (list, tuple, str,
        numpy array, ...). A one-shot iterator is not supported.
    :param n: number of chunks; must be >= 1.
    :return: generator yielding n slices of l.
    :raises ValueError: if n < 1.
    """
    # Validate eagerly (before building the generator) so the caller gets
    # a clear error instead of a ZeroDivisionError or a silent empty result.
    if n < 1:
        raise ValueError(f"n must be >= 1, got {n}.")
    # Integer ceiling division: exact for any length, unlike math.ceil on a
    # float quotient which loses precision beyond 2**53.
    chunksize = -(-len(l) // n)
    return (l[i * chunksize:(i + 1) * chunksize] for i in range(n))
def chunk_it(l, n):
    """
    Lazily yield consecutive slices of l holding n items each.

    The final slice may be shorter than n.

    :param l: sequence supporting ``len()`` and slicing.
    :param n: int, chunk size.
    """
    yield from (l[start:start + n] for start in range(0, len(l), n))
class Dict2Obj(object):
    """
    Expose the entries of a dictionary as attributes of an object.

    Each key becomes an attribute name, and its value becomes the attribute
    value, e.g. ``Dict2Obj({'a': 1}).a == 1``.
    """

    def __init__(self, dictionary):
        # Any object with .keys() and item access is accepted.
        for key in dictionary.keys():
            setattr(self, key, dictionary[key])

    def __repr__(self):
        # Only attribute names are shown (values may be large).
        attrs = str(list(self.__dict__))
        # Closing '>' was missing from the original format string.
        return "<Dict2Obj: %s>" % attrs
def count_nb_params(model):
    """
    Count the scalar values held by a model's parameters.

    Every tensor returned by ``model.parameters()`` is counted, whether or
    not it requires gradients.

    :param model: nn.Module or None.
    :return: int, 0 when model is None, else the sum of ``numel()`` over
        ``model.parameters()``.
    """
    if model is None:
        return 0
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total
def create_folders_for_exp(exp_folder, name):
    """
    Make sure the folders of one split of the current experiment exist.

    :param exp_folder: str, path to the current experiment.
    :param name: str, name of the split (train, validation, test).
    :return: Dict2Obj with two attributes:
        . folder: <exp_folder>/<name>, holds everything about this split.
        . prediction: <exp_folder>/<name>/prediction, for image predictions.
    """
    paths = {
        "folder": join(exp_folder, name),
        "prediction": join(exp_folder, "{}/prediction".format(name)),
    }
    for path in paths.values():
        # Paths that already exist (even as files) are left untouched.
        if not os.path.exists(path):
            os.makedirs(path, exist_ok=True)
    return Dict2Obj(paths)
def copy_code(dest,
              compress=False,
              verbose=False
              ):
    """Copy code to the exp folder for reproducibility.

    Copies the top-level ``.py``, ``.sh`` and ``.yaml`` files of
    ``root_dir`` into ``dest``. It then recursively copies the ``dlib`` and
    ``cmds`` folders of ``root_dir`` into ``dest`` with ``cp -r``.

    Input:
        dest: path to the destination folder (the exp folder). Created if
            missing.
        compress: bool. if true, archive the destination folder into
            ``<name>.tar`` next to it and delete the folder. The file is
            named ``<name>.tar.gz`` (see the NOTE below).
        verbose: bool. if true, print the shell commands before running them.

    Raises:
        subprocess.CalledProcessError: if one of the shell commands fails.
    """
    import shlex  # local import: only needed here, to shell-quote paths.

    # File extensions to copy. The leading dot matters: without it, any file
    # whose name merely ends in "py"/"sh"/"yaml" (e.g. "happy") would match.
    exts = (".py", ".sh", ".yaml")
    flds_files = ['.']
    for fld in flds_files:
        files = glob.iglob(os.path.join(root_dir, fld, "*"))
        subfd = join(dest, fld) if fld != "." else dest
        if not os.path.exists(subfd):
            os.makedirs(subfd, exist_ok=True)
        for file in files:
            if file.endswith(exts) and os.path.isfile(file):
                shutil.copy(file, subfd)

    # Copy whole folders. Paths are shell-quoted so that a root_dir or dest
    # containing spaces or shell metacharacters cannot break the command.
    for dirx in ("dlib", "cmds"):
        cmd = "cd {} && cp -r {} {} ".format(
            shlex.quote(root_dir), shlex.quote(dirx), shlex.quote(dest))
        if verbose:
            print("Running bash-cmds: \n{}".format(cmd.replace("&& ", "\n")))
        subprocess.run(cmd, shell=True, check=True)

    if compress:
        # basename(normpath(...)) strips any number of trailing separators.
        head = basename(normpath(dest))
        qhead = shlex.quote(head)
        # NOTE(review): `tar -cf` writes an *uncompressed* archive despite
        # the .tar.gz suffix. Kept unchanged so existing consumers of these
        # archives are unaffected; switch to `tar -czf` for real gzip.
        cmd = ("cd {} && ".format(shlex.quote(dest))
               + "cd .. && "
               + "tar -cf {}.tar.gz {} && ".format(qhead, qhead)
               + "rm -rf {}".format(qhead))
        if verbose:
            print("Running bash-cmds: \n{}".format(cmd.replace("&& ", "\n")))
        subprocess.run(cmd, shell=True, check=True)
def log_args(args_dict):
    """
    Log the "Configuration" section header through DLLogger.

    :param args_dict: configuration dict; not used yet.
    """
    # TODO: also log the content of args_dict.
    header = fmsg("Configuration")
    DLLogger.log(header)
def save_model(model, args, outfd):
    """
    Save a model's weights after switching it to eval mode on the CPU.

    Always writes the full state dict to ``<outfd>/best_model.pt``. For the
    ``constants.STD_CL`` task it also writes the encoder and the
    classification-head state dicts to
    ``<outfd>/<dataset>-<encoder_name>-<spatial_pooling>/{encoder,head}.pt``.

    Side effect: ``model`` is left in eval mode on the CPU.

    :param model: nn.Module to save.
    :param args: run configuration (reads task, dataset,
        model['encoder_name'], spatial_pooling).
    :param outfd: str, output folder; it is not created here.
    """
    model.eval()
    model.to(get_cpu_device())
    torch.save(model.state_dict(), join(outfd, "best_model.pt"))

    if args.task != constants.STD_CL:
        return

    tag = "{}-{}-{}".format(
        args.dataset, args.model['encoder_name'], args.spatial_pooling)
    sub_folder = join(outfd, tag)
    if not os.path.isdir(sub_folder):
        os.makedirs(sub_folder, exist_ok=True)
    torch.save(model.encoder.state_dict(), join(sub_folder, 'encoder.pt'))
    torch.save(model.classification_head.state_dict(),
               join(sub_folder, 'head.pt'))
    DLLogger.log(message="Stored classifier. TAG: {}".format(tag))
def save_config(config_dict, outfd):
    """
    Dump a configuration as YAML into ``<outfd>/config.yaml``.

    Any existing file is overwritten.

    :param config_dict: configuration to serialize.
    :param outfd: str, existing output folder.
    """
    target = join(outfd, 'config.yaml')
    with open(target, 'w') as stream:
        yaml.dump(config_dict, stream)
def get_best_epoch(fyaml):
    """
    Read the ``best_epoch`` entry of a YAML file.

    :param fyaml: str, path to the YAML file (parsed with yaml.safe_load).
    :return: the value stored under the top-level key 'best_epoch'.
    """
    with open(fyaml, 'r') as stream:
        content = yaml.safe_load(stream)
    return content['best_epoch']
def compute_auc(vec: np.ndarray, nbr_p: int):
    """
    Compute the area under a curve, as a percentage of the perfect area.

    The curve is sampled at x = 0, 1, ..., vec.size - 1 and integrated with
    ``sklearn.metrics.auc`` (trapezoidal rule). The result is divided by the
    area of a curve that stays at 100 over ``nbr_p - 1`` unit steps.

    :param vec: vector contains values in [0, 100.].
    :param nbr_p: int. number of points in the x-axis. it is expected to be
        the same as the number of values in `vec`.
    :return: float in [0, 100]. percentage of the area from the perfect area.
        If `vec` holds a single value, that value is returned as a float.
    :raises ValueError: if `vec` is empty, or if it holds several values
        while `nbr_p` < 2 (the reference area would be 0, giving inf/nan).
    """
    if vec.size == 0:
        raise ValueError("vec must contain at least one value.")
    if vec.size == 1:
        # .item() works for any shape holding a single value, including 0-d
        # arrays where vec[0] raises IndexError.
        return float(vec.item())
    if nbr_p < 2:
        raise ValueError(
            f"nbr_p must be >= 2 when vec has {vec.size} values; got {nbr_p}.")
    area_under_c = auc(x=np.arange(vec.size), y=vec)
    area_under_c /= (100. * (nbr_p - 1))
    area_under_c *= 100.  # (%)
    return float(area_under_c)
# WSOL
def check_box_convention(boxes, convention, tolerate_neg=False):
    """
    Validate that boxes follow a coordinate convention.

    Args:
        boxes: numpy.ndarray of ints or floats, shape (4,) or (num_boxes, 4).
        convention: string. One of ['x0y0x1y1', 'xywh'].
            'x0y0x1y1' requires x1 >= x0 and y1 >= y0. 'xywh' requires
            non-negative values in columns 2 and 3 (width, height).
        tolerate_neg: bool. if true, negative coordinates are accepted.
    Raises:
        RuntimeError: negative coordinates (unless tolerate_neg), a wrong
            shape, or boxes that violate the convention.
        ValueError: unknown convention.
    """
    if not tolerate_neg and (boxes < 0).any():
        raise RuntimeError("Box coordinates must be non-negative.")

    ndim = len(boxes.shape)
    if ndim == 1:
        # A single box: promote to a batch of one.
        boxes = np.expand_dims(boxes, 0)
    elif ndim != 2:
        raise RuntimeError(f"Box array must have dimension (4) or "
                           f"(num_boxes, 4): {ndim}.")
    if boxes.shape[1] != 4:
        raise RuntimeError(f"Box array must have dimension (4) or "
                           f"(num_boxes, 4): {boxes.shape[1]}.")

    if convention == 'xywh':
        widths, heights = boxes[:, 2], boxes[:, 3]
    elif convention == 'x0y0x1y1':
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
    else:
        raise ValueError("Unknown convention {}.".format(convention))

    if (widths < 0).any() or (heights < 0).any():
        raise RuntimeError("Boxes do not follow the {} convention."
                           .format(convention))
def t2n(t):
    """
    Convert a torch tensor into a float64 numpy array.

    The tensor is detached from the autograd graph and copied to the CPU
    first.
    """
    array = t.detach().cpu().numpy()
    return array.astype(float)
def check_scoremap_validity(scoremap):
    """
    Validate a score map: a 2D float64 numpy array with values in [0, 1].

    :param scoremap: candidate score map.
    :raises TypeError: if it is not a numpy array, or its dtype is not
        float64.
    :raises ValueError: if it is not 2D, contains NaNs, or has any value
        outside [0, 1].
    """
    if not isinstance(scoremap, np.ndarray):
        raise TypeError("Scoremap must be a numpy array; it is {}."
                        .format(type(scoremap)))
    # dtype compared to the builtin float, i.e. float64 only: float32 maps
    # are rejected.
    if scoremap.dtype != float:
        raise TypeError("Scoremap must be of np.float type; it is of {} type."
                        .format(scoremap.dtype))
    if len(scoremap.shape) != 2:
        raise ValueError("Scoremap must be a 2D array; it is {}D."
                         .format(len(scoremap.shape)))
    if np.isnan(scoremap).any():
        raise ValueError("Scoremap must not contain nans.")
    if (scoremap > 1).any() or (scoremap < 0).any():
        # Trailing space after "[0, 1]." keeps the two adjacent literals
        # from running together ("...[0, 1].scoremap.min()=...").
        raise ValueError("Scoremap must be in range [0, 1]. "
                         "scoremap.min()={}, scoremap.max()={}."
                         .format(scoremap.min(), scoremap.max()))
def get_tag(args: object, checkpoint_type: str = None) -> str:
    """
    Build the run tag
    ``<dataset>-<encoder>-<method>-<pooling>-cp_<checkpoint>-boxv2_<metric>``.

    :param args: run configuration (reads dataset, model['encoder_name'],
        method, spatial_pooling, box_v2_metric and, when needed,
        eval_checkpoint_type).
    :param checkpoint_type: checkpoint name; defaults to
        ``args.eval_checkpoint_type`` when None.
    :return: str tag.
    """
    cp = args.eval_checkpoint_type if checkpoint_type is None \
        else checkpoint_type
    return (f"{args.dataset}-{args.model['encoder_name']}-{args.method}-"
            f"{args.spatial_pooling}-cp_{cp}-boxv2_{args.box_v2_metric}")
def get_heatmap_tag(args: object, key: str) -> str:
    """
    Build the identifier of a precomputed heatmap set.

    The heatmap settings are read from ``args`` under the attribute names
    that belong to ``key``. The tag starts with ``<dataset>-<type_heatmap>``
    and is extended with the settings relevant to that heatmap type.

    :param args: run configuration object.
    :param key: one of constants.HEATMAP_KEYS.
    :return: str tag.
    :raises AssertionError: if key is not in constants.HEATMAP_KEYS.
    :raises NotImplementedError: for an unsupported key or heatmap type.
    """
    assert key in constants.HEATMAP_KEYS, key

    # (key, attribute names for: heatmap type, landmark variance, jaw,
    # normalize, aus_seg_full). Scanned in order; the first match wins.
    attr_table = [
        (constants.ALIGN_ATTEN_HEATMAP, (
            'align_atten_to_heatmap_type_heatmap',
            'align_atten_to_heatmap_lndmk_variance',
            'align_atten_to_heatmap_jaw',
            'align_atten_to_heatmap_normalize',
            'align_atten_to_heatmap_aus_seg_full')),
        (constants.TRAIN_HEATMAP, (
            'train_daug_mask_img_heatmap_type',
            'train_daug_mask_img_heatmap_lndmk_variance',
            'train_daug_mask_img_heatmap_jaw',
            'train_daug_mask_img_heatmap_normalize',
            'train_daug_mask_img_heatmap_aus_seg_full')),
        (constants.EVAL_HEATMAP, (
            'eval_daug_mask_img_heatmap_type',
            'eval_daug_mask_img_heatmap_lndmk_variance',
            'eval_daug_mask_img_heatmap_jaw',
            'eval_daug_mask_img_heatmap_normalize',
            'eval_daug_mask_img_heatmap_aus_seg_full')),
        (constants.AUS_SEGM, (
            'aus_seg_heatmap_type',
            'aus_seg_lndmk_variance',
            'aus_seg_jaw',
            'aus_seg_normalize',
            'aus_seg_aus_seg_full')),
    ]
    for candidate, names in attr_table:
        if key == candidate:
            type_heatmap, lndmk_variance, jaw, normalize, aus_seg_full = (
                getattr(args, name) for name in names)
            break
    else:
        raise NotImplementedError(key)

    tag = f"{args.dataset}-{type_heatmap}"
    if type_heatmap == constants.HEATMAP_LNDMKS:
        return f"{tag}-{normalize}-{lndmk_variance}-{jaw}"
    if type_heatmap in [constants.HEATMAP_AUNITS_LNMKS,
                        constants.HEATMAP_GENERIC_AUNITS_LNMKS,
                        constants.HEATMAP_PER_CLASS_AUNITS_LNMKS
                        ]:
        return f"{tag}-{normalize}"
    if type_heatmap == constants.HEATMAP_AUNITS_LEARNED_SEG:
        return f"{tag}-{normalize}-{aus_seg_full}"
    raise NotImplementedError(type_heatmap)
def bye(args):
    """
    Finalize a run.

    Steps, in order:
    - log the end time and total time;
    - append a one-line summary to ``<root_dir>/LOG.txt``;
    - write ``<args.outd>/passed.txt``;
    - when ``is_cc()`` is true (presumably a Compute Canada cluster —
      confirm), archive ``args.outd`` and unpack it into the matching
      folder under ``$SCRATCH/<constants.SCRATCH_FOLDER>``.

    :param args: run configuration; reads t0, tend, dataset, method,
        spatial_pooling, model['encoder_name'], eval_checkpoint_type,
        box_v2_metric, sl_fc, crf_fc, outd and, on CC, subpath.
    """
    DLLogger.log(fmsg("End time: {}".format(args.tend)))
    duration = args.tend - args.t0
    DLLogger.log(fmsg("Total time: {}".format(duration)))

    with open(join(root_dir, 'LOG.txt'), 'a') as flog:
        fields = [
            "{}:".format(dt.datetime.now()),
            "Dataset: {}".format(args.dataset),
            "Method: {}".format(args.method),
            "Spatial pooling: {}".format(args.spatial_pooling),
            "Encoder: {}".format(args.model['encoder_name']),
            "Check point: {}".format(args.eval_checkpoint_type),
            "Box_v2_metric: {}".format(args.box_v2_metric),
            "SL: {}".format(args.sl_fc),
            "CRF: {}".format(args.crf_fc),
            "... Passed in [{}]. \n".format(duration),
        ]
        flog.write(" \t ".join(fields))

    with open(join(args.outd, 'passed.txt'), 'w') as fout:
        fout.write('Passed.')
    DLLogger.log(fmsg('bye.'))

    if not is_cc():
        return

    # clean cc: mirror the output folder into the scratch experiment folder.
    scratch_exp_fd = join(os.environ["SCRATCH"], constants.SCRATCH_FOLDER,
                          args.subpath)
    scratch_tmp = dirname(normpath(scratch_exp_fd))  # parent
    _tag = basename(normpath(args.outd))
    steps = [
        "cd {} ".format(args.outd),
        "cd .. ",
        "tar -cf {}.tar.gz {}".format(_tag, _tag),
        'cp {}.tar.gz {}'.format(_tag, scratch_tmp),
        'cd {}'.format(scratch_tmp),
        'tar -xf {}.tar.gz -C {} --strip-components=1'.format(
            _tag, basename(normpath(scratch_exp_fd))),
        "rm {}.tar.gz".format(_tag),
    ]
    cmdx = " && ".join(steps)
    print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
    subprocess.run(cmdx, shell=True, check=True)
def get_root_wsol_dataset():
    """
    Return the root folder of the WSOL datasets for the current host.

    Resolution order:
    1. ``HOST_XXX`` is set: its value picks a folder built from ``EXDRIVE``
       (laptop) or ``DATASETSH`` (lab, gsys, tay, ESON). Any other value
       raises NotImplementedError.
    2. ``CC_CLUSTER`` is set: ``$SLURM_TMPDIR/datasets/wsol-done-right``
       when SLURM_TMPDIR exists, else ``$SCRATCH/datasets/wsol-done-right``.
    3. Neither is set: ValueError.

    :return: str, dataset root (no trailing slash).
    """
    env = os.environ
    # host name -> (environment variable holding the base path, sub-folder)
    known_hosts = {
        'laptop': ('EXDRIVE', 'datasets'),
        'lab': ('DATASETSH', 'wsol-done-right'),
        'gsys': ('DATASETSH', 'wsol-done-right'),
        'tay': ('DATASETSH', 'wsol-done-right'),
        'ESON': ('DATASETSH', 'datasets'),
    }

    baseurl = None
    if "HOST_XXX" in env:
        host = env['HOST_XXX']
        if host not in known_hosts:
            raise NotImplementedError(host)
        var_name, sub_folder = known_hosts[host]
        baseurl = f"{env[var_name]}/{sub_folder}"
    elif "CC_CLUSTER" in env:
        # Inside a job use the node-local disk ($SLURM_TMPDIR); outside a
        # job (e.g. called from a login node) fall back to scratch.
        base = env["SLURM_TMPDIR"] if "SLURM_TMPDIR" in env \
            else env["SCRATCH"]
        baseurl = "{}/datasets/wsol-done-right".format(base)

    if baseurl is None:
        raise ValueError("Sorry, it seems we are unable to recognize the "
                         "host. You seem to be new to this code. "
                         "We recommend you to add your baseurl on your own.")
    return baseurl
def build_heatmap_folder(args: dict,
                         key: str,
                         force_it: bool = False
                         ) -> str:
    """
    Resolve the folder holding precomputed heatmaps for ``key``.

    Operates for:
        - train_daug_mask_img_heatmap
        - eval_daug_mask_img_heatmap
        - align_atten_to_heatmap
        - aus_seg

    The configured folder is returned unchanged in two cases: heatmaps are
    not precomputed for ``key``, or a folder is already set and
    ``force_it`` is False. Otherwise the folder is derived from the heatmap
    tag. On CC, ``$SCRATCH/datasets/wsol-done-right/<tag>.tar.gz`` is copied
    to the dataset root and extracted there. Elsewhere,
    ``<dataset root>/<tag>`` is used directly.

    :param args: dict, run configuration.
    :param key: one of constants.HEATMAP_KEYS.
    :param force_it: bool, derive the path even if a folder is configured.
    :return: str, path to the heatmap folder.
    :raises ValueError: on CC, when the scratch archive does not exist.
    """
    assert key in constants.HEATMAP_KEYS, key
    pre_computed = args[constants.PRECOMPUTED[key]]
    folder = args[constants.FOLDER_HEATMAP[key]]

    # Bitwise `&` (as before) rather than `and`: same result for bool flags.
    must_build = pre_computed & ((folder == '') or force_it)
    if not must_build:
        return folder

    tag = get_heatmap_tag(Dict2Obj(args), key=key)

    if not is_cc():
        path_heatmaps = join(get_root_wsol_dataset(), tag)
        assert os.path.isdir(path_heatmaps), path_heatmaps
        return path_heatmaps

    baseurl_sc = "{}/datasets/wsol-done-right".format(os.environ["SCRATCH"])
    scratch_path = join(baseurl_sc, '{}.tar.gz'.format(tag))
    if not os.path.isfile(scratch_path):
        raise ValueError(f"{scratch_path} not found.")

    slurm_dir = get_root_wsol_dataset()
    cmdx = " && ".join([
        'cp {} {} '.format(scratch_path, slurm_dir),
        'cd {} '.format(slurm_dir),
        'tar -xf {}'.format('{}.tar.gz'.format(tag)),
    ])
    print("Running bash-cmds: \n{}".format(cmdx.replace("&& ", "\n")))
    subprocess.run(cmdx, shell=True, check=True)
    path_heatmaps = join(slurm_dir, tag)
    assert os.path.isdir(path_heatmaps)
    return path_heatmaps
if __name__ == '__main__':
    # Manual sanity check: show the project root resolved at import time.
    print(root_dir)