diff --git a/clu/deterministic_data.py b/clu/deterministic_data.py index 6fb4671..27419d7 100644 --- a/clu/deterministic_data.py +++ b/clu/deterministic_data.py @@ -372,6 +372,7 @@ def create_dataset(dataset_builder: DatasetBuilder, num_epochs: Optional[int] = None, shuffle: bool = True, shuffle_buffer_size: int = 10_000, + reshuffle_each_iteration: Optional[bool] = None, prefetch_size: int = 4, pad_up_to_batches: Optional[Union[int, str]] = None, cardinality: Optional[int] = None, @@ -402,6 +403,9 @@ def create_dataset(dataset_builder: DatasetBuilder, forever. shuffle: Whether to shuffle the dataset (both on file and example level). shuffle_buffer_size: Number of examples in the shuffle buffer. + reshuffle_each_iteration: A boolean, which if true indicates that the + dataset should be pseudorandomly reshuffled each time it is iterated over. + (Defaults to `True`.) prefetch_size: The number of elements in the final dataset to prefetch in the background. This should be a small (say <10) positive integer or tf.data.experimental.AUTOTUNE. @@ -453,7 +457,11 @@ def create_dataset(dataset_builder: DatasetBuilder, ds = ds.cache() if shuffle: - ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0]) + ds = ds.shuffle( + shuffle_buffer_size, + seed=rngs.pop()[0], + reshuffle_each_iteration=reshuffle_each_iteration, + ) ds = ds.repeat(num_epochs) if preprocess_fn is not None: