Skip to content

Commit

Permalink
Supports builder_kwargs in TfdsDataSource
Browse files: browse the repository at this point in the history
PiperOrigin-RevId: 636306489
  • Loading branch information
jimlinntu authored and SeqIO committed May 22, 2024
1 parent e1b6c86 commit bdb3f7f
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 0 deletions.
3 changes: 3 additions & 0 deletions seqio/dataset_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ def __init__(
] = None,
caching_permitted: bool = True,
decoders: Optional[tfds.typing.TreeDict[tfds.decode.Decoder]] = None,
tfds_builder_kwargs: Optional[dict[str, Any]] = None,
):
"""TfdsTask constructor.
Expand All @@ -514,6 +515,7 @@ def __init__(
Default True.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
tfds_builder_kwargs: dict (optional), keyword arguments forwarded as `builder_kwargs` to `tfds.load()`.
"""
if splits and not isinstance(splits, dict):
splits = {k: k for k in splits}
Expand All @@ -523,6 +525,7 @@ def __init__(
data_dir=tfds_data_dir,
split_map=splits if isinstance(splits, dict) else None,
decoders=decoders,
builder_kwargs=tfds_builder_kwargs,
)

# If splits are not provided, we pass an empty tuple and use the lazy
Expand Down
4 changes: 4 additions & 0 deletions seqio/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __init__(
data_dir: Optional[str] = None,
split_map: Union[Mapping[str, str], Mapping[str, TfdsSplit], None] = None,
decoders=None,
builder_kwargs: Optional[dict[str, Any]] = None,
):
"""LazyTfdsLoader constructor.
Expand All @@ -140,12 +141,14 @@ def __init__(
split='train')`). If `TfdsSplit` are used then `name` must be empty.
decoders: dict (optional), mapping from features to tfds.decode.Decoders,
such as tfds.decode.SkipDecoding() for skipping image byte decoding.
builder_kwargs: dict (optional), keyword arguments forwarded as `builder_kwargs` to `tfds.load()`.
"""
_validate_tfds_name(name)
self._name = name
self._data_dir = data_dir
self._split_map = split_map
self._decoders = decoders
self._builder_kwargs = builder_kwargs

self._is_custom_split_map = False
if split_map:
Expand Down Expand Up @@ -374,6 +377,7 @@ def load(
try_gcs=True,
read_config=read_config,
decoders=self._decoders,
builder_kwargs=self._builder_kwargs,
)

def load_shard(
Expand Down

0 comments on commit bdb3f7f

Please sign in to comment.