Commit 9d2061e

Merge branch 'main' into main

chrisschopp authored Dec 10, 2024
2 parents 4f195dc + 7baa826
Showing 13 changed files with 115 additions and 59 deletions.
19 changes: 11 additions & 8 deletions kedro-datasets/RELEASE.md
@@ -1,22 +1,25 @@
 # Upcoming Release 6.0.0

 ## Major features and improvements
-- Added functionality to save Pandas DataFrame directly to Snowflake, facilitating seemless `.csv` ingestion
-- Added Python 3.9, 3.10 and 3.11 support for SnowflakeTableDataset

+- Added functionality to save pandas DataFrames directly to Snowflake, facilitating seamless `.csv` ingestion.
+- Added Python 3.9, 3.10 and 3.11 support for `snowflake.SnowflakeTableDataset`.
+- Enabled connection sharing between `ibis.FileDataset` and `ibis.TableDataset` instances, thereby allowing nodes to save data loaded by one to the other (as long as they share the same connection configuration).
 - Added the following new **experimental** datasets:

-| Type                              | Description                                             | Location                                  |
-| --------------------------------- | ------------------------------------------------------- | ----------------------------------------- |
-| `databricks.ExternalTableDataset` | A dataset for accessing external tables in Databricks.  | `kedro_datasets_experimental.databricks`  |
+| Type                              | Description                                                                 | Location                                   |
+| --------------------------------- | --------------------------------------------------------------------------- | ------------------------------------------ |
+| `databricks.ExternalTableDataset` | A dataset for accessing external tables in Databricks.                      | `kedro_datasets_experimental.databricks`   |
+| `safetensors.SafetensorsDataset`  | A dataset for securely saving and loading files in the SafeTensors format.  | `kedro_datasets_experimental.safetensors`  |


 ## Bug fixes and other changes
-- Implemented Snowflake's (local testing framework)[https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally] for testing purposes

+- Implemented Snowflake's [local testing framework](https://docs.snowflake.com/en/developer-guide/snowpark/python/testing-locally) for testing purposes.
+- Improved the dependency management for Spark-based datasets by refactoring the Spark and Databricks utility functions used across the datasets.
-- Add deprecation warning for `tracking.MetricsDataset` and `tracking.JSONDataset`.
+- Added deprecation warning for `tracking.MetricsDataset` and `tracking.JSONDataset`.

 ## Breaking Changes

 - Demoted `video.VideoDataset` from core to experimental dataset.

 ## Community contributions
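Note: the sketch below shows what the new connection sharing enables; the DuckDB config, file path, and table name are illustrative assumptions, not values from this diff.

```python
# Hypothetical usage: identical connection config means both datasets reuse
# one cached Ibis backend, so a table loaded by one can be saved by the other.
from kedro_datasets.ibis import FileDataset, TableDataset

connection = {"backend": "duckdb", "database": "example.db"}  # assumed config

file_ds = FileDataset(filepath="data/input.csv", file_format="csv", connection=connection)
table_ds = TableDataset(table_name="my_table", connection=connection)

table_ds.save(file_ds.load())  # data loaded by one dataset, saved by the other
assert file_ds.connection is table_ds.connection  # one shared backend instance
```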
1 change: 1 addition & 0 deletions kedro-datasets/kedro_datasets/_utils/__init__.py
@@ -0,0 +1 @@
+from .connection_mixin import ConnectionMixin  # noqa: F401
32 changes: 32 additions & 0 deletions kedro-datasets/kedro_datasets/_utils/connection_mixin.py
@@ -0,0 +1,32 @@
+from abc import ABC, abstractmethod
+from collections.abc import Hashable
+from typing import Any, ClassVar
+
+
+class ConnectionMixin(ABC):
+    _CONNECTION_GROUP: ClassVar[str]
+
+    _connection_config: dict[str, Any]
+
+    _connections: ClassVar[dict[Hashable, Any]] = {}
+
+    @abstractmethod
+    def _connect(self) -> Any:
+        ...  # pragma: no cover
+
+    @property
+    def _connection(self) -> Any:
+        def hashable(value: Any) -> Hashable:
+            """Return a hashable key for a potentially-nested object."""
+            if isinstance(value, dict):
+                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
+            if isinstance(value, list):
+                return tuple(hashable(x) for x in value)
+            return value
+
+        cls = type(self)
+        key = self._CONNECTION_GROUP, hashable(self._connection_config)
+        if key not in cls._connections:
+            cls._connections[key] = self._connect()
+
+        return cls._connections[key]
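For orientation, here is a rough sketch of the contract the mixin expects from a subclass; the `SQLiteDataset` below is hypothetical and not part of this changeset.

```python
# Hypothetical subclass: set _CONNECTION_GROUP, store _connection_config,
# implement _connect(); the mixin caches one connection per (group, config).
import sqlite3
from typing import Any, ClassVar

from kedro_datasets._utils import ConnectionMixin


class SQLiteDataset(ConnectionMixin):
    _CONNECTION_GROUP: ClassVar[str] = "sqlite"

    def __init__(self, connection: dict[str, Any]):
        self._connection_config = connection

    def _connect(self) -> sqlite3.Connection:
        return sqlite3.connect(**self._connection_config)


ds_a = SQLiteDataset({"database": ":memory:"})
ds_b = SQLiteDataset({"database": ":memory:"})
assert ds_a._connection is ds_b._connection  # equal configs share one connection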
33 changes: 12 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/file_dataset.py
@@ -8,11 +8,13 @@
 import ibis.expr.types as ir
 from kedro.io import AbstractVersionedDataset, DatasetError, Version

+from kedro_datasets._utils import ConnectionMixin
+
 if TYPE_CHECKING:
     from ibis import BaseBackend


-class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
+class FileDataset(ConnectionMixin, AbstractVersionedDataset[ir.Table, ir.Table]):
     """``FileDataset`` loads/saves data from/to a specified file format.

     Example usage for the
@@ -73,7 +75,7 @@ class FileDataset(AbstractVersionedDataset[ir.Table, ir.Table]):
     DEFAULT_LOAD_ARGS: ClassVar[dict[str, Any]] = {}
     DEFAULT_SAVE_ARGS: ClassVar[dict[str, Any]] = {}

-    _connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
+    _CONNECTION_GROUP: ClassVar[str] = "ibis"

     def __init__(  # noqa: PLR0913
         self,
@@ -143,28 +145,17 @@ def __init__(  # noqa: PLR0913
         if save_args is not None:
             self._save_args.update(save_args)

+    def _connect(self) -> BaseBackend:
+        import ibis
+
+        config = deepcopy(self._connection_config)
+        backend = getattr(ibis, config.pop("backend"))
+        return backend.connect(**config)
+
     @property
     def connection(self) -> BaseBackend:
         """The ``Backend`` instance for the connection configuration."""
-
-        def hashable(value):
-            """Return a hashable key for a potentially-nested object."""
-            if isinstance(value, dict):
-                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
-            if isinstance(value, list):
-                return tuple(hashable(x) for x in value)
-            return value
-
-        cls = type(self)
-        key = hashable(self._connection_config)
-        if key not in cls._connections:
-            import ibis
-
-            config = deepcopy(self._connection_config)
-            backend = getattr(ibis, config.pop("backend"))
-            cls._connections[key] = backend.connect(**config)
-
-        return cls._connections[key]
+        return self._connection

     def load(self) -> ir.Table:
         load_path = self._get_load_path()
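To see how the shared cache key is derived, the following standalone sketch replays the `hashable` helper from `ConnectionMixin` on an assumed nested config:

```python
# Standalone replay of ConnectionMixin's key derivation (config is illustrative).
from collections.abc import Hashable
from typing import Any


def hashable(value: Any) -> Hashable:
    """Return a hashable key for a potentially-nested object."""
    if isinstance(value, dict):
        return tuple((k, hashable(v)) for k, v in sorted(value.items()))
    if isinstance(value, list):
        return tuple(hashable(x) for x in value)
    return value


config = {"backend": "duckdb", "database": "example.db", "extensions": ["spatial"]}
print(("ibis", hashable(config)))
# ('ibis', (('backend', 'duckdb'), ('database', 'example.db'), ('extensions', ('spatial',))))
```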
32 changes: 11 additions & 21 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
@@ -9,12 +9,13 @@
 from kedro.io import AbstractDataset, DatasetError

 from kedro_datasets import KedroDeprecationWarning
+from kedro_datasets._utils import ConnectionMixin

 if TYPE_CHECKING:
     from ibis import BaseBackend


-class TableDataset(AbstractDataset[ir.Table, ir.Table]):
+class TableDataset(ConnectionMixin, AbstractDataset[ir.Table, ir.Table]):
     """``TableDataset`` loads/saves data from/to Ibis table expressions.

     Example usage for the
@@ -70,7 +71,7 @@ class TableDataset(AbstractDataset[ir.Table, ir.Table]):
         "overwrite": True,
     }

-    _connections: ClassVar[dict[tuple[tuple[str, str]], BaseBackend]] = {}
+    _CONNECTION_GROUP: ClassVar[str] = "ibis"

     def __init__(  # noqa: PLR0913
         self,
@@ -145,28 +146,17 @@ def __init__(  # noqa: PLR0913

         self._materialized = self._save_args.pop("materialized")

+    def _connect(self) -> BaseBackend:
+        import ibis
+
+        config = deepcopy(self._connection_config)
+        backend = getattr(ibis, config.pop("backend"))
+        return backend.connect(**config)
+
     @property
     def connection(self) -> BaseBackend:
         """The ``Backend`` instance for the connection configuration."""
-
-        def hashable(value):
-            """Return a hashable key for a potentially-nested object."""
-            if isinstance(value, dict):
-                return tuple((k, hashable(v)) for k, v in sorted(value.items()))
-            if isinstance(value, list):
-                return tuple(hashable(x) for x in value)
-            return value
-
-        cls = type(self)
-        key = hashable(self._connection_config)
-        if key not in cls._connections:
-            import ibis
-
-            config = deepcopy(self._connection_config)
-            backend = getattr(ibis, config.pop("backend"))
-            cls._connections[key] = backend.connect(**config)
-
-        return cls._connections[key]
+        return self._connection

     def load(self) -> ir.Table:
         if self._filepath is not None:
10 changes: 8 additions & 2 deletions kedro-datasets/kedro_datasets/polars/csv_dataset.py
@@ -52,15 +52,21 @@ class CSVDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
     .. code-block:: pycon

-        >>> from kedro_datasets.polars import CSVDataset
+        >>> import sys
+        >>>
         >>> import polars as pl
+        >>> import pytest
+        >>> from kedro_datasets.polars import CSVDataset
         >>>
+        >>> if sys.platform.startswith("win"):
+        ...     pytest.skip("this doctest fails on Windows CI runner")
+        ...
         >>> data = pl.DataFrame({"col1": [1, 2], "col2": [4, 5], "col3": [5, 6]})
         >>>
         >>> dataset = CSVDataset(filepath=tmp_path / "test.csv")
         >>> dataset.save(data)
         >>> reloaded = dataset.load()
-        >>> assert data.frame_equal(reloaded)
+        >>> assert data.equals(reloaded)
     """

kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py
@@ -51,7 +51,7 @@ class EagerPolarsDataset(AbstractVersionedDataset[pl.DataFrame, pl.DataFrame]):
         >>> dataset = EagerPolarsDataset(filepath=tmp_path / "test.parquet", file_format="parquet")
         >>> dataset.save(data)
         >>> reloaded = dataset.load()
-        >>> assert data.frame_equal(reloaded)
+        >>> assert data.equals(reloaded)
     """

6 changes: 5 additions & 1 deletion kedro-datasets/kedro_datasets/polars/lazy_polars_dataset.py
@@ -4,7 +4,9 @@
 """
 from __future__ import annotations

+import errno
 import logging
+import os
 from copy import deepcopy
 from pathlib import PurePosixPath
 from typing import Any, ClassVar
@@ -69,7 +71,7 @@ class LazyPolarsDataset(
         >>> dataset = LazyPolarsDataset(filepath=tmp_path / "test.csv", file_format="csv")
         >>> dataset.save(data)
         >>> reloaded = dataset.load()
-        >>> assert data.frame_equal(reloaded.collect())
+        >>> assert data.equals(reloaded.collect())
     """

@@ -199,6 +201,8 @@ def _describe(self) -> dict[str, Any]:

     def load(self) -> pl.LazyFrame:
         load_path = str(self._get_load_path())
+        if not self._exists():
+            raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), load_path)

         if self._protocol == "file":
             # With local filesystems, we can use Polars' built-in I/O method:
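The added guard uses the standard errno-based `FileNotFoundError` construction; a small self-contained sketch (the path is made up) of the resulting message:

```python
# Sketch of the error construction added to LazyPolarsDataset.load().
import errno
import os

load_path = "data/01_raw/missing.csv"  # illustrative path
err = FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), load_path)
print(err)  # [Errno 2] No such file or directory: 'data/01_raw/missing.csv'
```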
2 changes: 1 addition & 1 deletion kedro-datasets/pyproject.toml
@@ -237,7 +237,7 @@ test = [
     "pandas>=2.0",
     "Pillow~=10.0",
     "plotly>=4.8.0, <6.0",
-    "polars[xlsx2csv, deltalake]~=0.18.0",
+    "polars[deltalake,xlsx2csv]>=1.0",
     "pyarrow>=1.0; python_version < '3.11'",
     "pyarrow>=7.0; python_version >= '3.11'",  # Adding to avoid numpy build errors
     "pyodbc~=5.0",
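This bump lines up with the docstring changes above: Polars 1.0 removed the long-deprecated `DataFrame.frame_equal` in favor of `DataFrame.equals`, so the pinned test dependency and the doctests move together. A minimal before/after, assuming `polars>=1.0` is installed:

```python
# Requires polars>=1.0; frame_equal no longer exists there.
import polars as pl

a = pl.DataFrame({"x": [1, 2]})
b = pl.DataFrame({"x": [1, 2]})
assert a.equals(b)  # pre-1.0 code would have called a.frame_equal(b)
```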
4 changes: 2 additions & 2 deletions kedro-datasets/tests/ibis/test_file_dataset.py
@@ -59,7 +59,7 @@ def dummy_table():


 class TestFileDataset:
-    def test_save_and_load(self, file_dataset, dummy_table, database):
+    def test_save_and_load(self, file_dataset, dummy_table):
         """Test saving and reloading the data set."""
         file_dataset.save(dummy_table)
         reloaded = file_dataset.load()
@@ -127,7 +127,7 @@ def test_connection_config(self, mocker, file_dataset, connection_config, key):
         )
         mocker.patch(f"ibis.{backend}")
         file_dataset.load()
-        assert key in file_dataset._connections
+        assert ("ibis", key) in file_dataset._connections


 class TestFileDatasetVersioned:
22 changes: 20 additions & 2 deletions kedro-datasets/tests/ibis/test_table_dataset.py
@@ -4,7 +4,7 @@
 from kedro.io import DatasetError
 from pandas.testing import assert_frame_equal

-from kedro_datasets.ibis import TableDataset
+from kedro_datasets.ibis import FileDataset, TableDataset

 _SENTINEL = object()

@@ -56,6 +56,17 @@ def dummy_table(table_dataset_from_csv):
     return table_dataset_from_csv.load()


+@pytest.fixture
+def file_dataset(filepath_csv, connection_config, load_args, save_args):
+    return FileDataset(
+        filepath=filepath_csv,
+        file_format="csv",
+        connection=connection_config,
+        load_args=load_args,
+        save_args=save_args,
+    )
+
+
 class TestTableDataset:
     def test_save_and_load(self, table_dataset, dummy_table, database):
         """Test saving and reloading the dataset."""
@@ -146,4 +157,11 @@ def test_connection_config(self, mocker, table_dataset, connection_config, key):
         )
         mocker.patch(f"ibis.{backend}")
         table_dataset.load()
-        assert key in table_dataset._connections
+        assert ("ibis", key) in table_dataset._connections
+
+    def test_save_data_loaded_using_file_dataset(self, file_dataset, table_dataset):
+        """Test interoperability of Ibis datasets sharing a database."""
+        dummy_table = file_dataset.load()
+        assert not table_dataset.exists()
+        table_dataset.save(dummy_table)
+        assert table_dataset.exists()
10 changes: 10 additions & 0 deletions kedro-datasets/tests/polars/test_csv_dataset.py
@@ -88,12 +88,14 @@ def mocked_csv_in_s3(mocked_s3_bucket, mocked_dataframe: pl.DataFrame):


 class TestCSVDataset:
+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_save_and_load(self, csv_dataset, dummy_dataframe):
         """Test saving and reloading the dataset."""
         csv_dataset.save(dummy_dataframe)
         reloaded = csv_dataset.load()
         assert_frame_equal(dummy_dataframe, reloaded)

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_exists(self, csv_dataset, dummy_dataframe):
         """Test `exists` method invocation for both existing and
         nonexistent dataset."""
@@ -202,13 +204,15 @@ def test_version_str_repr(self, load_version, save_version):
         assert "load_args={'rechunk': True}" in str(ds)
         assert "load_args={'rechunk': True}" in str(ds_versioned)

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_save_and_load(self, versioned_csv_dataset, dummy_dataframe):
         """Test that saved and reloaded data matches the original one for
         the versioned dataset."""
         versioned_csv_dataset.save(dummy_dataframe)
         reloaded_df = versioned_csv_dataset.load()
         assert_frame_equal(dummy_dataframe, reloaded_df)

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv):
         """Test that if a new version is created mid-run, by an
         external system, it won't be loaded in the current run."""
@@ -232,6 +236,7 @@ def test_multiple_loads(self, versioned_csv_dataset, dummy_dataframe, filepath_csv):
             ds_new.resolve_load_version() == v_new
         )  # new version is discoverable by a new instance

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_multiple_saves(self, dummy_dataframe, filepath_csv):
         """Test multiple cycles of save followed by load for the same dataset"""
         ds_versioned = CSVDataset(filepath=filepath_csv, version=Version(None, None))
@@ -254,6 +259,7 @@ def test_multiple_saves(self, dummy_dataframe, filepath_csv):
         ds_new = CSVDataset(filepath=filepath_csv, version=Version(None, None))
         assert ds_new.resolve_load_version() == second_load_version

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_release_instance_cache(self, dummy_dataframe, filepath_csv):
         """Test that cache invalidation does not affect other instances"""
         ds_a = CSVDataset(filepath=filepath_csv, version=Version(None, None))
@@ -282,12 +288,14 @@ def test_no_versions(self, versioned_csv_dataset):
         with pytest.raises(DatasetError, match=pattern):
             versioned_csv_dataset.load()

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_exists(self, versioned_csv_dataset, dummy_dataframe):
         """Test `exists` method invocation for versioned dataset."""
         assert not versioned_csv_dataset.exists()
         versioned_csv_dataset.save(dummy_dataframe)
         assert versioned_csv_dataset.exists()

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
         """Check the error when attempting to override the dataset if the
         corresponding CSV file for a given save version already exists."""
@@ -299,6 +307,7 @@ def test_prevent_overwrite(self, versioned_csv_dataset, dummy_dataframe):
         with pytest.raises(DatasetError, match=pattern):
             versioned_csv_dataset.save(dummy_dataframe)

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     @pytest.mark.parametrize(
         "load_version", ["2019-01-01T23.59.59.999Z"], indirect=True
     )
@@ -325,6 +334,7 @@ def test_http_filesystem_no_versioning(self):
                 filepath="https://example.com/file.csv", version=Version(None, None)
             )

+    @pytest.mark.xfail(sys.platform == "win32", reason="file encoding is not UTF-8")
     def test_versioning_existing_dataset(
         self, csv_dataset, versioned_csv_dataset, dummy_dataframe
     ):