Skip to content

Commit

Permalink
my public commit msg
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 615267599
  • Loading branch information
SeqIO Team authored and SeqIO committed Mar 13, 2024
1 parent 11706e4 commit 220ba25
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 7 deletions.
107 changes: 107 additions & 0 deletions seqio/beam_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
import seqio
import tensorflow.compat.v2 as tf

from array_record.python import array_record_module # copybara:strip

PROVENANCE_PREFIX = "provenance/"
TASK_PROVENANCE_KEY = PROVENANCE_PREFIX + "task"
SOURCE_SHARD_PROVENANCE_KEY = PROVENANCE_PREFIX + "source_shard"
Expand Down Expand Up @@ -203,6 +205,111 @@ def expand(self, pcoll):
)


class _ArrayRecordSink(beam.io.filebasedsink.FileBasedSink):
"""Sink Class for use in Arrayrecord PTransform."""

def __init__(
self,
file_path_prefix,
file_name_suffix=None,
num_shards=0,
shard_name_template=None,
coder=beam.coders.coders.ToBytesCoder(),
compression_type=beam.io.filesystem.CompressionTypes.AUTO,
preserve_random_access: bool = False,
):

super().__init__(
file_path_prefix,
file_name_suffix=file_name_suffix,
num_shards=num_shards,
shard_name_template=shard_name_template,
coder=coder,
mime_type="application/octet-stream",
compression_type=compression_type,
)
self._preserve_random_access = preserve_random_access

def open(self, temp_path):
group_size = 1 if self._preserve_random_access else self.num_shards
array_writer = array_record_module.ArrayRecordWriter(
temp_path, f"group_size:{group_size}"
)
return array_writer

def close(self, file_handle):
file_handle.close()

def write_encoded_record(self, file_handle, value):
file_handle.write(value)


class WriteToArrayRecord(beam.transforms.PTransform):
"""PTransform for a disk-based write to ArrayRecord."""

def __init__(
self,
file_path_prefix,
file_name_suffix="",
num_shards=0,
shard_name_template=None,
coder=beam.coders.coders.ToBytesCoder(),
compression_type=beam.io.filesystem.CompressionTypes.AUTO,
preserve_random_access: bool = False,
):

self._sink = _ArrayRecordSink(
file_path_prefix,
file_name_suffix,
num_shards,
shard_name_template,
coder,
compression_type,
preserve_random_access,
)

def expand(self, pcoll):
return pcoll | beam.io.iobase.Write(self._sink)


class WriteExampleArrayRecord(beam.PTransform):
"""Writes examples (dicts) to an ArrayRecord of tf.Example protos."""

def __init__(
self,
output_path: str,
num_shards: Optional[int] = None,
preserve_random_access: bool = False,
):
"""WriteExampleArrayRecord constructor.
Args:
output_path: string, path to the output ArrayRecord file (w/o shard
suffix).
num_shards: (optional) int, number of shards to output or None to use
liquid sharding.
preserve_random_access: Whether to preserve the random access of the
written ArrayRecord. If true, set group_size=1, else, set to number of
shards.
"""
self._output_path = output_path
self._num_shards = num_shards
self._preserve_random_access = preserve_random_access

def expand(self, pcoll):
return (
pcoll
| beam.Map(seqio.dict_to_tfexample)
| beam.Reshuffle()
| WriteToArrayRecord(
self._output_path,
num_shards=self._num_shards,
coder=beam.coders.ProtoCoder(tf.train.Example),
preserve_random_access=self._preserve_random_access,
)
)


class WriteJson(beam.PTransform):
"""Writes datastructures to file as JSON(L)."""

Expand Down
43 changes: 36 additions & 7 deletions seqio/scripts/cache_tasks_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,21 @@
)


flags.DEFINE_enum(
"output_format",
"tfrecord",
["arrayrecord", "tfrecord"],
"Output format of the cached tasks.",
)
flags.DEFINE_boolean(
"preserve_random_access",
False,
"Used only if --output_format=arrayrecord. If true, preserve the random"
" access by setting group_size=1, else, set group_size to number of output"
" shards. Be aware that preserve_random_access will significantly slow down"
" the process of writing to the ArrayRecord.",
)


def _import_modules(modules):
for module in modules:
Expand Down Expand Up @@ -289,14 +304,28 @@ def run_pipeline(
| "%s_global_example_shuffle" % label >> beam.Reshuffle()
)

completion_values.append(
examples
| "%s_write_tfrecord" % label
>> beam_utils.WriteExampleTfRecord(
seqio.get_cached_tfrecord_prefix(output_dir, split),
num_shards=num_shards,
match FLAGS.output_format:
case "arrayrecord":
completion_values.append(
examples
| "%s_write_arrayrecord" % label
>> beam_utils.WriteExampleArrayRecord(
os.path.join(
output_dir, "{split}.array_record".format(split=split)
),
num_shards=num_shards,
preserve_random_access=FLAGS.preserve_random_access,
)
)
case "tfrecord":
completion_values.append(
examples
| "%s_write_tfrecord" % label
>> beam_utils.WriteExampleTfRecord(
seqio.get_cached_tfrecord_prefix(output_dir, split),
num_shards=num_shards,
)
)
)
completion_values.append(
examples
| "%s_info" % label >> beam_utils.GetInfo(num_shards)
Expand Down

0 comments on commit 220ba25

Please sign in to comment.