-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathutils.py
99 lines (79 loc) · 3.38 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import torch
import numpy as np
from preprocess import load_sparse
def load_adj(path, device=torch.device('cpu')):
    """Load the code adjacency matrix stored at ``<path>/code_adj.npz``.

    Uses the project's ``load_sparse`` helper to densify the stored matrix,
    then moves it to ``device`` as a float32 tensor.
    """
    adj_file = os.path.join(path, 'code_adj.npz')
    dense = load_sparse(adj_file)
    return torch.from_numpy(dense).to(device=device, dtype=torch.float32)
class EHRDataset:
    """Mini-batch iterator over preprocessed EHR arrays stored under ``data_path``.

    Indexing (``ds[i]``) yields one batch of torch tensors on ``device``:
    ``(code_x, visit_lens, divided, y, neighbors)``. Call ``on_epoch_end()``
    between epochs to reshuffle when ``shuffle`` is enabled.
    """

    def __init__(self, data_path, label='m', batch_size=32, shuffle=True, device=torch.device('cpu')):
        super().__init__()
        self.path = data_path
        arrays = self._load(label)
        self.code_x, self.visit_lens, self.y, self.divided, self.neighbors = arrays
        self._size = self.code_x.shape[0]
        self.idx = np.arange(self._size)
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.device = device

    def _load(self, label):
        """Read the preprocessed arrays; ``label`` picks the target: 'm' (codes) or 'h' (heart failure)."""
        join = os.path.join
        code_x = load_sparse(join(self.path, 'code_x.npz'))
        visit_lens = np.load(join(self.path, 'visit_lens.npz'))['lens']
        if label == 'm':
            y = load_sparse(join(self.path, 'code_y.npz'))
        elif label == 'h':
            y = np.load(join(self.path, 'hf_y.npz'))['hf_y']
        else:
            raise KeyError('Unsupported label type')
        divided = load_sparse(join(self.path, 'divided.npz'))
        neighbors = load_sparse(join(self.path, 'neighbors.npz'))
        return code_x, visit_lens, y, divided, neighbors

    def on_epoch_end(self):
        """Reshuffle the sample order for the next epoch (no-op when shuffle=False)."""
        if self.shuffle:
            np.random.shuffle(self.idx)

    def size(self):
        """Total number of samples."""
        return self._size

    def label(self):
        """The raw label array (numpy), untouched by batching."""
        return self.y

    def __len__(self):
        # Number of batches, counting a trailing partial batch.
        full, remainder = divmod(self._size, self.batch_size)
        return full + (1 if remainder else 0)

    def __getitem__(self, index):
        """Return batch ``index`` as torch tensors on ``self.device``."""
        dev = self.device
        lo = index * self.batch_size
        chosen = self.idx[lo:lo + self.batch_size]
        code_x = torch.from_numpy(self.code_x[chosen]).to(dev)
        visit_lens = torch.from_numpy(self.visit_lens[chosen]).to(device=dev, dtype=torch.long)
        y = torch.from_numpy(self.y[chosen]).to(device=dev, dtype=torch.float32)
        divided = torch.from_numpy(self.divided[chosen]).to(dev)
        neighbors = torch.from_numpy(self.neighbors[chosen]).to(dev)
        return code_x, visit_lens, divided, y, neighbors
class MultiStepLRScheduler:
    """Piecewise-constant learning-rate schedule over a fixed number of epochs.

    The lr is ``init_lr`` for epochs 1..milestones[0]-1, then ``lrs[i]`` from
    ``milestones[i]`` up to the next milestone. ``step()`` writes the current
    epoch's lr into every optimizer param group and advances the epoch counter.
    """

    def __init__(self, optimizer, epochs, init_lr, milestones, lrs):
        self.optimizer = optimizer
        self.epochs = epochs
        self.init_lr = init_lr
        self.lrs = self._generate_lr(milestones, lrs)
        self.current_epoch = 0

    def _generate_lr(self, milestones, lrs):
        """Expand (milestones, lrs) into one lr value per epoch 1..epochs (float64 array)."""
        bounds = [1] + milestones + [self.epochs + 1]
        values = [self.init_lr] + lrs
        segments = [np.full(bounds[i + 1] - bounds[i], values[i], dtype=np.float64)
                    for i in range(len(bounds) - 1)]
        return np.concatenate(segments)

    def step(self):
        """Apply this epoch's lr to all param groups, then move to the next epoch."""
        lr = self.lrs[self.current_epoch]
        for group in self.optimizer.param_groups:
            group['lr'] = lr
        self.current_epoch += 1

    def reset(self):
        """Rewind the schedule to epoch 0."""
        self.current_epoch = 0
def format_time(seconds):
    """Render a duration in seconds as a compact string.

    Up to 60s -> '12.3s'; up to 1h -> '5m12.3s'; above that -> '2h5m12.3s'.
    Boundary values (exactly 60 / 3600) fall into the smaller unit.
    """
    if seconds > 3600:
        return '%dh%dm%.1fs' % (seconds // 3600, (seconds % 3600) // 60, seconds % 60)
    if seconds > 60:
        return '%dm%.1fs' % (seconds // 60, seconds % 60)
    return '%.1fs' % seconds