Skip to content

Commit

Permalink
fixed bug in otfinmemorydataset when nepochs is 1
Browse files Browse the repository at this point in the history
  • Loading branch information
M-R-Schaefer committed Apr 5, 2024
1 parent cb2f9d0 commit 0be92aa
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions apax/data/input_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,18 +251,19 @@ def cleanup(self):

class OTFInMemoryDataset(InMemoryDataset):
def __iter__(self):
epoch = 0
while epoch < self.n_epochs or len(self.buffer) > 0:
outer_count = 0
max_iter = self.n_data * self.n_epochs
while outer_count < max_iter:
yield self.buffer.popleft()

space = self.buffer_size - len(self.buffer)
if self.count + space > self.n_data:
space = self.n_data - self.count

if self.count >= self.n_data and epoch < self.n_epochs:
epoch += 1
if self.count >= self.n_data:
self.count = 0
self.enqueue(space)
outer_count += 1

def shuffle_and_batch(self):
"""Shuffles and batches the inputs/labels. This function prepares the
Expand Down

0 comments on commit 0be92aa

Please sign in to comment.