-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
84 lines (68 loc) · 2.7 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright (C) 2021. Huawei Technologies Co., Ltd. All rights reserved.
# This program is free software; you can redistribute it and/or modify
# it under the terms of the MIT License.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# MIT License for more details.
import os
import glob
import numpy as np
import matplotlib.pyplot as plt
import torch
def intersperse(lst, item):
    """Return a new list with ``item`` before, between and after every element.

    Typically used to add a blank symbol around each token, e.g.
    ``intersperse([1, 2], 0) == [0, 1, 0, 2, 0]``.

    Args:
        lst: Sized iterable of tokens.
        item: Filler value inserted at every even position.

    Returns:
        A list of length ``2 * len(lst) + 1``.
    """
    size = 2 * len(lst) + 1
    result = [item] * size
    # Original elements occupy the odd slots; the filler stays in the even ones.
    for idx, token in enumerate(lst):
        result[2 * idx + 1] = token
    return result
def parse_filelist(filelist_path, split_char="|"):
    """Read a UTF-8 filelist and split each line into fields.

    Args:
        filelist_path: Path to the text file to read.
        split_char: Field separator (``"|"`` by default).

    Returns:
        One list of string fields per line, with surrounding whitespace
        stripped from the line before splitting.
    """
    rows = []
    with open(filelist_path, encoding='utf-8') as handle:
        for raw_line in handle:
            cleaned = raw_line.strip()
            rows.append(cleaned.split(split_char))
    return rows
def latest_checkpoint_path(dir_path, regex="grad_*.pt"):
    """Return the path of the newest checkpoint in ``dir_path``.

    "Newest" means the matching file whose *file name* contains the largest
    number once all of its digits are concatenated (``grad_120.pt`` beats
    ``grad_99.pt``). Matching names without any digits rank lowest.

    Args:
        dir_path: Directory to search.
        regex: A glob pattern (not a regular expression, despite the name)
            matched against file names inside ``dir_path``.

    Returns:
        Path of the selected checkpoint, as returned by ``glob.glob``.

    Raises:
        FileNotFoundError: If no file in ``dir_path`` matches ``regex``.
    """
    f_list = glob.glob(os.path.join(dir_path, regex))
    if not f_list:
        raise FileNotFoundError(
            f"No checkpoint matching {regex!r} found in {dir_path!r}")

    def _step(path):
        # Use only the digits in the file name: digits from dir_path would be
        # prepended to every key and can reorder e.g. grad_010 vs grad_11.
        digits = "".join(filter(str.isdigit, os.path.basename(path)))
        # A name with no digits (e.g. grad_best.pt) would crash int('').
        return int(digits) if digits else -1

    f_list.sort(key=_step)
    return f_list[-1]
def load_checkpoint(logdir, model, num=None):
    """Load a saved ``state_dict`` from ``logdir`` into ``model``.

    Args:
        logdir: Directory holding ``grad_<num>.pt`` checkpoint files.
        model: Object whose ``load_state_dict`` receives the loaded dict
            (non-strict, so partial checkpoints are accepted).
        num: Checkpoint number to load. If ``None``, the newest file matching
            ``grad_*.pt`` (per ``latest_checkpoint_path``) is used.

    Returns:
        The same ``model`` object that was passed in.
    """
    if num is None:
        model_path = latest_checkpoint_path(logdir, regex="grad_*.pt")
    else:
        model_path = os.path.join(logdir, f"grad_{num}.pt")
    print(f'Loading checkpoint {model_path}...')
    # Deserialize every tensor onto the CPU regardless of the device it was
    # saved from; load_state_dict copies the values into the model's params.
    model_dict = torch.load(model_path, map_location="cpu")
    # strict=False is deliberate (partial loads), but report mismatches
    # instead of dropping them silently.
    result = model.load_state_dict(model_dict, strict=False)
    missing = getattr(result, "missing_keys", ())
    unexpected = getattr(result, "unexpected_keys", ())
    if missing or unexpected:
        print(f'Warning: {len(missing)} missing and {len(unexpected)} '
              f'unexpected keys when loading {model_path}')
    return model
def save_figure_to_numpy(fig):
    """Return the rendered pixels of ``fig`` as an ``(H, W, 3)`` uint8 array.

    The canvas must already have been drawn (callers run
    ``fig.canvas.draw()`` first). The returned array is an independent copy
    of the canvas buffer.

    Args:
        fig: A Matplotlib figure (or any object whose ``canvas`` exposes
            ``buffer_rgba()`` or the legacy ``tostring_rgb()``).

    Returns:
        ``numpy.ndarray`` of dtype uint8 with shape ``(height, width, 3)``.
    """
    canvas = fig.canvas
    if hasattr(canvas, "buffer_rgba"):
        # buffer_rgba() already carries the (H, W, 4) shape in physical pixels,
        # so it also works on HiDPI canvases; drop alpha and copy out.
        rgba = np.asarray(canvas.buffer_rgba())
        return rgba[..., :3].copy()
    # Legacy canvases: tostring_rgb() was removed from Agg in Matplotlib 3.10,
    # and np.fromstring's binary mode is deprecated, so use frombuffer + copy.
    data = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8).copy()
    return data.reshape(canvas.get_width_height()[::-1] + (3,))
def pt_to_pdf(pt, pdf, vmin=-12.5, vmax=0.0, extension='png'):
    """Plot a spectrogram tensor and save it to ``f"{pdf}.{extension}"``.

    Args:
        pt: 2-D tensor; it is transposed with ``.t()`` before plotting, so
            presumably laid out (frames, bins) -- TODO confirm with callers.
        pdf: Output path without extension. Despite the name, any format
            Matplotlib supports can be written (see ``extension``).
        vmin: Lower colour-scale limit passed to ``imshow``.
        vmax: Upper colour-scale limit passed to ``imshow``.
        extension: File extension, also passed as ``format`` to ``savefig``.
    """
    spec = pt.t()
    fig = plt.figure(figsize=(12, 3), tight_layout=True)  # (20,4)
    try:
        ax = fig.add_subplot()
        image = ax.imshow(spec, cmap="jet", origin="lower", aspect="equal",
                          interpolation="none", vmax=vmax, vmin=vmin)
        fig.colorbar(mappable=image, orientation='vertical', ax=ax, shrink=0.5)
        fig.savefig(f'{pdf}.{extension}', format=extension)
    finally:
        # Always release this specific figure, even if plotting or saving
        # fails, so repeated calls cannot accumulate open figures.
        plt.close(fig)
def plot_tensor(tensor):
    """Render a 2-D tensor as a heat map and return the image as a numpy array.

    Note: ``plt.style.use('default')`` resets Matplotlib's global style as a
    side effect.

    Args:
        tensor: 2-D array-like; its transpose (``.T``) is what gets drawn.

    Returns:
        The rendered figure as produced by ``save_figure_to_numpy``
        (an ``(H, W, 3)`` uint8 array).
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor.T, cmap='jet', aspect="equal", origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # Pixels must be rendered before they can be read back.
        fig.canvas.draw()
        data = save_figure_to_numpy(fig)
    finally:
        # Close this figure even on failure so training loops don't leak figures.
        plt.close(fig)
    return data
def save_plot(tensor, savepath):
    """Render a 2-D tensor as a heat map and save it to ``savepath``.

    Unlike ``plot_tensor``, the tensor is drawn as-is (not transposed).
    Note: ``plt.style.use('default')`` resets Matplotlib's global style as a
    side effect.

    Args:
        tensor: 2-D array-like to display.
        savepath: Destination file; the format is inferred by ``savefig``.
    """
    plt.style.use('default')
    fig, ax = plt.subplots(figsize=(12, 3))
    try:
        im = ax.imshow(tensor, cmap='jet', aspect="equal", origin="lower",
                       interpolation='none')
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        # savefig renders the figure itself, so no explicit canvas.draw() needed.
        fig.savefig(savepath)
    finally:
        # Close this figure even on failure so repeated calls don't leak figures.
        plt.close(fig)