From f52f9583cce7d62ca54b403cbb68fd00e4850ab9 Mon Sep 17 00:00:00 2001 From: Sadeep Jayasumana Date: Fri, 9 Aug 2024 12:08:49 -0700 Subject: [PATCH] Adds `reshuffle_each_iteration` argument to `deterministic_data.create_dataset()`. This argument is passed to `tf.data.Dataset.shuffle()` and controls whether the dataset is reshuffled each time it is iterated over. The default value is `None`, which is the same as the default value of `reshuffle_each_iteration` in `tf.data.Dataset.shuffle()`. This change is being made to support the use of `deterministic_data.create_dataset()` in evaluation loops that need to access the same evaluation data batches in each iteration of the dataset without reshuffling before each iteration/epoch over the dataset. This is useful, for example, in visualizing the progress of image generation models at different model checkpoints. Visualizing the model progress on the same evaluation data makes Tensorboard qualitative evaluation easier. This change is backwards compatible. If the `reshuffle_each_iteration` argument is not specified, the default value of `None` will be used. PiperOrigin-RevId: 661355447 --- clu/deterministic_data.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/clu/deterministic_data.py b/clu/deterministic_data.py index 6fb4671..27419d7 100644 --- a/clu/deterministic_data.py +++ b/clu/deterministic_data.py @@ -372,6 +372,7 @@ def create_dataset(dataset_builder: DatasetBuilder, num_epochs: Optional[int] = None, shuffle: bool = True, shuffle_buffer_size: int = 10_000, + reshuffle_each_iteration: Optional[bool] = None, prefetch_size: int = 4, pad_up_to_batches: Optional[Union[int, str]] = None, cardinality: Optional[int] = None, @@ -402,6 +403,9 @@ def create_dataset(dataset_builder: DatasetBuilder, forever. shuffle: Whether to shuffle the dataset (both on file and example level). shuffle_buffer_size: Number of examples in the shuffle buffer. + reshuffle_each_iteration: A boolean, which if true indicates that the + dataset should be pseudorandomly reshuffled each time it is iterated over. + (Defaults to `True`.) prefetch_size: The number of elements in the final dataset to prefetch in the background. This should be a small (say <10) positive integer or tf.data.experimental.AUTOTUNE. @@ -453,7 +457,11 @@ def create_dataset(dataset_builder: DatasetBuilder, ds = ds.cache() if shuffle: - ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0]) + ds = ds.shuffle( + shuffle_buffer_size, + seed=rngs.pop()[0], + reshuffle_each_iteration=reshuffle_each_iteration, + ) ds = ds.repeat(num_epochs) if preprocess_fn is not None: