Skip to content

Commit

Permalink
Add functionality to override the decoders in a TFDS data source
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 691001368
  • Loading branch information
tomvdw authored and SeqIO committed Oct 29, 2024
1 parent 14cbb0c commit 247bb52
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 0 deletions.
4 changes: 4 additions & 0 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ def __init__(
def name(self) -> Optional[str]:
return self._name

def set_decoders(self, decoders) -> None:
"""Override the decoders for this dataset."""
self._decoders = decoders

@property
def tfds_splits(self) -> Optional[Mapping[str, TfdsSplit]]:
return self._split_map if self._is_custom_split_map else None
Expand Down
27 changes: 27 additions & 0 deletions seqio/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,33 @@ def test_read_only(self, mock_tfds_builder, mock_tfds_load):
decoders=None,
)

@mock.patch("tensorflow_datasets.builder")
def test_get_dataset_decoders(self, mock_tfds_builder):
mock_builder = mock.create_autospec(tfds.core.DatasetBuilder)
mock_tfds_builder.return_value = mock_builder
init_decoders = mock.MagicMock()
get_dataset_decoders = mock.MagicMock()
loader = utils.LazyTfdsLoader(
"ds/cfg:1.2.3", data_dir="/data", read_only=True, decoders=init_decoders
)
_ = loader.load(split="train", shuffle_files=False)
mock_builder.as_dataset.assert_called_once_with(
split="train",
shuffle_files=False,
read_config=AnyArg(),
decoders=init_decoders,
)

mock_builder.reset_mock()
loader.set_decoders(get_dataset_decoders)
_ = loader.load(split="train", shuffle_files=False)
mock_builder.as_dataset.assert_called_once_with(
split="train",
shuffle_files=False,
read_config=AnyArg(),
decoders=get_dataset_decoders,
)

@mock.patch("tensorflow_datasets.load")
def test_split_map(self, mock_tfds_load):
seed = 0
Expand Down

0 comments on commit 247bb52

Please sign in to comment.