* added parquet writer
* nit
* Update src/datatrove/pipeline/writers/parquet.py
* updated test
* nit

Co-authored-by: Mario Šaško <[email protected]>
1 parent 795e542 · commit d4cf053 · 6 changed files with 98 additions and 8 deletions.
@@ -1 +1,2 @@
from .jsonl import JsonlWriter
from .parquet import ParquetWriter
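With this change the new writer is re-exported from the writers package, so it can be imported the same way as the existing JsonlWriter. A minimal sketch of the resulting import path:

```python
# Both writers are now available from the package namespace.
from datatrove.pipeline.writers import JsonlWriter, ParquetWriter
```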
src/datatrove/pipeline/writers/parquet.py

@@ -0,0 +1,55 @@
from collections import defaultdict
from typing import IO, Callable

from datatrove.io import DataFolderLike
from datatrove.pipeline.writers.disk_base import DiskWriter


class ParquetWriter(DiskWriter):
    default_output_filename: str = "${rank}.parquet"
    name = "📒 Parquet"
    _requires_dependencies = ["pyarrow"]

    def __init__(
        self,
        output_folder: DataFolderLike,
        output_filename: str = None,
        compression: str | None = None,
        adapter: Callable = None,
        batch_size: int = 1000,
    ):
        super().__init__(output_folder, output_filename, compression, adapter, mode="wb")
        self._writers = {}
        self._batches = defaultdict(list)
        self.batch_size = batch_size

    def _write_batch(self, filename):
        if not self._batches[filename]:
            return
        import pyarrow as pa

        # prepare batch
        batch = pa.RecordBatch.from_pylist(self._batches.pop(filename))
        # write batch
        self._writers[filename].write_batch(batch)

    def _write(self, document: dict, file_handler: IO, filename: str):
        import pyarrow as pa
        import pyarrow.parquet as pq

        if filename not in self._writers:
            self._writers[filename] = pq.ParquetWriter(
                file_handler, schema=pa.RecordBatch.from_pylist([document]).schema
            )
        self._batches[filename].append(document)
        if len(self._batches[filename]) == self.batch_size:
            self._write_batch(filename)

    def close(self):
        for filename in list(self._batches.keys()):
            self._write_batch(filename)
        for writer in self._writers.values():
            writer.close()
        self._batches.clear()
        self._writers.clear()
        super().close()
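To make the batching behaviour concrete, here is a minimal standalone sketch (not part of the commit) of the same pyarrow pattern the writer uses: infer the schema from the first record, then stream fixed-size RecordBatches into one Parquet file. The file name and document fields below are invented for illustration.

```python
import pyarrow as pa
import pyarrow.parquet as pq

# Invented records; in datatrove these dicts come from the writer's adapter,
# which turns Document objects into plain dictionaries before _write is called.
docs = [{"id": str(i), "text": f"document {i}"} for i in range(5)]

# Infer the schema from the first record, as ParquetWriter._write does.
schema = pa.RecordBatch.from_pylist(docs[:1]).schema
batch_size = 2  # flush every 2 records, analogous to the writer's batch_size

with pq.ParquetWriter("example_00000.parquet", schema=schema) as writer:
    for start in range(0, len(docs), batch_size):
        # One RecordBatch per chunk, mirroring _write_batch above
        writer.write_batch(pa.RecordBatch.from_pylist(docs[start : start + batch_size]))

print(pq.read_table("example_00000.parquet").num_rows)  # -> 5
```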
@@ -0,0 +1,33 @@
import shutil
import tempfile
import unittest

from datatrove.data import Document
from datatrove.pipeline.readers.parquet import ParquetReader
from datatrove.pipeline.writers.parquet import ParquetWriter

from ..utils import require_pyarrow


@require_pyarrow
class TestParquetWriter(unittest.TestCase):
    def setUp(self):
        # Create a temporary directory
        self.tmp_dir = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self.tmp_dir)

    def test_write(self):
        data = [
            Document(text=text, id=str(i), metadata={"somedata": 2 * i, "somefloat": i * 0.4, "somestring": "hello"})
            for i, text in enumerate(["hello", "text2", "more text"])
        ]
        with ParquetWriter(output_folder=self.tmp_dir, batch_size=2) as w:
            for doc in data:
                w.write(doc)
        reader = ParquetReader(self.tmp_dir)
        c = 0
        for read_doc, original in zip(reader(), data):
            read_doc.metadata.pop("file_path", None)
            assert read_doc == original
            c += 1
        assert c == len(data)
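Beyond the round-trip test above, here is a hedged sketch of how the writer could be used as a regular pipeline step. The input path, executor arguments, and the JsonlReader source are illustrative assumptions; only ParquetWriter's arguments come from the code in this commit.

```python
from datatrove.executor import LocalPipelineExecutor
from datatrove.pipeline.readers import JsonlReader
from datatrove.pipeline.writers import ParquetWriter

# Hypothetical input/output folders; "${rank}" in the default output
# filename is filled in with each task's rank.
pipeline = [
    JsonlReader("data/jsonl/"),
    ParquetWriter("data/parquet/", batch_size=1000),
]

LocalPipelineExecutor(pipeline=pipeline, tasks=2).run()
```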