-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathutils.py
executable file
·102 lines (84 loc) · 3.52 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import print_function
import os
import math
import json
import logging
import numpy as np
from PIL import Image
from datetime import datetime
def prepare_dirs_and_logger(config):
    """Reset root logging and resolve/create the run's directories.

    Replaces every handler on the root logger with a single StreamHandler
    using a ``time:LEVEL::message`` format. It then derives
    ``config.model_name`` / ``config.model_dir`` / ``config.data_path`` from
    the other config fields. Finally it creates ``log_dir``, ``data_dir`` and
    ``model_dir`` on disk if they are missing.

    Args:
        config: attribute container. It is read for ``load_path``,
            ``log_dir``, ``dataset`` and ``data_dir``, and mutated in place:
            ``model_dir`` and ``data_path`` are always set afterwards.
            ``model_name`` is set unless ``load_path`` starts with
            ``log_dir``.
    """
    formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s")
    logger = logging.getLogger()

    # Iterate over a snapshot: removeHandler() mutates logger.handlers, so
    # looping over the live list would skip every other handler.
    for hdlr in logger.handlers[:]:
        logger.removeHandler(hdlr)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.addHandler(handler)

    if config.load_path:
        if config.load_path.startswith(config.log_dir):
            # load_path already lives under log_dir: use it as the model dir.
            config.model_dir = config.load_path
        elif config.load_path.startswith(config.dataset):
            config.model_name = config.load_path
        else:
            config.model_name = "{}_{}".format(config.dataset, config.load_path)
    else:
        # Fresh run: name it after the dataset plus a timestamp.
        config.model_name = "{}_{}".format(config.dataset, get_time())

    # Respect a model_dir that was already set (by the branch above or by
    # the caller); otherwise place the model under log_dir.
    if not hasattr(config, 'model_dir'):
        config.model_dir = os.path.join(config.log_dir, config.model_name)
    config.data_path = os.path.join(config.data_dir, config.dataset)

    for path in [config.log_dir, config.data_dir, config.model_dir]:
        if not os.path.exists(path):
            os.makedirs(path)
def get_time():
    """Return the current local time as a ``MMDD_HHMMSS`` timestamp string."""
    stamp_format = "%m%d_%H%M%S"
    return datetime.strftime(datetime.now(), stamp_format)
def save_config(config):
    """Write every attribute of ``config`` to ``<model_dir>/params.json``.

    The model directory and parameter file path are printed first. The JSON
    is indented by 4 spaces with sorted keys.
    """
    model_dir = config.model_dir
    params_file = os.path.join(model_dir, "params.json")

    print("[*] MODEL dir: %s" % model_dir)
    print("[*] PARAM path: %s" % params_file)

    settings = vars(config)
    with open(params_file, 'w') as handle:
        json.dump(settings, handle, indent=4, sort_keys=True)
def rank(array):
    """Return how many dimensions ``array`` has, i.e. the length of its shape."""
    dims = array.shape
    return len(dims)
def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False, is_train=True):
    """Tile a batch of images into a single uint8 RGB grid image.

    Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py

    Args:
        tensor: array-like batch indexed as ``tensor[k]``. Its shape is
            ``(N, H, W, C)`` with C of 3 or 1 (a single channel is broadcast
            to RGB), or ``(N, H, W)`` for grayscale (also broadcast to RGB).
            Values are assigned into a ``np.uint8`` array, so NumPy's
            casting applies; presumably callers pass data already in
            [0, 255] — TODO confirm.
        nrow: maximum number of images per grid row.
        padding: spacing in pixels allotted around each image.
        normalize: accepted for torchvision signature compatibility; ignored.
        scale_each: accepted for torchvision signature compatibility; ignored.
        is_train: when True, every tile is shifted one extra pixel down and
            right and the grid grows by one pixel in each dimension.

    Returns:
        ``np.uint8`` array of shape
        ``(ymaps * (H + padding) + offset, xmaps * (W + padding) + offset, 3)``
        where ``offset = padding // 2 + (1 if is_train else 0)``,
        ``xmaps = min(nrow, N)`` and ``ymaps = ceil(N / xmaps)``.
        Unfilled cells stay zero.

    Raises:
        ValueError: if the batch is empty or ``nrow`` is less than 1.
    """
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    if xmaps < 1:
        raise ValueError(
            "make_grid needs a non-empty batch and nrow >= 1 "
            "(got {} images, nrow={})".format(nmaps, nrow))
    # Integer ceil(nmaps / xmaps): number of grid rows needed.
    ymaps = (nmaps + xmaps - 1) // xmaps

    img_h, img_w = int(tensor.shape[1]), int(tensor.shape[2])
    height, width = img_h + padding, img_w + padding  # cell size incl. padding
    offset = (1 if is_train else 0) + padding // 2    # leading margin of every tile

    grid = np.zeros([height * ymaps + offset, width * xmaps + offset, 3],
                    dtype=np.uint8)
    for k in range(nmaps):
        # Fill cells row-major: image k goes to row k // xmaps, column k % xmaps.
        y, x = divmod(k, xmaps)
        top = y * height + offset
        left = x * width + offset
        image = tensor[k]
        if np.ndim(image) == 2:
            # Grayscale (H, W): add a channel axis so it broadcasts to RGB.
            image = np.asarray(image)[:, :, np.newaxis]
        grid[top:top + img_h, left:left + img_w] = image
    return grid
def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, scale_each=False, is_train=True):
    """Tile ``tensor`` with ``make_grid`` and save the result to ``filename``.

    All layout arguments are forwarded unchanged to ``make_grid``. The
    resulting array is converted with ``PIL.Image.fromarray`` and written
    via ``Image.save``.
    """
    grid_kwargs = dict(nrow=nrow, padding=padding, normalize=normalize,
                       scale_each=scale_each, is_train=is_train)
    picture = Image.fromarray(make_grid(tensor, **grid_kwargs))
    picture.save(filename)