diff --git a/seqio/utils.py b/seqio/utils.py
index 700d1b77..ca94aa91 100644
--- a/seqio/utils.py
+++ b/seqio/utils.py
@@ -180,6 +180,10 @@ def __init__(
   def name(self) -> Optional[str]:
     return self._name
 
+  def set_decoders(self, decoders) -> None:
+    """Replaces the decoders stored on this instance with `decoders`."""
+    self._decoders = decoders  # Stored as given; no validation or copy here.
+
   @property
   def tfds_splits(self) -> Optional[Mapping[str, TfdsSplit]]:
     return self._split_map if self._is_custom_split_map else None
diff --git a/seqio/utils_test.py b/seqio/utils_test.py
index b10593ae..e8bd46db 100644
--- a/seqio/utils_test.py
+++ b/seqio/utils_test.py
@@ -162,6 +162,33 @@ def test_read_only(self, mock_tfds_builder, mock_tfds_load):
         decoders=None,
     )
 
+  @mock.patch("tensorflow_datasets.builder")
+  def test_get_dataset_decoders(self, mock_tfds_builder):
+    mock_builder = mock.create_autospec(tfds.core.DatasetBuilder)
+    mock_tfds_builder.return_value = mock_builder
+    init_decoders = mock.MagicMock()  # Given to the constructor.
+    get_dataset_decoders = mock.MagicMock()  # Installed via set_decoders().
+    loader = utils.LazyTfdsLoader(
+        "ds/cfg:1.2.3", data_dir="/data", read_only=True, decoders=init_decoders
+    )
+    _ = loader.load(split="train", shuffle_files=False)  # First load.
+    mock_builder.as_dataset.assert_called_once_with(
+        split="train",
+        shuffle_files=False,
+        read_config=AnyArg(),
+        decoders=init_decoders,  # Expect the constructor's decoders.
+    )
+
+    mock_builder.reset_mock()  # Clear recorded calls before the second load.
+    loader.set_decoders(get_dataset_decoders)  # Swap decoders in place.
+    _ = loader.load(split="train", shuffle_files=False)  # Second load.
+    mock_builder.as_dataset.assert_called_once_with(
+        split="train",
+        shuffle_files=False,
+        read_config=AnyArg(),
+        decoders=get_dataset_decoders,  # Expect the overriding decoders.
+    )
+
   @mock.patch("tensorflow_datasets.load")
   def test_split_map(self, mock_tfds_load):
     seed = 0