-
Notifications
You must be signed in to change notification settings - Fork 93
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(datasets): Added the Experimental SafetensorsDataset (#898)
* added the skeleton for the Safetensors experimental dataset Signed-off-by: Minura Punchihewa <[email protected]> * implemented the save() and load() funcs Signed-off-by: Minura Punchihewa <[email protected]> * updated the default backend Signed-off-by: Minura Punchihewa <[email protected]> * implemented the describe() and exists() funcs Signed-off-by: Minura Punchihewa <[email protected]> * imported the dataset to main pkg Signed-off-by: Minura Punchihewa <[email protected]> * fixed how data is passed to load() Signed-off-by: Minura Punchihewa <[email protected]> * fixed save() to access the file path Signed-off-by: Minura Punchihewa <[email protected]> * added a release() func Signed-off-by: Minura Punchihewa <[email protected]> * added the docstrings for the dataset Signed-off-by: Minura Punchihewa <[email protected]> * fixed lint issues Signed-off-by: Minura Punchihewa <[email protected]> * added unit tests Signed-off-by: Minura Punchihewa <[email protected]> * added a few more unit tests Signed-off-by: Minura Punchihewa <[email protected]> * fixed broken unit test Signed-off-by: Minura Punchihewa <[email protected]> * fixed lint issues Signed-off-by: Minura Punchihewa <[email protected]> * fixed use of insecure temp files Signed-off-by: Minura Punchihewa <[email protected]> * added the dataset to the documentation Signed-off-by: Minura Punchihewa <[email protected]> * listed the dependencies for the dataset Signed-off-by: Minura Punchihewa <[email protected]> * fixed typo in dataset reference Signed-off-by: Minura Punchihewa <[email protected]> * updated the release notes Signed-off-by: Minura Punchihewa <[email protected]> * updated the docstring for the class Co-authored-by: Nok Lam Chan <[email protected]> Signed-off-by: Minura Punchihewa <[email protected]> * changed the default backend to numpy Signed-off-by: Minura Punchihewa <[email protected]> * added numpy to the list of dependencies Signed-off-by: Minura Punchihewa <[email protected]> * fixed the 
assert statement in the docstring Signed-off-by: Minura Punchihewa <[email protected]> * fixed syntax error in example in docstring Signed-off-by: Minura Punchihewa <[email protected]> * updated the examples in the docs to use a numpy backend Signed-off-by: Minura Punchihewa <[email protected]> --------- Signed-off-by: Minura Punchihewa <[email protected]> Signed-off-by: Minura Punchihewa <[email protected]> Co-authored-by: Dmitry Sorokin <[email protected]> Co-authored-by: Nok Lam Chan <[email protected]>
- Loading branch information
1 parent
9c92e15
commit 0dec688
Showing
7 changed files
with
439 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
kedro-datasets/kedro_datasets_experimental/safetensors/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
"""``AbstractDataset`` implementation to load/save tensors using the SafeTensors library."""

from typing import Any

import lazy_loader as lazy

# Declared for static type checkers; the real class is imported lazily below.
SafetensorsDataset: Any

# Defer importing ``safetensors_dataset`` until the attribute is first accessed,
# keeping package import cheap when the dataset (and its deps) is unused.
_submod_attrs = {"safetensors_dataset": ["SafetensorsDataset"]}

__getattr__, __dir__, __all__ = lazy.attach(__name__, submod_attrs=_submod_attrs)
190 changes: 190 additions & 0 deletions
190
kedro-datasets/kedro_datasets_experimental/safetensors/safetensors_dataset.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,190 @@ | ||
from __future__ import annotations | ||
|
||
import importlib | ||
from copy import deepcopy | ||
from pathlib import PurePosixPath | ||
from typing import Any | ||
|
||
import fsspec | ||
from kedro.io.core import ( | ||
AbstractVersionedDataset, | ||
DatasetError, | ||
Version, | ||
get_filepath_str, | ||
get_protocol_and_path, | ||
) | ||
|
||
|
||
class SafetensorsDataset(AbstractVersionedDataset[Any, Any]):
    """``SafetensorsDataset`` loads/saves data from/to a Safetensors file using an underlying
    filesystem (e.g.: local, S3, GCS). The underlying functionality is supported by
    the specified backend library passed in (defaults to the ``numpy`` library), so it
    supports all options that backend allows for loading and saving Safetensors files.

    Example usage for the
    `YAML API <https://docs.kedro.org/en/stable/data/\
    data_catalog_yaml_examples.html>`_:

    .. code-block:: yaml

        test_model:
          type: safetensors.SafetensorsDataset
          filepath: data/07_model_output/test_model.safetensors

    Example usage for the
    `Python API <https://docs.kedro.org/en/stable/data/\
    advanced_data_catalog_usage.html>`_:

    .. code-block:: pycon

        >>> from kedro_datasets_experimental.safetensors import SafetensorsDataset
        >>> import numpy as np
        >>>
        >>> data = {
        ...     "embedding": np.zeros((512, 1024)),
        ...     "attention": np.zeros((256, 256))
        ... }
        >>> dataset = SafetensorsDataset(
        ...     filepath="test.safetensors",
        ... )
        >>> dataset.save(data)
        >>> reloaded = dataset.load()
        >>> assert all(np.array_equal(data[key], reloaded[key]) for key in data)
    """

    DEFAULT_LOAD_ARGS: dict[str, Any] = {}
    DEFAULT_SAVE_ARGS: dict[str, Any] = {}
    DEFAULT_FS_ARGS: dict[str, Any] = {"open_args_save": {"mode": "wb"}}

    def __init__(  # noqa: PLR0913
        self,
        *,
        filepath: str,
        backend: str = "numpy",
        version: Version | None = None,
        credentials: dict[str, Any] | None = None,
        fs_args: dict[str, Any] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Creates a new instance of ``SafetensorsDataset`` pointing to a concrete Safetensors
        file on a specific filesystem. ``SafetensorsDataset`` supports custom backends to
        serialise/deserialise objects.

        The following backends are supported:
            * `numpy`
            * `torch`
            * `tensorflow`
            * `paddle`
            * `flax`

        Args:
            filepath: Filepath in POSIX format to a Safetensors file prefixed with a protocol like
                `s3://`. If prefix is not provided, `file` protocol (local filesystem) will be used.
                The prefix should be any protocol supported by ``fsspec``.
                Note: `http(s)` doesn't support versioning.
            backend: The backend library to use for serialising/deserialising objects.
                The default backend is 'numpy'.
            version: If specified, should be an instance of
                ``kedro.io.core.Version``. If its ``load`` attribute is
                None, the latest version will be loaded. If its ``save``
                attribute is None, save version will be autogenerated.
            credentials: Credentials required to get access to the underlying filesystem.
                E.g. for ``GCSFileSystem`` it should look like `{"token": None}`.
            fs_args: Extra arguments to pass into underlying filesystem class constructor
                (e.g. `{"project": "my-project"}` for ``GCSFileSystem``), as well as
                to pass to the filesystem's `open` method through nested keys
                `open_args_load` and `open_args_save`.
                Here you can find all available arguments for `open`:
                https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem.open
                All defaults are preserved, except `mode`, which is set to `wb` when saving.
            metadata: Any arbitrary metadata.
                This is ignored by Kedro, but may be consumed by users or external plugins.

        Raises:
            ImportError: If the ``backend`` module could not be imported.
        """
        # Fail fast: validate the backend is importable at construction time so
        # misconfiguration surfaces before any load/save is attempted.
        try:
            importlib.import_module(f"safetensors.{backend}")
        except ImportError as exc:
            raise ImportError(
                f"Selected backend '{backend}' could not be imported. "
                "Make sure it is installed and importable."
            ) from exc

        _fs_args = deepcopy(fs_args) or {}
        _fs_open_args_load = _fs_args.pop("open_args_load", {})
        _fs_open_args_save = _fs_args.pop("open_args_save", {})
        _credentials = deepcopy(credentials) or {}

        protocol, path = get_protocol_and_path(filepath, version)
        if protocol == "file":
            # Create parent directories on save for local paths.
            _fs_args.setdefault("auto_mkdir", True)

        self._protocol = protocol
        self._fs = fsspec.filesystem(self._protocol, **_credentials, **_fs_args)

        self.metadata = metadata

        super().__init__(
            filepath=PurePosixPath(path),
            version=version,
            exists_function=self._fs.exists,
            glob_function=self._fs.glob,
        )

        self._backend = backend

        # User-supplied open args override the class-level defaults.
        self._fs_open_args_load = {
            **self.DEFAULT_FS_ARGS.get("open_args_load", {}),
            **(_fs_open_args_load or {}),
        }
        self._fs_open_args_save = {
            **self.DEFAULT_FS_ARGS.get("open_args_save", {}),
            **(_fs_open_args_save or {}),
        }

    def load(self) -> Any:
        """Read the file contents and deserialise them with the configured backend.

        Returns:
            The deserialised tensors, as returned by ``safetensors.<backend>.load``.
        """
        load_path = get_filepath_str(self._get_load_path(), self._protocol)
        imported_backend = importlib.import_module(f"safetensors.{self._backend}")

        with self._fs.open(load_path, **self._fs_open_args_load) as fs_file:
            return imported_backend.load(fs_file.read())

    def save(self, data: Any) -> None:
        """Serialise ``data`` with the configured backend and write it to ``filepath``.

        Raises:
            DatasetError: If ``data`` could not be serialised by the backend.
        """
        save_path = get_filepath_str(self._get_save_path(), self._protocol)

        # Import outside the ``try`` so a broken backend surfaces as an
        # ImportError rather than being re-wrapped as a serialisation error.
        imported_backend = importlib.import_module(f"safetensors.{self._backend}")

        with self._fs.open(save_path, **self._fs_open_args_save) as fs_file:
            try:
                # Serialise to bytes and write through the fsspec file object so
                # remote filesystems (s3://, gcs://, ...) work too. The previous
                # ``save_file(data, fs_file.name)`` wrote via a local *path* and
                # only worked for the local filesystem.
                fs_file.write(imported_backend.save(data))
            except Exception as exc:
                raise DatasetError(
                    f"{data.__class__} was not serialised due to: {exc}"
                ) from exc

        self._invalidate_cache()

    def _describe(self) -> dict[str, Any]:
        # Shown by ``str(dataset)`` and in catalog listings.
        return {
            "filepath": self._filepath,
            "backend": self._backend,
            "protocol": self._protocol,
            "version": self._version,
        }

    def _exists(self) -> bool:
        try:
            load_path = get_filepath_str(self._get_load_path(), self._protocol)
        except DatasetError:
            # No load version could be resolved (e.g. versioned dataset with no
            # saved versions yet) — treat as "does not exist".
            return False

        return self._fs.exists(load_path)

    def _release(self) -> None:
        super()._release()
        self._invalidate_cache()

    def _invalidate_cache(self) -> None:
        """Invalidate underlying filesystem caches."""
        filepath = get_filepath_str(self._filepath, self._protocol)
        self._fs.invalidate_cache(filepath)
Empty file.
Oops, something went wrong.