From e08792a48029e40be390348d544120f550e7300d Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 13 Dec 2024 10:11:47 -0800 Subject: [PATCH] Remove open_async usage in put raw data (#2998) Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 39 +++++++------------ .../flytekitplugins/polars/sd_transformers.py | 3 +- .../tests/test_polars_plugin_sd.py | 27 ++++++++++++- .../unit/core/test_data_persistence.py | 18 ++++++++- 4 files changed, 58 insertions(+), 29 deletions(-) diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 7035147016..0640bc2eb5 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -423,47 +423,34 @@ async def async_put_raw_data( r = await self._put(from_path, to_path, **kwargs) return r or to_path + # See https://github.com/fsspec/s3fs/issues/871 for more background and pending work on the fsspec side to + # support effectively async open(). For now these use-cases below will revert to sync calls. # raw bytes if isinstance(lpath, bytes): - fs = await self.get_async_filesystem_for_path(to_path) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - s.write(lpath) - else: - with fs.open(to_path, "wb", **kwargs) as s: - s.write(lpath) - + fs = self.get_filesystem_for_path(to_path) + with fs.open(to_path, "wb", **kwargs) as s: + s.write(lpath) return to_path # If lpath is a buffered reader of some kind if isinstance(lpath, io.BufferedReader) or isinstance(lpath, io.BytesIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data := lpath.read(read_chunk_size_bytes): - s.write(data) + with fs.open(to_path, "wb", **kwargs) as s: + while data := lpath.read(read_chunk_size_bytes): + s.write(data) return to_path if isinstance(lpath, io.StringIO): if not lpath.readable(): raise FlyteAssertion("Buffered reader must be readable") - fs = await self.get_async_filesystem_for_path(to_path) + fs = self.get_filesystem_for_path(to_path) lpath.seek(0) - if isinstance(fs, AsyncFileSystem): - async with fs.open_async(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) - else: - with fs.open(to_path, "wb", **kwargs) as s: - while data_str := lpath.read(read_chunk_size_bytes): - s.write(data_str.encode(encoding)) + with fs.open(to_path, "wb", **kwargs) as s: + while data_str := lpath.read(read_chunk_size_bytes): + s.write(data_str.encode(encoding)) return to_path raise FlyteAssertion(f"Unsupported lpath type {type(lpath)}") diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 474901544d..e6359641ca 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -69,9 +69,10 @@ def encode( df.to_parquet(output_bytes) if structured_dataset.uri is not None: + output_bytes.seek(0) fs = ctx.file_access.get_filesystem_for_path(path=structured_dataset.uri) with fs.open(structured_dataset.uri, "wb") as s: - s.write(output_bytes) + s.write(output_bytes.read()) output_uri = structured_dataset.uri else: remote_fn = "00000" # 00000 is our default unnamed parquet filename diff --git a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py index c2d4a39be7..9acae1c274 100644 --- a/plugins/flytekit-polars/tests/test_polars_plugin_sd.py +++ b/plugins/flytekit-polars/tests/test_polars_plugin_sd.py @@ -5,7 +5,7 @@ import pytest from flytekitplugins.polars.sd_transformers import PolarsDataFrameRenderer from typing_extensions import Annotated -from packaging import version +import numpy as np from polars.testing import assert_frame_equal from flytekit import kwtypes, task, workflow @@ -134,3 +134,28 @@ def consume_sd_return_sd(sd: StructuredDataset) -> StructuredDataset: opened_sd = opened_sd.collect() assert_frame_equal(opened_sd, polars_df) + + +def test_with_uri(): + temp_file = tempfile.mktemp() + + @task + def random_dataframe(num_rows: int) -> StructuredDataset: + feature_1_list = np.random.randint(low=100, high=999, size=(num_rows,)) + feature_2_list = np.random.normal(loc=0, scale=1, size=(num_rows, )) + pl_df = pl.DataFrame({'protein_length': feature_1_list, + 'protein_feature': feature_2_list}) + sd = StructuredDataset(dataframe=pl_df, uri=temp_file) + return sd + + @task + def consume(df: pd.DataFrame): + print(df.head(5)) + print(df.describe()) + + @workflow + def my_wf(num_rows: int): + pl = random_dataframe(num_rows=num_rows) + consume(pl) + + my_wf(num_rows=100) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index d992ed1fa5..116717b92d 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -1,16 +1,17 @@ import io import os -import fsspec import pathlib import random import string import sys import tempfile +import fsspec import mock import pytest from azure.identity import ClientSecretCredential, DefaultAzureCredential +from flytekit.configuration import Config from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.local_fsspec import FlyteLocalFileSystem @@ -207,3 +208,18 @@ def __init__(self, *args, **kwargs): fp = FileAccessProvider("/tmp", "s3://my-bucket") fp.get_filesystem("testgetfs", test_arg="test_arg") + + +@pytest.mark.sandbox_test +def test_put_raw_data_bytes(): + dc = Config.for_sandbox().data_config + raw_output = f"s3://my-s3-bucket/" + provider = FileAccessProvider(local_sandbox_dir="/tmp/unittest", raw_output_prefix=raw_output, data_config=dc) + prefix = provider.get_random_string() + provider.put_raw_data(lpath=b"hello", upload_prefix=prefix, file_name="hello_bytes") + provider.put_raw_data(lpath=io.BytesIO(b"hello"), upload_prefix=prefix, file_name="hello_bytes_io") + provider.put_raw_data(lpath=io.StringIO("hello"), upload_prefix=prefix, file_name="hello_string_io") + + fs = provider.get_filesystem("s3") + listing = fs.ls(f"{raw_output}{prefix}/") + assert len(listing) == 3