-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathminibatcher.py
22 lines (20 loc) · 858 Bytes
/
minibatcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
class MiniBatcher(object):
    """Yield index arrays that partition ``range(n_examples)`` into mini-batches.

    Each pass over the data (an "epoch") produces ``ceil(n_examples / batch_size)``
    batches; the final batch is shorter when ``batch_size`` does not divide
    ``n_examples``. When ``shuffle`` is true, a fresh permutation is drawn from
    the global NumPy RNG (``np.random.shuffle``) at the start of every epoch.

    Args:
        batch_size: Number of indices per batch; must satisfy
            ``0 < batch_size <= n_examples``.
        n_examples: Total number of examples to index.
        shuffle: Whether to randomly permute the indices each epoch.

    Raises:
        ValueError: If ``batch_size`` is not positive or exceeds ``n_examples``.
    """

    def __init__(self, batch_size, n_examples, shuffle=True):
        # Explicit checks instead of ``assert`` so validation survives ``python -O``.
        # A non-positive batch_size would otherwise make get_one_batch() loop forever.
        if batch_size <= 0:
            raise ValueError("Error: batch_size must be a positive integer")
        if batch_size > n_examples:
            raise ValueError("Error: batch_size is larger than n_examples")
        self.batch_size = batch_size
        self.n_examples = n_examples
        self.shuffle = shuffle
        self.idxs = np.arange(self.n_examples)
        if self.shuffle:
            # get_one_batch() draws its own permutation, so this one is never
            # used for batching. It is kept so the global NumPy RNG advances
            # exactly as before, which keeps seeded experiments reproducible.
            np.random.shuffle(self.idxs)
        self.current_start = 0

    def __len__(self):
        """Return the number of batches produced per epoch (ceiling division)."""
        return (self.n_examples + self.batch_size - 1) // self.batch_size

    def __iter__(self):
        """Iterate over one epoch of batches; equivalent to ``get_one_batch()``."""
        return self.get_one_batch()

    def get_one_batch(self):
        """Generate every batch of one epoch as a NumPy index array.

        Despite the name, this is a generator covering the whole epoch: each
        yielded array is an independent copy of the next ``batch_size`` indices
        from this epoch's (optionally shuffled) ordering.

        The permutation and cursor are held in locals, so two interleaved
        generators from the same instance do not disturb each other.
        ``self.idxs`` and ``self.current_start`` are still updated to mirror
        the most recently advanced generator, for code that inspects them.
        """
        order = np.arange(self.n_examples)
        if self.shuffle:
            np.random.shuffle(order)
        self.idxs = order
        self.current_start = 0
        for start in range(0, self.n_examples, self.batch_size):
            stop = start + self.batch_size
            # Advance the public cursor before yielding, as the original did.
            self.current_start = stop
            # np.array copies the slice so callers may mutate the batch safely.
            yield np.array(order[start:stop])