From 105fcf217603749516961b00229795f6804d5f7e Mon Sep 17 00:00:00 2001 From: michal-mmm <84588078+michal-mmm@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:46:41 +0200 Subject: [PATCH] fix(datasets): add metadata parameter to datasets (#708) Signed-off-by: michal-mmm Co-authored-by: Ankita Katiyar <110245118+ankatiyar@users.noreply.github.com> --- .../kedro_datasets/databricks/managed_table_dataset.py | 4 ++++ .../kedro_datasets/huggingface/hugging_face_dataset.py | 2 ++ .../huggingface/transformer_pipeline_dataset.py | 2 ++ kedro-datasets/kedro_datasets/ibis/table_dataset.py | 4 ++++ .../kedro_datasets/polars/eager_polars_dataset.py | 2 ++ .../kedro_datasets/spark/spark_streaming_dataset.py | 6 +++++- 6 files changed, 19 insertions(+), 1 deletion(-) diff --git a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py index ae193e68a..b139449aa 100644 --- a/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py +++ b/kedro-datasets/kedro_datasets/databricks/managed_table_dataset.py @@ -230,6 +230,7 @@ def __init__( # noqa: PLR0913 schema: dict[str, Any] | None = None, partition_columns: list[str] | None = None, owner_group: str | None = None, + metadata: dict[str, Any] | None = None, ) -> None: """Creates a new instance of ``ManagedTableDataset``. @@ -259,6 +260,8 @@ def __init__( # noqa: PLR0913 owner_group: if table access control is enabled in your workspace, specifying owner_group will transfer ownership of the table and database to this owner. All databases should have the same owner_group. Defaults to None. + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. Raises: DatasetError: Invalid configuration supplied (through ManagedTable validation) """ @@ -276,6 +279,7 @@ def __init__( # noqa: PLR0913 ) self._version = version + self.metadata = metadata super().__init__( filepath=None, # type: ignore[arg-type] diff --git a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py index e368cb9a5..b6f605616 100644 --- a/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/hugging_face_dataset.py @@ -41,9 +41,11 @@ def __init__( *, dataset_name: str, dataset_kwargs: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ): self.dataset_name = dataset_name self._dataset_kwargs = dataset_kwargs or {} + self.metadata = metadata def _load(self): return load_dataset(self.dataset_name, **self._dataset_kwargs) diff --git a/kedro-datasets/kedro_datasets/huggingface/transformer_pipeline_dataset.py b/kedro-datasets/kedro_datasets/huggingface/transformer_pipeline_dataset.py index 942e96627..663edaab6 100644 --- a/kedro-datasets/kedro_datasets/huggingface/transformer_pipeline_dataset.py +++ b/kedro-datasets/kedro_datasets/huggingface/transformer_pipeline_dataset.py @@ -43,12 +43,14 @@ def __init__( task: str | None = None, model_name: str | None = None, pipeline_kwargs: dict[str, t.Any] | None = None, + metadata: dict[str, t.Any] | None = None, ): if task is None and model_name is None: raise ValueError("At least 'task' or 'model_name' are needed") self._task = task if task else None self._model_name = model_name self._pipeline_kwargs = pipeline_kwargs or {} + self.metadata = metadata if self._pipeline_kwargs and ( "task" in self._pipeline_kwargs or "model" in self._pipeline_kwargs diff --git a/kedro-datasets/kedro_datasets/ibis/table_dataset.py b/kedro-datasets/kedro_datasets/ibis/table_dataset.py index 3f06a63c6..67932e9f7 100644 --- a/kedro-datasets/kedro_datasets/ibis/table_dataset.py +++ b/kedro-datasets/kedro_datasets/ibis/table_dataset.py @@ -80,6 +80,7 @@ def __init__( # noqa: PLR0913 connection: dict[str, Any] | None = None, load_args: dict[str, Any] | None = None, save_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: """Creates a new ``TableDataset`` pointing to a table (or file). @@ -117,6 +118,8 @@ def __init__( # noqa: PLR0913 objects are materialized as views. To save a table using a different materialization strategy, supply a value for `materialized` in `save_args`. + metadata: Any arbitrary metadata. This is ignored by Kedro, + but may be consumed by users or external plugins. """ if filepath is None and table_name is None: raise DatasetError( @@ -127,6 +130,7 @@ def __init__( # noqa: PLR0913 self._file_format = file_format self._table_name = table_name self._connection_config = connection + self.metadata = metadata # Set load and save arguments, overwriting defaults if provided. self._load_args = deepcopy(self.DEFAULT_LOAD_ARGS) diff --git a/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py index b8e3f9b17..b9c2d64ea 100644 --- a/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py +++ b/kedro-datasets/kedro_datasets/polars/eager_polars_dataset.py @@ -68,6 +68,7 @@ def __init__( # noqa: PLR0913 version: Version | None = None, credentials: dict[str, Any] | None = None, fs_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ): """Creates a new instance of ``EagerPolarsDataset`` pointing to a concrete data file on a specific filesystem. The appropriate polars load/save methods are dynamically @@ -124,6 +125,7 @@ def __init__( # noqa: PLR0913 self._protocol = protocol self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args) + self.metadata = metadata super().__init__( filepath=PurePosixPath(path), diff --git a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py index 715c90720..71bd5f8c8 100644 --- a/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py +++ b/kedro-datasets/kedro_datasets/spark/spark_streaming_dataset.py @@ -42,13 +42,14 @@ class SparkStreamingDataset(AbstractDataset): DEFAULT_LOAD_ARGS = {} # type: dict[str, Any] DEFAULT_SAVE_ARGS = {} # type: dict[str, Any] - def __init__( + def __init__( # noqa: PLR0913 self, *, filepath: str = "", file_format: str = "", save_args: dict[str, Any] | None = None, load_args: dict[str, Any] | None = None, + metadata: dict[str, Any] | None = None, ) -> None: """Creates a new instance of SparkStreamingDataset. @@ -73,10 +74,13 @@ def __init__( respectively. You can find a list of options for each selected format in Spark DataFrame write documentation, see https://spark.apache.org/docs/latest/structured-streaming-programming-guide.html + metadata: Any arbitrary metadata. + This is ignored by Kedro, but may be consumed by users or external plugins. """ self._file_format = file_format self._save_args = save_args self._load_args = load_args + self.metadata = metadata fs_prefix, filepath = _split_filepath(filepath)