dataset.py

import os
import random

import scipy.io as sio
import tensorflow as tf
from tensorflow.data import AUTOTUNE

from ldsr.utils import coded2DTO3D


def get_list_imgs(data_path):
    """Return a shuffled list of full paths to the files in data_path."""
    list_imgs = [os.path.join(data_path, img) for img in os.listdir(data_path)]
    random.shuffle(list_imgs)
    return list_imgs


def generate_H(coded_size, transmittance):
    """Draw a random binary coded aperture and lift it to its 3D form."""
    H = tf.random.uniform(coded_size, dtype=tf.float32)
    # Uniform samples lie in [0, 1), so roughly a (1 - transmittance) fraction
    # of the entries survive the threshold as ones.
    H = tf.cast(H > transmittance, dtype=tf.float32)
    return coded2DTO3D(H)
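
# Quick sanity check (a sketch; the shape below is illustrative, not from the
# source): thresholding uniform noise at 0.5 should give a 2D mask whose mean
# is close to 0.5 before coded2DTO3D is applied.
#   H2d = tf.cast(tf.random.uniform((4, 512, 542, 1)) > 0.5, tf.float32)
#   float(tf.reduce_mean(H2d))  # ~= 0.5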


def csi_mapping(x, coded_size, transmittance=0.5):
    """Pair each batch with a freshly sampled coded aperture: ((x, H), x)."""
    batch = x.shape[0]  # static, since batching uses drop_remainder=True
    H = generate_H((batch,) + coded_size, transmittance)
    return (x, H), x


class DataGen(tf.data.Dataset):
    """Stream spectral cubes stored as .mat files (key 'img') from a directory."""

    @staticmethod
    def _generator(data_path):
        # from_generator passes string args through as bytes; decode them back.
        if isinstance(data_path, bytes):
            data_path = data_path.decode("utf-8")
        for img_path in get_list_imgs(data_path):
            yield sio.loadmat(img_path)["img"]

    def __new__(cls, input_size=(512, 512, 31), data_path="../data/kaist/train"):
        output_signature = tf.TensorSpec(shape=input_size, dtype=tf.float32)
        return tf.data.Dataset.from_generator(
            cls._generator,
            output_signature=output_signature,
            args=(data_path,),
        )
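
# Minimal usage sketch (assumes .mat cubes with an 'img' key under data_path;
# the path below is this file's default, adjust it locally):
#   ds = DataGen(input_size=(512, 512, 31), data_path="../data/kaist/train")
#   for x in ds.take(1):
#       print(x.shape)  # (512, 512, 31)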


def get_csi_pipeline(data_path, input_size=(512, 512, 31), batch_size=32,
                     buffer_size=3, cache_dir=''):
    """CSI training pipeline: yields ((cube, coded aperture), cube) batches."""
    M, N, L = input_size
    # The mask spans M + L - 1 columns, presumably to cover the spectral shift,
    # e.g. input_size=(512, 512, 31) gives a (512, 542, 1) mask.
    coded_size = (N, M + L - 1, 1)
    map_fun = lambda x: csi_mapping(x, coded_size)
    dataset = DataGen(input_size=input_size, data_path=data_path)
    pipeline_data = (
        dataset
        .batch(batch_size, drop_remainder=True)
        .cache(cache_dir)  # cache_dir='' keeps the cache in memory (RAM)
        .shuffle(buffer_size)
        # Mapping after cache() means a fresh H is drawn every epoch.
        .map(map_fun, num_parallel_calls=AUTOTUNE)
        .prefetch(AUTOTUNE)
    )
    return pipeline_data


def get_pipeline(data_path, input_size=(512, 512, 31), batch_size=32,
                 buffer_size=3, cache_dir=''):
    """Identity-target pipeline: yields (cube, cube) batches."""
    dataset = DataGen(input_size=input_size, data_path=data_path)
    map_fun = lambda x: (x, x)
    pipeline_data = (
        dataset
        .cache(cache_dir)  # cache_dir='' keeps the cache in memory (RAM)
        # .shuffle(batch_size)
        .batch(batch_size, drop_remainder=True)
        .map(map_fun, num_parallel_calls=AUTOTUNE)
        .prefetch(buffer_size)
    )
    return pipeline_data
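

# Smoke test (a sketch, not part of the original module): the data path and
# batch size below are assumptions; point them at a local copy of the dataset.
if __name__ == "__main__":
    train = get_csi_pipeline("../data/kaist/train", batch_size=2, buffer_size=2)
    (x, H), y = next(iter(train))
    # x and y are the same spectral cubes; H is the per-batch coded aperture.
    print(x.shape, H.shape, y.shape)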