diff --git a/seqio/beam_utils.py b/seqio/beam_utils.py index ccf1ff23..fa4fb9a6 100644 --- a/seqio/beam_utils.py +++ b/seqio/beam_utils.py @@ -196,16 +196,12 @@ def __init__(self, output_path: str, num_shards: Optional[int] = None): self._num_shards = num_shards def expand(self, pcoll): - return ( - pcoll - | beam.Map(seqio.dict_to_tfexample) - | beam.Reshuffle() - | beam.io.tfrecordio.WriteToTFRecord( - self._output_path, - num_shards=self._num_shards, - coder=beam.coders.ProtoCoder(tf.train.Example), - ) + sink = beam.io.tfrecordio.WriteToTFRecord( + self._output_path, + num_shards=self._num_shards, + coder=beam.coders.ProtoCoder(tf.train.Example), ) + return pcoll | beam.Map(seqio.dict_to_tfexample) | beam.Reshuffle() | sink @@ -302,17 +298,13 @@ def __init__( self._preserve_random_access = preserve_random_access def expand(self, pcoll): - return ( - pcoll - | beam.Map(seqio.dict_to_tfexample) - | beam.Reshuffle() - | WriteToArrayRecord( - self._output_path, - num_shards=self._num_shards, - coder=beam.coders.ProtoCoder(tf.train.Example), - preserve_random_access=self._preserve_random_access, - ) + sink = WriteToArrayRecord( + self._output_path, + num_shards=self._num_shards, + coder=beam.coders.ProtoCoder(tf.train.Example), + preserve_random_access=self._preserve_random_access, ) + return pcoll | beam.Map(seqio.dict_to_tfexample) | beam.Reshuffle() | sink class WriteJson(beam.PTransform): @@ -337,14 +329,10 @@ def _jsonify(self, el): return json.dumps(el) def expand(self, pcoll): - return ( - pcoll - | beam.Map(self._jsonify) - | "write_info" - >> beam.io.WriteToText( - self._output_path, num_shards=1, shard_name_template="" - ) + sink = beam.io.WriteToText( + self._output_path, num_shards=1, shard_name_template="" ) + return pcoll | beam.Map(self._jsonify) | "write_info" >> sink class GetInfo(beam.PTransform):