forked from MichiganCOG/TASED-Net
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataset.py
80 lines (66 loc) · 2.77 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import os
import csv
import cv2
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
def transform(snippet):
    """Stack a list of frames into one tensor and normalize it to [-1, 1].

    Args:
        snippet: sequence of T arrays, each H x W x 3 (presumably uint8 values
            in [0, 255] -- the scaling below maps 0 -> -1 and 255 -> 1).

    Returns:
        float32 tensor of shape (3, T, H, W): channel axis first, then time.
    """
    frames = np.concatenate(snippet, axis=-1)          # H x W x 3T
    chw = torch.from_numpy(frames).permute(2, 0, 1)    # 3T x H x W
    chw = chw.contiguous().float()
    _, height, width = chw.size()
    scaled = chw.mul_(2.).sub_(255).div(255)           # x -> (2x - 255) / 255
    per_frame = scaled.view(-1, 3, height, width)      # T x 3 x H x W
    return per_frame.permute(1, 0, 2, 3)               # 3 x T x H x W
class DHF1KDataset(Dataset):
    """Dataset yielding a random clip of consecutive frames plus one saliency map.

    Paths are built as:
        <path_data>/video/<NNNN>/<FFFF>.png
        <path_data>/annotation/<NNNN>/maps/<FFFF>.png
    where NNNN is the 1-based video index and FFFF the 1-based frame index,
    both zero-padded to four digits.

    Args:
        path_data: root directory of the data.
        len_snippet: number of consecutive frames per sample.
        path_num_frame: CSV file whose first column is the frame count of each
            video, one row per video in index order. Defaults to
            'DHF1K_num_frame_train.csv', resolved against the current working
            directory (the previous hard-coded behavior).
    """

    def __init__(self, path_data, len_snippet,
                 path_num_frame='DHF1K_num_frame_train.csv'):
        self.path_data = path_data
        self.len_snippet = len_snippet
        # Context manager so the handle is closed; blank rows (e.g. a trailing
        # newline) are skipped instead of failing on row[0].
        with open(path_num_frame, 'r', newline='') as f:
            self.list_num_frame = [int(row[0]) for row in csv.reader(f) if row]

    def __len__(self):
        return len(self.list_num_frame)

    @staticmethod
    def _imread(path, *flags):
        """Read an image with OpenCV, raising instead of returning None on failure."""
        img = cv2.imread(path, *flags)
        if img is None:
            # cv2.imread signals a missing/unreadable file by returning None,
            # which would otherwise surface as an obscure cv2.resize error.
            raise FileNotFoundError('could not read image: %s' % path)
        return img

    def __getitem__(self, idx):
        """Return (clip, annotation) for video ``idx``.

        A window of ``len_snippet`` consecutive frames is chosen uniformly at
        random; the target is the annotation of the window's last frame. With
        probability 0.5 the clip and the annotation are both flipped horizontally.

        Returns:
            clip: ``transform`` applied to the resized RGB frames.
            annt: float tensor holding the grayscale annotation resized to
                width 384, height 224.
        """
        file_name = '%04d' % (idx + 1)
        path_clip = os.path.join(self.path_data, 'video', file_name)
        path_annt = os.path.join(self.path_data, 'annotation', file_name, 'maps')

        # Start so that [start_idx, start_idx + len_snippet) fits in the video.
        start_idx = np.random.randint(0, self.list_num_frame[idx] - self.len_snippet + 1)
        flip = np.random.random() < 0.5

        clip = []
        for i in range(self.len_snippet):
            img = self._imread(os.path.join(path_clip, '%04d.png' % (start_idx + i + 1)))
            img = cv2.resize(img, (384, 224))  # dsize is (width, height)
            img = img[..., ::-1]  # OpenCV loads BGR; reverse channels to RGB
            if flip:
                img = img[:, ::-1, ...]
            clip.append(img)

        # Frame files are 1-based, so this is the last frame of the window.
        annt = self._imread(os.path.join(path_annt, '%04d.png' % (start_idx + self.len_snippet)), 0)
        annt = cv2.resize(annt, (384, 224))
        if flip:
            annt = annt[:, ::-1]
        # copy() materializes the negative-stride flipped view for torch.from_numpy.
        return transform(clip), torch.from_numpy(annt.copy()).contiguous().float()
# Reference: gist.github.com/MFreidank/821cc87b012c53fade03b0c7aba13958
class InfiniteDataLoader(DataLoader):
    """DataLoader that never raises StopIteration between epochs.

    ``iter(loader)`` returns the loader itself, so ``next(loader)`` can be
    called indefinitely. When the current pass over the data is exhausted, a
    new one is started. The first epoch iterator is created eagerly in
    ``__init__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dataset_iterator = super().__iter__()

    def __iter__(self):
        return self

    def __next__(self):
        try:
            return next(self.dataset_iterator)
        except StopIteration:
            # Epoch finished: start a fresh pass over the data.
            self.dataset_iterator = super().__iter__()
        return next(self.dataset_iterator)
# ** Update **
# Please consider using the following sampler during training instead of the above InfiniteDataLoader.
# You can simply refer to: https://github.com/MichiganCOG/Gaze-Attention
# Reference: https://github.com/facebookresearch/detectron2
class trainingSampler(torch.utils.data.sampler.Sampler):
    """Infinite sampler yielding a fresh random permutation of range(size) per pass.

    Pass it as ``sampler=`` to a regular DataLoader. Iteration never ends, so
    the training loop decides when to stop.

    Args:
        size: number of samples in the dataset; must be positive.
        seed: optional seed for the permutation generator. When None (the
            default), a new ``torch.Generator`` is used without reseeding.
    """

    def __init__(self, size, seed=None):
        if size <= 0:
            # randperm(0) is empty, so the infinite loop below would spin forever.
            raise ValueError('size must be positive, got %r' % (size,))
        self.size = size
        self.seed = seed

    def _infinite_indices(self):
        """Yield indices forever, one full random permutation at a time."""
        g = torch.Generator()
        if self.seed is not None:
            g.manual_seed(self.seed)
        while True:
            # tolist() yields plain Python ints rather than 0-d tensors.
            yield from torch.randperm(self.size, generator=g).tolist()

    def __iter__(self):
        # Every index is yielded (no per-rank sharding), so no islice is needed.
        yield from self._infinite_indices()