diff --git a/seqio/beam_utils.py b/seqio/beam_utils.py
index c2b6f4f8..96c0d123 100644
--- a/seqio/beam_utils.py
+++ b/seqio/beam_utils.py
@@ -28,6 +28,8 @@
 import seqio
 import tensorflow.compat.v2 as tf
 
+from array_record.python import array_record_module  # copybara:strip
+
 PROVENANCE_PREFIX = "provenance/"
 TASK_PROVENANCE_KEY = PROVENANCE_PREFIX + "task"
 SOURCE_SHARD_PROVENANCE_KEY = PROVENANCE_PREFIX + "source_shard"
@@ -203,6 +205,111 @@ def expand(self, pcoll):
     )
 
 
+class _ArrayRecordSink(beam.io.filebasedsink.FileBasedSink):
+  """Sink Class for use in Arrayrecord PTransform."""
+
+  def __init__(
+      self,
+      file_path_prefix,
+      file_name_suffix="",
+      num_shards=0,
+      shard_name_template=None,
+      coder=beam.coders.coders.ToBytesCoder(),
+      compression_type=beam.io.filesystem.CompressionTypes.AUTO,
+      preserve_random_access: bool = False,
+  ):
+
+    super().__init__(
+        file_path_prefix,
+        file_name_suffix=file_name_suffix,
+        num_shards=num_shards,
+        shard_name_template=shard_name_template,
+        coder=coder,
+        mime_type="application/octet-stream",
+        compression_type=compression_type,
+    )
+    self._preserve_random_access = preserve_random_access
+
+  def open(self, temp_path):
+    group_size = 1 if self._preserve_random_access else self.num_shards
+    array_writer = array_record_module.ArrayRecordWriter(
+        temp_path, f"group_size:{group_size}"
+    )
+    return array_writer
+
+  def close(self, file_handle):
+    file_handle.close()
+
+  def write_encoded_record(self, file_handle, value):
+    file_handle.write(value)
+
+
+class WriteToArrayRecord(beam.transforms.PTransform):
+  """PTransform for a disk-based write to ArrayRecord."""
+
+  def __init__(
+      self,
+      file_path_prefix,
+      file_name_suffix="",
+      num_shards=0,
+      shard_name_template=None,
+      coder=beam.coders.coders.ToBytesCoder(),
+      compression_type=beam.io.filesystem.CompressionTypes.AUTO,
+      preserve_random_access: bool = False,
+  ):
+
+    self._sink = _ArrayRecordSink(
+        file_path_prefix,
+        file_name_suffix,
+        num_shards,
+        shard_name_template,
+        coder,
+        compression_type,
+        preserve_random_access,
+    )
+
+  def expand(self, pcoll):
+    return pcoll | beam.io.iobase.Write(self._sink)
+
+
+class WriteExampleArrayRecord(beam.PTransform):
+  """Writes examples (dicts) to an ArrayRecord of tf.Example protos."""
+
+  def __init__(
+      self,
+      output_path: str,
+      num_shards: Optional[int] = None,
+      preserve_random_access: bool = False,
+  ):
+    """WriteExampleArrayRecord constructor.
+
+    Args:
+      output_path: string, path to the output ArrayRecord file (w/o shard
+        suffix).
+      num_shards: (optional) int, number of shards to output or None to use
+        liquid sharding.
+      preserve_random_access: Whether to preserve the random access of the
+        written ArrayRecord. If true, group_size is set to 1; otherwise it is
+        set to the number of shards.
+    """
+    self._output_path = output_path
+    self._num_shards = num_shards
+    self._preserve_random_access = preserve_random_access
+
+  def expand(self, pcoll):
+    return (
+        pcoll
+        | beam.Map(seqio.dict_to_tfexample)
+        | beam.Reshuffle()
+        | WriteToArrayRecord(
+            self._output_path,
+            num_shards=self._num_shards,
+            coder=beam.coders.ProtoCoder(tf.train.Example),
+            preserve_random_access=self._preserve_random_access,
+        )
+    )
+
+
 class WriteJson(beam.PTransform):
   """Writes datastructures to file as JSON(L)."""
 
diff --git a/seqio/scripts/cache_tasks_main.py b/seqio/scripts/cache_tasks_main.py
index 2cef90b2..1d235fc6 100644
--- a/seqio/scripts/cache_tasks_main.py
+++ b/seqio/scripts/cache_tasks_main.py
@@ -142,6 +142,21 @@
 )
 
 
+flags.DEFINE_enum(
+    "output_format",
+    "tfrecord",
+    ["arrayrecord", "tfrecord"],
+    "Output format of the cached tasks.",
+)
+flags.DEFINE_boolean(
+    "preserve_random_access",
+    False,
+    "Used only if --output_format=arrayrecord. If true, preserve the random"
+    " access by setting group_size=1, else, set group_size to number of output"
+    " shards. Be aware that preserve_random_access will significantly slow down"
+    " the process of writing to the ArrayRecord.",
+)
+
 
 def _import_modules(modules):
   for module in modules:
@@ -289,14 +304,28 @@ def run_pipeline(
           | "%s_global_example_shuffle" % label >> beam.Reshuffle()
       )
 
-      completion_values.append(
-          examples
-          | "%s_write_tfrecord" % label
-          >> beam_utils.WriteExampleTfRecord(
-              seqio.get_cached_tfrecord_prefix(output_dir, split),
-              num_shards=num_shards,
+      match FLAGS.output_format:
+        case "arrayrecord":
+          completion_values.append(
+              examples
+              | "%s_write_arrayrecord" % label
+              >> beam_utils.WriteExampleArrayRecord(
+                  os.path.join(
+                      output_dir, f"{split}.array_record"
+                  ),
+                  num_shards=num_shards,
+                  preserve_random_access=FLAGS.preserve_random_access,
+              )
+          )
+        case "tfrecord":
+          completion_values.append(
+              examples
+              | "%s_write_tfrecord" % label
+              >> beam_utils.WriteExampleTfRecord(
+                  seqio.get_cached_tfrecord_prefix(output_dir, split),
+                  num_shards=num_shards,
+              )
           )
-      )
       completion_values.append(
           examples
           | "%s_info" % label >> beam_utils.GetInfo(num_shards)