From 191cd79dd1f07e7c0eca47bec190c22c11e6d412 Mon Sep 17 00:00:00 2001 From: avibhstarburst Date: Tue, 30 Jan 2024 14:46:59 +0200 Subject: [PATCH] Add support for 'None' on response data --- tests/unit/conftest.py | 23 +++++++++++++++++++++++ tests/unit/test_client.py | 26 ++++++++++++++++++++++++++ trino/client.py | 4 +++- 3 files changed, 52 insertions(+), 1 deletion(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 1f84d13e..96b0b486 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -185,6 +185,29 @@ def sample_get_response_data(): } +@pytest.fixture(scope="session") +def sample_get_response_data_none(): + """ + This is the response to the second HTTP request (a GET) from an actual + Trino session. It is deliberately not truncated to document such response + and allow to use it for other tests. After doing the steps above, do: + + :: + >>> cur.fetchall() + + """ + yield { + "id": "20210817_140827_00000_arvdv", + "nextUri": "coordinator:8080/v1/statement/20210817_140827_00000_arvdv/2", + "data": None, + "columns": [], + "taskDownloadUris": [], + "partialCancelUri": "http://localhost:8080/v1/stage/20210817_140827_00000_arvdv.0", # NOQA: E501 + "stats": {}, + "infoUri": "http://coordinator:8080/query.html?20210817_140827_00000_arvdv", # NOQA: E501 + } + + @pytest.fixture(scope="session") def sample_get_error_response_data(): yield { diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index d1b23ae7..2c522f75 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -704,6 +704,32 @@ def test_trino_fetch_request(mock_requests, sample_get_response_data): assert status.rows == sample_get_response_data["data"] +@mock.patch("trino.client.TrinoRequest.http") +def test_trino_fetch_request_data_none(mock_requests, sample_get_response_data_none): + mock_requests.Response.return_value.json.return_value = sample_get_response_data_none + + req = TrinoRequest( + host="coordinator", + port=8080, + client_session=ClientSession( + user="test", + source="test", + catalog="test", + schema="test", + properties={}, + ), + http_scheme="http", + ) + + http_resp = TrinoRequest.http.Response() + http_resp.status_code = 200 + status = req.process(http_resp) + + assert status.next_uri == sample_get_response_data_none["nextUri"] + assert status.id == sample_get_response_data_none["id"] + assert status.rows == [] + + @mock.patch("trino.client.TrinoRequest.http") def test_trino_fetch_error(mock_requests, sample_get_error_response_data): mock_requests.Response.return_value.json.return_value = sample_get_error_response_data diff --git a/trino/client.py b/trino/client.py index c7f26a80..010d6308 100644 --- a/trino/client.py +++ b/trino/client.py @@ -633,6 +633,8 @@ def process(self, http_response) -> TrinoStatus: self._next_uri = response.get("nextUri") + data = response.get("data") if response.get("data") else [] + return TrinoStatus( id=response["id"], stats=response["stats"], @@ -641,7 +643,7 @@ def process(self, http_response) -> TrinoStatus: next_uri=self._next_uri, update_type=response.get("updateType"), update_count=response.get("updateCount"), - rows=response.get("data", []), + rows=data, columns=response.get("columns"), )