Skip to content

Commit

Permalink
Merge branch 'main' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
chrisschopp authored Dec 11, 2024
2 parents 9d2061e + 50fa3c0 commit cd694ae
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 5 deletions.
2 changes: 2 additions & 0 deletions kedro-datasets/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

## Major features and improvements

- Supported passing `database` to `ibis.TableDataset` for load and save operations.
- Added functionality to save pandas DataFrames directly to Snowflake, facilitating seamless `.csv` ingestion.
- Added Python 3.9, 3.10 and 3.11 support for `snowflake.SnowflakeTableDataset`.
- Enabled connection sharing between `ibis.FileDataset` and `ibis.TableDataset` instances, thereby allowing nodes to save data loaded by one to the other (as long as they share the same connection configuration).
Expand All @@ -28,6 +29,7 @@ Many thanks to the following Kedroids for contributing PRs to this release:

- [Thomas d'Hooghe](https://github.com/tdhooghe)
- [Minura Punchihewa](https://github.com/MinuraPunchihewa)
- [Mark Druffel](https://github.com/mark-druffel)

# Release 5.1.0

Expand Down
23 changes: 20 additions & 3 deletions kedro-datasets/kedro_datasets/ibis/table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def __init__( # noqa: PLR0913
filepath: str | None = None,
file_format: str | None = None,
table_name: str | None = None,
database: str | None = None,
connection: dict[str, Any] | None = None,
load_args: dict[str, Any] | None = None,
save_args: dict[str, Any] | None = None,
Expand All @@ -103,6 +104,12 @@ def __init__( # noqa: PLR0913
Args:
table_name: The name of the table or view to read or create.
database: The name of the database to read the table or view
from or create the table or view in. If not passed, then
the current database is used. Provide a tuple of strings
(e.g. `("catalog", "database")`) or a dotted string path
(e.g. `"catalog.database"`) to reference a table or view
in a multi-level table hierarchy.
connection: Configuration for connecting to an Ibis backend.
If not provided, connect to DuckDB in in-memory mode.
load_args: Additional arguments passed to the Ibis backend's
Expand Down Expand Up @@ -132,17 +139,22 @@ def __init__( # noqa: PLR0913
self._filepath = filepath
self._file_format = file_format
self._table_name = table_name
self._database = database
self._connection_config = connection or self.DEFAULT_CONNECTION_CONFIG
self.metadata = metadata

# Set load and save arguments, overwriting defaults if provided.
self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS)
if load_args is not None:
self._load_args.update(load_args)
if database is not None:
self._load_args["database"] = database

self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS)
if save_args is not None:
self._save_args.update(save_args)
if database is not None:
self._save_args["database"] = database

self._materialized = self._save_args.pop("materialized")

Expand All @@ -166,7 +178,7 @@ def load(self) -> ir.Table:
reader = getattr(self.connection, f"read_{self._file_format}")
return reader(self._filepath, self._table_name, **self._load_args)
else:
return self.connection.table(self._table_name)
return self.connection.table(self._table_name, **self._load_args)

def save(self, data: ir.Table) -> None:
if self._table_name is None:
Expand All @@ -176,13 +188,18 @@ def save(self, data: ir.Table) -> None:
writer(self._table_name, data, **self._save_args)

def _describe(self) -> dict[str, Any]:
load_args = deepcopy(self._load_args)
save_args = deepcopy(self._save_args)
load_args.pop("database", None)
save_args.pop("database", None)
return {
"filepath": self._filepath,
"file_format": self._file_format,
"table_name": self._table_name,
"database": self._database,
"backend": self._connection_config["backend"],
"load_args": self._load_args,
"save_args": self._save_args,
"load_args": load_args,
"save_args": save_args,
"materialized": self._materialized,
}

Expand Down
1 change: 0 additions & 1 deletion kedro-datasets/kedro_datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""``SQLDataset`` to load and save data to a SQL backend."""

from __future__ import annotations

import copy
Expand Down
71 changes: 70 additions & 1 deletion kedro-datasets/tests/ibis/test_table_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ibis
import pytest
from kedro.io import DatasetError
from packaging.version import Version
from pandas.testing import assert_frame_equal

from kedro_datasets.ibis import FileDataset, TableDataset
Expand All @@ -21,6 +22,11 @@ def database(tmp_path):
return (tmp_path / "file.db").as_posix()


@pytest.fixture(params=[None])
def database_name(request):
return request.param


@pytest.fixture(params=[_SENTINEL])
def connection_config(request, database):
return (
Expand All @@ -31,9 +37,10 @@ def connection_config(request, database):


@pytest.fixture
def table_dataset(connection_config, load_args, save_args):
def table_dataset(database_name, connection_config, load_args, save_args):
return TableDataset(
table_name="test",
database=database_name,
connection=connection_config,
load_args=load_args,
save_args=save_args,
Expand Down Expand Up @@ -79,6 +86,25 @@ def test_save_and_load(self, table_dataset, dummy_table, database):
assert not con.sql("SELECT * FROM duckdb_tables").fetchnumpy()["table_name"]
assert "test" in con.sql("SELECT * FROM duckdb_views").fetchnumpy()["view_name"]

@pytest.mark.parametrize(
"connection_config", [{"backend": "polars"}], indirect=True
)
@pytest.mark.parametrize("save_args", [{"materialized": "table"}], indirect=True)
def test_save_and_load_polars(
self, table_dataset, connection_config, save_args, dummy_table
):
"""Test saving and reloading the dataset configured with Polars.
If and when the Polars backend handles the `database` parameter,
this test can be removed. Additionally, the `create_view` method
is supported since Ibis 9.1.0, so `save_args` doesn't need to be
overridden.
"""
table_dataset.save(dummy_table)
reloaded = table_dataset.load()
assert_frame_equal(dummy_table.execute(), reloaded.execute())

def test_exists(self, table_dataset, dummy_table):
"""Test `exists` method invocation for both existing and
nonexistent dataset."""
Expand All @@ -103,6 +129,49 @@ def test_save_extra_params(self, table_dataset, save_args, dummy_table, database
)
assert not con.sql("SELECT * FROM duckdb_views").fetchnumpy()["view_name"]

@pytest.mark.parametrize("database_name", ["test"], indirect=True)
def test_external_database(
self, tmp_path, table_dataset, database_name, dummy_table, database
):
"""Test passing the database name to read from and create in."""
# Attach another DuckDB database to the existing DuckDB session.
table_dataset.connection.attach(tmp_path / f"{database_name}.db")

table_dataset.save(dummy_table)
reloaded = table_dataset.load()
assert_frame_equal(dummy_table.execute(), reloaded.execute())

# Verify that the attached database file was the one written to.
con = duckdb.connect(database)
assert (
"test"
in con.sql("SELECT * FROM duckdb_views").fetchnumpy()["database_name"]
)

@pytest.mark.skipif(
Version(ibis.__version__) < Version("9.0.0"),
reason='Ibis 9.0 standardised use of "database" to mean a collection of tables',
)
@pytest.mark.parametrize("database_name", ["test"], indirect=True)
def test_database(
self, tmp_path, table_dataset, database_name, dummy_table, database
):
"""Test passing the database name to read from and create in."""
# Create a database (meaning a collection of tables, or schema).
# To learn more about why Ibis uses "database" in this way, read
# https://ibis-project.org/posts/ibis-version-9.0.0-release/#what-does-schema-mean
table_dataset.connection.create_database(database_name)

table_dataset.save(dummy_table)
reloaded = table_dataset.load()
assert_frame_equal(dummy_table.execute(), reloaded.execute())

# Verify that the attached database file was the one written to.
con = duckdb.connect(database)
assert (
"test" in con.sql("SELECT * FROM duckdb_views").fetchnumpy()["schema_name"]
)

def test_no_filepath_or_table_name(connection_config):
pattern = r"Must provide at least one of `filepath` or `table_name`\."
with pytest.raises(DatasetError, match=pattern):
Expand Down

0 comments on commit cd694ae

Please sign in to comment.