Adds parquet writer (#103)
* added parquet writer

* nit

* Update src/datatrove/pipeline/writers/parquet.py

Co-authored-by: Mario Šaško <[email protected]>

* updated test

* nit

---------

Co-authored-by: Mario Šaško <[email protected]>
guipenedo and mariosasko authored Feb 22, 2024
1 parent 795e542 commit d4cf053
Showing 6 changed files with 98 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/datatrove/io.py
@@ -23,7 +23,7 @@ def __init__(self, fs, mode: str = "wt", compression: str | None = "infer"):
 
     def get_file(self, filename):
         """
-        Opens file `filename` if it hasn't been opened yet. Otherwise just returns it from the file cache
+        Opens file `filename` if it hasn't been opened yet. Otherwise, just returns it from the file cache
         Args:
             filename: name of the file to open/get if previously opened
1 change: 1 addition & 0 deletions src/datatrove/pipeline/writers/__init__.py
@@ -1 +1,2 @@
 from .jsonl import JsonlWriter
+from .parquet import ParquetWriter
11 changes: 6 additions & 5 deletions src/datatrove/pipeline/writers/disk_base.py
@@ -1,7 +1,7 @@
 import dataclasses
 from abc import ABC, abstractmethod
 from string import Template
-from typing import Callable
+from typing import IO, Callable
 
 from datatrove.data import Document, DocumentsPipeline
 from datatrove.io import DataFolderLike, get_datafolder
@@ -31,6 +31,7 @@ def __init__(
         output_filename: str = None,
         compression: str | None = "infer",
         adapter: Callable = None,
+        mode: str = "wt",
     ):
         """
         Base writer block to save data to disk.
@@ -47,7 +48,7 @@ def __init__(
         if self.compression == "gzip" and not output_filename.endswith(".gz"):
             output_filename += ".gz"
         self.output_filename = Template(output_filename)
-        self.output_mg = self.output_folder.get_output_file_manager(mode="wt", compression=compression)
+        self.output_mg = self.output_folder.get_output_file_manager(mode=mode, compression=compression)
         self.adapter = adapter if adapter else _default_adapter
 
     def __enter__(self):
@@ -81,13 +82,13 @@ def _get_output_filename(self, document: Document, rank: int | str = 0, **kwargs
         )
 
     @abstractmethod
-    def _write(self, document: dict, file_handler):
+    def _write(self, document: dict, file_handler: IO, filename: str):
         """
         Main method that subclasses should implement. Receives an adapted (after applying self.adapter) dictionary with data to save to `file_handler`
         Args:
             document: dictionary with the data to save
             file_handler: file_handler where it should be saved
+            filename: to use as a key for writer helpers and other data
         Returns:
         """
@@ -105,7 +106,7 @@ def write(self, document: Document, rank: int = 0, **kwargs):
         """
         output_filename = self._get_output_filename(document, rank, **kwargs)
-        self._write(self.adapter(document), self.output_mg.get_file(output_filename))
+        self._write(self.adapter(document), self.output_mg.get_file(output_filename), output_filename)
         self.stat_update(self._get_output_filename(document, "XXXXX", **kwargs))
         self.stat_update(StatHints.total)
         self.update_doc_stats(document)
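The two signature changes above are the heart of the refactor: the new `mode` parameter lets a writer open its output files in binary mode (Parquet is a binary format), and the extra `filename` argument gives `_write` a stable key for per-file state. The JsonlWriter and ParquetWriter changes below adopt this contract; as an additional illustration only, a minimal sketch of a hypothetical subclass (not part of this commit; the `document["text"]` field is assumed to come from the default adapter) that keys per-file state on `filename`:

from typing import IO

from datatrove.pipeline.writers.disk_base import DiskWriter


class CountingTextWriter(DiskWriter):  # hypothetical example, not in datatrove
    default_output_filename: str = "${rank}.txt"
    name = "example counting writer"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._counts = {}  # per-output-file document counts, keyed by filename

    def _write(self, document: dict, file_handler: IO, filename: str):
        # `filename` keys per-file state, mirroring how the ParquetWriter below
        # keeps one pyarrow writer and one batch buffer per output file
        self._counts[filename] = self._counts.get(filename, 0) + 1
        file_handler.write(document["text"] + "\n")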
4 changes: 2 additions & 2 deletions src/datatrove/pipeline/writers/jsonl.py
@@ -18,5 +18,5 @@ def __init__(
     ):
         super().__init__(output_folder, output_filename=output_filename, compression=compression, adapter=adapter)
 
-    def _write(self, document: dict, file: IO):
-        file.write(json.dumps(document, ensure_ascii=False) + "\n")
+    def _write(self, document: dict, file_handler: IO, _filename: str):
+        file_handler.write(json.dumps(document, ensure_ascii=False) + "\n")
55 changes: 55 additions & 0 deletions src/datatrove/pipeline/writers/parquet.py
@@ -0,0 +1,55 @@
from collections import defaultdict
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class ParquetWriter(DiskWriter):
    default_output_filename: str = "${rank}.parquet"
    name = "📒 Parquet"
    _requires_dependencies = ["pyarrow"]

    def __init__(
        self,
        output_folder: DataFolderLike,
        output_filename: str = None,
        compression: str | None = None,
        adapter: Callable = None,
        batch_size: int = 1000,
    ):
        super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
        self._writers = {}
        self._batches = defaultdict(list)
        self.batch_size = batch_size

    def _write_batch(self, filename):
        if not self._batches[filename]:
            return
        import pyarrow as pa

        # prepare batch
        batch = pa.RecordBatch.from_pylist(self._batches.pop(filename))
        # write batch
        self._writers[filename].write_batch(batch)

    def _write(self, document: dict, file_handler: IO, filename: str):
        import pyarrow as pa
        import pyarrow.parquet as pq

        if filename not in self._writers:
            self._writers[filename] = pq.ParquetWriter(
                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
            )
        self._batches[filename].append(document)
        if len(self._batches[filename]) == self.batch_size:
            self._write_batch(filename)

    def close(self):
        for filename in list(self._batches.keys()):
            self._write_batch(filename)
        for writer in self._writers.values():
            writer.close()
        self._batches.clear()
        self._writers.clear()
        super().close()
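For reference, a minimal usage sketch of the new writer, mirroring the test below (assumes pyarrow is installed; the output path and documents are illustrative):

from datatrove.data import Document
from datatrove.pipeline.writers.parquet import ParquetWriter

docs = [Document(text=f"doc {i}", id=str(i)) for i in range(5)]

# exiting the context manager calls close(), which flushes any partially
# filled batch and closes the per-file pyarrow writers
with ParquetWriter(output_folder="/tmp/parquet_out", batch_size=2) as writer:
    for doc in docs:
        writer.write(doc)  # buffered and flushed in RecordBatches of batch_size

Since the Parquet schema is inferred from the first document written to each file, every document written to the same file should produce the same adapted fields.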
33 changes: 33 additions & 0 deletions tests/pipeline/test_parquet_writer.py
@@ -0,0 +1,33 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter

from ..utils import require_pyarrow


@require_pyarrow
class TestParquetWriter(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_write(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
            for doc in data:
                w.write(doc)
        reader = ParquetReader(self.tmp_dir)
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
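The test round-trips through ParquetReader, but the output can also be inspected with pyarrow directly. A small sketch, assuming the writer produced a single shard named 00000.parquet (a zero-padded rank from the default ${rank}.parquet template):

import pyarrow.parquet as pq

table = pq.read_table("/tmp/parquet_out/00000.parquet")
print(table.schema)       # schema inferred from the first document written
print(table.to_pylist())  # one dict per adapted document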
