forked from woodywff/deepcorr
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpipeline.py
145 lines (136 loc) · 5.45 KB
/
pipeline.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import h5py
import numpy as np
from random import shuffle
import pdb
import yaml
import pickle
'''
The pipeline yields NumPy ndarrays rather than framework-specific tensor formats.
'''
class Generator():
    '''
    Batch iterator over samples addressed by integer ids.

    Each call to `epoch` yields (x, y) pairs of NumPy arrays; every sample in x
    carries a singleton channel axis (last if channel_last, else first).
    '''
    def __init__(self, ids, h5_path, batch_size, is_test=False, channel_last=True, x=None, y=None, gan_sim=False):
        '''
        ids: id list (any iterable of indices into x/y or into the .h5 datasets).
        h5_path: path of the .h5 file; only opened when x is None.
        batch_size: samples per yielded batch (the last batch may be smaller).
        is_test: if True, it is the test dataset and ids are visited in their given order,
            otherwise training dataset and ids are reshuffled at the start of every epoch.
        channel_last: if True, corresponds to inputs with shape [batch, height, width, channels] (for tensorflow),
            otherwise, [batch, channels, height, width] (for pytorch and paddlepaddle).
        x,y: if None, read from .h5 file; otherwise indexed directly as x[i], y[i].
        gan_sim: If True, simulate GAN on ingress traffic of Tor
            (when reading from .h5, samples come from data/x_gan instead of data/x).
        '''
        self.ids = ids
        self.h5_path = h5_path
        self.batch_size = batch_size
        self.is_test = is_test
        self.channel_last = channel_last
        self.spe = int(np.ceil(len(self.ids)/self.batch_size)) # steps per epoch
        self.x = x
        self.y = y
        self.gan_sim = gan_sim

    def epoch(self):
        '''
        Iterate once over all ids, yielding (x, y) batches of at most batch_size samples.

        Training ids are shuffled on every call; test ids keep their given order,
        so the k-th test batch holds ids[k*batch_size:(k+1)*batch_size].
        '''
        ids = list(self.ids)  # private copy; also accepts tuples / numpy arrays
        if not self.is_test:
            shuffle(ids)
        for start in range(0, len(ids), self.batch_size):
            x = []
            y = []
            for i in ids[start:start + self.batch_size]:
                self.append(x, y, i)
            yield self.feed(x, y)

    def append(self, x, y, i):
        '''
        Dataset specific.
        This is for [deepcorr](http://traces.cs.umass.edu/index.php/Network/Network)
        Append sample i to list x and its label to list y (both mutated in place).
        The sample always gets a singleton channel axis (last if channel_last,
        else first), whether it comes from the in-memory arrays or from the .h5 file.
        x,y: lists to be fed
        i: index
        '''
        if self.x is not None:
            sample = self.x[i]
            label = self.y[i]
        else:
            # NOTE: the file is reopened for every sample; slow, but keeps no handle open.
            with h5py.File(self.h5_path, 'r') as f:
                key_x = 'x_gan' if self.gan_sim else 'x'
                sample = f['data'][key_x][i]
                label = f['data']['y'][i]
        channel_axis = -1 if self.channel_last else 0
        x.append(np.expand_dims(np.asarray(sample), channel_axis))
        y.append(label)
        return

    def feed(self, x, y):
        '''Stack the accumulated sample and label lists into NumPy arrays.'''
        return np.asarray(x), np.asarray(y)
class Dataset():
    '''
    Builds train / validation / test Generators from a YAML config, an .h5 file
    and a pickled cross-validation split.

    Config keys read: data.h5_path, data.crossval_indices_path, data.n_fold,
    train.batch_size, test.batch_size.
    '''
    def __init__(self, cf='config.yml', cv_i=0, test_only=False, channel_last=True, h5_path=None, in_mem=True, gan_sim=False):
        '''
        cf: config.yml path
        cv_i: which fold in the cross validation.
            if cv_i >= n_fold: use all the training dataset (train_0 + val_0),
            for both training and validation.
        test_only: if True, only for test dataset; the cross-validation split is not
            loaded, so train_generator / val_generator are unavailable.
        channel_last: if True, corresponds to inputs with shape (batch, height, width, channels) (for tensorflow),
            otherwise, (batch, channels, height, width) (for pytorch and paddlepaddle).
        h5_path: if None, use default .h5 file in config.yml, otherwise, use the given path.
        in_mem: if True, read .h5 once and save x,y in memory.
        gan_sim: If True, simulate GAN on ingress traffic of Tor (samples come from data/x_gan).
        '''
        with open(cf) as f:
            self.config = yaml.load(f, Loader=yaml.FullLoader)
        self.h5_path = h5_path or self.config['data']['h5_path']
        self.channel_last = channel_last
        self.gan_sim = gan_sim
        if in_mem:
            with h5py.File(self.h5_path, 'r') as f:
                key_x = 'x_gan' if gan_sim else 'x'
                self.x = np.asarray(f['data'][key_x])
                self.y = np.asarray(f['data']['y'])
        else:
            self.x = self.y = None
        if test_only:
            return
        crossval_file = self.config['data']['crossval_indices_path']
        self.n_fold = self.config['data']['n_fold']
        # NOTE(review): pickle.load can execute arbitrary code; only load split files you trust.
        with open(crossval_file, 'rb') as f:
            self.crossval_dict = pickle.load(f)
        self.cv_i = cv_i

    @property
    def _all_train_ids(self):
        # list() makes `+` concatenate even if the pickled splits are numpy arrays
        # (on arrays `+` would silently add the ids element-wise).
        return list(self.crossval_dict['train_0']) + list(self.crossval_dict['val_0'])

    @property
    def _train_ids(self):
        if self.cv_i >= self.n_fold:
            return self._all_train_ids
        return list(self.crossval_dict[f'train_{self.cv_i}'])

    @property
    def _val_ids(self):
        # With cv_i >= n_fold there is no held-out fold: validation reuses all training ids.
        if self.cv_i >= self.n_fold:
            return self._all_train_ids
        return list(self.crossval_dict[f'val_{self.cv_i}'])

    @property
    def _test_ids(self):
        with h5py.File(self.h5_path, 'r') as f:
            return list(f['indices']['test'])

    def _make_generator(self, ids, batch_size, is_test=False):
        '''Build a Generator over `ids` sharing this Dataset's data source and settings.'''
        return Generator(ids=ids,
                         h5_path=self.h5_path,
                         batch_size=batch_size,
                         is_test=is_test,
                         channel_last=self.channel_last,
                         x=self.x, y=self.y, gan_sim=self.gan_sim)

    @property
    def train_generator(self):
        return self._make_generator(self._train_ids, self.config['train']['batch_size'])

    @property
    def val_generator(self):
        return self._make_generator(self._val_ids, self.config['train']['batch_size'])

    @property
    def test_generator(self):
        return self._make_generator(self._test_ids, self.config['test']['batch_size'], is_test=True)