diff --git a/kedro-datasets/RELEASE.md b/kedro-datasets/RELEASE.md index f7f69ce13..6b9579f9d 100644 --- a/kedro-datasets/RELEASE.md +++ b/kedro-datasets/RELEASE.md @@ -5,6 +5,7 @@ * Added SQLAlchemy 2.0 support (and dropped support for versions below 1.4). * Added a save method to the APIDataSet +* Reduced constructor arguments for `APIDataSet` by replacing most arguments with a single constructor argument `load_args`. This makes it more consistent with other Kedro DataSets and the underlying `requests` API, and automatically enables the full configuration domain: stream, certificates, proxies, and more. ## Bug fixes and other changes * Relaxed `delta-spark` upper bound to allow compatibility with Spark 3.1.x and 3.2.x. diff --git a/kedro-datasets/kedro_datasets/api/api_dataset.py b/kedro-datasets/kedro_datasets/api/api_dataset.py index 04a490d73..f633e8e55 100644 --- a/kedro-datasets/kedro_datasets/api/api_dataset.py +++ b/kedro-datasets/kedro_datasets/api/api_dataset.py @@ -3,12 +3,17 @@ """ import json as json_ # make pylint happy from copy import deepcopy -from typing import Any, Dict, Iterable, List, Union +from typing import Any, Dict, List, Tuple, Union import requests from kedro.io.core import AbstractDataSet, DataSetError +from requests import Session, sessions from requests.auth import AuthBase +# NOTE: kedro.extras.datasets will be removed in Kedro 0.19.0. +# Any contribution to datasets should be made in kedro-datasets +# in kedro-plugins (https://github.com/kedro-org/kedro-plugins) + class APIDataSet(AbstractDataSet[None, requests.Response]): """``APIDataSet`` loads the data from HTTP(S) APIs. @@ -27,19 +32,22 @@ class APIDataSet(AbstractDataSet[None, requests.Response]): Example usage for the `Python API `_: :: - >>> from kedro_datasets.api import APIDataSet + >>> from kedro.extras.datasets.api import APIDataSet >>> >>> >>> data_set = APIDataSet( >>> url="https://quickstats.nass.usda.gov", - >>> params={ - >>> "key": "SOME_TOKEN", - >>> "format": "JSON", - >>> "commodity_desc": "CORN", - >>> "statisticcat_des": "YIELD", - >>> "agg_level_desc": "STATE", - >>> "year": 2000 - >>> } + >>> load_args={ + >>> "params": { + >>> "key": "SOME_TOKEN", + >>> "format": "JSON", + >>> "commodity_desc": "CORN", + >>> "statisticcat_des": "YIELD", + >>> "agg_level_desc": "STATE", + >>> "year": 2000 + >>> } + >>> }, + >>> credentials=("username", "password") >>> ) >>> data = data_set.load() @@ -89,76 +97,65 @@ def __init__( self, url: str, method: str = "GET", - data: Any = None, - params: Dict[str, Any] = None, - headers: Dict[str, Any] = None, - auth: Union[Iterable[str], AuthBase] = None, - json: Union[List, Dict[str, Any]] = None, - timeout: int = 60, - credentials: Union[Iterable[str], AuthBase] = None, - save_args: Dict[str, Any] = None, + load_args: Dict[str, Any] = None, + credentials: Union[Tuple[str, str], List[str], AuthBase] = None, ) -> None: """Creates a new instance of ``APIDataSet`` to fetch data from an API endpoint. Args: url: The API URL endpoint. - method: The Method of the request, GET, POST, PUT, - DELETE, HEAD, etc... data: The request payload, used for POST, PUT, etc - requests - https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests - params: The url parameters of the API. - https://requests.readthedocs.io/en/latest/user/quickstart/#passing-parameters-in-urls - headers: The HTTP headers. - https://requests.readthedocs.io/en/latest/user/quickstart/#custom-headers - auth: Anything ``requests`` accepts. Normally it's either ``('login', - 'password')``, - or ``AuthBase``, ``HTTPBasicAuth`` instance for more complex cases. Any - iterable will be cast to a tuple. - json: The request payload, used for POST, PUT, etc requests, passed in - to the json kwarg in the requests object. - https://requests.readthedocs.io/en/latest/user/quickstart/#more-complicated-post-requests - timeout: The wait time in seconds for a response, defaults to 1 minute. - https://requests.readthedocs.io/en/latest/user/quickstart/#timeouts - credentials: same as ``auth``. Allows specifying ``auth`` secrets in - credentials.yml. - save_args: Options for saving data on server. Includes all parameters used - during load method. Adds an optional parameter, ``chunk_size`` which determines the - size of the package sent at each request. + method: The Method of the request, GET, POST, PUT, DELETE, HEAD, etc... + load_args: Additional parameters to be fed to requests.request. + https://requests.readthedocs.io/en/latest/api/#requests.request + credentials: Allows specifying secrets in credentials.yml. + Expected format is ``('login', 'password')`` if given as a tuple or list. + An ``AuthBase`` instance can be provided for more complex cases. Raises: - ValueError: if both ``credentials`` and ``auth`` are specified. + ValueError: if both ``auth`` in ``load_args`` and ``credentials`` are specified. """ super().__init__() - self._save_args = deepcopy(self.DEFAULT_SAVE_ARGS) - - if save_args is not None: - self._save_args.update(save_args) + self._load_args = load_args or {} + self._load_args_auth = self._load_args.pop("auth", None) - if credentials is not None and auth is not None: + if credentials is not None and self._load_args_auth is not None: raise ValueError("Cannot specify both auth and credentials.") - auth = credentials or auth + self._auth = credentials or self._load_args_auth + + if "cert" in self._load_args: + self._load_args["cert"] = self._convert_type(self._load_args["cert"]) - if isinstance(auth, Iterable): - auth = tuple(auth) + if "timeout" in self._load_args: + self._load_args["timeout"] = self._convert_type(self._load_args["timeout"]) self._request_args: Dict[str, Any] = { "url": url, "method": method, - "data": data, - "params": params, - "headers": headers, - "auth": auth, - "json": json, - "timeout": timeout, + "auth": self._convert_type(self._auth), + **self._load_args, } + @staticmethod + def _convert_type(value: Any): + """ + From the Data Catalog, iterables are provided as Lists. + However, for some parameters in the Python requests library, + only Tuples are allowed. + """ + if isinstance(value, List): + return tuple(value) + return value + def _describe(self) -> Dict[str, Any]: - return {**self._request_args} + # prevent auth from logging + request_args_cp = self._request_args.copy() + request_args_cp.pop("auth", None) + return request_args_cp - def _execute_request(self) -> requests.Response: + def _execute_request(self, session: Session) -> requests.Response: try: - response = requests.request(**self._request_args) + response = session.request(**self._request_args) response.raise_for_status() except requests.exceptions.HTTPError as exc: raise DataSetError("Failed to fetch data", exc) from exc @@ -170,50 +167,8 @@ def _execute_request(self) -> requests.Response: def _load(self) -> requests.Response: return self._execute_request() - def _execute_save_with_chunks( - self, - json_data: List[Dict[str, Any]], - ) -> requests.Response: - chunk_size = self._save_args["chunk_size"] - n_chunks = len(json_data) // chunk_size + 1 - - for i in range(n_chunks): - send_data = json_data[i * chunk_size : (i + 1) * chunk_size] - - self._save_args["json"] = send_data - try: - response = requests.request(**self._request_args) - response.raise_for_status() - - except requests.exceptions.HTTPError as exc: - raise DataSetError("Failed to send data", exc) from exc - - except OSError as exc: - raise DataSetError("Failed to connect to the remote server") from exc - return response - - def _execute_save_request(self, json_data: Any) -> requests.Response: - self._save_args["json"] = json_data - try: - response = requests.request(**self._request_args) - response.raise_for_status() - except requests.exceptions.HTTPError as exc: - raise DataSetError("Failed to send data", exc) from exc - - except OSError as exc: - raise DataSetError("Failed to connect to the remote server") from exc - return response - - def _save(self, data: Any) -> requests.Response: - # case where we have a list of json data - if isinstance(data, list): - return self._execute_save_with_chunks(json_data=data) - try: - json_.loads(data) - except TypeError: - data = json_.dumps(data) - - return self._execute_save_request(json_data=data) + def _save(self, data: None) -> NoReturn: + raise DataSetError(f"{self.__class__.__name__} is a read only data set type") def _exists(self) -> bool: response = self._execute_request() diff --git a/kedro-datasets/tests/api/test_api_dataset.py b/kedro-datasets/tests/api/test_api_dataset.py index 05e182e87..740a2d817 100644 --- a/kedro-datasets/tests/api/test_api_dataset.py +++ b/kedro-datasets/tests/api/test_api_dataset.py @@ -1,11 +1,11 @@ # pylint: disable=no-member -import json +import base64 import socket import pytest import requests -import requests_mock from kedro.io.core import DataSetError +from requests.auth import HTTPBasicAuth from kedro_datasets.api import APIDataSet @@ -13,79 +13,176 @@ TEST_URL = "http://example.com/api/test" TEST_TEXT_RESPONSE_DATA = "This is a response." -TEST_JSON_RESPONSE_DATA = [{"key": "value"}] +TEST_JSON_REQUEST_DATA = [{"key": "value"}] TEST_PARAMS = {"param": "value"} TEST_URL_WITH_PARAMS = TEST_URL + "?param=value" - +TEST_METHOD = "GET" TEST_HEADERS = {"key": "value"} TEST_SAVE_DATA = [json.dumps({"key1": "info1", "key2": "info2"})] -@pytest.mark.parametrize("method", POSSIBLE_METHODS) class TestAPIDataSet: - @pytest.fixture - def requests_mocker(self): - with requests_mock.Mocker() as mock: - yield mock + @pytest.mark.parametrize("method", POSSIBLE_METHODS) + def test_request_method(self, requests_mock, method): + api_data_set = APIDataSet(url=TEST_URL, method=method) + requests_mock.register_uri(method, TEST_URL, text=TEST_TEXT_RESPONSE_DATA) + + response = api_data_set.load() + assert response.text == TEST_TEXT_RESPONSE_DATA - def test_successfully_load_with_response(self, requests_mocker, method): + @pytest.mark.parametrize( + "parameters_in, url_postfix", + [ + ({"param": "value"}, "?param=value"), + (bytes("a=1", "latin-1"), "?a=1"), + ], + ) + def test_params_in_request(self, requests_mock, parameters_in, url_postfix): api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS + url=TEST_URL, method=TEST_METHOD, load_args={"params": parameters_in} ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text=TEST_TEXT_RESPONSE_DATA, + requests_mock.register_uri( + TEST_METHOD, TEST_URL + url_postfix, text=TEST_TEXT_RESPONSE_DATA ) response = api_data_set.load() assert isinstance(response, requests.Response) assert response.text == TEST_TEXT_RESPONSE_DATA - def test_successful_json_load_with_response(self, requests_mocker, method): + def test_json_in_request(self, requests_mock): api_data_set = APIDataSet( url=TEST_URL, - method=method, - json=TEST_JSON_RESPONSE_DATA, - headers=TEST_HEADERS, + method=TEST_METHOD, + load_args={"json": TEST_JSON_REQUEST_DATA}, ) - requests_mocker.register_uri( - method, + requests_mock.register_uri(TEST_METHOD, TEST_URL) + + response = api_data_set.load() + assert response.request.json() == TEST_JSON_REQUEST_DATA + + def test_headers_in_request(self, requests_mock): + api_data_set = APIDataSet( + url=TEST_URL, method=TEST_METHOD, load_args={"headers": TEST_HEADERS} + ) + requests_mock.register_uri(TEST_METHOD, TEST_URL, headers={"pan": "cake"}) + + response = api_data_set.load() + + assert response.request.headers["key"] == "value" + assert response.headers["pan"] == "cake" + + def test_api_cookies(self, requests_mock): + api_data_set = APIDataSet( + url=TEST_URL, method=TEST_METHOD, load_args={"cookies": {"pan": "cake"}} + ) + requests_mock.register_uri(TEST_METHOD, TEST_URL, text="text") + + response = api_data_set.load() + assert response.request.headers["Cookie"] == "pan=cake" + + def test_credentials_auth_error(self): + """ + If ``auth`` in ``load_args`` and ``credentials`` are both provided, + the constructor should raise a ValueError. + """ + with pytest.raises(ValueError, match="both auth and credentials"): + APIDataSet( + url=TEST_URL, method=TEST_METHOD, load_args={"auth": []}, credentials={} + ) + + @staticmethod + def _basic_auth(username, password): + encoded = base64.b64encode(f"{username}:{password}".encode("latin-1")) + return f"Basic {encoded.decode('latin-1')}" + + @pytest.mark.parametrize( + "auth_kwarg", + [ + {"load_args": {"auth": ("john", "doe")}}, + {"load_args": {"auth": ["john", "doe"]}}, + {"load_args": {"auth": HTTPBasicAuth("john", "doe")}}, + {"credentials": ("john", "doe")}, + {"credentials": ["john", "doe"]}, + {"credentials": HTTPBasicAuth("john", "doe")}, + ], + ) + def test_auth_sequence(self, requests_mock, auth_kwarg): + api_data_set = APIDataSet(url=TEST_URL, method=TEST_METHOD, **auth_kwarg) + requests_mock.register_uri( + TEST_METHOD, TEST_URL, - headers=TEST_HEADERS, - text=json.dumps(TEST_JSON_RESPONSE_DATA), + text=TEST_TEXT_RESPONSE_DATA, ) response = api_data_set.load() assert isinstance(response, requests.Response) - assert response.json() == TEST_JSON_RESPONSE_DATA + assert response.request.headers["Authorization"] == TestAPIDataSet._basic_auth( + "john", "doe" + ) + assert response.text == TEST_TEXT_RESPONSE_DATA - def test_http_error(self, requests_mocker, method): + @pytest.mark.parametrize( + "timeout_in, timeout_out", + [ + (1, 1), + ((1, 2), (1, 2)), + ([1, 2], (1, 2)), + ], + ) + def test_api_timeout(self, requests_mock, timeout_in, timeout_out): api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS + url=TEST_URL, method=TEST_METHOD, load_args={"timeout": timeout_in} ) - requests_mocker.register_uri( - method, - TEST_URL_WITH_PARAMS, - headers=TEST_HEADERS, - text="Nope, not found", - status_code=requests.codes.FORBIDDEN, + requests_mock.register_uri(TEST_METHOD, TEST_URL) + response = api_data_set.load() + assert response.request.timeout == timeout_out + + def test_stream(self, requests_mock): + text = "I am being streamed." + + api_data_set = APIDataSet( + url=TEST_URL, method=TEST_METHOD, load_args={"stream": True} ) - with pytest.raises(DataSetError, match="Failed to fetch data"): - api_data_set.load() + requests_mock.register_uri(TEST_METHOD, TEST_URL, text=text) + + response = api_data_set.load() + assert isinstance(response, requests.Response) + assert response.request.stream - def test_socket_error(self, requests_mocker, method): + chunks = list(response.iter_content(chunk_size=2, decode_unicode=True)) + assert chunks == ["I ", "am", " b", "ei", "ng", " s", "tr", "ea", "me", "d."] + + def test_proxy(self, requests_mock): api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS + url="ftp://example.com/api/test", + method=TEST_METHOD, + load_args={"proxies": {"ftp": "ftp://127.0.0.1:3000"}}, + ) + requests_mock.register_uri( + TEST_METHOD, + "ftp://example.com/api/test", ) - requests_mocker.register_uri(method, TEST_URL_WITH_PARAMS, exc=socket.error) - with pytest.raises(DataSetError, match="Failed to connect"): - api_data_set.load() + response = api_data_set.load() + assert response.request.proxies.get("ftp") == "ftp://127.0.0.1:3000" + + @pytest.mark.parametrize( + "cert_in, cert_out", + [ + (("cert.pem", "privkey.pem"), ("cert.pem", "privkey.pem")), + (["cert.pem", "privkey.pem"], ("cert.pem", "privkey.pem")), + ("some/path/to/file.pem", "some/path/to/file.pem"), + (None, None), + ], + ) + def test_certs(self, requests_mock, cert_in, cert_out): + api_data_set = APIDataSet( + url=TEST_URL, method=TEST_METHOD, load_args={"cert": cert_in} + ) + requests_mock.register_uri(TEST_METHOD, TEST_URL) def test_successful_save(self, requests_mocker, method): """ @@ -175,16 +272,18 @@ def test_save_socket_error(self, requests_mocker, method): ): api_data_set.save(TEST_SAVE_DATA[0]) - def test_exists_http_error(self, requests_mocker, method): + def test_exists_http_error(self, requests_mock): """ In case of an unexpected HTTP error, ``exists()`` should not silently catch it. """ api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS + url=TEST_URL, + method=TEST_METHOD, + load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, ) - requests_mocker.register_uri( - method, + requests_mock.register_uri( + TEST_METHOD, TEST_URL_WITH_PARAMS, headers=TEST_HEADERS, text="Nope, not found", @@ -193,16 +292,18 @@ def test_exists_http_error(self, requests_mocker, method): with pytest.raises(DataSetError, match="Failed to fetch data"): api_data_set.exists() - def test_exists_ok(self, requests_mocker, method): + def test_exists_ok(self, requests_mock): """ If the file actually exists and server responds 200, ``exists()`` should return True """ api_data_set = APIDataSet( - url=TEST_URL, method=method, params=TEST_PARAMS, headers=TEST_HEADERS + url=TEST_URL, + method=TEST_METHOD, + load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, ) - requests_mocker.register_uri( - method, + requests_mock.register_uri( + TEST_METHOD, TEST_URL_WITH_PARAMS, headers=TEST_HEADERS, text=TEST_TEXT_RESPONSE_DATA, @@ -210,43 +311,38 @@ def test_exists_ok(self, requests_mocker, method): assert api_data_set.exists() - def test_credentials_auth_error(self, method): - """ - If ``auth`` and ``credentials`` are both provided, - the constructor should raise a ValueError. - """ - with pytest.raises(ValueError, match="both auth and credentials"): - APIDataSet(url=TEST_URL, method=method, auth=[], credentials=[]) - - @pytest.mark.parametrize("auth_kwarg", ["auth", "credentials"]) - @pytest.mark.parametrize( - "auth_seq", - [ - ("username", "password"), - ["username", "password"], - (e for e in ["username", "password"]), # Generator. - ], - ) - def test_auth_sequence(self, requests_mocker, method, auth_seq, auth_kwarg): - """ - ``auth`` and ``credentials`` should be able to be any Iterable. - """ - kwargs = { - "url": TEST_URL, - "method": method, - "params": TEST_PARAMS, - "headers": TEST_HEADERS, - auth_kwarg: auth_seq, - } - - api_data_set = APIDataSet(**kwargs) - requests_mocker.register_uri( - method, + def test_http_error(self, requests_mock): + api_data_set = APIDataSet( + url=TEST_URL, + method=TEST_METHOD, + load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, + ) + requests_mock.register_uri( + TEST_METHOD, TEST_URL_WITH_PARAMS, headers=TEST_HEADERS, - text=TEST_TEXT_RESPONSE_DATA, + text="Nope, not found", + status_code=requests.codes.FORBIDDEN, ) - response = api_data_set.load() - assert isinstance(response, requests.Response) - assert response.text == TEST_TEXT_RESPONSE_DATA + with pytest.raises(DataSetError, match="Failed to fetch data"): + api_data_set.load() + + def test_socket_error(self, requests_mock): + api_data_set = APIDataSet( + url=TEST_URL, + method=TEST_METHOD, + load_args={"params": TEST_PARAMS, "headers": TEST_HEADERS}, + ) + requests_mock.register_uri(TEST_METHOD, TEST_URL_WITH_PARAMS, exc=socket.error) + + with pytest.raises(DataSetError, match="Failed to connect"): + api_data_set.load() + + def test_read_only_mode(self): + """ + Saving is disabled on the data set. + """ + api_data_set = APIDataSet(url=TEST_URL, method=TEST_METHOD) + with pytest.raises(DataSetError, match="is a read only data set type"): + api_data_set.save({})