From fc65a0d4705370bd90581f1c6b0bdd8517c81226 Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Wed, 22 May 2024 15:04:51 -0700 Subject: [PATCH] Supports `builder_kwargs` in `TfdsDataSource` PiperOrigin-RevId: 636306489 --- seqio/dataset_providers.py | 4 ++++ seqio/utils.py | 15 ++++++++++++++- seqio/utils_test.py | 2 ++ 3 files changed, 20 insertions(+), 1 deletion(-) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index c71ad75f..8f84c340 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -493,6 +493,7 @@ def __init__( ] = None, caching_permitted: bool = True, decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None, + tfds_builder_kwargs: Optional[dict[str, Any]] = None, ): """TfdsTask constructor. @@ -514,6 +515,8 @@ def __init__( Default True. decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + tfds_builder_kwargs: `dict` (optional), keyword arguments to be passed to + the `tfds.core.DatasetBuilder` constructor through `tfds.load()`. """ if splits and not isinstance(splits, dict): splits = {k: k for k in splits} @@ -523,6 +526,7 @@ def __init__( data_dir=tfds_data_dir, split_map=splits if isinstance(splits, dict) else None, decoders=decoders, + builder_kwargs=tfds_builder_kwargs, ) # If splits are not provided, we pass an empty tuple and use the lazy diff --git a/seqio/utils.py b/seqio/utils.py index 68d2cb15..625c85e9 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -127,6 +127,7 @@ def __init__( data_dir: Optional[str] = None, split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None, decoders=None, + builder_kwargs: Optional[dict[str, Any]] = None, ): """LazyTfdsLoader constructor. @@ -140,12 +141,16 @@ def __init__( split='train')`). If `TfdsSplit` are used then `name` must be empty. 
decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + builder_kwargs: `dict` (optional), keyword arguments to be passed to the + `tfds.core.DatasetBuilder` constructor through `tfds.load()` and + `tfds.builder()`. """ _validate_tfds_name(name) self._name = name self._data_dir = data_dir self._split_map = split_map self._decoders = decoders + self._builder_kwargs = builder_kwargs self._is_custom_split_map = False if split_map: @@ -302,8 +307,15 @@ def _get_builder(self, split: Optional[str] = None): builder_key = self._get_builder_key(dataset, data_dir) if builder_key not in LazyTfdsLoader._MEMOIZED_BUILDERS: if dataset: - builder = tfds.builder(dataset, data_dir=data_dir) + builder = tfds.builder( + dataset, data_dir=data_dir, **(self._builder_kwargs or {}) + ) else: + if self._builder_kwargs: + raise ValueError( + "`builder_kwargs` should be empty when `dataset` value is not" + " present." + ) builder = tfds.builder_from_directory(data_dir) LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] = builder return LazyTfdsLoader._MEMOIZED_BUILDERS[builder_key] @@ -374,6 +386,7 @@ def load( try_gcs=True, read_config=read_config, decoders=self._decoders, + builder_kwargs=self._builder_kwargs, ) def load_shard( diff --git a/seqio/utils_test.py b/seqio/utils_test.py index 5154e06f..a7ca73f9 100644 --- a/seqio/utils_test.py +++ b/seqio/utils_test.py @@ -178,6 +178,7 @@ def test_split_map(self, mock_tfds_load): try_gcs=True, read_config=AnyArg(), decoders=None, + builder_kwargs=None, ) # test .size() @@ -238,6 +239,7 @@ def test_tfds_split(self, mock_tfds_load): try_gcs=True, read_config=AnyArg(), decoders=None, + builder_kwargs=None, ) # test .size()