diff --git a/aiida_restapi/config.py b/aiida_restapi/config.py index fead108..9f877de 100644 --- a/aiida_restapi/config.py +++ b/aiida_restapi/config.py @@ -17,3 +17,6 @@ 'disabled': False, } } + +# The chunks size for streaming data for download +DOWNLOAD_CHUNK_SIZE = 1024 diff --git a/aiida_restapi/routers/nodes.py b/aiida_restapi/routers/nodes.py index bedbb42..f3ac146 100644 --- a/aiida_restapi/routers/nodes.py +++ b/aiida_restapi/routers/nodes.py @@ -4,16 +4,18 @@ import os import tempfile from pathlib import Path -from typing import Any, List, Optional +from typing import Any, Generator, List, Optional from aiida import orm from aiida.cmdline.utils.decorators import with_dbenv -from aiida.common.exceptions import EntryPointError +from aiida.common.exceptions import EntryPointError, LicensingException, NotExistent from aiida.plugins.entry_point import load_entry_point from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile +from fastapi.responses import StreamingResponse from pydantic import ValidationError from aiida_restapi import models, resources +from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE from .auth import get_current_active_user @@ -41,6 +43,50 @@ async def get_nodes_download_formats() -> dict[str, Any]: return resources.get_all_download_formats() +@router.get('/nodes/{nodes_id}/download') +@with_dbenv() +async def download_node(nodes_id: int, download_format: Optional[str] = None) -> StreamingResponse: + """Get nodes by id.""" + from aiida.orm import load_node + + try: + node = load_node(nodes_id) + except NotExistent: + raise HTTPException(status_code=404, detail=f'Could no find any node with id {nodes_id}') + + if download_format is None: + raise HTTPException( + status_code=422, + detail='Please specify the download format. ' + 'The available download formats can be ' + 'queried using the /nodes/download_formats/ endpoint.', + ) + + elif download_format in node.get_export_formats(): + # byteobj, dict with {filename: filecontent} + import io + + try: + exported_bytes, _ = node._exportcontent(download_format) + except LicensingException as exc: + raise HTTPException(status_code=500, detail=str(exc)) + + def stream() -> Generator[bytes, None, None]: + with io.BytesIO(exported_bytes) as handler: + while chunk := handler.read(DOWNLOAD_CHUNK_SIZE): + yield chunk + + return StreamingResponse(stream(), media_type=f'application/{download_format}') + + else: + raise HTTPException( + status_code=422, + detail='The format {} is not supported. ' + 'The available download formats can be ' + 'queried using the /nodes/download_formats/ endpoint.'.format(download_format), + ) + + @router.get('/nodes/{nodes_id}', response_model=models.Node) @with_dbenv() async def read_node(nodes_id: int) -> Optional[models.Node]: diff --git a/pyproject.toml b/pyproject.toml index 5d1ac28..c23cef3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,8 @@ testing = [ 'pytest-regressions', 'pytest-cov', 'requests', - 'httpx' + 'httpx', + 'numpy~=1.21' ] [project.urls] diff --git a/tests/conftest.py b/tests/conftest.py index c39c403..b680eba 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ from datetime import datetime from typing import Any, Callable, Mapping, MutableMapping, Optional, Union +import numpy as np import pytest import pytz from aiida import orm @@ -164,6 +165,17 @@ def default_nodes(): return [node_1.pk, node_2.pk, node_3.pk, node_4.pk] +@pytest.fixture(scope='function') +def array_data_node(): + """Populate database with downloadable node (implmenting a _prepare_* function). + For testing the chunking of the streaming we create an array that needs to be splitted int two chunks.""" + + from aiida_restapi.config import DOWNLOAD_CHUNK_SIZE + + nb_elements = DOWNLOAD_CHUNK_SIZE // 64 + 1 + return orm.ArrayData(np.arange(nb_elements, dtype=np.int64)).store() + + @pytest.fixture(scope='function') def authenticate(): """Authenticate user. diff --git a/tests/test_nodes.py b/tests/test_nodes.py index 5429b6e..72c8444 100644 --- a/tests/test_nodes.py +++ b/tests/test_nodes.py @@ -330,3 +330,24 @@ def test_create_bool_with_extra(client, authenticate): # pylint: disable=unused assert check_response.status_code == 200, response.content assert check_response.json()['extras']['extra_one'] == 'value_1' assert check_response.json()['extras']['extra_two'] == 'value_2' + + +@pytest.mark.anyio +async def test_get_download_node(array_data_node, async_client): + """Test download node /nodes/{nodes_id}/download. + The async client is needed to avoid an error caused by an I/O operation on closed file""" + + # Test that array is correctly downloaded as json + response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=json') + assert response.status_code == 200, response.json() + assert response.json().get('default', None) == array_data_node.get_array().tolist() + + # Test exception when wrong download format given + response = await async_client.get(f'/nodes/{array_data_node.pk}/download?download_format=cif') + assert response.status_code == 422, response.json() + assert 'format cif is not supported' in response.json()['detail'] + + # Test exception when no download format given + response = await async_client.get(f'/nodes/{array_data_node.pk}/download') + assert response.status_code == 422, response.json() + assert 'Please specify the download format' in response.json()['detail']