From 131f6f86a32d953c3ae1b8b4584933113982fcff Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Wed, 15 May 2024 21:18:37 -0400 Subject: [PATCH] API: refactor and fix `get_result(wait=True)` The previous `GET /dispatches/{dispatch_id}` endpoint was trying to do too much. Its responsibilities are now separated into two endpoints: * `GET /dispatches`: bulk query dispatch summaries (including status) with options to filter by `dispatch_id`, sort chronologically, and also limit the output to status only. * `GET /dispatches/{dispatch_id}`: download manifest To achieve the desired behavior of `get_result(id, wait=True)`, the client 1. Polls the dispatch status by querying the first endpoint. 2. Downloads the manifest after the dispatch has reached a final status. The server no longer returns 503 errors when the dispatch is not yet "ready". A 503 status code is not entirely accurate here because it is intended to convey temporary service unavailablity resulting from server overload or rate limiting. However, the fact that the workflow is still running does not indicate any fault of the server. These changes will allow `get_result(dispatch_id, wait=True)` to wait as long as required instead of erroring out after some time. Supporting improvements: DAL: Add sorting and pagination to Controller DAL: improve bulk get when retrieving only some columns Directly select the specified columns instead of retrieving the whole ORM entities and deferring column loading using load_only --- CHANGELOG.md | 4 + covalent/_dispatcher_plugins/local.py | 66 +------- covalent/_results_manager/results_manager.py | 98 ++++++------ covalent/_shared_files/defaults.py | 3 - covalent/triggers/base.py | 17 +-- covalent_dispatcher/_dal/controller.py | 46 ++++-- covalent_dispatcher/_dal/result.py | 40 ++++- covalent_dispatcher/_service/app.py | 143 ++++++++---------- covalent_dispatcher/_service/models.py | 34 ++++- .../_dal/result_test.py | 81 ++++++++++ .../_service/app_test.py | 52 +------ .../dispatcher_plugins/local_test.py | 84 ---------- .../results_manager_test.py | 41 ++--- tests/covalent_tests/triggers/base_test.py | 24 +-- 14 files changed, 324 insertions(+), 409 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a23904f2..a660a2891 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Fixed + +- `get_result(wait=True)` no longer times out + ### Changed - Added support for Python 3.11 and 3.12 diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index 9857342cf..bc29af30c 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -129,8 +129,6 @@ def dispatch( Wrapper function which takes the inputs of the workflow as arguments """ - multistage = get_config("sdk.multistage_dispatch") == "true" - # Extract triggers here if "triggers" in orig_lattice.metadata: triggers_data = orig_lattice.metadata.pop("triggers") @@ -155,14 +153,7 @@ def wrapper(*args, **kwargs) -> str: The dispatch id of the workflow. """ - if multistage: - dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)( - *args, **kwargs - ) - else: - dispatch_id = LocalDispatcher.submit(orig_lattice, dispatcher_addr)( - *args, **kwargs - ) + dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)(*args, **kwargs) if triggers_data: LocalDispatcher.register_triggers(triggers_data, dispatch_id) @@ -237,61 +228,6 @@ def wrapper(*args, **kwargs) -> str: return wrapper - @staticmethod - def submit( - orig_lattice: Lattice, - dispatcher_addr: str = None, - ) -> Callable: - """ - Wrapping the dispatching functionality to allow input passing - and server address specification. - - Afterwards, send the lattice to the dispatcher server and return - the assigned dispatch id. - - Args: - orig_lattice: The lattice/workflow to send to the dispatcher server. - dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. - - Returns: - Wrapper function which takes the inputs of the workflow as arguments - """ - - if dispatcher_addr is None: - dispatcher_addr = format_server_url() - - @wraps(orig_lattice) - def wrapper(*args, **kwargs) -> str: - """ - Send the lattice to the dispatcher server and return - the assigned dispatch id. - - Args: - *args: The inputs of the workflow. - **kwargs: The keyword arguments of the workflow. - - Returns: - The dispatch id of the workflow. - """ - - if not isinstance(orig_lattice, Lattice): - message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." - app_log.error(message) - raise TypeError(message) - - lattice = deepcopy(orig_lattice) - - lattice.build_graph(*args, **kwargs) - - # Serialize the transport graph to JSON - json_lattice = lattice.serialize_to_json() - endpoint = "/api/v2/dispatches/submit" - r = APIClient(dispatcher_addr).post(endpoint, data=json_lattice) - r.raise_for_status() - return r.content.decode("utf-8").strip().replace('"', "") - - return wrapper - @staticmethod def start( dispatch_id: str, diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index 4c751206a..941999546 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -19,12 +19,11 @@ import contextlib import os +import time from pathlib import Path -from typing import Dict, List, Optional +from typing import List, Optional from furl import furl -from requests.adapters import HTTPAdapter -from urllib3.util import Retry from .._api.apiclient import CovalentAPIClient from .._serialize.common import load_asset @@ -40,9 +39,9 @@ from .._shared_files.exceptions import MissingLatticeRecordError from .._shared_files.schemas.asset import AssetSchema from .._shared_files.schemas.result import ResultSchema +from .._shared_files.util_classes import RESULT_STATUS, Status from .._shared_files.utils import copy_file_locally, format_server_url from .result import Result -from .wait import EXTREME app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -139,12 +138,20 @@ def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = # Multi-part +def _query_dispatch_status(dispatch_id: str, api_client: CovalentAPIClient): + endpoint = "/api/v2/dispatches" + resp = api_client.get(endpoint, params={"dispatch_id": dispatch_id, "status_only": True}) + resp.raise_for_status() + dispatches = resp.json()["dispatches"] + if len(dispatches) == 0: + raise MissingLatticeRecordError + + return dispatches[0]["status"] + + def _get_result_export_from_dispatcher( - dispatch_id: str, - wait: bool = False, - status_only: bool = False, - dispatcher_addr: str = None, -) -> Dict: + dispatch_id: str, api_client: CovalentAPIClient +) -> ResultSchema: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -161,24 +168,13 @@ def _get_result_export_from_dispatcher( MissingLatticeRecordError: If the result is not found. """ - if dispatcher_addr is None: - dispatcher_addr = format_server_url() - - retries = int(EXTREME) if wait else 5 - - adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) - api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) - endpoint = f"/api/v2/dispatches/{dispatch_id}" - response = api_client.get( - endpoint, - params={"wait": wait, "status_only": status_only}, - ) + response = api_client.get(endpoint) if response.status_code == 404: raise MissingLatticeRecordError response.raise_for_status() export = response.json() - return export + return ResultSchema.model_validate(export) # Function to download default assets @@ -346,11 +342,17 @@ def from_dispatch_id( wait: bool = False, dispatcher_addr: str = None, ) -> "ResultManager": - export = _get_result_export_from_dispatcher( - dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr - ) + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + api_client = CovalentAPIClient(dispatcher_addr, auto_raise=False) + if wait: + status = Status(_query_dispatch_status(dispatch_id, api_client)) + while not RESULT_STATUS.is_terminal(status): + time.sleep(1) + status = Status(_query_dispatch_status(dispatch_id, api_client)) - manifest = ResultSchema.model_validate(export["result_export"]) + manifest = _get_result_export_from_dispatcher(dispatch_id, api_client) # sort the nodes manifest.lattice.transport_graph.nodes.sort(key=lambda x: x.id) @@ -408,14 +410,15 @@ def _get_result_multistage( """ + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + api_client = CovalentAPIClient(dispatcher_addr) try: if status_only: - return _get_result_export_from_dispatcher( - dispatch_id=dispatch_id, - wait=wait, - status_only=status_only, - dispatcher_addr=dispatcher_addr, - ) + status = _query_dispatch_status(dispatch_id, api_client) + return {"id": dispatch_id, "status": status} + rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr) _get_default_assets(rm) @@ -496,23 +499,14 @@ def get_result( The Result object from the Covalent server """ - max_attempts = int(os.getenv("COVALENT_GET_RESULT_RETRIES", 10)) - num_attempts = 0 - while num_attempts < max_attempts: - try: - return _get_result_multistage( - dispatch_id=dispatch_id, - wait=wait, - dispatcher_addr=dispatcher_addr, - status_only=status_only, - results_dir=results_dir, - workflow_output=workflow_output, - intermediate_outputs=intermediate_outputs, - sublattice_results=sublattice_results, - qelectron_db=qelectron_db, - ) - - except RecursionError as re: - app_log.error(re) - num_attempts += 1 - raise RuntimeError("Timed out waiting for result. Please retry or check dispatch.") + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + qelectron_db=qelectron_db, + ) diff --git a/covalent/_shared_files/defaults.py b/covalent/_shared_files/defaults.py index aef61086d..2780d7c53 100644 --- a/covalent/_shared_files/defaults.py +++ b/covalent/_shared_files/defaults.py @@ -67,9 +67,6 @@ def get_default_sdk_config(): + "/covalent/dispatches" ), "task_packing": "true" if os.environ.get("COVALENT_ENABLE_TASK_PACKING") else "false", - "multistage_dispatch": ( - "false" if os.environ.get("COVALENT_DISABLE_MULTISTAGE_DISPATCH") == "1" else "true" - ), "results_dir": os.environ.get( "COVALENT_RESULTS_DIR" ) # COVALENT_RESULTS_DIR is where the client downloads workflow artifacts during get_result() which is different from COVALENT_DATA_DIR diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index 2eb49a434..341cd7c74 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -15,8 +15,6 @@ # limitations under the License. -import asyncio -import json from abc import abstractmethod import requests @@ -108,17 +106,12 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import export_result + from covalent_dispatcher._service.app import get_dispatches_bulk - response = asyncio.run_coroutine_threadsafe( - export_result(self.lattice_dispatch_id, status_only=True), - self.event_loop, - ).result() - - if isinstance(response, dict): - return response["status"] - - return json.loads(response.body.decode()).get("status") + response = get_dispatches_bulk( + dispatch_id=[self.lattice_dispatch_id], status_only=True + ) + return response.dispatches[0].status from .. import get_result diff --git a/covalent_dispatcher/_dal/controller.py b/covalent_dispatcher/_dal/controller.py index 3e682b979..53928a2fe 100644 --- a/covalent_dispatcher/_dal/controller.py +++ b/covalent_dispatcher/_dal/controller.py @@ -17,10 +17,12 @@ from __future__ import annotations -from typing import Generic, Type, TypeVar +from typing import Generic, List, Optional, Sequence, Type, TypeVar, Union from sqlalchemy import select, update -from sqlalchemy.orm import Session, load_only +from sqlalchemy.engine import Row +from sqlalchemy.orm import Session +from sqlalchemy.sql.expression import Select, desc from .._db import models @@ -50,11 +52,16 @@ def get( cls, session: Session, *, + stmt: Optional[Select] = None, fields: list, equality_filters: dict, membership_filters: dict, for_update: bool = False, - ): + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: Optional[int] = None, + ) -> Union[Sequence[Row], Sequence[T]]: """Bulk ORM-enabled SELECT. Args: @@ -64,19 +71,40 @@ def get( membership_filters: Dict{field_name: value_list} for_update: Whether to lock the selected rows + Returns: + A list of SQLAlchemy Rows or whole ORM entities depending + on whether only a subset of fields is specified. + """ - stmt = select(cls.model) + if stmt is None: + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + for attr, val in equality_filters.items(): stmt = stmt.where(getattr(cls.model, attr) == val) for attr, vals in membership_filters.items(): stmt = stmt.where(getattr(cls.model, attr).in_(vals)) - if len(fields) > 0: - attrs = [getattr(cls.model, f) for f in fields] - stmt = stmt.options(load_only(*attrs)) if for_update: stmt = stmt.with_for_update() - - return session.scalars(stmt).all() + for attr in sort_fields: + if reverse: + stmt = stmt.order_by(desc(getattr(cls.model, attr))) + else: + stmt = stmt.order_by(getattr(cls.model, attr)) + + stmt = stmt.offset(offset) + if max_items: + stmt = stmt.limit(max_items) + + if len(fields) == 0: + # Return whole ORM entities + return session.scalars(stmt).all() + else: + # Return a named tuple containing the selected cols + return session.execute(stmt).all() @classmethod def get_by_primary_key( diff --git a/covalent_dispatcher/_dal/result.py b/covalent_dispatcher/_dal/result.py index a9378558c..489efe11d 100644 --- a/covalent_dispatcher/_dal/result.py +++ b/covalent_dispatcher/_dal/result.py @@ -21,6 +21,7 @@ from datetime import datetime from typing import Any, Dict, List +from sqlalchemy import select from sqlalchemy.orm import Session from covalent._shared_files import logger @@ -45,6 +46,41 @@ class ResultMeta(Record[models.Lattice]): model = models.Lattice + @classmethod + def get_toplevel_dispatches( + cls, + session: Session, + *, + fields: list, + equality_filters: dict, + membership_filters: dict, + for_update: bool = False, + sort_fields: List[str] = [], + reverse: bool = True, + offset: int = 0, + max_items: int = 10, + ): + if len(fields) > 0: + entities = [getattr(cls.model, attr) for attr in fields] + stmt = select(*entities) + else: + stmt = select(cls.model) + + stmt = stmt.where(models.Lattice.root_dispatch_id == models.Lattice.dispatch_id) + + return cls.get( + session=session, + stmt=stmt, + fields=fields, + equality_filters=equality_filters, + membership_filters=membership_filters, + for_update=for_update, + sort_fields=sort_fields, + reverse=reverse, + offset=offset, + max_items=max_items, + ) + class ResultAsset(Record[models.LatticeAsset]): model = models.LatticeAsset @@ -175,7 +211,7 @@ def _update_dispatch( with self.session() as session: electron_rec = Electron.get_db_records( session, - keys={"id", "parent_lattice_id"}, + keys=ELECTRON_KEYS, equality_filters={"id": self._electron_id}, membership_filters={}, )[0] @@ -343,7 +379,7 @@ def _get_incomplete_nodes(self): A dictionary {"failed": [node_ids], "cancelled": [node_ids]} """ with self.session() as session: - query_keys = {"parent_lattice_id", "node_id", "name", "status"} + query_keys = {"id", "parent_lattice_id", "node_id", "name", "status"} records = Electron.get_db_records( session, keys=query_keys, diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index 03e71186d..3d582a7bd 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -20,11 +20,11 @@ import asyncio import json from contextlib import asynccontextmanager -from typing import List, Optional, Union -from uuid import UUID +from typing import List, Union -from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi import APIRouter, FastAPI, HTTPException, Query from fastapi.responses import JSONResponse +from typing_extensions import Annotated import covalent_dispatcher.entry_point as dispatcher from covalent._shared_files import logger @@ -38,7 +38,13 @@ from .._db.datastore import workflow_db from .._db.dispatchdb import DispatchDB from .heartbeat import Heartbeat -from .models import DispatchStatusSetSchema, ExportResponseSchema, TargetDispatchStatus +from .models import ( + BulkDispatchGetSchema, + BulkGetMetadata, + DispatchStatusSetSchema, + DispatchSummary, + TargetDispatchStatus, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -98,31 +104,6 @@ async def cancel_all_with_status(status: RESULT_STATUS): await dispatcher.cancel_running_dispatch(dispatch_id) -@router.post("/dispatches/submit") -async def submit(request: Request) -> UUID: - """ - Function to accept the submit request of - new dispatch and return the dispatch id - back to the client. - - Args: - None - - Returns: - dispatch_id: The dispatch id in a json format - returned as a Fast API Response object. - """ - try: - data = await request.json() - data = json.dumps(data).encode("utf-8") - return await dispatcher.make_dispatch(data) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to submit workflow: {e}", - ) from e - - async def start(dispatch_id: str): """Start a previously registered (re-)dispatch. @@ -266,74 +247,74 @@ async def set_dispatch_status(dispatch_id: str, desired_status: DispatchStatusSe return await cancel(dispatch_id, desired_status.task_ids) -@router.get("/dispatches/{dispatch_id}") -async def export_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: - """Export all metadata about a registered dispatch +@router.get("/dispatches", response_model_exclude_unset=True) +def get_dispatches_bulk( + dispatch_id: Annotated[Union[List[str], None], Query()] = None, + status: Annotated[Union[List[str], None], Query()] = None, + latest: bool = True, + offset: int = 0, + count: int = 10, + status_only: bool = False, +) -> BulkDispatchGetSchema: + dispatch_controller = Result.meta_type - Args: - `dispatch_id`: The dispatch's unique id. - - Returns: - { - id: `dispatch_id`, - status: status, - result_export: manifest for the result - } + if status_only: + fields = ["dispatch_id", "status"] + else: + fields = [ + "dispatch_id", + "root_dispatch_id", + "status", + "name", + "electron_num", + "completed_electron_num", + "created_at", + "updated_at", + "completed_at", + ] + + summaries = [] + with workflow_db.session() as session: + in_filters = {} + if dispatch_id is not None: + in_filters["dispatch_id"] = dispatch_id + if status is not None: + in_filters["status"] = status - The manifest `result_export` has the same schema as that which is - submitted to `/register`. + results = dispatch_controller.get( + session, + fields=fields, + equality_filters={"is_active": True}, + membership_filters=in_filters, + sort_fields=["created_at"], + reverse=latest, + offset=offset, + max_items=count, + ) + for res in results: + dispatch_id = res.dispatch_id + summary = DispatchSummary.model_validate(res) + summaries.append(summary) - """ - loop = asyncio.get_running_loop() - return await loop.run_in_executor( - None, - _export_result_sync, - dispatch_id, - wait, - status_only, - ) + bulk_meta = BulkGetMetadata(count=len(results)) + return BulkDispatchGetSchema(dispatches=summaries, metadata=bulk_meta) -def _export_result_sync( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -) -> ExportResponseSchema: +@router.get("/dispatches/{dispatch_id}") +def export_manifest(dispatch_id: str) -> ResultSchema: result_object = _try_get_result_object(dispatch_id) if not result_object: return JSONResponse( status_code=404, content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, ) - status = str(result_object.get_value("status", refresh=False)) - if not wait or status in [ - str(RESULT_STATUS.COMPLETED), - str(RESULT_STATUS.FAILED), - str(RESULT_STATUS.CANCELLED), - ]: - output = { - "id": dispatch_id, - "status": status, - } - if not status_only: - output["result_export"] = export_result_manifest(dispatch_id) - - return output - - response = JSONResponse( - status_code=503, - content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, - headers={"Retry-After": "2"}, - ) - return response + return export_result_manifest(dispatch_id) def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: try: - res = get_result_object( - dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] - ) + res = get_result_object(dispatch_id, bare=True) except KeyError: res = None return res diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py index 18a33a071..f751c1bc2 100644 --- a/covalent_dispatcher/_service/models.py +++ b/covalent_dispatcher/_service/models.py @@ -17,12 +17,13 @@ """FastAPI models for /api/v1/resultv2 endpoints""" import re +from datetime import datetime from enum import Enum from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict -from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.schemas.result import StatusEnum range_regex = "bytes=([0-9]+)-([0-9]*)" range_pattern = re.compile(range_regex) @@ -56,12 +57,6 @@ class ElectronAssetKey(str, Enum): qelectron_db = "qelectron_db" -class ExportResponseSchema(BaseModel): - id: str - status: str - result_export: Optional[ResultSchema] = None - - class AssetRepresentation(str, Enum): string = "string" b64pickle = "object" @@ -78,3 +73,26 @@ class DispatchStatusSetSchema(BaseModel): # For cancellation, an optional list of task ids to cancel task_ids: Optional[List] = [] + + +class BulkGetMetadata(BaseModel): + count: int + + +class DispatchSummary(BaseModel): + model_config = ConfigDict(from_attributes=True) + + dispatch_id: str + root_dispatch_id: Optional[str] = None + status: StatusEnum + name: Optional[str] = None + electron_num: Optional[int] = None + completed_electron_num: Optional[int] = None + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + completed_at: Optional[datetime] = None + + +class BulkDispatchGetSchema(BaseModel): + dispatches: List[DispatchSummary] + metadata: BulkGetMetadata diff --git a/tests/covalent_dispatcher_tests/_dal/result_test.py b/tests/covalent_dispatcher_tests/_dal/result_test.py index 5b2ec19fa..cb858ebb5 100644 --- a/tests/covalent_dispatcher_tests/_dal/result_test.py +++ b/tests/covalent_dispatcher_tests/_dal/result_test.py @@ -551,3 +551,84 @@ def test_result_filters_parent_electron_updates(test_db, mocker): assert third_update assert subl_node.get_value("output").get_deserialized() == 42 + + +def test_result_controller_bulk_get(test_db, mocker): + record_1 = models.Lattice( + dispatch_id="dispatch_1", + root_dispatch_id="dispatch_1", + name="dispatch_1", + status="NEW_OBJECT", + electron_num=5, + completed_electron_num=0, + ) + + record_2 = models.Lattice( + dispatch_id="dispatch_2", + root_dispatch_id="dispatch_2", + name="dispatch_2", + status="NEW_OBJECT", + electron_num=25, + completed_electron_num=0, + ) + + record_3 = models.Lattice( + dispatch_id="dispatch_3", + root_dispatch_id="dispatch_2", + name="dispatch_3", + status="COMPLETED", + electron_num=25, + completed_electron_num=25, + ) + + with test_db.session() as session: + session.add(record_1) + session.add(record_2) + session.add(record_3) + session.commit() + + dispatch_controller = Result.meta_type + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 3 + + with test_db.session() as session: + results = dispatch_controller.get_toplevel_dispatches( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + ) + assert len(results) == 2 + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + reverse=False, + max_items=1, + ) + assert len(results) == 1 + assert results[0].dispatch_id == "dispatch_1" + + with test_db.session() as session: + results = dispatch_controller.get( + session, + fields=["dispatch_id"], + equality_filters={}, + membership_filters={}, + sort_fields=["name"], + max_items=2, + offset=1, + ) + assert len(results) == 2 + assert results[0].dispatch_id == "dispatch_2" diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index 4615e35c5..d02382b98 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -103,30 +103,6 @@ def test_db_file(): return MockDataStore(db_URL="sqlite+pysqlite:////tmp/testdb.sqlite") -@pytest.mark.asyncio -async def test_submit(mocker, client): - """Test the submit endpoint.""" - mock_data = json.dumps({}).encode("utf-8") - run_dispatcher_mock = mocker.patch( - "covalent_dispatcher.entry_point.make_dispatch", return_value=DISPATCH_ID - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - response = client.post("/api/v2/dispatches/submit", data=mock_data) - assert response.json() == DISPATCH_ID - run_dispatcher_mock.assert_called_once_with(mock_data) - - -@pytest.mark.asyncio -async def test_submit_exception(mocker, client): - """Test the submit endpoint.""" - mock_data = json.dumps({}).encode("utf-8") - mocker.patch("covalent_dispatcher.entry_point.make_dispatch", side_effect=Exception("mock")) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - response = client.post("/api/v2/dispatches/submit", data=mock_data) - assert response.status_code == 400 - assert response.json()["detail"] == "Failed to submit workflow: mock" - - def test_cancel_dispatch(mocker, app, client): """ Test cancelling dispatch @@ -262,8 +238,8 @@ def test_start(mocker, app, client): assert resp.json() == dispatch_id -def test_export_result_nowait(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch( @@ -274,29 +250,11 @@ def test_export_result_nowait(mocker, app, client, mock_manifest): ) mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") resp = client.get(f"/api/v2/dispatches/{dispatch_id}") - assert resp.status_code == 200 - assert resp.json()["id"] == dispatch_id - assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) - assert resp.json()["result_export"] == json.loads(mock_manifest.json()) - - -def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" - mock_result_object = MagicMock() - mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) - mocker.patch( - "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object - ) - mock_export = mocker.patch( - "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest - ) - mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") - resp = client.get(f"/api/v2/dispatches/{dispatch_id}", params={"wait": True}) - assert resp.status_code == 503 + assert resp.json() == json.loads(mock_manifest.json()) -def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): - dispatch_id = "test_export_result" +def test_export_manifest_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_manifest" mock_result_object = MagicMock() mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) diff --git a/tests/covalent_tests/dispatcher_plugins/local_test.py b/tests/covalent_tests/dispatcher_plugins/local_test.py index d3c09c316..e10c83d31 100644 --- a/tests/covalent_tests/dispatcher_plugins/local_test.py +++ b/tests/covalent_tests/dispatcher_plugins/local_test.py @@ -74,42 +74,6 @@ def workflow(a, b): mock_print.assert_called_once_with(message) -def test_dispatcher_dispatch_single(mocker): - """test dispatching a lattice with submit api""" - - @ct.electron - def task(a, b, c): - return a + b + c - - @ct.lattice - def workflow(a, b): - return task(a, b, c=4) - - # test when api raises an implicit error - - dispatch_id = "test_dispatcher_dispatch_single" - # multistage = False - mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value=False) - - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - - mock_reg_tr = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" - ) - mock_start = mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id - ) - - assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) - - mock_submit_callable.assert_called() - mock_start.assert_called() - - def test_dispatcher_dispatch_multi(mocker): """test dispatching a lattice with multistage api""" @@ -131,12 +95,6 @@ def workflow(a, b): return_value=mock_register_callable, ) - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - mock_reg_tr = mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" ) @@ -146,7 +104,6 @@ def workflow(a, b): assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) - mock_submit_callable.assert_not_called() mock_register_callable.assert_called() mock_start.assert_called() @@ -172,12 +129,6 @@ def workflow(a, b): return_value=mock_register_callable, ) - mock_submit_callable = MagicMock(return_value=dispatch_id) - mocker.patch( - "covalent._dispatcher_plugins.local.LocalDispatcher.submit", - return_value=mock_submit_callable, - ) - mock_reg_tr = mocker.patch( "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" ) @@ -190,41 +141,6 @@ def workflow(a, b): mock_start.assert_not_called() -def test_dispatcher_submit_api(mocker): - """test dispatching a lattice with submit api""" - - @ct.electron - def task(a, b, c): - return a + b + c - - @ct.lattice - def workflow(a, b): - return task(a, b, c=4) - - # test when api raises an implicit error - r = Response() - r.status_code = 404 - r.url = "http://dummy" - r.reason = "dummy reason" - - mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - - with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): - dispatch_id = LocalDispatcher.submit(workflow)(1, 2) - assert dispatch_id is None - - # test when api doesn't raise an implicit error - r = Response() - r.status_code = 201 - r.url = "http://dummy" - r._content = b"abcde" - - mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - - dispatch_id = LocalDispatcher.submit(workflow)(1, 2) - assert dispatch_id == "abcde" - - def test_dispatcher_start(mocker): """Test starting a dispatch""" diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 06ba8ece0..72c0260d9 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -22,9 +22,9 @@ from unittest.mock import MagicMock import pytest -from requests import Response import covalent as ct +from covalent._api.apiclient import CovalentAPIClient from covalent._results_manager.results_manager import ( MissingLatticeRecordError, Result, @@ -105,24 +105,23 @@ def test_cancel_with_multiple_task_ids(mocker): def test_result_export(mocker): + import json + with tempfile.TemporaryDirectory() as staging_dir: test_manifest = get_test_manifest(staging_dir) dispatch_id = "test_result_export" - mock_body = {"id": "test_result_export", "status": "COMPLETED"} - + mock_response_body = json.loads(test_manifest.model_dump_json()) mock_client = MagicMock() - mock_response = Response() + mock_response = MagicMock() mock_response.status_code = 200 - mock_response.json = MagicMock(return_value=mock_body) + mock_response.json = MagicMock(return_value=mock_response_body) mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) - + apiclient = CovalentAPIClient("http://localhost:48008") endpoint = f"/api/v2/dispatches/{dispatch_id}" - assert mock_body == _get_result_export_from_dispatcher( - dispatch_id, wait=False, status_only=True - ) + assert test_manifest == _get_result_export_from_dispatcher(dispatch_id, apiclient) def test_result_manager_assets_local_copies(): @@ -176,11 +175,7 @@ def test_get_result(mocker): # local file copies from server_dir to results_dir. manifest = get_test_manifest(server_dir) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest mocker.patch( "covalent._results_manager.results_manager._get_result_export_from_dispatcher", return_value=mock_result_export, @@ -208,17 +203,9 @@ def test_get_result_sublattice(mocker): # Sublattice manifest sub_manifest = get_test_manifest(server_dir_sub) - mock_result_export = { - "id": dispatch_id, - "status": "COMPLETED", - "result_export": manifest.dict(), - } + mock_result_export = manifest - mock_subresult_export = { - "id": sub_dispatch_id, - "status": "COMPLETED", - "result_export": sub_manifest.dict(), - } + mock_subresult_export = sub_manifest exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} @@ -277,10 +264,10 @@ def test_get_result_RecursionError(mocker): def test_get_status_only(mocker): """Check get_result when status_only=True""" - dispatch_id = "test_get_result_st" + dispatch_id = "test_get_status_only" mock_get_result_export = mocker.patch( - "covalent._results_manager.results_manager._get_result_export_from_dispatcher", - return_value={"id": dispatch_id, "status": "RUNNING"}, + "covalent._results_manager.results_manager._query_dispatch_status", + return_value="RUNNING", ) status_report = get_result(dispatch_id, status_only=True) diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index 0ca0c8d7c..b70aee295 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -17,7 +17,6 @@ from unittest import mock import pytest -from fastapi.responses import JSONResponse from covalent.triggers import BaseTrigger @@ -46,7 +45,6 @@ def test_register(mocker): @pytest.mark.parametrize( "use_internal_func, mock_status", [ - (True, JSONResponse("mock")), (True, {"status": "mocked-status"}), (False, {"status": "mocked-status"}), ], @@ -61,27 +59,15 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.export_result") - - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_status - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res + mock_bulk_get_res = mock.Mock() + mock_bulk_get_res.dispatches = [mock.Mock()] + mock_bulk_get_res.dispatches[0].status = mock_status["status"] + mocker.patch( + "covalent_dispatcher._service.app.get_dispatches_bulk", return_value=mock_bulk_get_res ) - if not isinstance(mock_status, dict): - mock_json_loads = mocker.patch( - "covalent.triggers.base.json.loads", return_value={"status": "mocked-status"} - ) - status = base_trigger._get_status() - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() - - if not isinstance(mock_status, dict): - mock_json_loads.assert_called_once() - else: mock_get_status = mocker.patch("covalent.get_result", return_value=mock_status) status = base_trigger._get_status()