From bdb3f7f3b99b31a2cffa9a7b2550538cd06460ba Mon Sep 17 00:00:00 2001 From: Jim Lin Date: Wed, 22 May 2024 15:04:51 -0700 Subject: [PATCH] Supports `builder_kwargs` in `TfdsDataSource` PiperOrigin-RevId: 636306489 --- seqio/dataset_providers.py | 3 +++ seqio/utils.py | 4 ++++ 2 files changed, 7 insertions(+) diff --git a/seqio/dataset_providers.py b/seqio/dataset_providers.py index c71ad75f..b13c1c8c 100644 --- a/seqio/dataset_providers.py +++ b/seqio/dataset_providers.py @@ -493,6 +493,7 @@ def __init__( ] = None, caching_permitted: bool = True, decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None, + tfds_builder_kwargs: Optional[dict[str, Any]] = None, ): """TfdsTask constructor. @@ -514,6 +515,7 @@ def __init__( Default True. decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + tfds_builder_kwargs: `builder_kwargs` later used in `tfds.load()`. """ if splits and not isinstance(splits, dict): splits = {k: k for k in splits} @@ -523,6 +525,7 @@ def __init__( data_dir=tfds_data_dir, split_map=splits if isinstance(splits, dict) else None, decoders=decoders, + builder_kwargs=tfds_builder_kwargs, ) # If splits are not provided, we pass an empty tuple and use the lazy diff --git a/seqio/utils.py b/seqio/utils.py index 68d2cb15..d371450e 100644 --- a/seqio/utils.py +++ b/seqio/utils.py @@ -127,6 +127,7 @@ def __init__( data_dir: Optional[str] = None, split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None, decoders=None, + builder_kwargs: Optional[dict[str, Any]] = None, ): """LazyTfdsLoader constructor. @@ -140,12 +141,14 @@ def __init__( split='train')`). If `TfdsSplit` are used then `name` must be empty. decoders: dict (optional), mapping from features to tfds.decode.Decoders, such as tfds.decode.SkipDecoding() for skipping image byte decoding. + builder_kwargs: `builder_kwargs` in `tfds.load()`. """ _validate_tfds_name(name) self._name = name self._data_dir = data_dir self._split_map = split_map self._decoders = decoders + self._builder_kwargs = builder_kwargs self._is_custom_split_map = False if split_map: @@ -374,6 +377,7 @@ def load( try_gcs=True, read_config=read_config, decoders=self._decoders, + builder_kwargs=self._builder_kwargs, ) def load_shard(