From 46f2945ab95e348a845870bb1ca1e07cdbbb7932 Mon Sep 17 00:00:00 2001 From: SeqIO Team Date: Tue, 14 May 2024 16:23:11 -0700 Subject: [PATCH] Allow data sources to specify that they can be shuffled without a buffer. PiperOrigin-RevId: 633740168 --- seqio/dataset_providers.py | 34 +++++++++++++++++++++++++-------- seqio/dataset_providers_test.py | 7 +++++++ 2 files changed, 33 insertions(+), 8 deletions(-) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index 79b53410..c71ad75f 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -288,12 +288,14 @@ def __init__( splits: Iterable[str], num_input_examples: Optional[Mapping[str, int]] = None, caching_permitted: bool = True, + performs_internal_shuffling: bool = False, ): self._splits = tuple(splits) self._num_input_examples = ( dict(num_input_examples) if num_input_examples is not None else None ) self._caching_permitted = caching_permitted + self._performs_internal_shuffling = performs_internal_shuffling @property def caching_permitted(self) -> bool: @@ -319,6 +321,15 @@ def output_features(self) -> Mapping[str, Feature]: """Override unused property of `DatasetProviderBase`.""" raise NotImplementedError + @property + def performs_internal_shuffling(self) -> bool: + """Indicates whether this data source performs internal shuffling. + + Some datasets may provide internal shuffling mechanisms that could allow + the dataset to be shuffled without calling ds.shuffle(). + """ + return self._performs_internal_shuffling + @abc.abstractmethod def list_shards(self, split: str) -> Sequence[str]: """Returns string identifiers of input shards.""" @@ -590,6 +601,7 @@ def __init__( file_shuffle_buffer_size: Optional[int] = None, cycle_length: int = 16, block_length: int = 16, + performs_internal_shuffling: bool = False, ): """FileDataSource constructor. @@ -609,6 +621,9 @@ def __init__( replicate earlier behavior. cycle_length: The cycle_length to pass to tf.data.Dataset.interleave. block_length: The block_length to pass to tf.data.Dataset.interleave. + performs_internal_shuffling: Allow enclosing task to call get_dataset with + shuffle_buffer_size=None. In this case, only filename shuffling will be + performed when shuffle==True. """ self._split_to_filepattern = split_to_filepattern self._reader = read_file_fn @@ -619,6 +634,7 @@ def __init__( splits=split_to_filepattern.keys(), num_input_examples=num_input_examples, caching_permitted=caching_permitted, + performs_internal_shuffling=performs_internal_shuffling, ) @property @@ -1663,14 +1679,16 @@ def get_dataset( ds = self._trim_output_features(ds, sequence_length=sequence_length) if shuffle: if self._shuffle_buffer_size is None: - raise ValueError( - f"Shuffling is disallowed for Task '{self.name}' since its " - "`shuffle_buffer_size` was set to `None` on construction." - ) - shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size - # Shuffle before mixing since preprocessor can output multiple - # (correlated) examples per input. - ds = ds.shuffle(shuffle_buffer_size, seed=seed) + if not self.source.performs_internal_shuffling: + raise ValueError( + f"Shuffling is disallowed for Task '{self.name}' since its " + "`shuffle_buffer_size` was set to `None` on construction." + ) + else: + shuffle_buffer_size = shuffle_buffer_size or self._shuffle_buffer_size + # Shuffle before mixing since preprocessor can output multiple + # (correlated) examples per input. + ds = ds.shuffle(shuffle_buffer_size, seed=seed) return ds.prefetch(tf.data.experimental.AUTOTUNE) diff --git a/seqio/dataset_providers_test.py b/seqio/dataset_providers_test.py index 5fbd247d..35074cec 100644 --- a/seqio/dataset_providers_test.py +++ b/seqio/dataset_providers_test.py @@ -302,6 +302,13 @@ def test_disallow_shuffle(self): task.get_dataset(None, shuffle=False) + # When the source specifies performs_internal_shuffling, it should be + # possible to call get_dataset with shuffle=True and + # shuffle_buffer_size=None. In this case, only the source's internal + # shuffling mechanism will be active. + self.function_source._performs_internal_shuffling = True + task.get_dataset(None, shuffle=True) + def test_supports_caching(self): self.assertFalse( dataset_providers.Task(