Skip to content

Commit

Permalink
Adds reshuffle_each_iteration argument to `deterministic_data.creat…
Browse files Browse the repository at this point in the history
…e_dataset()`.

This argument is passed to `tf.data.Dataset.shuffle()` and controls whether the dataset is reshuffled each time it is iterated over. The default value is `None`, which is the same as the default value of `reshuffle_each_iteration` in `tf.data.Dataset.shuffle()`.

This change is being made to support the use of `deterministic_data.create_dataset()` in evaluation loops that need to access the same evaluation data batches in each iteration of the dataset without reshuffling before each iteration/epoch over the dataset. This is useful, for example, in visualizing the progress of image generation models at different model checkpoints. Visualizing the model progress on the same evaluation data makes Tensorboard qualitative evaluation easier.

This change is backwards compatible. If the `reshuffle_each_iteration` argument is not specified, the default value of `None` will be used.

PiperOrigin-RevId: 661355447
  • Loading branch information
sadeepj authored and copybara-github committed Aug 9, 2024
1 parent b64aa29 commit f52f958
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion clu/deterministic_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def create_dataset(dataset_builder: DatasetBuilder,
num_epochs: Optional[int] = None,
shuffle: bool = True,
shuffle_buffer_size: int = 10_000,
reshuffle_each_iteration: Optional[bool] = None,
prefetch_size: int = 4,
pad_up_to_batches: Optional[Union[int, str]] = None,
cardinality: Optional[int] = None,
Expand Down Expand Up @@ -402,6 +403,9 @@ def create_dataset(dataset_builder: DatasetBuilder,
forever.
shuffle: Whether to shuffle the dataset (both on file and example level).
shuffle_buffer_size: Number of examples in the shuffle buffer.
reshuffle_each_iteration: A boolean, which if true indicates that the
dataset should be pseudorandomly reshuffled each time it is iterated over.
(Defaults to `True`.)
prefetch_size: The number of elements in the final dataset to prefetch in
the background. This should be a small (say <10) positive integer or
tf.data.experimental.AUTOTUNE.
Expand Down Expand Up @@ -453,7 +457,11 @@ def create_dataset(dataset_builder: DatasetBuilder,
ds = ds.cache()

if shuffle:
ds = ds.shuffle(shuffle_buffer_size, seed=rngs.pop()[0])
ds = ds.shuffle(
shuffle_buffer_size,
seed=rngs.pop()[0],
reshuffle_each_iteration=reshuffle_each_iteration,
)
ds = ds.repeat(num_epochs)

if preprocess_fn is not None:
Expand Down

0 comments on commit f52f958

Please sign in to comment.