From 1995292b2d75c4335cf2afc7fb8a421ad8157ee0 Mon Sep 17 00:00:00 2001 From: Casey Jao Date: Thu, 19 Oct 2023 12:49:15 -0400 Subject: [PATCH 1/2] Memory improvements (2/3): Migrate core Covalent to new data access layer (#1729) * Mem (1/3): introduce new DAL and core Pydantic models * Mem (1/3): Fix schemas * Mem (1/3): DAL PR: temporarily redirect core dispatcher tests * Mem (1/3): DAL PR: fix tests Introduce temporary implementations of `update._node` and `update.lattice_data`. These will be removed once core covalent is transitioned to the new DAL. * Mem (1/3): Fix requirements workflow Change abs imports to rel imports. Needed to please pip-missing-reqs. * Mem (1/3): Uncomment boilerplate in disabled unit tests * Mem (1/3): Add unit test for format_server_url * Mem (1/3): defer copy_file_locally to next PR * Mem (1/3): update changelog * Mem (1/3): Core DAL improvements - Improve type hints * Mem (2/3): Revert "DAL PR: fix tests" * Mem (2/3): Revert "Mem (1/3): defer copy_file_locally to next PR" This reverts commit a3ab70b91c3a527461fc915f4888a4ccbe158180. * Mem (2/3): Revert "Mem (1/3): Uncomment boilerplate in disabled unit tests" * Mem (2/3): Revert "Mem (1/3): Fix requirements workflow" * Mem (2/3): Revert "DAL PR: temporarily redirect core dispatcher tests" This reverts commit 388df38236ebec7555ec7e83ffc1834427b46650. * Mem (2/3): Core migration -- Re-enable electron_tests * Mem (2/3): migrate core to new DAL * Mem (2/3): redirect dispatcher to in-memory runner Make API endpoints restful Cancel all dispatches upon shutdown * Mem (2/3): Update changelog * updated license to Apache * some fixes to make dispatching work * reverted some changes to make the funcs sync * changes after reviewing * fixing tests, removed cancel_requested from Electron model * fixing tests, sublattice issue fixed * fixing tests, cancellation issue fixed * fixing tests, sdk issue fixed * cancel_tasks import location changed * tg utils import location changed * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests * fixing tests, skipping some ui backend tests * Fix UI backend tests * Undo some changes to UI assert data * renamed core dispatcher tests * removing the _db/load.py file --------- Co-authored-by: sankalp --- .github/workflows/requirements.yml | 7 +- .github/workflows/tests.yml | 6 +- CHANGELOG.md | 2 + covalent/__init__.py | 6 +- covalent/_api/__init__.py | 15 + covalent/_api/apiclient.py | 128 +++ covalent/_dispatcher_plugins/local.py | 525 ++++++++--- covalent/_results_manager/result.py | 11 +- covalent/_results_manager/results_manager.py | 505 ++++++++--- covalent/_serialize/__init__.py | 2 +- covalent/_serialize/result.py | 1 - covalent/_shared_files/schemas/__init__.py | 2 +- covalent/_shared_files/schemas/common.py | 4 +- covalent/_shared_files/utils.py | 19 + covalent/_workflow/electron.py | 3 +- covalent/_workflow/transport_graph_ops.py | 200 ----- covalent/_workflow/transportable_object.py | 78 +- covalent/executor/executor_plugins/dask.py | 4 +- covalent/triggers/base.py | 31 +- covalent_dispatcher/__init__.py | 2 +- covalent_dispatcher/_cli/service.py | 6 + covalent_dispatcher/_core/__init__.py | 3 +- covalent_dispatcher/_core/data_manager.py | 452 ++++------ .../_core/data_modules/asset_manager.py | 89 ++ .../_core/data_modules/dispatch.py | 66 ++ .../_core/data_modules/electron.py | 110 +++ .../_core/data_modules/graph.py | 102 +++ .../_core/data_modules/importer.py | 148 +++ .../_core/data_modules/job_manager.py | 8 +- .../_core/data_modules/lattice.py | 38 + .../_core/data_modules/utils.py | 31 + covalent_dispatcher/_core/dispatcher.py | 627 ++++++++----- .../_core/dispatcher_modules/__init__.py | 15 + .../_core/dispatcher_modules/caches.py | 101 +++ .../_core/dispatcher_modules/store.py | 66 ++ covalent_dispatcher/_core/execution.py | 25 +- covalent_dispatcher/_core/runner.py | 272 ++---- .../_core/runner_modules/cancel.py | 146 +++ .../_core/runner_modules/executor_proxy.py | 59 +- .../_core/runner_modules/jobs.py | 126 +++ .../_core/runner_modules/utils.py | 43 + covalent_dispatcher/_dal/__init__.py | 2 +- .../_dal/db_interfaces/__init__.py | 2 +- .../_dal/exporters/__init__.py | 2 +- .../_dal/exporters/electron.py | 5 +- covalent_dispatcher/_dal/exporters/lattice.py | 9 +- covalent_dispatcher/_dal/exporters/result.py | 17 +- covalent_dispatcher/_dal/exporters/tg.py | 9 +- .../_dal/importers/__init__.py | 2 +- .../_dal/importers/electron.py | 15 +- covalent_dispatcher/_dal/importers/lattice.py | 17 +- covalent_dispatcher/_dal/importers/result.py | 18 +- covalent_dispatcher/_dal/importers/tg.py | 27 +- covalent_dispatcher/_dal/tg_ops.py | 6 +- covalent_dispatcher/_dal/utils/__init__.py | 2 +- covalent_dispatcher/_dal/utils/uri_filters.py | 13 +- covalent_dispatcher/_db/__init__.py | 2 +- covalent_dispatcher/_db/load.py | 223 ----- covalent_dispatcher/_db/models.py | 3 - covalent_dispatcher/_db/update.py | 96 +- covalent_dispatcher/_db/upsert.py | 30 +- covalent_dispatcher/_db/write_result_to_db.py | 2 - covalent_dispatcher/_service/app.py | 491 ++++++---- covalent_dispatcher/_service/assets.py | 526 +++++++++++ .../_service}/heartbeat.py | 50 -- covalent_dispatcher/_service/models.py | 126 +++ covalent_dispatcher/entry_point.py | 102 ++- covalent_ui/api/main.py | 2 +- covalent_ui/api/v1/data_layer/electron_dal.py | 2 +- covalent_ui/api/v1/data_layer/lattice_dal.py | 2 + .../api/v1/database/schema/electron.py | 6 +- .../api/v1/database/schema/lattices.py | 6 + covalent_ui/api/v1/models/lattices_model.py | 1 - .../v1/routes/end_points/electron_routes.py | 161 ++-- .../api/v1/routes/end_points/lattice_route.py | 79 +- covalent_ui/api/v1/routes/routes.py | 8 +- covalent_ui/api/v1/utils/file_handle.py | 9 + requirements.txt | 1 + tests/__init__.py | 15 + tests/covalent_dispatcher_tests/__init__.py | 15 + .../_cli/__init__.py | 15 + .../_core/__init__.py | 15 + .../_core/data_manager_test.py | 472 ++++++++++ .../asset_manager_db_integration_test.py | 165 ++++ .../_core/data_modules/dispatch_test.py | 69 ++ .../_core/data_modules/graph_test.py | 94 ++ .../_core/data_modules/importer_test.py | 120 +++ .../_core/data_modules/job_manager_test.py | 10 +- .../_core/data_modules/lattice_query_test.py | 41 + .../_core/dispatcher_db_integration_test.py | 322 +++++++ .../_core/dispatcher_test.py | 849 ++++++++++++++++++ .../_core/execution_test.py | 279 ++++++ .../_core/runner_db_integration_test.py | 123 +++ .../_core/runner_modules/cancel_test.py | 102 +++ .../_core/runner_modules/jobs_test.py | 99 ++ .../_core/runner_test.py | 281 ++---- .../_core/tmp_data_manager_test.py | 543 ----------- .../_core/tmp_dispatcher_test.py | 592 ------------ .../_core/tmp_execution_test.py | 391 -------- .../_dal/exporters/result_export_test.py | 2 +- .../_dal/importers/result_import_test.py | 2 +- .../_dal/tg_ops_test.py | 2 +- .../_db/load_test.py | 159 ---- .../_db/update_test.py | 93 +- .../_db/write_result_to_db_test.py | 10 - .../_object_store/__init__.py | 2 +- .../_object_store/local_test.py | 2 +- .../_service/app_test.py | 312 ++++--- .../_service/assets_test.py | 739 +++++++++++++++ .../entry_point_test.py | 115 ++- tests/covalent_tests/__init__.py | 15 + .../dispatcher_plugins/__init__.py | 15 + .../dispatcher_plugins/local_test.py | 570 ++++++++++-- .../covalent_tests/file_transfer/__init__.py | 2 +- .../results_manager_tests/__init__.py | 15 + .../results_manager_test.py | 324 +++++-- tests/covalent_tests/triggers/base_test.py | 34 +- .../workflow/electron_metadata_test.py | 1 - .../covalent_tests/workflow/electron_test.py | 38 +- .../workflow/transport_graph_ops_test.py | 248 ----- .../utils/assert_data/electrons.py | 2 +- .../utils/assert_data/lattices.py | 4 +- .../utils/data/electrons.json | 152 ++++ .../utils/data/lattices.json | 20 +- .../utils/seed_script.py | 4 +- tests/functional_tests/__init__.py | 15 + tests/functional_tests/file_transfer_test.py | 4 +- tests/functional_tests/local_executor_test.py | 32 + .../functional_tests/results_manager_test.py | 77 ++ tests/functional_tests/triggers_test.py | 2 +- .../workflow_cancellation_test.py | 10 +- tests/functional_tests/workflow_stack_test.py | 40 +- tests/load_tests/locustfiles/basic.py | 8 +- tests/load_tests/workflows/horizontal.py | 16 + tests/stress_tests/benchmarks/__init__.py | 15 + 135 files changed, 9219 insertions(+), 4652 deletions(-) create mode 100644 covalent/_api/__init__.py create mode 100644 covalent/_api/apiclient.py delete mode 100644 covalent/_workflow/transport_graph_ops.py create mode 100644 covalent_dispatcher/_core/data_modules/asset_manager.py create mode 100644 covalent_dispatcher/_core/data_modules/dispatch.py create mode 100644 covalent_dispatcher/_core/data_modules/electron.py create mode 100644 covalent_dispatcher/_core/data_modules/graph.py create mode 100644 covalent_dispatcher/_core/data_modules/importer.py create mode 100644 covalent_dispatcher/_core/data_modules/lattice.py create mode 100644 covalent_dispatcher/_core/data_modules/utils.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/__init__.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/caches.py create mode 100644 covalent_dispatcher/_core/dispatcher_modules/store.py create mode 100644 covalent_dispatcher/_core/runner_modules/cancel.py create mode 100644 covalent_dispatcher/_core/runner_modules/jobs.py create mode 100644 covalent_dispatcher/_core/runner_modules/utils.py delete mode 100644 covalent_dispatcher/_db/load.py create mode 100644 covalent_dispatcher/_service/assets.py rename {covalent_ui => covalent_dispatcher/_service}/heartbeat.py (58%) create mode 100644 covalent_dispatcher/_service/models.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_manager_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/dispatcher_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/execution_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py create mode 100644 tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py delete mode 100644 tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py delete mode 100644 tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py delete mode 100644 tests/covalent_dispatcher_tests/_core/tmp_execution_test.py delete mode 100644 tests/covalent_dispatcher_tests/_db/load_test.py create mode 100644 tests/covalent_dispatcher_tests/_service/assets_test.py delete mode 100644 tests/covalent_tests/workflow/transport_graph_ops_test.py create mode 100644 tests/functional_tests/results_manager_test.py diff --git a/.github/workflows/requirements.yml b/.github/workflows/requirements.yml index 3ab5ffabe..8d692b0fe 100644 --- a/.github/workflows/requirements.yml +++ b/.github/workflows/requirements.yml @@ -39,14 +39,17 @@ jobs: run: python -m pip install pip-check-reqs - name: Check extra core requirements - run: pip-extra-reqs -r werkzeug covalent covalent_dispatcher covalent_ui --ignore-requirement=qiskit --ignore-requirement=qiskit-ibm-provider --ignore-requirement=amazon-braket-pennylane-plugin + run: pip-extra-reqs -r werkzeug -r python-multipart covalent covalent_dispatcher covalent_ui --ignore-requirement=qiskit --ignore-requirement=qiskit-ibm-provider --ignore-requirement=amazon-braket-pennylane-plugin - name: Check missing SDK requirements run: > pip-missing-reqs --ignore-module=covalent_ui.* + --ignore-module=covalent.* --ignore-module=urllib3.* --ignore-module=pkg_resources + --ignore-module=covalent/_dispatcher_plugins + --ignore-module=covalent/_shared_files --ignore-file=covalent/executor/** --ignore-file=covalent/triggers/** --ignore-file=covalent/cloud_resource_manager/** @@ -58,7 +61,7 @@ jobs: pip-missing-reqs --ignore-module=covalent_ui.* --ignore-module=covalent.* - --ignore-module=covalent_dispatcher + --ignore-module=covalent_dispatcher.* --ignore-module=distributed.* covalent_dispatcher diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a9befb923..99c5d6075 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -236,7 +236,7 @@ jobs: if: > steps.modified-files.outputs.sdk == 'true' || env.BUILD_AND_RUN_ALL - run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_tests --cov=covalent --cov-config=.coveragerc + run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_tests --cov=covalent --cov-config=.coveragerc - name: Generate SDK coverage report id: sdk-coverage @@ -248,7 +248,7 @@ jobs: if: > steps.modified-files.outputs.dispatcher == 'true' || env.BUILD_AND_RUN_ALL - run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_dispatcher_tests --cov=covalent_dispatcher --cov-config=.coveragerc + run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_dispatcher_tests --cov=covalent_dispatcher --cov-config=.coveragerc - name: Generate dispatcher coverage report id: dispatcher-coverage @@ -260,7 +260,7 @@ jobs: if: > steps.modified-files.outputs.ui_backend == 'true' || env.BUILD_AND_RUN_ALL - run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_ui_backend_tests --cov=covalent_ui --cov-config=.coveragerc + run: PYTHONPATH=$PWD/ pytest -vvs --reruns=5 tests/covalent_ui_backend_tests --cov=covalent_ui --cov-config=.coveragerc - name: Generate UI backend coverage report id: ui-backend-coverage diff --git a/CHANGELOG.md b/CHANGELOG.md index 323c3a90f..9dc84c08b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Removed strict version pins on `lmdbm`, `mpire`, `orjson`, and `pennylane` - Changed license to Apache - Improved error handling in generate_docs.py +- [Significant Changes] Migrated core server-side code to new data access layer. ### Added @@ -72,6 +73,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Rsync command fixed to recursively copy files when using SSH - Removed accidentally added migrations build files - Updated migration script to add a default value for `qelectron_data_exists` in the `electrons` table since it cannot be nullable +- Reduced server memory consumption during workflow processing ### Changed diff --git a/covalent/__init__.py b/covalent/__init__.py index 66c4736ed..aed63a724 100644 --- a/covalent/__init__.py +++ b/covalent/__init__.py @@ -25,7 +25,11 @@ from ._dispatcher_plugins import local_redispatch as redispatch # nopycln: import from ._dispatcher_plugins import stop_triggers # nopycln: import from ._file_transfer import strategies as fs_strategies # nopycln: import -from ._results_manager.results_manager import cancel, get_result, sync # nopycln: import +from ._results_manager.results_manager import ( # nopycln: import + cancel, + get_result, + get_result_manager, +) from ._shared_files.config import get_config, reload_config, set_config # nopycln: import from ._shared_files.util_classes import RESULT_STATUS as status # nopycln: import from ._workflow import ( # nopycln: import diff --git a/covalent/_api/__init__.py b/covalent/_api/__init__.py new file mode 100644 index 000000000..21d7eaa5c --- /dev/null +++ b/covalent/_api/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/covalent/_api/apiclient.py b/covalent/_api/apiclient.py new file mode 100644 index 000000000..c4c2a5492 --- /dev/null +++ b/covalent/_api/apiclient.py @@ -0,0 +1,128 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""API client""" + +import json +import os +from typing import Dict + +import requests +from requests.adapters import HTTPAdapter + + +class CovalentAPIClient: + """Thin wrapper around Requests to centralize error handling.""" + + def __init__(self, dispatcher_addr: str, adapter: HTTPAdapter = None, auto_raise: bool = True): + self.dispatcher_addr = dispatcher_addr + self.adapter = adapter + self.auto_raise = auto_raise + + def prepare_headers(self, **kwargs): + extra_headers = CovalentAPIClient.get_extra_headers() + headers = kwargs.get("headers", {}) + if headers: + kwargs.pop("headers") + headers.update(extra_headers) + return headers + + def get(self, endpoint: str, **kwargs): + headers = self.prepare_headers(**kwargs) + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.get(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + def put(self, endpoint: str, **kwargs): + headers = self.prepare_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.put(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + def post(self, endpoint: str, **kwargs): + headers = self.prepare_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.post(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + def delete(self, endpoint: str, **kwargs): + headers = self.prepare_headers() + url = self.dispatcher_addr + endpoint + try: + with requests.Session() as session: + if self.adapter: + session.mount("http://", self.adapter) + + r = session.delete(url, headers=headers, **kwargs) + + if self.auto_raise: + r.raise_for_status() + except requests.exceptions.ConnectionError: + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + print(message) + raise + + return r + + @classmethod + def get_extra_headers(headers: Dict) -> Dict: + # This is expected to be a JSONified dictionary + data = os.environ.get("COVALENT_EXTRA_HEADERS") + if data: + return json.loads(data) + else: + return {} diff --git a/covalent/_dispatcher_plugins/local.py b/covalent/_dispatcher_plugins/local.py index b3a82e296..11d0c38cb 100644 --- a/covalent/_dispatcher_plugins/local.py +++ b/covalent/_dispatcher_plugins/local.py @@ -14,18 +14,28 @@ # See the License for the specific language governing permissions and # limitations under the License. -import json +import tempfile from copy import deepcopy from functools import wraps +from pathlib import Path from typing import Callable, Dict, List, Optional, Union -import requests +from furl import furl -from .._results_manager import wait +from .._api.apiclient import CovalentAPIClient as APIClient from .._results_manager.result import Result -from .._results_manager.results_manager import get_result +from .._results_manager.results_manager import get_result, get_result_manager +from .._serialize.result import ( + extract_assets, + merge_response_manifest, + serialize_result, + strip_local_uris, +) from .._shared_files import logger from .._shared_files.config import get_config +from .._shared_files.schemas.asset import AssetSchema +from .._shared_files.schemas.result import ResultSchema +from .._shared_files.utils import copy_file_locally, format_server_url from .._workflow.lattice import Lattice from ..triggers import BaseTrigger from .base import BaseDispatcher @@ -33,36 +43,55 @@ app_log = logger.app_log log_stack_info = logger.log_stack_info +dispatch_cache_dir = Path(get_config("sdk.dispatch_cache_dir")) +dispatch_cache_dir.mkdir(parents=True, exist_ok=True) -def get_redispatch_request_body( + +def get_redispatch_request_body_v2( dispatch_id: str, - new_args: Optional[List] = None, - new_kwargs: Optional[Dict] = None, - replace_electrons: Optional[Dict[str, Callable]] = None, - reuse_previous_results: bool = False, -) -> Dict: - """Get request body for re-dispatching a workflow.""" - if new_args is None: - new_args = [] - if new_kwargs is None: - new_kwargs = {} + staging_dir: str, + new_args: List, + new_kwargs: Dict, + replace_electrons: Optional[Dict[str, Callable]], + dispatcher_addr: str = None, +) -> ResultSchema: + rm = get_result_manager(dispatch_id, dispatcher_addr=dispatcher_addr, wait=True) + manifest = ResultSchema.parse_obj(rm._manifest) + + # If no changes to inputs or electron, just retry the dispatch + if not new_args and not new_kwargs and not replace_electrons: + manifest.reset_metadata() + app_log.debug("Resubmitting manifest only") + return manifest + + # In all other cases we need to rebuild the graph + rm.download_lattice_asset("workflow_function") + rm.download_lattice_asset("workflow_function_string") + rm.load_lattice_asset("workflow_function") + rm.load_lattice_asset("workflow_function_string") + if replace_electrons is None: replace_electrons = {} - if new_args or new_kwargs: - res = get_result(dispatch_id) - lat = res.lattice - lat.build_graph(*new_args, **new_kwargs) - json_lattice = lat.serialize_to_json() - else: - json_lattice = None - updates = {k: v.electron_object.as_transportable_dict for k, v in replace_electrons.items()} - return { - "json_lattice": json_lattice, - "dispatch_id": dispatch_id, - "electron_updates": updates, - "reuse_previous_results": reuse_previous_results, - } + lat = rm.result_object.lattice + + if replace_electrons: + lat._replace_electrons = replace_electrons + + # If lattice inputs are not supplied, retrieve them from the previous dispatch + if not new_args and not new_kwargs: + rm.download_lattice_asset("inputs") + rm.load_lattice_asset("inputs") + res_obj = rm.result_object + inputs = res_obj.inputs.get_deserialized() + new_args = inputs["args"] + new_kargs = inputs["kwargs"] + + lat.build_graph(*new_args, **new_kwargs) + if replace_electrons: + del lat.__dict__["_replace_electrons"] + + return serialize_result(Result(lat), staging_dir) class LocalDispatcher(BaseDispatcher): @@ -75,6 +104,7 @@ class LocalDispatcher(BaseDispatcher): def dispatch( orig_lattice: Lattice, dispatcher_addr: str = None, + *, disable_run: bool = False, ) -> Callable: """ @@ -86,20 +116,26 @@ def dispatch( Args: orig_lattice: The lattice/workflow to send to the dispatcher server. - dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. - disable_run: Whether to disable running the workflow and rather just save it on Covalent's server for later execution + dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. + + Kwargs: + disable_run: Whether to disable running the workflow and rather just save it on Covalent's server for later execution. Returns: Wrapper function which takes the inputs of the workflow as arguments """ - if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + multistage = get_config("sdk.multistage_dispatch") == "true" + + # Extract triggers here + if "triggers" in orig_lattice.metadata: + triggers_data = orig_lattice.metadata.pop("triggers") + else: + triggers_data = None + + if not disable_run: + # Determine whether to disable first run based on trigger_data + disable_run = triggers_data is not None @wraps(orig_lattice) def wrapper(*args, **kwargs) -> str: @@ -115,8 +151,61 @@ def wrapper(*args, **kwargs) -> str: The dispatch id of the workflow. """ - # To access the disable_run passed to the dispatch function - nonlocal disable_run + if multistage: + dispatch_id = LocalDispatcher.register(orig_lattice, dispatcher_addr)( + *args, **kwargs + ) + else: + dispatch_id = LocalDispatcher.submit(orig_lattice, dispatcher_addr)( + *args, **kwargs + ) + + if triggers_data: + LocalDispatcher.register_triggers(triggers_data, dispatch_id) + + if not disable_run: + return LocalDispatcher.start(dispatch_id, dispatcher_addr) + else: + return dispatch_id + + return wrapper + + @staticmethod + def submit( + orig_lattice: Lattice, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. + + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + @wraps(orig_lattice) + def wrapper(*args, **kwargs) -> str: + """ + Send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The dispatch id of the workflow. + """ if not isinstance(orig_lattice, Lattice): message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." @@ -129,42 +218,41 @@ def wrapper(*args, **kwargs) -> str: # Serialize the transport graph to JSON json_lattice = lattice.serialize_to_json() + endpoint = "/api/v2/dispatches/submit" + r = APIClient(dispatcher_addr).post(endpoint, data=json_lattice) + r.raise_for_status() + return r.content.decode("utf-8").strip().replace('"', "") - # Extract triggers here - json_lattice = json.loads(json_lattice) - triggers_data = json_lattice["metadata"].pop("triggers") - - if not disable_run: - # Determine whether to disable first run based on trigger_data - disable_run = triggers_data is not None - - json_lattice = json.dumps(json_lattice) + return wrapper - submit_dispatch_url = f"{dispatcher_addr}/api/submit" + @staticmethod + def start( + dispatch_id: str, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. - lattice_dispatch_id = None - try: - r = requests.post( - submit_dispatch_url, - data=json_lattice, - params={"disable_run": disable_run}, - timeout=5, - ) - r.raise_for_status() - lattice_dispatch_id = r.content.decode("utf-8").strip().replace('"', "") - except requests.exceptions.ConnectionError: - message = f"The Covalent server cannot be reached at {dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." - print(message) - return + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. - if not disable_run or triggers_data is None: - return lattice_dispatch_id + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. - LocalDispatcher.register_triggers(triggers_data, lattice_dispatch_id) + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ - return lattice_dispatch_id + if dispatcher_addr is None: + dispatcher_addr = format_server_url() - return wrapper + endpoint = f"/api/v2/dispatches/{dispatch_id}/status" + body = {"status": "RUNNING"} + r = APIClient(dispatcher_addr).put(endpoint, json=body) + r.raise_for_status() + return r.content.decode("utf-8").strip().replace('"', "") @staticmethod def dispatch_sync( @@ -187,12 +275,7 @@ def dispatch_sync( """ if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + dispatcher_addr = format_server_url() @wraps(lattice) def wrapper(*args, **kwargs) -> Result: @@ -210,7 +293,7 @@ def wrapper(*args, **kwargs) -> Result: return get_result( LocalDispatcher.dispatch(lattice, dispatcher_addr)(*args, **kwargs), - wait=wait.EXTREME, + wait=True, ) return wrapper @@ -221,14 +304,13 @@ def redispatch( dispatcher_addr: str = None, replace_electrons: Dict[str, Callable] = None, reuse_previous_results: bool = False, - is_pending: bool = False, ) -> Callable: """ Wrapping the dispatching functionality to allow input passing and server address specification. Args: dispatch_id: The dispatch id of the workflow to re-dispatch. - dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. + dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. replace_electrons: A dictionary of electron names and the new electron to replace them with. reuse_previous_results: Boolean value whether to reuse the results from the previous dispatch. @@ -237,45 +319,17 @@ def redispatch( """ if dispatcher_addr is None: - dispatcher_addr = ( - "http://" - + get_config("dispatcher.address") - + ":" - + str(get_config("dispatcher.port")) - ) + dispatcher_addr = format_server_url() if replace_electrons is None: replace_electrons = {} - def func(*new_args, **new_kwargs): - """ - Prepare the redispatch request body and redispatch the workflow. - - Args: - *args: The inputs of the workflow. - **kwargs: The keyword arguments of the workflow. - - Returns: - The result of the executed workflow. - - """ - body = get_redispatch_request_body( - dispatch_id, new_args, new_kwargs, replace_electrons, reuse_previous_results - ) - redispatch_url = f"{dispatcher_addr}/api/redispatch" - try: - r = requests.post( - redispatch_url, json=body, params={"is_pending": is_pending}, timeout=5 - ) - r.raise_for_status() - except requests.exceptions.ConnectionError: - message = f"The Covalent server cannot be reached at {dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." - print(message) - return - - return r.content.decode("utf-8").strip().replace('"', "") - - return func + return LocalDispatcher.register_redispatch( + dispatch_id=dispatch_id, + dispatcher_addr=dispatcher_addr, + replace_electrons=replace_electrons, + reuse_previous_results=reuse_previous_results, + ) @staticmethod def register_triggers(triggers_data: List[Dict], dispatch_id: str) -> None: @@ -324,9 +378,252 @@ def stop_triggers( if isinstance(dispatch_ids, str): dispatch_ids = [dispatch_ids] - r = requests.post(stop_triggers_url, json=dispatch_ids) + endpoint = "/api/triggers/stop_observe" + r = APIClient(triggers_server_addr).post(endpoint, json=dispatch_ids) r.raise_for_status() app_log.debug("Triggers for following dispatch_ids have stopped observing:") for d_id in dispatch_ids: app_log.debug(d_id) + + @staticmethod + def register( + orig_lattice: Lattice, + dispatcher_addr: str = None, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing + and server address specification. + + Afterwards, send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + orig_lattice: The lattice/workflow to send to the dispatcher server. + dispatcher_addr: The address of the dispatcher server. If None then then defaults to the address set in Covalent's config. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + @wraps(orig_lattice) + def wrapper(*args, **kwargs) -> str: + """ + Send the lattice to the dispatcher server and return + the assigned dispatch id. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The dispatch id of the workflow. + """ + + if not isinstance(orig_lattice, Lattice): + message = f"Dispatcher expected a Lattice, received {type(orig_lattice)} instead." + app_log.error(message) + raise TypeError(message) + + lattice = deepcopy(orig_lattice) + + lattice.build_graph(*args, **kwargs) + + with tempfile.TemporaryDirectory() as tmp_dir: + manifest = LocalDispatcher.prepare_manifest(lattice, tmp_dir) + LocalDispatcher.register_manifest(manifest, dispatcher_addr) + + dispatch_id = manifest.metadata.dispatch_id + + path = dispatch_cache_dir / f"{dispatch_id}" + + with open(path, "w") as f: + f.write(manifest.json()) + + LocalDispatcher.upload_assets(manifest) + + return dispatch_id + + return wrapper + + @staticmethod + def register_redispatch( + dispatch_id: str, + dispatcher_addr: str = None, + replace_electrons: Dict[str, Callable] = None, + reuse_previous_results: bool = False, + ) -> Callable: + """ + Wrapping the dispatching functionality to allow input passing and server address specification. + + Args: + dispatch_id: The dispatch id of the workflow to re-dispatch. + dispatcher_addr: The address of the dispatcher server. If None then defaults to the address set in Covalent's config. + replace_electrons: A dictionary of electron names and the new electron to replace them with. + reuse_previous_results: Boolean value whether to reuse the results from the previous dispatch. + + Returns: + Wrapper function which takes the inputs of the workflow as arguments. + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + def func(*new_args, **new_kwargs): + """ + Prepare the redispatch request body and redispatch the workflow. + + Args: + *args: The inputs of the workflow. + **kwargs: The keyword arguments of the workflow. + + Returns: + The result of the executed workflow. + """ + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = get_redispatch_request_body_v2( + dispatch_id=dispatch_id, + staging_dir=staging_dir, + new_args=new_args, + new_kwargs=new_kwargs, + replace_electrons=replace_electrons, + dispatcher_addr=dispatcher_addr, + ) + + LocalDispatcher.register_derived_manifest( + manifest, + dispatch_id, + reuse_previous_results=reuse_previous_results, + dispatcher_addr=dispatcher_addr, + ) + + redispatch_id = manifest.metadata.dispatch_id + + path = dispatch_cache_dir / f"{redispatch_id}" + + with open(path, "w") as f: + f.write(manifest.json()) + + LocalDispatcher.upload_assets(manifest) + + return LocalDispatcher.start(redispatch_id, dispatcher_addr) + + return func + + @staticmethod + def prepare_manifest(lattice, storage_path) -> ResultSchema: + """Prepare a built-out lattice for submission""" + + result_object = Result(lattice) + return serialize_result(result_object, storage_path) + + @staticmethod + def register_manifest( + manifest: ResultSchema, + dispatcher_addr: Optional[str] = None, + parent_dispatch_id: Optional[str] = None, + push_assets: bool = True, + ) -> ResultSchema: + """Submits a manifest for registration. + + Returns: + Dictionary representation of manifest with asset remote_uris filled in + + Side effect: + If push_assets is False, the server will + automatically pull the task assets from the submitted asset URIs. + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + if push_assets: + stripped = strip_local_uris(manifest) + else: + stripped = manifest + + endpoint = "/api/v2/dispatches" + + if parent_dispatch_id: + endpoint = f"{endpoint}/{parent_dispatch_id}/subdispatches" + + r = APIClient(dispatcher_addr).post(endpoint, data=stripped.json()) + r.raise_for_status() + + parsed_resp = ResultSchema.parse_obj(r.json()) + + return merge_response_manifest(manifest, parsed_resp) + + @staticmethod + def register_derived_manifest( + manifest: ResultSchema, + dispatch_id: str, + reuse_previous_results: bool = False, + dispatcher_addr: Optional[str] = None, + ) -> ResultSchema: + """Submits a derived manifest for registration. + + Returns: + Dictionary representation of manifest with asset remote_uris filled in + + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + # We don't yet support pulling assets for redispatch + stripped = strip_local_uris(manifest) + + endpoint = f"/api/v2/dispatches/{dispatch_id}/redispatches" + + params = {"reuse_previous_results": reuse_previous_results} + r = APIClient(dispatcher_addr).post(endpoint, data=stripped.json(), params=params) + r.raise_for_status() + + parsed_resp = ResultSchema.parse_obj(r.json()) + + return merge_response_manifest(manifest, parsed_resp) + + @staticmethod + def upload_assets(manifest: ResultSchema): + assets = extract_assets(manifest) + LocalDispatcher._upload(assets) + + @staticmethod + def _upload(assets: List[AssetSchema]): + local_scheme_prefix = "file://" + total = len(assets) + for i, asset in enumerate(assets): + if not asset.remote_uri: + app_log.debug(f"Skipping asset {i+1} out of {total}") + continue + if asset.remote_uri.startswith(local_scheme_prefix): + copy_file_locally(asset.uri, asset.remote_uri) + else: + _upload_asset(asset.uri, asset.remote_uri) + app_log.debug(f"uploaded {i+1} out of {total} assets.") + + +def _upload_asset(local_uri, remote_uri): + scheme_prefix = "file://" + if local_uri.startswith(scheme_prefix): + local_path = local_uri[len(scheme_prefix) :] + else: + local_path = local_uri + + with open(local_path, "rb") as reader: + app_log.debug(f"uploading to {remote_uri}") + f = furl(remote_uri) + scheme = f.scheme + host = f.host + port = f.port + dispatcher_addr = f"{scheme}://{host}:{port}" + endpoint = str(f.path) + api_client = APIClient(dispatcher_addr) + + r = api_client.put(endpoint, data=reader) + r.raise_for_status() diff --git a/covalent/_results_manager/result.py b/covalent/_results_manager/result.py index 0c2b4a0d3..6941de762 100644 --- a/covalent/_results_manager/result.py +++ b/covalent/_results_manager/result.py @@ -295,11 +295,12 @@ def get_all_node_outputs(self) -> dict: node_outputs: A dictionary containing the output of every node execution. """ - all_node_outputs = {} - for node_id in self._lattice.transport_graph._graph.nodes: - all_node_outputs[ - f"{self._get_node_name(node_id=node_id)}({node_id})" - ] = self._get_node_output(node_id=node_id) + all_node_outputs = { + f"{self._get_node_name(node_id=node_id)}({node_id})": self._get_node_output( + node_id=node_id + ) + for node_id in self._lattice.transport_graph._graph.nodes + } return all_node_outputs def get_all_node_results(self) -> List[Dict]: diff --git a/covalent/_results_manager/results_manager.py b/covalent/_results_manager/results_manager.py index bc7175480..65534687e 100644 --- a/covalent/_results_manager/results_manager.py +++ b/covalent/_results_manager/results_manager.py @@ -15,19 +15,32 @@ # limitations under the License. -import codecs +from __future__ import annotations + import contextlib import os -from typing import Dict, List, Optional, Union +from pathlib import Path +from typing import Dict, List, Optional -import cloudpickle as pickle -import requests +from furl import furl from requests.adapters import HTTPAdapter from urllib3.util import Retry +from .._api.apiclient import CovalentAPIClient +from .._serialize.common import load_asset +from .._serialize.electron import ASSET_FILENAME_MAP as ELECTRON_ASSET_FILENAMES +from .._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES +from .._serialize.lattice import ASSET_FILENAME_MAP as LATTICE_ASSET_FILENAMES +from .._serialize.lattice import ASSET_TYPES as LATTICE_ASSET_TYPES +from .._serialize.result import ASSET_FILENAME_MAP as RESULT_ASSET_FILENAMES +from .._serialize.result import ASSET_TYPES as RESULT_ASSET_TYPES +from .._serialize.result import deserialize_result from .._shared_files import logger from .._shared_files.config import get_config from .._shared_files.exceptions import MissingLatticeRecordError +from .._shared_files.schemas.asset import AssetSchema +from .._shared_files.schemas.result import ResultSchema +from .._shared_files.utils import copy_file_locally, format_server_url from .result import Result from .wait import EXTREME @@ -35,52 +48,106 @@ log_stack_info = logger.log_stack_info -def get_result( - dispatch_id: str, wait: bool = False, dispatcher_addr: str = None, status_only: bool = False -) -> Result: +SDK_NODE_META_KEYS = { + "executor", + "executor_data", + "deps", + "call_before", + "call_after", +} + +SDK_LAT_META_KEYS = { + "executor", + "executor_data", + "workflow_executor", + "workflow_executor_data", + "deps", + "call_before", + "call_after", +} + +DEFERRED_KEYS = { + "output", + "value", + "result", +} + + +def _delete_result( + dispatch_id: str, + results_dir: str = None, + remove_parent_directory: bool = False, +) -> None: """ - Get the results of a dispatch from the Covalent server. + Internal function to delete the result. Args: dispatch_id: The dispatch id of the result. - wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize. - dispatcher_addr: Dispatcher server address, if None then defaults to the address set in Covalent's config. - status_only: If true, only returns result status, not the full result object, default is False. + results_dir: The directory where the results are stored in dispatch id named folders. + remove_parent_directory: Status of whether to delete the parent directory when removing the result. Returns: - The Result object from the Covalent server + None + Raises: + FileNotFoundError: If the result file is not found. """ - try: - result = _get_result_from_dispatcher( - dispatch_id, - wait, - dispatcher_addr, - status_only, - ) + if results_dir is None: + results_dir = os.environ.get("COVALENT_DATA_DIR") or get_config("dispatcher.results_dir") - if not status_only: - result = pickle.loads(codecs.decode(result["result"].encode(), "base64")) + import shutil - except MissingLatticeRecordError as ex: - app_log.warning( - f"Dispatch ID {dispatch_id} was not found in the database. Incorrect dispatch id." - ) + result_folder_path = os.path.join(results_dir, f"{dispatch_id}") - raise ex + if os.path.exists(result_folder_path): + shutil.rmtree(result_folder_path, ignore_errors=True) - except requests.exceptions.ConnectionError: - return None + with contextlib.suppress(OSError): + os.rmdir(results_dir) - return result + if remove_parent_directory: + shutil.rmtree(results_dir, ignore_errors=True) -def _get_result_from_dispatcher( +def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = None) -> str: + """ + Cancel a running dispatch. + + Args: + dispatch_id: The dispatch id of the dispatch to be cancelled. + task_ids: Optional, list of task ids to cancel within the workflow + dispatcher_addr: Dispatcher server address, if None then defaults to the address set in Covalent's config. + + Returns: + Cancellation response + """ + + if dispatcher_addr is None: + dispatcher_addr = format_server_url() + + if task_ids is None: + task_ids = [] + + api_client = CovalentAPIClient(dispatcher_addr) + endpoint = f"/api/v2/dispatches/{dispatch_id}/status" + + if isinstance(task_ids, int): + task_ids = [task_ids] + + body = {"status": "CANCELLED", "task_ids": task_ids} + r = api_client.put(endpoint, json=body) + return r.content.decode("utf-8").strip().replace('"', "") + + +# Multi-part + + +def _get_result_export_from_dispatcher( dispatch_id: str, wait: bool = False, - dispatcher_addr: str = None, status_only: bool = False, + dispatcher_addr: str = None, ) -> Dict: """ Internal function to get the results of a dispatch from the server without checking if it is ready to read. @@ -88,8 +155,8 @@ def _get_result_from_dispatcher( Args: dispatch_id: The dispatch id of the result. wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize. - dispatcher_addr: Dispatcher server address, if None then defaults to the address set in Covalent's config. status_only: If true, only returns result status, not the full result object, default is False. + dispatcher_addr: Dispatcher server address, defaults to the address set in covalent.config. Returns: The result object from the server. @@ -99,141 +166,329 @@ def _get_result_from_dispatcher( """ if dispatcher_addr is None: - dispatcher_addr = ( - "http://" + get_config("dispatcher.address") + ":" + str(get_config("dispatcher.port")) - ) + dispatcher_addr = format_server_url() retries = int(EXTREME) if wait else 5 adapter = HTTPAdapter(max_retries=Retry(total=retries, backoff_factor=1)) - http = requests.Session() - http.mount("http://", adapter) - - result_url = f"{dispatcher_addr}/api/result/{dispatch_id}" - - try: - response = http.get( - result_url, - params={"wait": bool(int(wait)), "status_only": status_only}, - timeout=5, - ) - except requests.exceptions.ConnectionError: - message = f"The Covalent server cannot be reached at {dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." - print(message) - raise + api_client = CovalentAPIClient(dispatcher_addr, adapter=adapter, auto_raise=False) + endpoint = f"/api/v2/dispatches/{dispatch_id}" + response = api_client.get( + endpoint, + params={"wait": wait, "status_only": status_only}, + ) if response.status_code == 404: raise MissingLatticeRecordError response.raise_for_status() + export = response.json() + return export - return response.json() +# Function to download default assets +def _get_default_assets(rm: ResultManager): + for key in RESULT_ASSET_TYPES.keys(): + if key not in DEFERRED_KEYS: + rm.download_result_asset(key) + rm.load_result_asset(key) -def _delete_result( - dispatch_id: str, - results_dir: str = None, - remove_parent_directory: bool = False, -) -> None: - """ - Internal function to delete the result. + for key in LATTICE_ASSET_TYPES.keys(): + if key not in DEFERRED_KEYS: + rm.download_lattice_asset(key) + rm.load_lattice_asset(key) - Args: - dispatch_id: The dispatch id of the result. - results_dir: The directory where the results are stored in dispatch id named folders. - remove_parent_directory: Status of whether to delete the parent directory when removing the result. + tg = rm.result_object.lattice.transport_graph - Returns: - None + tg.lattice_metadata = rm.result_object.lattice.metadata + rm.result_object.lattice.__doc__ = rm.result_object.lattice.__dict__.pop("doc") - Raises: - FileNotFoundError: If the result file is not found. - """ + for key in ELECTRON_ASSET_TYPES.keys(): + if key not in DEFERRED_KEYS: + for node_id in tg._graph.nodes: + rm.download_node_asset(node_id, key) + rm.load_node_asset(node_id, key) - if results_dir is None: - results_dir = os.environ.get("COVALENT_DATA_DIR") or get_config("dispatcher.results_dir") - import shutil +# Functions for computing local URIs +def get_node_asset_path(results_dir: str, node_id: int, key: str): + filename = ELECTRON_ASSET_FILENAMES[key] + return f"{results_dir}/node_{node_id}/{filename}" - result_folder_path = os.path.join(results_dir, f"{dispatch_id}") - if os.path.exists(result_folder_path): - shutil.rmtree(result_folder_path, ignore_errors=True) +def get_lattice_asset_path(results_dir: str, key: str): + filename = LATTICE_ASSET_FILENAMES[key] + return f"{results_dir}/{filename}" - with contextlib.suppress(OSError): - os.rmdir(results_dir) - if remove_parent_directory: - shutil.rmtree(results_dir, ignore_errors=True) +def get_result_asset_path(results_dir: str, key: str): + filename = RESULT_ASSET_FILENAMES[key] + return f"{results_dir}/{filename}" -def redispatch_result(result_object: Result, dispatcher: str = None) -> str: - """ - Function to redispatch the result as a new dispatch. +# Asset transfers - Args: - result_object: The result object to be redispatched. - dispatcher: The address to the dispatcher in the form of hostname:port, e.g. "localhost:8080". - Returns: - dispatch_id: The dispatch id of the new dispatch. - """ - result_object._lattice.metadata["dispatcher"] = ( - dispatcher or result_object.lattice.metadata["dispatcher"] - ) +def download_asset(remote_uri: str, local_path: str, chunk_size: int = 1024 * 1024): + local_scheme = "file" + if remote_uri.startswith(local_scheme): + copy_file_locally(remote_uri, f"file://{local_path}") + else: + f = furl(remote_uri) + scheme = f.scheme + host = f.host + port = f.port + dispatcher_addr = f"{scheme}://{host}:{port}" + endpoint = str(f.path) + api_client = CovalentAPIClient(dispatcher_addr) + r = api_client.get(endpoint, stream=True) + with open(local_path, "wb") as f: + for chunk in r.iter_content(chunk_size=chunk_size): + f.write(chunk) + + +def _download_result_asset(manifest: dict, results_dir: str, key: str): + remote_uri = manifest["assets"][key]["remote_uri"] + local_path = get_result_asset_path(results_dir, key) + download_asset(remote_uri, local_path) + manifest["assets"][key]["uri"] = f"file://{local_path}" + + +def _download_lattice_asset(manifest: dict, results_dir: str, key: str): + lattice_assets = manifest["lattice"]["assets"] + remote_uri = lattice_assets[key]["remote_uri"] + local_path = get_lattice_asset_path(results_dir, key) + download_asset(remote_uri, local_path) + lattice_assets[key]["uri"] = f"file://{local_path}" + + +def _download_node_asset(manifest: dict, results_dir: str, node_id: int, key: str): + node = manifest["lattice"]["transport_graph"]["nodes"][node_id] + node_assets = node["assets"] + remote_uri = node_assets[key]["remote_uri"] + local_path = get_node_asset_path(results_dir, node_id, key) + download_asset(remote_uri, local_path) + node_assets[key]["uri"] = f"file://{local_path}" + + +def _load_result_asset(manifest: dict, key: str): + asset_meta = AssetSchema(**manifest["assets"][key]) + return load_asset(asset_meta, RESULT_ASSET_TYPES[key]) + + +def _load_lattice_asset(manifest: dict, key: str): + asset_meta = AssetSchema(**manifest["lattice"]["assets"][key]) + return load_asset(asset_meta, LATTICE_ASSET_TYPES[key]) + + +def _load_node_asset(manifest: dict, node_id: int, key: str): + node = manifest["lattice"]["transport_graph"]["nodes"][node_id] + asset_meta = AssetSchema(**node["assets"][key]) + return load_asset(asset_meta, ELECTRON_ASSET_TYPES[key]) + + +class ResultManager: + def __init__(self, manifest: ResultSchema, results_dir: str): + self.result_object = deserialize_result(manifest) + self._manifest = manifest.dict() + self._results_dir = results_dir + + def save(self, path: Optional[str] = None): + if not path: + path = os.path.join(self._results_dir, "manifest.json") + with open(path, "w") as f: + f.write(ResultSchema.parse_obj(self._manifest).json()) + + @staticmethod + def load(path: str, results_dir: str) -> "ResultManager": + with open(path, "r") as f: + manifest_json = f.read() + + return ResultManager(ResultSchema.parse_raw(manifest_json), results_dir) + + def download_result_asset(self, key: str): + _download_result_asset(self._manifest, self._results_dir, key) + + def download_lattice_asset(self, key: str): + _download_lattice_asset(self._manifest, self._results_dir, key) + + def download_node_asset(self, node_id: int, key: str): + _download_node_asset(self._manifest, self._results_dir, node_id, key) + + def load_result_asset(self, key: str): + data = _load_result_asset(self._manifest, key) + self.result_object.__dict__[f"_{key}"] = data + + def load_lattice_asset(self, key: str): + data = _load_lattice_asset(self._manifest, key) + if key in SDK_LAT_META_KEYS: + self.result_object.lattice.metadata[key] = data + else: + self.result_object.lattice.__dict__[key] = data + + def load_node_asset(self, node_id: int, key: str): + data = _load_node_asset(self._manifest, node_id, key) + tg = self.result_object.lattice.transport_graph + if key in SDK_NODE_META_KEYS: + node_meta = tg.get_node_value(node_id, "metadata") + node_meta[key] = data + else: + tg.set_node_value(node_id, key, data) + + @staticmethod + def from_dispatch_id( + dispatch_id: str, + results_dir: str, + wait: bool = False, + dispatcher_addr: str = None, + ) -> "ResultManager": + export = _get_result_export_from_dispatcher( + dispatch_id, wait, status_only=False, dispatcher_addr=dispatcher_addr + ) - return result_object.lattice._server_dispatch(result_object) + manifest = ResultSchema.parse_obj(export["result_export"]) + # sort the nodes + manifest.lattice.transport_graph.nodes.sort(key=lambda x: x.id) -def sync( - dispatch_id: Optional[Union[List[str], str]] = None, -) -> None: + rm = ResultManager(manifest, results_dir) + result_object = rm.result_object + result_object._results_dir = results_dir + Path(results_dir).mkdir(parents=True, exist_ok=True) + + # Create node subdirectories + for node_id in result_object.lattice.transport_graph._graph.nodes: + node_dir = f"{results_dir}/node_{node_id}" + Path(node_dir).mkdir(exist_ok=True) + + return rm + + +def get_result_manager(dispatch_id, results_dir=None, wait=False, dispatcher_addr=None): + if not results_dir: + results_dir = get_config("sdk.results_dir") + f"/{dispatch_id}" + return ResultManager.from_dispatch_id(dispatch_id, results_dir, wait, dispatcher_addr) + + +def _get_result_multistage( + dispatch_id: str, + wait: bool = False, + dispatcher_addr: str = None, + status_only: bool = False, + results_dir: Optional[str] = None, + *, + workflow_output: bool = True, + intermediate_outputs: bool = True, + sublattice_results: bool = True, +) -> Result: """ - Synchronization call. Returns when one or more dispatches have completed. + Get the results of a dispatch from a file. Args: - dispatch_id: One or more dispatch IDs to wait for before returning. + dispatch_id: The dispatch id of the result. + wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize. + status_only: If true, only returns result status, not the full result object. Default is False. + dispatcher_addr: Dispatcher server address, defaults to the address set in Covalent's config. + results_dir: The directory where the results are stored in dispatch id named folders. + + + Kwargs: + workflow_output: Whether to return the workflow output. Defaults to True. + intermediate_outputs: Whether to return all intermediate outputs in the compute graph. Defaults to True. + sublattice_results: Whether to recursively retrieve sublattice results. Default is True. Returns: - None + The Result object from the Covalent server + """ - if isinstance(dispatch_id, str): - _get_result_from_dispatcher(dispatch_id, wait=True, status_only=True) - elif isinstance(dispatch_id, list): - for d in dispatch_id: - _get_result_from_dispatcher(d, wait=True, status_only=True) - else: - raise RuntimeError( - f"dispatch_id must be a string or a list. You passed a {type(dispatch_id)}." + try: + if status_only: + return _get_result_export_from_dispatcher( + dispatch_id=dispatch_id, + wait=wait, + status_only=status_only, + dispatcher_addr=dispatcher_addr, + ) + rm = get_result_manager(dispatch_id, results_dir, wait, dispatcher_addr) + _get_default_assets(rm) + + if workflow_output: + rm.download_result_asset("result") + rm.load_result_asset("result") + + if intermediate_outputs: + tg = rm.result_object.lattice.transport_graph + for node_id in tg._graph.nodes: + rm.download_node_asset(node_id, "output") + rm.load_node_asset(node_id, "output") + + # Fetch sublattice result objects recursively + tg = rm.result_object.lattice.transport_graph + for node_id in tg._graph.nodes: + sub_dispatch_id = tg.get_node_value(node_id, "sub_dispatch_id") + if sublattice_results and sub_dispatch_id: + sub_result = _get_result_multistage( + sub_dispatch_id, + wait, + dispatcher_addr, + status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + ) + tg.set_node_value(node_id, "sublattice_result", sub_result) + else: + tg.set_node_value(node_id, "sublattice_result", None) + + except MissingLatticeRecordError as ex: + app_log.warning( + f"Dispatch ID {dispatch_id} was not found in the database. Incorrect dispatch id." ) + raise ex -def cancel(dispatch_id: str, task_ids: List[int] = None, dispatcher_addr: str = None) -> str: - """ - Cancel a running dispatch. + return rm.result_object - Args: - dispatch_id: The dispatch id of the dispatch to be cancelled. - task_ids: Optional, list of task ids to cancel within the workflow - dispatcher_addr: Dispatcher server address, if None then defaults to the address set in Covalent's config. - Returns: - Cancellation response +def get_result( + dispatch_id: str, + wait: bool = False, + dispatcher_addr: str = None, + status_only: bool = False, + *, + results_dir: Optional[str] = None, + workflow_output: bool = True, + intermediate_outputs: bool = True, + sublattice_results: bool = True, +) -> Result: """ + Get the results of a dispatch. - if dispatcher_addr is None: - dispatcher_addr = ( - get_config("dispatcher.address") + ":" + str(get_config("dispatcher.port")) - ) + Args: + dispatch_id: The dispatch id of the result. + wait: Controls how long the method waits for the server to return a result. If False, the method will not wait and will return the current status of the workflow. If True, the method will wait for the result to finish and keep retrying for sys.maxsize. + dispatcher_addr: Dispatcher server address. Defaults to the address set in Covalent's config. + status_only: If true, only returns result status, not the full result object. Default is False. - if task_ids is None: - task_ids = [] + Kwargs: + results_dir: The directory where the results are stored in dispatch id named folders. + workflow_output: Whether to return the workflow output. Defaults to True. + intermediate_outputs: Whether to return all intermediate outputs in the compute graph. Defaults to True. + sublattice_results: Whether to recursively retrieve sublattice results. Default is True. - url = f"http://{dispatcher_addr}/api/cancel" + Returns: + The Result object from the Covalent server - if isinstance(task_ids, int): - task_ids = [task_ids] + """ - r = requests.post(url, json={"dispatch_id": dispatch_id, "task_ids": task_ids}) - r.raise_for_status() - return r.content.decode("utf-8").strip().replace('"', "") + return _get_result_multistage( + dispatch_id=dispatch_id, + wait=wait, + dispatcher_addr=dispatcher_addr, + status_only=status_only, + results_dir=results_dir, + workflow_output=workflow_output, + intermediate_outputs=intermediate_outputs, + sublattice_results=sublattice_results, + ) diff --git a/covalent/_serialize/__init__.py b/covalent/_serialize/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent/_serialize/__init__.py +++ b/covalent/_serialize/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent/_serialize/result.py b/covalent/_serialize/result.py index 612b10a60..c55fe1305 100644 --- a/covalent/_serialize/result.py +++ b/covalent/_serialize/result.py @@ -145,7 +145,6 @@ def merge_response_manifest(manifest: ResultSchema, response: ResultSchema) -> R response: The manifest returned from `/register`. Returns: A combined manifest with asset `remote_uri`s populated. - """ manifest.metadata.dispatch_id = response.metadata.dispatch_id diff --git a/covalent/_shared_files/schemas/__init__.py b/covalent/_shared_files/schemas/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent/_shared_files/schemas/__init__.py +++ b/covalent/_shared_files/schemas/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent/_shared_files/schemas/common.py b/covalent/_shared_files/schemas/common.py index 0c4a226d6..abe55bd6c 100644 --- a/covalent/_shared_files/schemas/common.py +++ b/covalent/_shared_files/schemas/common.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -18,7 +18,7 @@ from enum import Enum -from ..util_classes import RESULT_STATUS +from covalent._shared_files.util_classes import RESULT_STATUS class StatusEnum(str, Enum): diff --git a/covalent/_shared_files/utils.py b/covalent/_shared_files/utils.py index e0e2c9504..14136f6e3 100644 --- a/covalent/_shared_files/utils.py +++ b/covalent/_shared_files/utils.py @@ -18,6 +18,7 @@ import importlib import inspect +import shutil import socket from datetime import timedelta from typing import Any, Callable, Dict, Set, Tuple @@ -237,6 +238,7 @@ def format_server_url(hostname: str = None, port: int = None) -> str: url = hostname if not url.startswith("http"): url = f"https://{url}" if port == 443 else f"http://{url}" + # Inject port if port not in [80, 443]: parts = url.split("/") @@ -245,6 +247,23 @@ def format_server_url(hostname: str = None, port: int = None) -> str: return url.strip("/") +# For use by LocalDispatcher and ResultsManager when running Covalent +# server locally +def copy_file_locally(src_uri, dest_uri): + scheme_prefix = "file://" + if src_uri.startswith(scheme_prefix): + src_path = src_uri[len(scheme_prefix) :] + else: + raise TypeError(f"{src_uri} is not a valid URI") + # src_path = src_uri + if dest_uri.startswith(scheme_prefix): + dest_path = dest_uri[len(scheme_prefix) :] + else: + raise TypeError(f"{dest_uri} is not a valid URI") + + shutil.copyfile(src_path, dest_path) + + @_qml_mods_pickle def cloudpickle_serialize(obj): return cloudpickle.dumps(obj) diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index cfe45f485..05e4ee6d8 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -25,7 +25,8 @@ from functools import wraps from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Union -from .._dispatcher_plugins.local import LocalDispatcher +from covalent._dispatcher_plugins.local import LocalDispatcher + from .._file_transfer.enums import Order from .._file_transfer.file_transfer import FileTransfer from .._shared_files import logger diff --git a/covalent/_workflow/transport_graph_ops.py b/covalent/_workflow/transport_graph_ops.py deleted file mode 100644 index 8d9d96a92..000000000 --- a/covalent/_workflow/transport_graph_ops.py +++ /dev/null @@ -1,200 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Module for transport graph operations.""" - -from collections import deque -from typing import Callable, List - -import networkx as nx - -from .._shared_files import logger -from .transport import _TransportGraph - -app_log = logger.app_log - - -class TransportGraphOps: - def __init__(self, tg): - self.tg = tg - self._status_map = {1: True, -1: False} - - @staticmethod - def _flag_successors(A: nx.MultiDiGraph, node_statuses: dict, starting_node: int): - """Flag all successors of a node (including the node itself).""" - nodes_to_invalidate = [starting_node] - for node, successors in nx.bfs_successors(A, starting_node): - nodes_to_invalidate.extend(iter(successors)) - for node in nodes_to_invalidate: - node_statuses[node] = -1 - - @staticmethod - def is_same_node(A: nx.MultiDiGraph, B: nx.MultiDiGraph, node: int) -> bool: - """Check if the node attributes are the same in both graphs.""" - return A.nodes[node] == B.nodes[node] - - @staticmethod - def is_same_edge_attributes( - A: nx.MultiDiGraph, B: nx.MultiDiGraph, parent: int, node: int - ) -> bool: - """Check if the edge attributes are the same in both graphs.""" - return A.adj[parent][node] == B.adj[parent][node] - - def copy_nodes_from(self, tg: _TransportGraph, nodes): - """Copy nodes from the transport graph in the argument.""" - for n in nodes: - for k, v in tg._graph.nodes[n].items(): - self.tg.set_node_value(n, k, v) - - @staticmethod - def _cmp_name_and_pval(A: nx.MultiDiGraph, B: nx.MultiDiGraph, node: int) -> bool: - """Default node comparison function for diffing transport graphs.""" - name_A = A.nodes[node]["name"] - name_B = B.nodes[node]["name"] - - if name_A != name_B: - return False - - val_A = A.nodes[node].get("value", None) - val_B = B.nodes[node].get("value", None) - - return val_A == val_B - - def _max_cbms( - self, - A: nx.MultiDiGraph, - B: nx.MultiDiGraph, - node_cmp: Callable = None, - edge_cmp: Callable = None, - ): - """Computes a "maximum backward-maximal common subgraph" (cbms) - Args: - A: nx.MultiDiGraph - B: nx.MultiDiGraph - node_cmp: An optional function for comparing node attributes in A and B. - Defaults to testing for equality of the attribute dictionaries - edge_cmp: An optional function for comparing the edges between two nodes. - Defaults to checking that the two sets of edges have the same attributes - Returns: A_node_status, B_node_status, where each is a dictionary - `{node: True/False}` where True means reusable. - Performs a modified BFS of A and B. - """ - if node_cmp is None: - node_cmp = self.is_same_node - if edge_cmp is None: - edge_cmp = self.is_same_edge_attributes - - A_node_status = {node_id: 0 for node_id in A.nodes} - B_node_status = {node_id: 0 for node_id in B.nodes} - app_log.debug(f"A node status: {A_node_status}") - app_log.debug(f"B node status: {B_node_status}") - - virtual_root = -1 - - if virtual_root in A.nodes or virtual_root in B.nodes: - raise RuntimeError(f"Encountered forbidden node: {virtual_root}") - - assert virtual_root not in B.nodes - - nodes_to_visit = deque() - nodes_to_visit.appendleft(virtual_root) - - # Add a temporary root - A_parentless_nodes = [node for node, deg in A.in_degree() if deg == 0] - B_parentless_nodes = [node for node, deg in B.in_degree() if deg == 0] - for node_id in A_parentless_nodes: - A.add_edge(virtual_root, node_id) - - for node_id in B_parentless_nodes: - B.add_edge(virtual_root, node_id) - - # Assume inductively that predecessors subgraphs are the same; - # this is satisfied for the root - while nodes_to_visit: - current_node = nodes_to_visit.pop() - - app_log.debug(f"Visiting node {current_node}") - for y in A.adj[current_node]: - # Don't process already failed nodes - if A_node_status[y] == -1: - continue - - # Check if y is a valid child of current_node in B - if y not in B.adj[current_node]: - app_log.debug(f"A: {y} not adjacent to node {current_node} in B") - self._flag_successors(A, A_node_status, y) - continue - - if y in B.adj[current_node] and B_node_status[y] == -1: - app_log.debug(f"A: Node {y} is marked as failed in B") - self._flag_successors(A, A_node_status, y) - continue - - # Compare edges - if not edge_cmp(A, B, current_node, y): - app_log.debug(f"Edges between {current_node} and {y} differ") - self._flag_successors(A, A_node_status, y) - self._flag_successors(B, B_node_status, y) - continue - - # Compare nodes - if not node_cmp(A, B, y): - app_log.debug(f"Attributes of node {y} differ:") - app_log.debug(f"A[y] = {A.nodes[y]}") - app_log.debug(f"B[y] = {B.nodes[y]}") - self._flag_successors(A, A_node_status, y) - self._flag_successors(B, B_node_status, y) - continue - - # Predecessors subgraphs of y are the same in A and B, so - # enqueue y if it hasn't already been visited - assert A_node_status[y] != -1 - if A_node_status[y] == 0: - A_node_status[y] = 1 - B_node_status[y] = 1 - app_log.debug(f"Enqueueing node {y}") - nodes_to_visit.appendleft(y) - - # Prune children of current_node in B that aren't valid children in A - for y in B.adj[current_node]: - if B_node_status[y] == -1: - continue - if y not in A.adj[current_node]: - app_log.debug(f"B: {y} not adjacent to node {current_node} in A") - self._flag_successors(B, B_node_status, y) - continue - if y in A.adj[current_node] and B_node_status[y] == -1: - app_log.debug(f"B: Node {y} is marked as failed in A") - self._flag_successors(B, B_node_status, y) - - A.remove_node(-1) - B.remove_node(-1) - - app_log.debug(f"A node status: {A_node_status}") - app_log.debug(f"B node status: {B_node_status}") - - for k, v in A_node_status.items(): - A_node_status[k] = self._status_map[v] - for k, v in B_node_status.items(): - B_node_status[k] = self._status_map[v] - return A_node_status, B_node_status - - def get_reusable_nodes(self, tg_new: _TransportGraph) -> List[int]: - """Find which nodes are common between the current graph and a new graph.""" - A = self.tg.get_internal_graph_copy() - B = tg_new.get_internal_graph_copy() - status_A, _ = self._max_cbms(A, B, node_cmp=self._cmp_name_and_pval) - return [k for k, v in status_A.items() if v] diff --git a/covalent/_workflow/transportable_object.py b/covalent/_workflow/transportable_object.py index cf930ed93..4ba789662 100644 --- a/covalent/_workflow/transportable_object.py +++ b/covalent/_workflow/transportable_object.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""TransportableObject""" +"""Transportable object module defining relevant classes and functions""" import base64 import json @@ -32,16 +32,35 @@ class _TOArchive: - """ - Archived TransportableObject - """ + + """Archived transportable object.""" def __init__(self, header: bytes, object_string: bytes, data: bytes): + """ + Initialize TOArchive. + + Args: + header: Archived transportable object header. + object_string: Archived transportable object string. + data: Archived transportable object data. + + Returns: + None + """ + self.header = header self.object_string = object_string self.data = data def cat(self) -> bytes: + """ + Concatenate TOArchive. + + Returns: + Concatenated TOArchive. + + """ + header_size = len(self.header) string_size = len(self.object_string) data_offset = STRING_OFFSET_BYTES + DATA_OFFSET_BYTES + header_size + string_size @@ -54,6 +73,19 @@ def cat(self) -> bytes: @staticmethod def load(serialized: bytes, header_only: bool, string_only: bool) -> "_TOArchive": + """ + Load TOArchive object from serialized bytes. + + Args: + serialized: Serialized transportable object. + header_only: Load header only. + string_only: Load string only. + + Returns: + Archived transportable object. + + """ + string_offset = TOArchiveUtils.string_offset(serialized) header = TOArchiveUtils.parse_header(serialized, string_offset) object_string = b"" @@ -247,6 +279,17 @@ def deserialize_from_json(json_string: str) -> str: @staticmethod def make_transportable(obj) -> "TransportableObject": + """ + Make an object transportable. + + Args: + obj: The object to make transportable. + + Returns: + Transportable object. + + """ + if isinstance(obj, TransportableObject): return obj else: @@ -296,6 +339,11 @@ def deserialize_dict(collection: dict) -> dict: precisely, `collection` is a dict, each of whose entries is assumed to be either a `TransportableObject`, a list, or dict` + Args: + collection: A dictionary of TransportableObjects. + Returns: + A dictionary of deserialized objects. + """ new_dict = {} @@ -312,6 +360,17 @@ def deserialize_dict(collection: dict) -> dict: def _to_archive(to: TransportableObject) -> _TOArchive: + """ + Convert a TransportableObject to a _TOArchive. + + Args: + to: Transportable object to be converted. + + Returns: + Archived transportable object. + + """ + header = json.dumps(to._header).encode("utf-8") object_string = to._object_string.encode("utf-8") data = to._object.encode("utf-8") @@ -319,6 +378,17 @@ def _to_archive(to: TransportableObject) -> _TOArchive: def _from_archive(ar: _TOArchive) -> TransportableObject: + """ + Convert a _TOArchive to a TransportableObject. + + Args: + ar: Archived transportable object. + + Returns: + Transportable object. + + """ + decoded_object_str = ar.object_string.decode("utf-8") decoded_data = ar.data.decode("utf-8") decoded_header = json.loads(ar.header.decode("utf-8")) diff --git a/covalent/executor/executor_plugins/dask.py b/covalent/executor/executor_plugins/dask.py index 5e628bdfd..2344d4c28 100644 --- a/covalent/executor/executor_plugins/dask.py +++ b/covalent/executor/executor_plugins/dask.py @@ -31,7 +31,6 @@ # Relative imports are not allowed in executor plugins from covalent._shared_files.config import get_config from covalent._shared_files.exceptions import TaskCancelledError -from covalent._shared_files.utils import _address_client_mapper from covalent.executor.base import AsyncBaseExecutor from covalent.executor.utils.wrappers import io_wrapper as dask_wrapper @@ -55,6 +54,9 @@ "create_unique_workdir": False, } +# Temporary +_address_client_mapper = {} + class DaskExecutor(AsyncBaseExecutor): """ diff --git a/covalent/triggers/base.py b/covalent/triggers/base.py index 837c155a0..2eb49a434 100644 --- a/covalent/triggers/base.py +++ b/covalent/triggers/base.py @@ -21,10 +21,10 @@ import requests -from .._results_manager import Result +from .._dispatcher_plugins import local from .._shared_files import logger from .._shared_files.config import get_config -from .._shared_files.util_classes import Status +from .._shared_files.util_classes import RESULT_STATUS, Status app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -108,10 +108,10 @@ def _get_status(self) -> Status: """ if self.use_internal_funcs: - from covalent_dispatcher._service.app import get_result + from covalent_dispatcher._service.app import export_result response = asyncio.run_coroutine_threadsafe( - get_result(self.lattice_dispatch_id, status_only=True), + export_result(self.lattice_dispatch_id, status_only=True), self.event_loop, ).result() @@ -137,21 +137,12 @@ def _do_redispatch(self, is_pending: bool = False) -> str: new_dispatch_id: Dispatch id of the newly dispatched workflow """ - if self.use_internal_funcs: - from covalent_dispatcher import run_redispatch - - return asyncio.run_coroutine_threadsafe( - run_redispatch(self.lattice_dispatch_id, None, None, False, is_pending), - self.event_loop, - ).result() - - from .. import redispatch - - return redispatch( - dispatch_id=self.lattice_dispatch_id, - dispatcher_addr=self.dispatcher_addr, - is_pending=is_pending, - )() + if is_pending: + return local.LocalDispatcher.start(self.lattice_dispatch_id, self.dispatcher_addr) + else: + return local.LocalDispatcher.redispatch( + self.lattice_dispatch_id, self.dispatcher_addr + )() def trigger(self) -> None: """ @@ -169,7 +160,7 @@ def trigger(self) -> None: status = self._get_status() - if status == Result.NEW_OBJ or status is None: + if status == str(RESULT_STATUS.NEW_OBJECT) or status is None: # To continue the pending dispatch same_dispatch_id = self._do_redispatch(True) app_log.debug(f"Initiating run for pending dispatch_id: {same_dispatch_id}") diff --git a/covalent_dispatcher/__init__.py b/covalent_dispatcher/__init__.py index 1cf39cda1..0ad60669e 100644 --- a/covalent_dispatcher/__init__.py +++ b/covalent_dispatcher/__init__.py @@ -14,4 +14,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .entry_point import cancel_running_dispatch, run_dispatcher, run_redispatch +from .entry_point import cancel_running_dispatch, run_dispatcher diff --git a/covalent_dispatcher/_cli/service.py b/covalent_dispatcher/_cli/service.py index 815f42e09..bcd7a5557 100644 --- a/covalent_dispatcher/_cli/service.py +++ b/covalent_dispatcher/_cli/service.py @@ -592,6 +592,7 @@ def status() -> None: """ Display local server status """ + console = Console() print_header(console) @@ -618,6 +619,7 @@ def status() -> None: elif not exists or psutil.Process(pid).status() == psutil.STATUS_STOPPED: _rm_pid_file(UI_PIDFILE) status_table.add_row("Covalent Server", "[red]Stopped[/red]") + if exists and pid != -1: if Path(get_config("dispatcher.heartbeat_file")).is_file(): with open(get_config("dispatcher.heartbeat_file")) as f: @@ -628,11 +630,13 @@ def status() -> None: ) running_workflows = response.json()["total_count"] status_table.add_row("", f"There are {running_workflows} workflows currently running.") + admin_address = _get_cluster_admin_address() loop = asyncio.get_event_loop() cluster_status = ( loop.run_until_complete(_get_cluster_status(admin_address)) if admin_address else None ) + if _is_server_running() and cluster_status: status_table.add_row("Dask Cluster", f"[green]Running[/green] at {admin_address}") client = Client(get_config("dask.scheduler_address")) @@ -640,6 +644,7 @@ def status() -> None: status_table.add_row("", f"There are {running_tasks} tasks currently running.") else: status_table.add_row("Dask Cluster", "[red]Stopped[/red]") + try: response = requests.get(f"http://localhost:{port}/api/triggers/status", timeout=1) trigger_status = response.json()["status"] @@ -650,6 +655,7 @@ def status() -> None: status_table.add_row("Triggers Server", "[green]Running[/green]") else: status_table.add_row("Triggers Server", "[red]Stopped[/red]") + try: db = DataStore.factory() diff --git a/covalent_dispatcher/_core/__init__.py b/covalent_dispatcher/_core/__init__.py index 1803c81cf..58c050f35 100644 --- a/covalent_dispatcher/_core/__init__.py +++ b/covalent_dispatcher/_core/__init__.py @@ -14,5 +14,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .data_manager import make_derived_dispatch, make_dispatch +from .data_manager import make_dispatch +from .data_modules.importer import copy_futures from .dispatcher import cancel_dispatch, run_dispatch diff --git a/covalent_dispatcher/_core/data_manager.py b/covalent_dispatcher/_core/data_manager.py index 1de24eb74..7e1967f0d 100644 --- a/covalent_dispatcher/_core/data_manager.py +++ b/covalent_dispatcher/_core/data_manager.py @@ -19,37 +19,36 @@ """ import asyncio +import tempfile import traceback -import uuid -from datetime import datetime, timezone -from typing import Callable, Dict, Optional +from typing import Dict +from pydantic import ValidationError + +from covalent._dispatcher_plugins.local import LocalDispatcher from covalent._results_manager import Result from covalent._shared_files import logger -from covalent._shared_files.defaults import sublattice_prefix from covalent._shared_files.qelectron_utils import extract_qelectron_db, write_qelectron_db +from covalent._shared_files.schemas.result import ResultSchema from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow.lattice import Lattice -from covalent._workflow.transport_graph_ops import TransportGraphOps -from .._db import load, update +from .._dal.result import Result as SRVResult +from .._dal.result import get_result_object as get_result_object_from_db from .._db.write_result_to_db import resolve_electron_id +from . import dispatcher +from .data_modules import dispatch, electron # nopycln: import +from .data_modules import importer as manifest_importer +from .data_modules.utils import run_in_executor app_log = logger.app_log log_stack_info = logger.log_stack_info -# References to result objects of live dispatches -_registered_dispatches = {} - -# Map of dispatch_id -> message_queue for pushing node status updates -# to dispatcher -_dispatch_status_queues = {} - def generate_node_result( dispatch_id: str, node_id: int, - node_name: str, + node_name: str = None, start_time=None, end_time=None, status=None, @@ -57,8 +56,6 @@ def generate_node_result( error=None, stdout=None, stderr=None, - sub_dispatch_id=None, - sublattice_result=None, ): """ Helper routine to prepare the node result @@ -74,8 +71,6 @@ def generate_node_result( error: Error from the node stdout: STDOUT of a node stderr: STDERR generated during node execution - sub_dispatch_id: Dispatch ID of the sublattice - sublattice_result: Result of the sublattice Return(s) Dictionary of the inputs @@ -97,304 +92,247 @@ def generate_node_result( "error": error, "stdout": clean_stdout, "stderr": stderr, - "sub_dispatch_id": sub_dispatch_id, - "sublattice_result": sublattice_result, "qelectron_data_exists": qelectron_data_exists, } -async def _handle_built_sublattice(dispatch_id: str, node_result: Dict) -> None: - """Make dispatch for sublattice node. - - Note: The status COMPLETED which invokes this function refers to the graph being built. Once this step is completed, the sublattice is ready to be dispatched. Hence, the status is changed to DISPATCHING. - - Args: - dispatch_id: Dispatch ID - node_result: Node result dictionary - - """ - try: - node_result["status"] = RESULT_STATUS.DISPATCHING_SUBLATTICE - result_object = get_result_object(dispatch_id) - sub_dispatch_id = await make_sublattice_dispatch(result_object, node_result) - node_result["sub_dispatch_id"] = sub_dispatch_id - node_result["start_time"] = datetime.now(timezone.utc) - node_result["end_time"] = None - except Exception as ex: - tb = "".join(traceback.TracebackException.from_exception(ex).format()) - node_result["status"] = RESULT_STATUS.FAILED - node_result["error"] = tb - app_log.debug(f"Failed to make sublattice dispatch: {tb}") - - # Domain: result -async def update_node_result(result_object, node_result) -> None: - """ - Updates the result object with the current node_result - - Arg(s) - result_object: Result object the current dispatch - node_result: Result of the node to be updated in the result object - - Return(s) - None - - """ - app_log.debug(f"Updating node result for {node_result['node_id']}.") - - if ( - node_result["status"] == RESULT_STATUS.COMPLETED - and node_result["node_name"].startswith(sublattice_prefix) - and not node_result["sub_dispatch_id"] - ): - app_log.debug( - f"Sublattice {node_result['node_name']} build graph completed, invoking make sublattice dispatch..." +async def update_node_result(dispatch_id, node_result): + app_log.debug("Updating node result (run_planned_workflow).") + valid_update = True + try: + node_id = node_result["node_id"] + node_status = node_result["status"] + node_info = await electron.get(dispatch_id, node_id, ["type", "sub_dispatch_id"]) + node_type = node_info["type"] + sub_dispatch_id = node_info["sub_dispatch_id"] + + # Handle returns from _build_sublattice_graph -- change + # COMPLETED -> DISPATCHING + node_result = _filter_sublattice_status( + dispatch_id, node_id, node_status, node_type, sub_dispatch_id, node_result ) - await _handle_built_sublattice(result_object.dispatch_id, node_result) - try: - update._node(result_object, **node_result) - except Exception as ex: + valid_update = await electron.update(dispatch_id, node_result) + if not valid_update: + app_log.warning( + f"Invalid status update {node_status} for node {dispatch_id}:{node_id}" + ) + return + + if node_result["status"] == RESULT_STATUS.DISPATCHING: + app_log.debug("Received sublattice dispatch") + try: + sub_dispatch_id = await _make_sublattice_dispatch(dispatch_id, node_result) + except Exception as ex: + tb = "".join(traceback.TracebackException.from_exception(ex).format()) + node_result["status"] = RESULT_STATUS.FAILED + node_result["error"] = tb + await electron.update(dispatch_id, node_result) + + except KeyError as ex: + valid_update = False app_log.exception(f"Error persisting node update: {ex}") - node_result["status"] = RESULT_STATUS.FAILED - finally: - sub_dispatch_id = node_result["sub_dispatch_id"] - detail = {"sub_dispatch_id": sub_dispatch_id} if sub_dispatch_id is not None else {} - if node_status := node_result["status"]: - dispatch_id = result_object.dispatch_id - status_queue = get_status_queue(dispatch_id) - node_id = node_result["node_id"] - await status_queue.put((node_id, node_status, detail)) - - -# Domain: result -def initialize_result_object( - json_lattice: str, parent_result_object: Result = None, parent_electron_id: int = None -) -> Result: - """Convenience function for constructing a result object from a json-serialized lattice. - Args: - json_lattice: a JSON-serialized lattice - parent_result_object: the parent result object if json_lattice is a sublattice - parent_electron_id: the DB id of the parent electron (for sublattices) - - Returns: - Result: result object - - """ - dispatch_id = get_unique_id() - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice, dispatch_id) - if parent_result_object: - result_object._root_dispatch_id = parent_result_object._root_dispatch_id + except Exception as ex: + app_log.exception(f"Error persisting node update: {ex}") + sub_dispatch_id = None + node_result["status"] = Result.FAILED - result_object._electron_id = parent_electron_id - result_object._initialize_nodes() - app_log.debug("2: Constructed result object and initialized nodes.") + finally: + if not valid_update: + return - update.persist(result_object, electron_id=parent_electron_id) - app_log.debug("Result object persisted.") + node_id = node_result["node_id"] + node_status = node_result["status"] + detail = {"sub_dispatch_id": sub_dispatch_id} if sub_dispatch_id else {} - return result_object + if node_status and valid_update: + dispatch_id = dispatch_id + await dispatcher.notify_node_status(dispatch_id, node_id, node_status, detail) # Domain: result -def get_unique_id() -> str: - """ - Get a unique ID. - - Args: - None - - Returns: - str: Unique ID - - """ - return str(uuid.uuid4()) - - -async def make_dispatch( - json_lattice: str, parent_result_object: Result = None, parent_electron_id: int = None +def _redirect_lattice( + json_lattice: str, + parent_dispatch_id: str, + parent_electron_id: int, + loop: asyncio.AbstractEventLoop, ) -> str: - """Make a dispatch from a json-serialized lattice. + """Redirect a JSON lattice through the new DAL. Args: - json_lattice: a JSON-serialized lattice. - parent_result_object: the parent result object if json_lattice is a sublattice. - parent_electron_id: the DB id of the parent electron (for sublattices). - - Returns: - Dispatch ID of the lattice. - - """ - result_object = initialize_result_object( - json_lattice, parent_result_object, parent_electron_id - ) - _register_result_object(result_object) - return result_object.dispatch_id - + json_lattice: A JSON-serialized lattice. + parent_dispatch_id: The id of a sublattice's parent dispatch. -async def make_sublattice_dispatch(result_object: Result, node_result: dict) -> str: - """Get sublattice json lattice (once the transport graph has been built) and invoke make_dispatch. - - Args: - result_object: Result object for parent dispatch of the node. - node_result: Result of the node. + This will only be triggered from either the monolithic /submit + endpoint or a monolithic sublattice dispatch. Returns: - str: Dispatch ID of the sublattice. + The dispatch manifest """ - node_id = node_result["node_id"] - json_lattice = node_result["output"].object_string - parent_electron_id = load.electron_record(result_object.dispatch_id, node_id)["id"] - app_log.debug( - f"Making sublattice dispatch for node_id {node_id} and electron_id {parent_electron_id}." - ) - return await make_dispatch(json_lattice, result_object, parent_electron_id) - + lattice = Lattice.deserialize_from_json(json_lattice) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(lattice, staging_dir) + + # Trigger an internal asset pull from /tmp to object store + coro = manifest_importer.import_manifest( + manifest, + parent_dispatch_id, + parent_electron_id, + ) + filtered_manifest = manifest_importer._import_manifest( + manifest, + parent_dispatch_id, + parent_electron_id, + ) -def _get_result_object_from_new_lattice( - json_lattice: str, old_result_object: Result, reuse_previous_results: bool -) -> Result: - """Get new result object for re-dispatching from new lattice json. + manifest_importer._pull_assets(filtered_manifest) - Args: - json_lattice: JSON-serialized lattice. - old_result_object: Result object of the previous dispatch. + return filtered_manifest.metadata.dispatch_id - Returns: - Result object. - """ - lat = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lat, get_unique_id()) - result_object._initialize_nodes() +async def make_dispatch( + json_lattice: str, parent_dispatch_id: str = None, parent_electron_id: int = None +) -> str: + return await run_in_executor( + _redirect_lattice, + json_lattice, + parent_dispatch_id, + parent_electron_id, + asyncio.get_running_loop(), + ) - if reuse_previous_results: - tg = result_object.lattice.transport_graph - tg_old = old_result_object.lattice.transport_graph - reusable_nodes = TransportGraphOps(tg_old).get_reusable_nodes(tg) - TransportGraphOps(tg).copy_nodes_from(tg_old, reusable_nodes) - return result_object +def get_result_object(dispatch_id: str, bare: bool = True) -> SRVResult: + app_log.debug(f"Getting result object from db, bare={bare}") + return get_result_object_from_db(dispatch_id, bare) -def _get_result_object_from_old_result( - old_result_object: Result, reuse_previous_results: bool -) -> Result: - """Get new result object for re-dispatching from old result object. +def finalize_dispatch(dispatch_id: str): + app_log.debug(f"Finalizing dispatch {dispatch_id}") - Args: - old_result_object: Result object of the previous dispatch. - reuse_previous_results: Whether to reuse previous results. - Returns: - Result: Result object for the new dispatch. +async def persist_result(dispatch_id: str): + await _update_parent_electron(dispatch_id) - """ - result_object = Result(old_result_object.lattice, get_unique_id()) - result_object._num_nodes = old_result_object._num_nodes - if not reuse_previous_results: - result_object._initialize_nodes() +async def _update_parent_electron(dispatch_id: str): + dispatch_attrs = await dispatch.get(dispatch_id, ["electron_id", "status", "end_time"]) + parent_eid = dispatch_attrs["electron_id"] - return result_object + if parent_eid: + dispatch_id, node_id = resolve_electron_id(parent_eid) + status = dispatch_attrs["status"] + node_result = generate_node_result( + dispatch_id=dispatch_id, + node_id=node_id, + end_time=dispatch_attrs["end_time"], + status=status, + ) + parent_result_obj = get_result_object(dispatch_id) + app_log.debug(f"Updating sublattice parent node {dispatch_id}:{node_id}") + await update_node_result(parent_result_obj.dispatch_id, node_result) -def make_derived_dispatch( - parent_dispatch_id: str, - json_lattice: Optional[str] = None, - electron_updates: Optional[Dict[str, Callable]] = None, - reuse_previous_results: bool = False, -) -> str: - """Make a re-dispatch from a previous dispatch. +def _filter_sublattice_status( + dispatch_id, node_id, status, node_type, sub_dispatch_id, node_result +): + if status == Result.COMPLETED and node_type == "sublattice" and not sub_dispatch_id: + node_result["status"] = RESULT_STATUS.DISPATCHING + return node_result - Args: - parent_dispatch_id: Dispatch ID of the parent dispatch. - json_lattice: JSON-serialized lattice of the new dispatch. - electron_updates: Dictionary of electron updates. - reuse_previous_results: Whether to reuse previous results. - Returns: - str: Dispatch ID of the new dispatch. +async def _make_sublattice_dispatch(dispatch_id: str, node_result: dict): + try: + manifest, parent_electron_id = await run_in_executor( + _make_sublattice_dispatch_helper, + dispatch_id, + node_result, + ) - """ - if electron_updates is None: - electron_updates = {} + imported_manifest = await manifest_importer.import_manifest( + manifest=manifest, + parent_dispatch_id=dispatch_id, + parent_electron_id=parent_electron_id, + ) - old_result_object = load.get_result_object_from_storage(parent_dispatch_id) + return imported_manifest.metadata.dispatch_id - if json_lattice: - result_object = _get_result_object_from_new_lattice( - json_lattice, old_result_object, reuse_previous_results + except ValidationError as ex: + # Fall back to legacy sublattice handling + # NB: this loads the JSON sublattice in memory + json_lattice, parent_electron_id = await run_in_executor( + _legacy_sublattice_dispatch_helper, + dispatch_id, + node_result, ) - else: - result_object = _get_result_object_from_old_result( - old_result_object, reuse_previous_results + return await make_dispatch( + json_lattice, + dispatch_id, + parent_electron_id, ) - result_object.lattice.transport_graph.apply_electron_updates(electron_updates) - result_object.lattice.transport_graph.dirty_nodes = list( - result_object.lattice.transport_graph._graph.nodes - ) - update.persist(result_object) - _register_result_object(result_object) - app_log.debug(f"Redispatch result object: {result_object}") - return result_object.dispatch_id +def _legacy_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): + app_log.debug("falling back to legacy sublattice dispatch") + result_object = get_result_object(dispatch_id, bare=True) + node_id = node_result["node_id"] + parent_node = result_object.lattice.transport_graph.get_node(node_id) + bg_output = parent_node.get_value("output") -def get_result_object(dispatch_id: str) -> Result: - return _registered_dispatches.get(dispatch_id) + parent_electron_id = parent_node._electron_id + json_lattice = bg_output.object_string + return json_lattice, parent_electron_id -def _register_result_object(result_object: Result): - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object - _dispatch_status_queues[dispatch_id] = asyncio.Queue() +def _make_sublattice_dispatch_helper(dispatch_id: str, node_result: Dict): + """Helper function for performing DB queries related to sublattices.""" + result_object = get_result_object(dispatch_id, bare=True) + node_id = node_result["node_id"] + parent_node = result_object.lattice.transport_graph.get_node(node_id) + bg_output = parent_node.get_value("output") + manifest = ResultSchema.parse_raw(bg_output.object_string) + parent_electron_id = parent_node._electron_id -def finalize_dispatch(dispatch_id: str): - del _dispatch_status_queues[dispatch_id] - del _registered_dispatches[dispatch_id] + return manifest, parent_electron_id -def get_status_queue(dispatch_id: str): - return _dispatch_status_queues[dispatch_id] +# Common Result object queries -async def persist_result(dispatch_id: str): - result_object = get_result_object(dispatch_id) - upsert_lattice_data(result_object.dispatch_id) - await _update_parent_electron(result_object) +def generate_dispatch_result( + dispatch_id, + start_time=None, + end_time=None, + status=None, + error=None, + result=None, +): + return { + "start_time": start_time, + "end_time": end_time, + "status": status, + "error": error, + "result": result, + } -async def _update_parent_electron(result_object: Result): - if parent_eid := result_object._electron_id: - dispatch_id, node_id = resolve_electron_id(parent_eid) - status = result_object.status - if status == RESULT_STATUS.POSTPROCESSING_FAILED: - status = RESULT_STATUS.FAILED - parent_result_obj = get_result_object(dispatch_id) - node_result = generate_node_result( - dispatch_id=dispatch_id, - node_id=node_id, - node_name=parent_result_obj.lattice.transport_graph.get_node_value(node_id, "name"), - end_time=result_object.end_time, - status=status, - output=result_object._result, - error=result_object._error, - sub_dispatch_id=load.sublattice_dispatch_id(parent_eid), - sublattice_result=result_object, - ) +# Ensure that a dispatch is only run once; in the future, also check +# if all assets have been uploaded - app_log.debug(f"Updating sublattice parent node {dispatch_id}:{node_id}") - await update_node_result(parent_result_obj, node_result) +async def ensure_dispatch(dispatch_id: str) -> bool: + """Check if a dispatch can be run. -def upsert_lattice_data(dispatch_id: str): - result_object = get_result_object(dispatch_id) - # Redirect to new DAL -- this is a temporary fix as - # upsert_lattice_data will be obsoleted next by the next patch. - update.lattice_data(result_object) + The following criteria must be met: + * The dispatch has not been run before. + * (later) all assets have been uploaded + """ + return await run_in_executor( + SRVResult.ensure_run_once, + dispatch_id, + ) diff --git a/covalent_dispatcher/_core/data_modules/asset_manager.py b/covalent_dispatcher/_core/data_modules/asset_manager.py new file mode 100644 index 000000000..984ea6e46 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/asset_manager.py @@ -0,0 +1,89 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilties to transfer data between Covalent and compute backends +""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor +from typing import Dict + +from covalent._shared_files import logger +from covalent._shared_files.schemas.asset import AssetUpdate + +from ..._dal.result import get_result_object as get_result_object +from .utils import run_in_executor + +app_log = logger.app_log +am_pool = ThreadPoolExecutor() + + +# Consumed by Runner +async def upload_asset_for_nodes(dispatch_id: str, key: str, dest_uris: dict): + """Typical keys: "output", "deps", "call_before", "call_after", "function""" + + result_object = get_result_object(dispatch_id, bare=True) + tg = result_object.lattice.transport_graph + loop = asyncio.get_running_loop() + + futs = [] + for node_id, dest_uri in dest_uris.items(): + if dest_uri: + node = tg.get_node(node_id) + asset = node.get_asset(key, session=None) + futs.append(loop.run_in_executor(am_pool, asset.upload, dest_uri)) + + await asyncio.gather(*futs) + + +async def download_assets_for_node( + dispatch_id: str, node_id: int, asset_updates: Dict[str, AssetUpdate] +): + # Keys for src_uris: "output", "stdout", "stderr" + + result_object = get_result_object(dispatch_id, bare=True) + tg = result_object.lattice.transport_graph + node = tg.get_node(node_id) + loop = asyncio.get_running_loop() + + futs = [] + db_updates = {} + + # Mapping from asset key to (non-empty) remote uri + assets_to_download = {} + + # Prepare asset metadata update; prune empty fields + for key in asset_updates: + update = {} + asset = asset_updates[key].dict() + if asset["remote_uri"]: + assets_to_download[key] = asset["remote_uri"] + # Prune empty fields + for attr, val in asset.items(): + if val is not None: + update[attr] = val + if update: + db_updates[key] = update + + # Update metadata using the designated DB worker thread + await run_in_executor(node.update_assets, db_updates) + + for key, remote_uri in assets_to_download.items(): + asset = node.get_asset(key, session=None) + # Download assets concurrently. + futs.append(loop.run_in_executor(am_pool, asset.download, remote_uri)) + await asyncio.gather(*futs) diff --git a/covalent_dispatcher/_core/data_modules/dispatch.py b/covalent_dispatcher/_core/data_modules/dispatch.py new file mode 100644 index 000000000..9ecc72e3f --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/dispatch.py @@ -0,0 +1,66 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Queries involving dispatches +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_sync(dispatch_id: str, keys: List[str]) -> Dict: + refresh = False + result_object = get_result_object(dispatch_id) + return result_object.get_values(keys, refresh=refresh) + + +async def get(dispatch_id: str, keys: List[str]) -> Dict: + return await run_in_executor( + get_sync, + dispatch_id, + keys, + ) + + +def get_incomplete_tasks_sync(dispatch_id: str) -> Dict: + """Query all cancelled or failed tasks""" + result_object = get_result_object(dispatch_id) + return result_object._get_incomplete_nodes() + + +async def get_incomplete_tasks(dispatch_id: str) -> Dict: + """Query all cancelled or failed tasks in a dispatch. + + Args: + dispatch_id: The id of the dispatch + + Returns: + {"cancelled": [node_ids], "failed": [node_ids]} + """ + + return await run_in_executor(get_incomplete_tasks_sync, dispatch_id) + + +def update_sync(dispatch_id, dispatch_result): + result_object = get_result_object(dispatch_id) + result_object._update_dispatch(**dispatch_result) + + +async def update(dispatch_id, dispatch_result): + await run_in_executor(update_sync, dispatch_id, dispatch_result) diff --git a/covalent_dispatcher/_core/data_modules/electron.py b/covalent_dispatcher/_core/data_modules/electron.py new file mode 100644 index 000000000..9bb7cad02 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/electron.py @@ -0,0 +1,110 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for querying the transport graph +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_bulk_sync(dispatch_id: str, node_ids: List[int], keys: List[str]) -> List[Dict]: + result_object = get_result_object(dispatch_id) + attrs = result_object.lattice.transport_graph.get_values_for_nodes( + node_ids=node_ids, + keys=keys, + refresh=False, + ) + return attrs + + +async def get_bulk(dispatch_id: str, node_ids: List[int], keys: List[str]) -> List[Dict]: + """Query attributes for multiple electrons. + + Args: + node_ids: The list of nodes to query + keys: The list of attributes to query for each electron + + Returns: + A list of dictionaries {attr_key: attr_val}, one for + each node id, in the same order as `node_ids` + + Example: + ``` + await get_bulk( + "my_dispatch", [2, 4], ["name", "status"], + ) + ``` + will return + ``` + [ + { + "name": "task_2", "status": RESULT_STATUS.COMPLETED, + }, + { + "name": "task_4, "status": RESULT_STATUS.FAILED, + }, + ] + ``` + + """ + return await run_in_executor( + get_bulk_sync, + dispatch_id, + node_ids, + keys, + ) + + +async def get(dispatch_id: str, node_id: int, keys: List[str]) -> Dict: + """Convenience function to query attributes for an electron. + + Args: + node_id: The node to query + keys: The list of attributes to query + + Returns: + A dictionary {attr_key: attr_val} + + Example: + ``` + await get( + "my_dispatch", 2, ["name", "status"], + ) + ``` + will return + ``` + { + "name": "task_2", "status": RESULT_STATUS.COMPLETED, + } + ``` + + """ + attrs = await get_bulk(dispatch_id, [node_id], keys) + return attrs[0] + + +def update_sync(dispatch_id: str, node_result: Dict): + result_object = get_result_object(dispatch_id, bare=True) + return result_object._update_node(**node_result) + + +async def update(dispatch_id: str, node_result: Dict): + """Update a node's attributes""" + return await run_in_executor(update_sync, dispatch_id, node_result) diff --git a/covalent_dispatcher/_core/data_modules/graph.py b/covalent_dispatcher/_core/data_modules/graph.py new file mode 100644 index 000000000..ecea817d7 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/graph.py @@ -0,0 +1,102 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utilities for querying the transport graph +""" + + +# Note: these query static information which should be amenable to caching + +from typing import Dict, List + +import networkx as nx + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_incoming_edges_sync(dispatch_id: str, node_id: int): + """Query in-edges of a node. + + Returns: + List[Edge], where + + Edge is a dictionary with structure + source: int, + target: int, + attrs: dict + """ + + result_object = get_result_object(dispatch_id) + return result_object.lattice.transport_graph.get_incoming_edges(node_id) + + +def get_node_successors_sync( + dispatch_id: str, + node_id: int, + attrs: List[str], +) -> List[Dict]: + """Get child nodes with multiplicity. + + Parameters: + node_id: id of node + attr_keys: list of node attributes to return, such as task_group_id + + Returns: + List[Dict], where each dictionary is of the form + {"node_id": node_id, attr_key_1: node_attr[attr_key_1], ...} + + """ + + result_object = get_result_object(dispatch_id) + return result_object.lattice.transport_graph.get_successors(node_id, attrs) + + +def get_nodes_links_sync(dispatch_id: str) -> dict: + """Return the internal transport graph in NX node-link form""" + + # Need the whole NX graph here + result_object = get_result_object(dispatch_id, False) + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return nx.readwrite.node_link_data(g) + + +def get_nodes_sync(dispatch_id: str) -> List[int]: + """Return a list of all node ids in the graph.""" + result_object = get_result_object(dispatch_id, False) + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return list(g.nodes) + + +async def get_incoming_edges(dispatch_id: str, node_id: int): + return await run_in_executor(get_incoming_edges_sync, dispatch_id, node_id) + + +async def get_node_successors( + dispatch_id: str, + node_id: int, + attrs: List[str] = ["task_group_id"], +) -> List[Dict]: + return await run_in_executor(get_node_successors_sync, dispatch_id, node_id, attrs) + + +async def get_nodes_links(dispatch_id: str) -> Dict: + return await run_in_executor(get_nodes_links_sync, dispatch_id) + + +async def get_nodes(dispatch_id: str) -> List[int]: + return await run_in_executor(get_nodes_sync, dispatch_id) diff --git a/covalent_dispatcher/_core/data_modules/importer.py b/covalent_dispatcher/_core/data_modules/importer.py new file mode 100644 index 000000000..4630b9544 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/importer.py @@ -0,0 +1,148 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functionality for importing dispatch submissions +""" + +import uuid +from typing import Optional + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent._shared_files.schemas.result import ResultSchema +from covalent_dispatcher._dal.asset import copy_asset +from covalent_dispatcher._dal.importers.result import handle_redispatch, import_result +from covalent_dispatcher._dal.result import Result as SRVResult + +from .utils import dm_pool, run_in_executor + +BASE_PATH = get_config("dispatcher.results_dir") + +app_log = logger.app_log + +# Concurrent futures for copying assets during redispatch +copy_futures = {} + + +# Domain: result +def get_unique_id() -> str: + """ + Get a unique ID. + + Args: + None + + Returns: + str: Unique ID + """ + + return str(uuid.uuid4()) + + +def _import_manifest( + res: ResultSchema, + parent_dispatch_id: Optional[str], + parent_electron_id: Optional[int], +) -> ResultSchema: + if not res.metadata.dispatch_id: + res.metadata.dispatch_id = get_unique_id() + + # Compute root_dispatch_id for sublattice dispatches + if parent_dispatch_id: + parent_result_object = SRVResult.from_dispatch_id( + dispatch_id=parent_dispatch_id, + bare=True, + ) + res.metadata.root_dispatch_id = parent_result_object.root_dispatch_id + else: + res.metadata.root_dispatch_id = res.metadata.dispatch_id + + return import_result(res, BASE_PATH, parent_electron_id) + + +def _get_all_assets(dispatch_id: str): + result_object = SRVResult.from_dispatch_id(dispatch_id, bare=True) + return result_object.get_all_assets() + + +def _pull_assets(manifest: ResultSchema) -> None: + dispatch_id = manifest.metadata.dispatch_id + assets = _get_all_assets(dispatch_id) + futs = [] + for asset in assets["lattice"]: + if asset.remote_uri: + asset.download(asset.remote_uri) + + for asset in assets["nodes"]: + if asset.remote_uri: + asset.download(asset.remote_uri) + + app_log.debug(f"imported {len(futs)} assets for dispatch {dispatch_id}") + + +async def import_manifest( + manifest: ResultSchema, + parent_dispatch_id: Optional[str], + parent_electron_id: Optional[int], +) -> ResultSchema: + filtered_manifest = await run_in_executor( + _import_manifest, manifest, parent_dispatch_id, parent_electron_id + ) + await run_in_executor(_pull_assets, filtered_manifest) + + return filtered_manifest + + +def _copy_assets(assets_to_copy): + for item in assets_to_copy: + src, dest = item + copy_asset(src, dest) + + +def _import_derived_manifest( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + filtered_manifest = _import_manifest(manifest, None, None) + filtered_manifest, assets_to_copy = handle_redispatch( + filtered_manifest, parent_dispatch_id, reuse_previous_results + ) + + dispatch_id = filtered_manifest.metadata.dispatch_id + fut = dm_pool.submit(_copy_assets, assets_to_copy) + copy_futures[dispatch_id] = fut + fut.add_done_callback(lambda x: copy_futures.pop(dispatch_id)) + + return filtered_manifest + + +async def import_derived_manifest( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + filtered_manifest = await run_in_executor( + _import_derived_manifest, + manifest, + parent_dispatch_id, + reuse_previous_results, + ) + + await run_in_executor(_pull_assets, filtered_manifest) + + return filtered_manifest diff --git a/covalent_dispatcher/_core/data_modules/job_manager.py b/covalent_dispatcher/_core/data_modules/job_manager.py index d8bdb837e..58c437d92 100644 --- a/covalent_dispatcher/_core/data_modules/job_manager.py +++ b/covalent_dispatcher/_core/data_modules/job_manager.py @@ -98,16 +98,16 @@ async def set_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> Non await _set_job_metadata(dispatch_id, task_id, job_handle=job_handle) -async def set_cancel_result(dispatch_id: str, task_id: int, cancel_status: bool) -> None: +async def set_job_status(dispatch_id: str, task_id: int, status: str) -> None: """ - Update the cancel status of the job in the database if task cancellation is requested + Update the status of the job in the database Arg(s) dispatch_id: Dispatch ID of the lattice task_id: ID of the task in the lattice - cancel_status: True/False indicating whether the task is to be cancelled + status: status Return(s) None """ - await _set_job_metadata(dispatch_id, task_id, cancel_successful=cancel_status) + await _set_job_metadata(dispatch_id, task_id, job_status=status) diff --git a/covalent_dispatcher/_core/data_modules/lattice.py b/covalent_dispatcher/_core/data_modules/lattice.py new file mode 100644 index 000000000..cacfcfdf5 --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/lattice.py @@ -0,0 +1,38 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Queries involving lattice +""" + +from typing import Dict, List + +from ..._dal.result import get_result_object +from .utils import run_in_executor + + +def get_sync(dispatch_id: str, keys: List[str]) -> Dict: + refresh = False + result_object = get_result_object(dispatch_id) + return result_object.lattice.get_values(keys, refresh=refresh) + + +async def get(dispatch_id: str, keys: List[str]) -> Dict: + return await run_in_executor( + get_sync, + dispatch_id, + keys, + ) diff --git a/covalent_dispatcher/_core/data_modules/utils.py b/covalent_dispatcher/_core/data_modules/utils.py new file mode 100644 index 000000000..e8fba2cda --- /dev/null +++ b/covalent_dispatcher/_core/data_modules/utils.py @@ -0,0 +1,31 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Utils for the data service +""" + +import asyncio +from concurrent.futures import ThreadPoolExecutor + +# Worker thread for Datastore I/O Clamp this threadpool to one +# thread because Sqlite only supports a single writer. +dm_pool = ThreadPoolExecutor(max_workers=1) + + +def run_in_executor(func, *args) -> asyncio.Future: + loop = asyncio.get_running_loop() + return loop.run_in_executor(dm_pool, func, *args) diff --git a/covalent_dispatcher/_core/dispatcher.py b/covalent_dispatcher/_core/dispatcher.py index 0054adcf8..7041dde88 100644 --- a/covalent_dispatcher/_core/dispatcher.py +++ b/covalent_dispatcher/_core/dispatcher.py @@ -23,39 +23,39 @@ from datetime import datetime, timezone from typing import Dict, List, Tuple -from covalent._results_manager import Result +import networkx as nx + from covalent._shared_files import logger +from covalent._shared_files.config import get_config from covalent._shared_files.defaults import WAIT_EDGE_NAME, parameter_prefix from covalent._shared_files.util_classes import RESULT_STATUS -from covalent_ui import result_webhook from . import data_manager as datasvc from . import runner -from .data_modules.job_manager import set_cancel_requested +from .data_modules import graph as tg_utils +from .data_modules import job_manager as jbmgr +from .dispatcher_modules.caches import _pending_parents, _sorted_task_groups, _unresolved_tasks +from .runner_modules.cancel import cancel_tasks app_log = logger.app_log log_stack_info = logger.log_stack_info +_global_status_queue = None +_status_queues = {} +_futures = {} +_global_event_listener = None -""" -Dispatcher module is responsible for planning and dispatching workflows. The dispatcher - -1. Submits tasks to the Runner module. -2. Retrieves information using the Data Manager module. -3. Handles the tasks in terminal (COMPLETED, FAILED, CANCELLED) states. -4. Handles sublattice dispatches once the corresponding graph has been built in the Runner module. -""" +SYNC_DISPATCHES = get_config("dispatcher.use_async_dispatcher") == "false" # Domain: dispatcher -def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: +async def _get_abstract_task_inputs(dispatch_id: str, node_id: int, node_name: str) -> dict: """Return placeholders for the required inputs for a task execution. Args: + dispatch_id: id of the current dispatch node_id: Node id of this task in the transport graph. node_name: Name of the node. - result_object: Result object to be used to update and store execution related - info including the results. Returns: inputs: Input dictionary to be passed to the task with `node_id` placeholders for args, kwargs. These are to be @@ -64,16 +64,17 @@ def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Resul abstract_task_input = {"args": [], "kwargs": {}} - for parent in result_object.lattice.transport_graph.get_dependencies(node_id): - edge_data = result_object.lattice.transport_graph.get_edge_data(parent, node_id) + for edge in await tg_utils.get_incoming_edges(dispatch_id, node_id): + parent = edge["source"] - for _, d in edge_data.items(): - if d["edge_name"] != WAIT_EDGE_NAME: - if d["param_type"] == "arg": - abstract_task_input["args"].append((parent, d["arg_index"])) - elif d["param_type"] == "kwarg": - key = d["edge_name"] - abstract_task_input["kwargs"][key] = parent + d = edge["attrs"] + + if d["edge_name"] != WAIT_EDGE_NAME: + if d["param_type"] == "arg": + abstract_task_input["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_task_input["kwargs"][key] = parent sorted_args = sorted(abstract_task_input["args"], key=lambda x: x[1]) abstract_task_input["args"] = [x[0] for x in sorted_args] @@ -82,54 +83,40 @@ def _get_abstract_task_inputs(node_id: int, node_name: str, result_object: Resul # Domain: dispatcher -async def _handle_completed_node(result_object, node_id, pending_parents): - """ - Process the completed node in the transport graph - - Arg(s) - result_object: Result object associated with the workflow - node_id: ID of the node in the transport graph - pending_parents: Parents of this node yet to be executed - - Return(s) - List of nodes ready to be executed - """ - g = result_object.lattice.transport_graph._graph - - ready_nodes = [] +async def _handle_completed_node(dispatch_id: str, node_id: int): + next_task_groups = [] app_log.debug(f"Node {node_id} completed") - for child, edges in g.adj[node_id].items(): - for _ in edges: - pending_parents[child] -= 1 - if pending_parents[child] < 1: - app_log.debug(f"Queuing node {child} for execution") - ready_nodes.append(child) - return ready_nodes + parent_gid = (await datasvc.electron.get(dispatch_id, node_id, ["task_group_id"]))[ + "task_group_id" + ] + for child in await tg_utils.get_node_successors(dispatch_id, node_id): + node_id = child["node_id"] + gid = child["task_group_id"] + app_log.debug(f"dispatch {dispatch_id}: parent gid {parent_gid}, child gid {gid}") + if parent_gid != gid: + now_pending = await _pending_parents.decrement(dispatch_id, gid) + if now_pending < 1: + app_log.debug(f"Queuing task group {gid} for execution") + next_task_groups.append(gid) + + return next_task_groups # Domain: dispatcher -async def _handle_failed_node(result_object, node_id): - result_object._task_failed = True - result_object._end_time = datetime.now(timezone.utc) - app_log.debug(f"Node {result_object.dispatch_id}:{node_id} failed") +async def _handle_failed_node(dispatch_id: str, node_id: int): + app_log.debug(f"Node {dispatch_id}:{node_id} failed") app_log.debug("8A: Failed node upsert statement (run_planned_workflow)") - datasvc.upsert_lattice_data(result_object.dispatch_id) - await result_webhook.send_update(result_object) # Domain: dispatcher -async def _handle_cancelled_node(result_object, node_id): - result_object._task_cancelled = True - result_object._end_time = datetime.now(timezone.utc) - app_log.debug(f"Node {result_object.dispatch_id}:{node_id} cancelled") +async def _handle_cancelled_node(dispatch_id: str, node_id: int): + app_log.debug(f"Node {dispatch_id}:{node_id} cancelled") app_log.debug("9: Cancelled node upsert statement (run_planned_workflow)") - datasvc.upsert_lattice_data(result_object.dispatch_id) - await result_webhook.send_update(result_object) # Domain: dispatcher -async def _get_initial_tasks_and_deps(result_object: Result) -> Tuple[int, int, Dict]: +async def _get_initial_tasks_and_deps(dispatch_id: str) -> Tuple[int, int, Dict]: """Compute the initial batch of tasks to submit and initialize each task's dep count Returns: (num_tasks, ready_nodes, pending_parents) where num_tasks is @@ -140,192 +127,151 @@ async def _get_initial_tasks_and_deps(result_object: Result) -> Tuple[int, int, """ - num_tasks = 0 - ready_nodes = [] + # Number of pending predecessor nodes for each task group pending_parents = {} - g = result_object.lattice.transport_graph._graph - for node_id, d in g.in_degree(): - app_log.debug(f"Node {node_id} has {d} parents") + g_node_link = await tg_utils.get_nodes_links(dispatch_id) + g = nx.readwrite.node_link_graph(g_node_link) + + # Topologically sort each task group + sorted_task_groups = {} + for node_id in nx.topological_sort(g): + gid = g.nodes[node_id]["task_group_id"] + if gid not in sorted_task_groups: + sorted_task_groups[gid] = [node_id] + pending_parents[gid] = 0 + else: + sorted_task_groups[gid].append(node_id) + + for node_id in g.nodes: + parent_gid = g.nodes[node_id]["task_group_id"] + for succ, datadict in g.adj[node_id].items(): + child_gid = g.nodes[succ]["task_group_id"] - pending_parents[node_id] = d - num_tasks += 1 - if d == 0: - ready_nodes.append(node_id) + if parent_gid != child_gid: + n_edges = len(datadict.keys()) + pending_parents[child_gid] += n_edges - return num_tasks, ready_nodes, pending_parents + initial_task_groups = [gid for gid, d in pending_parents.items() if d == 0] + app_log.debug(f"Sorted task groups: {sorted_task_groups}") + return initial_task_groups, pending_parents, sorted_task_groups # Domain: dispatcher -async def _submit_task(result_object, node_id): +async def _submit_task_group(dispatch_id: str, sorted_nodes: List[int], task_group_id: int): + # Handle parameter nodes # Get name of the node for the current task - node_name = result_object.lattice.transport_graph.get_node_value(node_id, "name") - node_status = result_object.lattice.transport_graph.get_node_value(node_id, "status") + node_name = (await datasvc.electron.get(dispatch_id, sorted_nodes[0], ["name"]))["name"] + app_log.debug(f"7A: Node name: {node_name} (run_planned_workflow).") # Handle parameter nodes if node_name.startswith(parameter_prefix): - output = result_object.lattice.transport_graph.get_node_value(node_id, "value") - timestamp = datetime.now(timezone.utc) - node_result = datasvc.generate_node_result( - dispatch_id=result_object.dispatch_id, - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.COMPLETED, - output=output, - ) - await datasvc.update_node_result(result_object, node_result) - app_log.debug(f"Updated parameter node {node_id}.") - - elif node_status == RESULT_STATUS.COMPLETED: - timestamp = datetime.now(timezone.utc) - output = result_object.lattice.transport_graph.get_node_value(node_id, "output") - node_result = datasvc.generate_node_result( - dispatch_id=result_object.dispatch_id, - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.COMPLETED, - output=output, - ) - await datasvc.update_node_result(result_object, node_result) - app_log.debug(f"Skipped completed node execution {node_name}.") + if len(sorted_nodes) > 1: + raise RuntimeError("Parameter nodes cannot be packed") + + app_log.debug("7C: Encountered parameter node {node_id}.") + app_log.debug("8: Starting update node (run_planned_workflow).") + + ts = datetime.now(timezone.utc) + node_result = { + "node_id": sorted_nodes[0], + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + await datasvc.update_node_result(dispatch_id, node_result) + app_log.debug("8A: Update node success (run_planned_workflow).") else: - # Gather inputs and dispatch task - app_log.debug(f"Gathering inputs for task {node_id}.") - - abs_task_input = _get_abstract_task_inputs(node_id, node_name, result_object) - executor = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")[ - "executor" - ] - executor_data = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")[ - "executor_data" - ] - coro = runner.run_abstract_task( - dispatch_id=result_object.dispatch_id, - node_id=node_id, - executor=[executor, executor_data], - node_name=node_name, - abstract_inputs=abs_task_input, - ) - app_log.debug(f"Creating task {node_id}.") - asyncio.create_task(coro) - - -# Domain: dispatcher -async def _run_planned_workflow(result_object: Result, status_queue: asyncio.Queue) -> Result: - """ - Run the workflow in the topological order of their position on the - transport graph. Does this in an asynchronous manner so that nodes - at the same level are executed in parallel. Also updates the status - of the whole workflow execution. - - Args: - result_object: Result object being used for current dispatch - status_queue: message queue for notifying the main loop of status updates - - Returns: - None - """ - app_log.debug("Starting _run_planned_workflow ...") - result_object._status = RESULT_STATUS.RUNNING - result_object._start_time = datetime.now(timezone.utc) - datasvc.upsert_lattice_data(result_object.dispatch_id) - app_log.debug(f"Wrote lattice status {result_object._status} to DB.") + known_nodes = [] - tasks_left, initial_nodes, pending_parents = await _get_initial_tasks_and_deps(result_object) - - unresolved_tasks = 0 - - for node_id in initial_nodes: - unresolved_tasks += 1 - await _submit_task(result_object, node_id) - - while unresolved_tasks > 0: - app_log.debug(f"{tasks_left} tasks left to complete.") - app_log.debug( - f"{result_object.dispatch_id}: Waiting to hear from {unresolved_tasks} tasks." + # Skip the group if all task outputs can be reused from a + # previous dispatch (for redispatch). + statuses = await datasvc.electron.get_bulk(dispatch_id, sorted_nodes, ["status"]) + incomplete = list( + filter(lambda record: record["status"] != RESULT_STATUS.PENDING_REUSE, statuses) ) - node_id, node_status, detail = await status_queue.get() + if incomplete: + # Gather inputs for each task and send the task spec sequence to the runner + task_specs = [] + + for node_id in sorted_nodes: + app_log.debug(f"Gathering inputs for task {node_id} (run_planned_workflow).") + + abs_task_input = await _get_abstract_task_inputs(dispatch_id, node_id, node_name) + + executor_attrs = await datasvc.electron.get( + dispatch_id, + node_id, + ["executor", "executor_data"], + ) + selected_executor = executor_attrs["executor"] + selected_executor_data = executor_attrs["executor_data"] + task_spec = { + "function_id": node_id, + "name": node_name, + "args_ids": abs_task_input["args"], + "kwargs_ids": abs_task_input["kwargs"], + } + known_nodes += abs_task_input["args"] + known_nodes += list(abs_task_input["kwargs"].values()) + task_specs.append(task_spec) - app_log.debug( - f"Status queue msg for node id {node_id}: {node_status} with detail {detail}." - ) - - if node_status == RESULT_STATUS.RUNNING: - continue - - # Note: A node status can only be 'DISPATCHING' if it is a sublattice and the corresponding graph has been built. - if node_status == RESULT_STATUS.DISPATCHING_SUBLATTICE: - sub_dispatch_id = detail["sub_dispatch_id"] - run_dispatch(sub_dispatch_id) app_log.debug( - f"Submitted sublattice (dispatch id: {sub_dispatch_id}) to run_dispatch." + f"Submitting task group {dispatch_id}:{task_group_id} ({len(sorted_nodes)} tasks) to runner" ) - continue - - unresolved_tasks -= 1 + app_log.debug(f"Using new runner for task group {task_group_id}") - if node_status == RESULT_STATUS.COMPLETED: - tasks_left -= 1 - ready_nodes = await _handle_completed_node(result_object, node_id, pending_parents) - for node_id in ready_nodes: - unresolved_tasks += 1 - await _submit_task(result_object, node_id) + known_nodes = list(set(known_nodes)) - if node_status == RESULT_STATUS.FAILED: - await _handle_failed_node(result_object, node_id) - continue + task_spec = task_specs[0] + abstract_inputs = {"args": task_spec["args_ids"], "kwargs": task_spec["kwargs_ids"]} - if node_status == RESULT_STATUS.CANCELLED: - await _handle_cancelled_node(result_object, node_id) - continue - - if result_object._task_failed or result_object._task_cancelled: - app_log.debug(f"Workflow {result_object.dispatch_id} cancelled or failed") - failed_nodes = result_object._get_failed_nodes() - failed_nodes = map(lambda x: f"{x[0]}: {x[1]}", failed_nodes) - failed_nodes_msg = "\n".join(failed_nodes) - result_object._error = "The following tasks failed:\n" + failed_nodes_msg - result_object._status = ( - RESULT_STATUS.FAILED if result_object._task_failed else RESULT_STATUS.CANCELLED - ) - return result_object - - app_log.debug( - f"Tasks for {result_object.dispatch_id} finished running. Updating result webhook ..." - ) - await result_webhook.send_update(result_object) - return result_object + # Temporarily redirect to in-memory runner (this is incompatible with task packing) + if len(task_specs) > 1: + raise RuntimeError("Task packing is not supported yet.") + coro = runner.run_abstract_task( + dispatch_id=dispatch_id, + node_id=task_group_id, + node_name=node_name, + abstract_inputs=abstract_inputs, + selected_executor=[selected_executor, selected_executor_data], + ) -def _plan_workflow(result_object: Result) -> None: + asyncio.create_task(coro) + else: + ts = datetime.now(timezone.utc) + for node_id in sorted_nodes: + app_log.debug(f"Skipping already completed node {dispatch_id}:{node_id}") + node_result = { + "node_id": node_id, + "start_time": ts, + "end_time": ts, + "status": RESULT_STATUS.COMPLETED, + } + await datasvc.update_node_result(dispatch_id, node_result) + app_log.debug("8A: Update node success (run_planned_workflow).") + + +async def _plan_workflow(dispatch_id: str) -> None: """ Function to plan a workflow according to a schedule. Planning means to decide which executors (along with their arguments) will be used by each node. Args: - result_object: Result object being used for current dispatch + dispatch_id: id of current dispatch Returns: None """ - if result_object.lattice.get_metadata("schedule"): - # Custom scheduling logic of the format: - # scheduled_executors = get_schedule(result_object) + pass - # for node_id, executor in scheduled_executors.items(): - # result_object.lattice.transport_graph.set_node_value(node_id, "executor", executor) - pass - -async def run_workflow(result_object: Result) -> Result: +async def run_workflow(dispatch_id: str, wait: bool = SYNC_DISPATCHES) -> RESULT_STATUS: """ Plan and run the workflow by loading the result object corresponding to the dispatch id and retrieving essential information from it. @@ -338,31 +284,42 @@ async def run_workflow(result_object: Result) -> Result: Returns: The result object from the workflow execution - """ - app_log.debug(f"Starting run_workflow for dispatch id {result_object.dispatch_id} ...") - if result_object.status == RESULT_STATUS.COMPLETED: - datasvc.finalize_dispatch(result_object.dispatch_id) - return result_object + + app_log.debug("Inside run_workflow.") + + # Ensure that the dispatch is run at most once + can_run = await datasvc.ensure_dispatch(dispatch_id) + + if not can_run: + result_info = await datasvc.dispatch.get(dispatch_id, ["status"]) + dispatch_status = result_info["status"] + app_log.debug(f"Cannot start dispatch {dispatch_id}: current status {dispatch_status}") + return dispatch_status try: - _plan_workflow(result_object) - status_queue = datasvc.get_status_queue(result_object.dispatch_id) - result_object = await _run_planned_workflow(result_object, status_queue) + await _plan_workflow(dispatch_id) - except Exception as ex: - app_log.error(f"Exception during _run_planned_workflow: {ex}") + if wait: + fut = asyncio.Future() + _futures[dispatch_id] = fut + + dispatch_status = await _submit_initial_tasks(dispatch_id) - error_msg = "".join(traceback.TracebackException.from_exception(ex).format()) - result_object._status = RESULT_STATUS.FAILED - result_object._error = error_msg - result_object._end_time = datetime.now(timezone.utc) + if wait: + app_log.debug(f"Waiting for dispatch {dispatch_id}") + dispatch_status = await fut + else: + app_log.debug(f"Running dispatch {dispatch_id} asynchronously") + + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) finally: - await datasvc.persist_result(result_object.dispatch_id) - datasvc.finalize_dispatch(result_object.dispatch_id) + if dispatch_status != RESULT_STATUS.RUNNING: + datasvc.finalize_dispatch(dispatch_id) - return result_object + return dispatch_status # Domain: dispatcher @@ -376,43 +333,227 @@ async def cancel_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: Return(s) None + """ + if task_ids is None: task_ids = [] - if not dispatch_id: - return - res_object = datasvc.get_result_object(dispatch_id) - if res_object is None: + if not dispatch_id: return - tg = res_object.lattice.transport_graph if task_ids: app_log.debug(f"Cancelling tasks {task_ids} in dispatch {dispatch_id}") else: - task_ids = list(tg._graph.nodes) + task_ids = await tg_utils.get_nodes(dispatch_id) + app_log.debug(f"Cancelling dispatch {dispatch_id}") - await set_cancel_requested(dispatch_id, task_ids) - await runner.cancel_tasks(dispatch_id, task_ids) + await jbmgr.set_cancel_requested(dispatch_id, task_ids) + await cancel_tasks(dispatch_id, task_ids) # Recursively cancel running sublattice dispatches - sub_ids = list(map(lambda x: tg.get_node_value(x, "sub_dispatch_id"), task_ids)) + attrs = await datasvc.electron.get_bulk(dispatch_id, task_ids, ["sub_dispatch_id"]) + sub_ids = list(map(lambda x: x["sub_dispatch_id"], attrs)) for sub_dispatch_id in sub_ids: await cancel_dispatch(sub_dispatch_id) def run_dispatch(dispatch_id: str) -> asyncio.Future: - """ - Run the workflow and return immediately + return asyncio.create_task(run_workflow(dispatch_id)) - Arg(s) - dispatch_id: Dispatch ID of the lattice - Return(s) - asyncio.Future +async def notify_node_status( + dispatch_id: str, node_id: int, status: RESULT_STATUS, detail: Dict = None +): + if detail is None: + detail = {} - """ - app_log.debug(f"Running dispatch with dispatch_id: {dispatch_id}.") - result_object = datasvc.get_result_object(dispatch_id) - return asyncio.create_task(run_workflow(result_object)) + msg = { + "dispatch_id": dispatch_id, + "node_id": node_id, + "status": status, + "detail": detail, + } + + await _global_status_queue.put(msg) + + +async def _finalize_dispatch(dispatch_id: str): + await _clear_caches(dispatch_id) + app_log.debug(f"Removed unresolved counter for {dispatch_id}") + + incomplete_tasks = await datasvc.dispatch.get_incomplete_tasks(dispatch_id) + failed = incomplete_tasks["failed"] + cancelled = incomplete_tasks["cancelled"] + if failed or cancelled: + app_log.debug(f"Workflow {dispatch_id} cancelled or failed") + failed_nodes = failed + failed_nodes = map(lambda x: f"{x[0]}: {x[1]}", failed_nodes) + failed_nodes_msg = "\n".join(failed_nodes) + error_msg = "The following tasks failed:\n" + failed_nodes_msg + ts = datetime.now(timezone.utc) + status = RESULT_STATUS.FAILED if failed else RESULT_STATUS.CANCELLED + result_update = datasvc.generate_dispatch_result( + dispatch_id, + status=status, + error=error_msg, + end_time=ts, + ) + await datasvc.dispatch.update(dispatch_id, result_update) + + app_log.debug("8: All tasks finished running (run_planned_workflow)") + + app_log.debug("Workflow already postprocessed") + + result_info = await datasvc.dispatch.get(dispatch_id, ["status"]) + return result_info["status"] + + +async def _initialize_caches(dispatch_id, pending_parents, sorted_task_groups): + for gid, indegree in pending_parents.items(): + await _pending_parents.set_pending(dispatch_id, gid, indegree) + + for gid, sorted_nodes in sorted_task_groups.items(): + await _sorted_task_groups.set_task_group(dispatch_id, gid, sorted_nodes) + + await _unresolved_tasks.set_unresolved(dispatch_id, 0) + + +async def _submit_initial_tasks(dispatch_id: str): + app_log.debug("3: Inside run_planned_workflow (run_planned_workflow).") + dispatch_result = datasvc.generate_dispatch_result( + dispatch_id, start_time=datetime.now(timezone.utc), status=RESULT_STATUS.RUNNING + ) + await datasvc.dispatch.update(dispatch_id, dispatch_result) + + app_log.debug(f"4: Workflow status changed to running {dispatch_id} (run_planned_workflow).") + app_log.debug("5: Wrote lattice status to DB (run_planned_workflow).") + + initial_groups, pending_parents, sorted_task_groups = await _get_initial_tasks_and_deps( + dispatch_id + ) + + await _initialize_caches(dispatch_id, pending_parents, sorted_task_groups) + + for gid in initial_groups: + sorted_nodes = sorted_task_groups[gid] + app_log.debug(f"Sorted nodes group group {gid}: {sorted_nodes}") + await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes)) + + for gid in initial_groups: + sorted_nodes = sorted_task_groups[gid] + await _submit_task_group(dispatch_id, sorted_nodes, gid) + + return RESULT_STATUS.RUNNING + + +async def _handle_node_status_update(dispatch_id, node_id, node_status, detail): + app_log.debug(f"Received node status update {node_id}: {node_status}") + + if node_status == RESULT_STATUS.RUNNING: + return + + if node_status == RESULT_STATUS.DISPATCHING: + sub_dispatch_id = detail["sub_dispatch_id"] + run_dispatch(sub_dispatch_id) + app_log.debug(f"Running sublattice dispatch {sub_dispatch_id}") + + return + + # Terminal node statuses + + if node_status == RESULT_STATUS.COMPLETED: + next_task_groups = await _handle_completed_node(dispatch_id, node_id) + for gid in next_task_groups: + sorted_nodes = await _sorted_task_groups.get_task_group(dispatch_id, gid) + await _unresolved_tasks.increment(dispatch_id, len(sorted_nodes)) + await _submit_task_group(dispatch_id, sorted_nodes, gid) + + if node_status == RESULT_STATUS.FAILED: + await _handle_failed_node(dispatch_id, node_id) + + if node_status == RESULT_STATUS.CANCELLED: + await _handle_cancelled_node(dispatch_id, node_id) + + # Decrement after any increments to avoid race with + # finalize_dispatch() + await _unresolved_tasks.decrement(dispatch_id) + + +async def _handle_dispatch_exception(dispatch_id: str, ex: Exception) -> RESULT_STATUS: + error_msg = "".join(traceback.TracebackException.from_exception(ex).format()) + app_log.exception(f"Exception during _run_planned_workflow: {error_msg}") + + dispatch_result = datasvc.generate_dispatch_result( + dispatch_id, + end_time=datetime.now(timezone.utc), + status=RESULT_STATUS.FAILED, + error=error_msg, + ) + + await datasvc.dispatch.update(dispatch_id, dispatch_result) + return RESULT_STATUS.FAILED + + +# msg = { +# "dispatch_id": dispatch_id, +# "node_id": node_id, +# "status": status, +# "detail": detail, +# } +async def _node_event_listener(): + app_log.debug("Starting event listener") + while True: + msg = await _global_status_queue.get() + + asyncio.create_task(_handle_event(msg)) + + +async def _handle_event(msg: Dict): + dispatch_id = msg["dispatch_id"] + node_id = msg["node_id"] + node_status = msg["status"] + detail = msg["detail"] + + try: + await _handle_node_status_update(dispatch_id, node_id, node_status, detail) + + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) + await datasvc.persist_result(dispatch_id) + fut = _futures.get(dispatch_id) + if fut: + fut.set_result(dispatch_status) + return dispatch_status + + unresolved = await _unresolved_tasks.get_unresolved(dispatch_id) + if unresolved < 1: + app_log.debug("Finalizing dispatch") + try: + dispatch_status = await _finalize_dispatch(dispatch_id) + except Exception as ex: + dispatch_status = await _handle_dispatch_exception(dispatch_id, ex) + + finally: + await datasvc.persist_result(dispatch_id) + fut = _futures.get(dispatch_id) + if fut: + fut.set_result(dispatch_status) + + return dispatch_status + + +async def _clear_caches(dispatch_id: str): + """Clean up all keys in caches.""" + await _unresolved_tasks.remove(dispatch_id) + + g_node_link = await tg_utils.get_nodes_links(dispatch_id) + g = nx.readwrite.node_link_graph(g_node_link) + + task_groups = {g.nodes[i]["task_group_id"] for i in g.nodes} + + for gid in task_groups: + # Clean up no longer referenced keys + await _pending_parents.remove(dispatch_id, gid) + await _sorted_task_groups.remove(dispatch_id, gid) diff --git a/covalent_dispatcher/_core/dispatcher_modules/__init__.py b/covalent_dispatcher/_core/dispatcher_modules/__init__.py new file mode 100644 index 000000000..21d7eaa5c --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/covalent_dispatcher/_core/dispatcher_modules/caches.py b/covalent_dispatcher/_core/dispatcher_modules/caches.py new file mode 100644 index 000000000..ce7b53e59 --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/caches.py @@ -0,0 +1,101 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Helper classes for the dispatcher +""" + +from .store import _DictStore, _KeyValueBase + + +def _pending_parents_key(dispatch_id: str, node_id: int): + return f"pending-parents-{dispatch_id}:{node_id}" + + +def _unresolved_tasks_key(dispatch_id: str): + return f"unresolved-{dispatch_id}" + + +def _task_groups_key(dispatch_id: str, task_group_id: int): + return f"task-groups-{dispatch_id}:{task_group_id}" + + +class _UnresolvedTasksCache: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_unresolved(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.get(key) + + async def set_unresolved(self, dispatch_id: str, val: int): + key = _unresolved_tasks_key(dispatch_id) + await self._store.insert(key, val) + + async def increment(self, dispatch_id: str, interval: int = 1): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.increment(key, interval) + + async def decrement(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + return await self._store.increment(key, -1) + + async def remove(self, dispatch_id: str): + key = _unresolved_tasks_key(dispatch_id) + await self._store.remove(key) + + +class _PendingParentsCache: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_pending(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + return await self._store.get(key) + + async def set_pending(self, dispatch_id: str, task_group_id: int, val: int): + key = _pending_parents_key(dispatch_id, task_group_id) + await self._store.insert(key, val) + + async def decrement(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + return await self._store.increment(key, -1) + + async def remove(self, dispatch_id: str, task_group_id: int): + key = _pending_parents_key(dispatch_id, task_group_id) + await self._store.remove(key) + + +class _SortedTaskGroups: + def __init__(self, store: _KeyValueBase = _DictStore()): + self._store = store + + async def get_task_group(self, dispatch_id: str, task_group_id: int): + key = _task_groups_key(dispatch_id, task_group_id) + return await self._store.get(key) + + async def set_task_group(self, dispatch_id: str, task_group_id: int, sorted_nodes: list): + key = _task_groups_key(dispatch_id, task_group_id) + await self._store.insert(key, sorted_nodes) + + async def remove(self, dispatch_id: str, task_group_id: int): + key = _task_groups_key(dispatch_id, task_group_id) + await self._store.remove(key) + + +_pending_parents = _PendingParentsCache() +_unresolved_tasks = _UnresolvedTasksCache() +_sorted_task_groups = _SortedTaskGroups() diff --git a/covalent_dispatcher/_core/dispatcher_modules/store.py b/covalent_dispatcher/_core/dispatcher_modules/store.py new file mode 100644 index 000000000..a27fe5872 --- /dev/null +++ b/covalent_dispatcher/_core/dispatcher_modules/store.py @@ -0,0 +1,66 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Simple Key-Value store base +""" + + +class _KeyValueBase: + async def get(self, key): + raise NotImplementedError + + async def insert(self, key, val): + raise NotImplementedError + + async def belongs(self, key): + raise NotImplementedError + + async def remove(self, key): + raise NotImplementedError + + async def increment(self, key: str, delta: int) -> int: + """Increments value for `key` by amount `delta` + + Parameters: + key: the value to change + delta: the amount to change (can be negative) + Returns: + The new value + """ + + raise NotImplementedError + + +class _DictStore(_KeyValueBase): + def __init__(self): + self._store = {} + + async def get(self, key): + return self._store[key] + + async def insert(self, key, val): + self._store[key] = val + + async def belongs(self, key): + return key in self._store + + async def remove(self, key): + del self._store[key] + + async def increment(self, key, delta: int): + self._store[key] += delta + return self._store[key] diff --git a/covalent_dispatcher/_core/execution.py b/covalent_dispatcher/_core/execution.py index 9a0b8721d..5de943fab 100644 --- a/covalent_dispatcher/_core/execution.py +++ b/covalent_dispatcher/_core/execution.py @@ -18,12 +18,13 @@ Defines the core functionality of the dispatcher """ + from covalent._results_manager import Result -from . import dispatcher, runner +from . import runner -def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: +async def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dict: """ Return the required inputs for a task execution. This makes sure that any node with child nodes isn't executed twice and fetches the @@ -40,8 +41,24 @@ def _get_task_inputs(node_id: int, node_name: str, result_object: Result) -> dic and any parent node execution results if present. """ - abstract_inputs = dispatcher._get_abstract_task_inputs(node_id, node_name, result_object) - input_values = runner._get_task_input_values(result_object, abstract_inputs) + abstract_inputs = {"args": [], "kwargs": {}} + + for parent in result_object.lattice.transport_graph.get_dependencies(node_id): + edge_data = result_object.lattice.transport_graph.get_edge_data(parent, node_id) + # value = result_object.lattice.transport_graph.get_node_value(parent, "output") + + for e_key, d in edge_data.items(): + if not d.get("wait_for"): + if d["param_type"] == "arg": + abstract_inputs["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_inputs["kwargs"][key] = parent + + sorted_args = sorted(abstract_inputs["args"], key=lambda x: x[1]) + abstract_inputs["args"] = [x[0] for x in sorted_args] + + input_values = await runner._get_task_input_values(result_object.dispatch_id, abstract_inputs) abstract_args = abstract_inputs["args"] abstract_kwargs = abstract_inputs["kwargs"] diff --git a/covalent_dispatcher/_core/runner.py b/covalent_dispatcher/_core/runner.py index 192317027..c945dfa8d 100644 --- a/covalent_dispatcher/_core/runner.py +++ b/covalent_dispatcher/_core/runner.py @@ -20,137 +20,62 @@ import asyncio import importlib -import json import traceback -from concurrent.futures import ThreadPoolExecutor from datetime import datetime, timezone from functools import partial -from typing import Any, Dict, List, Literal, Tuple, Union +from typing import Any, Dict, List, Tuple -from covalent._results_manager import Result from covalent._shared_files import logger from covalent._shared_files.config import get_config from covalent._shared_files.util_classes import RESULT_STATUS from covalent._workflow import DepsBash, DepsCall, DepsPip from covalent._workflow.transport import TransportableObject -from covalent.executor import _executor_manager -from covalent.executor.base import AsyncBaseExecutor, wrapper_fn +from covalent.executor.base import wrapper_fn from covalent.executor.utils import set_context from . import data_manager as datasvc -from .data_modules.job_manager import get_jobs_metadata, set_cancel_result from .runner_modules import executor_proxy +from .runner_modules.utils import get_executor app_log = logger.app_log log_stack_info = logger.log_stack_info debug_mode = get_config("sdk.log_level") == "debug" -_cancel_threadpool = ThreadPoolExecutor() - - -# Domain: runner -def get_executor( - executor: Union[Tuple, List], - loop: asyncio.BaseEventLoop = None, - cancel_pool: ThreadPoolExecutor = None, -) -> AsyncBaseExecutor: - """Get unpacked and initialized executor object. - - Args: - executor: Tuple containing short name and object dictionary for the executor. - loop: Running event loop. Defaults to None. - cancel_pool: Threadpool for cancelling tasks. Defaults to None. - - Returns: - Executor object. - - """ - short_name, object_dict = executor - executor = _executor_manager.get_executor(short_name) - executor.from_dict(object_dict) - executor._init_runtime(loop=loop, cancel_pool=cancel_pool) - - return executor - # Domain: runner # to be called by _run_abstract_task -def _get_task_input_values(result_object: Result, abs_task_inputs: dict) -> dict: - """ - Retrieve the input values from the result_object for the task - - Arg(s) - result_object: Result object of the workflow - abs_task_inputs: Task inputs dictionary - - Return(s) - node_values: Dictionary of task inputs - - """ +async def _get_task_input_values(dispatch_id: str, abs_task_inputs: dict) -> dict: node_values = {} args = abs_task_inputs["args"] for node_id in args: - value = result_object.lattice.transport_graph.get_node_value(node_id, "output") + value = (await datasvc.electron.get(dispatch_id, node_id, ["output"]))["output"] node_values[node_id] = value kwargs = abs_task_inputs["kwargs"] - for _, node_id in kwargs.items(): - value = result_object.lattice.transport_graph.get_node_value(node_id, "output") + for key, node_id in kwargs.items(): + value = (await datasvc.electron.get(dispatch_id, node_id, ["output"]))["output"] node_values[node_id] = value return node_values -# Domain: runner -async def run_abstract_task( - dispatch_id: str, - node_id: int, - node_name: str, - abstract_inputs: Dict, - executor: Any, -) -> None: - node_result = await _run_abstract_task( - dispatch_id=dispatch_id, - node_id=node_id, - node_name=node_name, - abstract_inputs=abstract_inputs, - executor=executor, - ) - - result_object = datasvc.get_result_object(dispatch_id) - await datasvc.update_node_result(result_object, node_result) - - # Domain: runner async def _run_abstract_task( dispatch_id: str, node_id: int, node_name: str, abstract_inputs: Dict, - executor: Any, + selected_executor: Any, ) -> None: # Resolve abstract task and inputs to their concrete (serialized) values - result_object = datasvc.get_result_object(dispatch_id) timestamp = datetime.now(timezone.utc) try: - cancel_req = await executor_proxy._get_cancel_requested(dispatch_id, node_id) - if cancel_req: - app_log.debug(f"Don't run cancelled task {dispatch_id}:{node_id}") - return datasvc.generate_node_result( - dispatch_id=dispatch_id, - node_id=node_id, - node_name=node_name, - start_time=timestamp, - end_time=timestamp, - status=RESULT_STATUS.CANCELLED, - ) - - serialized_callable = result_object.lattice.transport_graph.get_node_value( - node_id, "function" - ) + serialized_callable = (await datasvc.electron.get(dispatch_id, node_id, ["function"]))[ + "function" + ] - input_values = _get_task_input_values(result_object, abstract_inputs) + input_values = await _get_task_input_values(dispatch_id, abstract_inputs) abstract_args = abstract_inputs["args"] abstract_kwargs = abstract_inputs["kwargs"] @@ -160,14 +85,14 @@ async def _run_abstract_task( task_input = {"args": args, "kwargs": kwargs} app_log.debug(f"Collecting deps for task {node_id}") - call_before, call_after = _gather_deps(result_object, node_id) + + call_before, call_after = await _gather_deps(dispatch_id, node_id) except Exception as ex: app_log.error(f"Exception when trying to resolve inputs or deps: {ex}") node_result = datasvc.generate_node_result( dispatch_id=dispatch_id, node_id=node_id, - node_name=node_name, start_time=timestamp, end_time=timestamp, status=RESULT_STATUS.FAILED, @@ -177,20 +102,20 @@ async def _run_abstract_task( node_result = datasvc.generate_node_result( dispatch_id=dispatch_id, - node_id=node_id, node_name=node_name, + node_id=node_id, start_time=timestamp, status=RESULT_STATUS.RUNNING, ) app_log.debug(f"7: Marking node {node_id} as running (_run_abstract_task)") - await datasvc.update_node_result(result_object, node_result) + await datasvc.update_node_result(dispatch_id, node_result) return await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=node_id, serialized_callable=serialized_callable, - executor=executor, + selected_executor=selected_executor, node_name=node_name, call_before=call_before, call_after=call_after, @@ -200,11 +125,11 @@ async def _run_abstract_task( # Domain: runner async def _run_task( - result_object: Result, + dispatch_id: str, node_id: int, inputs: Dict, serialized_callable: Any, - executor: Any, + selected_executor: Any, call_before: List, call_after: List, node_name: str, @@ -219,20 +144,24 @@ async def _run_task( Args: inputs: Inputs for the task. - result_object: Result object being used for current dispatch node_id: Node id of the task to be executed. Returns: None - """ - dispatch_id = result_object.dispatch_id - results_dir = result_object.results_dir + + dispatch_info = await datasvc.dispatch.get(dispatch_id, ["results_dir"]) + results_dir = dispatch_info["results_dir"] # Instantiate the executor from JSON try: - executor = get_executor(executor=executor, loop=asyncio.get_running_loop()) - + app_log.debug(f"Instantiating executor for {dispatch_id}:{node_id}") + executor = get_executor( + node_id=node_id, + selected_executor=selected_executor, + loop=asyncio.get_running_loop(), + pool=None, + ) except Exception as ex: tb = "".join(traceback.TracebackException.from_exception(ex).format()) app_log.debug("Exception when trying to instantiate executor:") @@ -241,14 +170,13 @@ async def _run_task( node_result = datasvc.generate_node_result( dispatch_id=dispatch_id, node_id=node_id, - node_name=node_name, end_time=datetime.now(timezone.utc), status=RESULT_STATUS.FAILED, error=error_msg, ) return node_result - # Run the task on the executor and register any failures. + # run the task on the executor and register any failures try: app_log.debug(f"Executing task {node_name}") @@ -271,8 +199,17 @@ def qelectron_compatible_wrapper(node_id, dispatch_id, ser_user_fn, *args, **kwa ) assembled_callable = partial(wrapper_fn, serialized_callable, call_before, call_after) + execute_callable = partial( + executor.execute, + function=assembled_callable, + args=inputs["args"], + kwargs=inputs["kwargs"], + dispatch_id=dispatch_id, + results_dir=results_dir, + node_id=node_id, + ) - # Note: Executor proxy monitors the executors instances and watches the send and receive queues of the executor. + # Start listening for messages from the plugin asyncio.create_task(executor_proxy.watch(dispatch_id, node_id, executor)) output, stdout, stderr, status = await executor._execute( @@ -308,23 +245,24 @@ def qelectron_compatible_wrapper(node_id, dispatch_id, ser_user_fn, *args, **kwa status=RESULT_STATUS.FAILED, error=error_msg, ) + app_log.debug(f"Node result: {node_result}") return node_result # Domain: runner -def _gather_deps(result_object: Result, node_id: int) -> Tuple[List, List]: +async def _gather_deps(dispatch_id: str, node_id: int) -> Tuple[List, List]: """Assemble deps for a node into the final call_before and call_after""" - deps = result_object.lattice.transport_graph.get_node_value(node_id, "metadata")["deps"] + deps_attrs = await datasvc.electron.get( + dispatch_id, node_id, ["deps", "call_before", "call_after"] + ) + + deps = deps_attrs["deps"] # Assemble call_before and call_after from all the deps - call_before_objs_json = result_object.lattice.transport_graph.get_node_value( - node_id, "metadata" - )["call_before"] - call_after_objs_json = result_object.lattice.transport_graph.get_node_value( - node_id, "metadata" - )["call_after"] + call_before_objs_json = deps_attrs["call_before"] + call_after_objs_json = deps_attrs["call_after"] call_before = [] call_after = [] @@ -353,97 +291,19 @@ def _gather_deps(result_object: Result, node_id: int) -> Tuple[List, List]: return call_before, call_after -async def _cancel_task( - dispatch_id: str, task_id: int, executor, executor_data: Dict, job_handle: str -) -> Union[Any, Literal[False]]: - """ - Cancel the task currently being executed by the executor - - Arg(s) - dispatch_id: Dispatch ID - task_id: Task ID of the electron in transport graph to be cancelled - executor: Covalent executor currently being used to execute the task - executor_data: Executor configuration arguments - job_handle: Unique identifier assigned to the task by the backend running the job - - Return(s) - cancel_job_result: Status of the job cancellation action - - """ - app_log.debug(f"Cancel task {task_id} using executor {executor}, {executor_data}") - app_log.debug(f"job_handle: {job_handle}") - - try: - executor = get_executor( - executor=executor, loop=asyncio.get_running_loop(), cancel_pool=_cancel_threadpool - ) - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - cancel_job_result = await executor._cancel(task_metadata, json.loads(job_handle)) - - except Exception as ex: - app_log.debug(f"Exception when cancel task {dispatch_id}:{task_id}: {ex}") - cancel_job_result = False - - await set_cancel_result(dispatch_id, task_id, cancel_job_result) - return cancel_job_result - - -def to_cancel_kwargs( - index: int, node_id: int, node_metadata: List[dict], job_metadata: List[dict] -) -> dict: - """ - Convert node_metadata for a given node `node_id` into a dictionary - - Arg(s) - index: Index into the node_metadata list - node_id: Node ID - node_metadata: List of node metadata attributes - job_metadata: List of metadata for the current job - - Return(s) - Node metadata dictionary - """ - return { - "task_id": node_id, - "executor": node_metadata[index]["executor"], - "executor_data": node_metadata[index]["executor_data"], - "job_handle": job_metadata[index]["job_handle"], - } - - -async def cancel_tasks(dispatch_id: str, task_ids: List[int]) -> None: - """ - Request all tasks with `task_ids` to be cancelled in the workflow identified by `dispatch_id` - - Arg(s) - dispatch_id: Dispatch ID of the workflow - task_ids: List of task ids to be cancelled - - Return(s) - None - """ - job_metadata = await get_jobs_metadata(dispatch_id, task_ids) - node_metadata = _get_metadata_for_nodes(dispatch_id, task_ids) - - cancel_task_kwargs = [ - to_cancel_kwargs(i, x, node_metadata, job_metadata) for i, x in enumerate(task_ids) - ] - - for kwargs in cancel_task_kwargs: - asyncio.create_task(_cancel_task(dispatch_id, **kwargs)) - - -def _get_metadata_for_nodes(dispatch_id: str, node_ids: list) -> List[Any]: - """ - Returns all the metadata associated with the node(s) for the workflow identified by `dispatch_id` - - Arg(s) - dispatch_id: Dispatch ID of the workflow - node_ids: List of node ids from the workflow to retrieve the metadata for - - Return(s) - List of node metadata for the given `node_ids` - """ - res = datasvc.get_result_object(dispatch_id) - tg = res.lattice.transport_graph - return list(map(lambda x: tg.get_node_value(x, "metadata"), node_ids)) +# Domain: runner +async def run_abstract_task( + dispatch_id: str, + node_id: int, + node_name: str, + abstract_inputs: Dict, + selected_executor: Any, +) -> None: + node_result = await _run_abstract_task( + dispatch_id=dispatch_id, + node_id=node_id, + node_name=node_name, + abstract_inputs=abstract_inputs, + selected_executor=selected_executor, + ) + await datasvc.update_node_result(dispatch_id, node_result) diff --git a/covalent_dispatcher/_core/runner_modules/cancel.py b/covalent_dispatcher/_core/runner_modules/cancel.py new file mode 100644 index 000000000..5968cb4aa --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/cancel.py @@ -0,0 +1,146 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Functions for cancelling jobs +""" + +import asyncio +import json +from concurrent.futures import ThreadPoolExecutor +from typing import Any, List + +from covalent._shared_files import logger +from covalent._shared_files.util_classes import RESULT_STATUS + +from .. import data_manager as datasvc +from ..data_modules import job_manager +from .utils import get_executor + +app_log = logger.app_log + +# Dedicated thread pool for invoking non-async Executor.cancel() +_cancel_threadpool = ThreadPoolExecutor() + +# Collects asyncio task futures +_background_tasks = set() + + +async def _cancel_task( + dispatch_id: str, task_id: int, selected_executor: List, job_handle: str +) -> None: + """ + Cancel the task currently being executed by the executor + + Arg(s) + dispatch_id: Dispatch ID + task_id: Task ID of the electron in transport graph to be cancelled + executor: Covalent executor currently being used to execute the task + executor_data: Executor configuration arguments + job_handle: Unique identifier assigned to the task by the backend running the job + + Return(s) + cancel_job_result: Status of the job cancellation action + """ + app_log.debug(f"Cancel task {task_id} using executor {selected_executor}") + app_log.debug(f"job_handle: {job_handle}") + + try: + executor = get_executor( + node_id=task_id, + selected_executor=selected_executor, + loop=asyncio.get_running_loop(), + pool=_cancel_threadpool, + ) + + task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} + + cancel_job_result = await executor._cancel(task_metadata, json.loads(job_handle)) + except Exception as ex: + app_log.debug(f"Exception when cancel task {dispatch_id}:{task_id}: {ex}") + cancel_job_result = False + + if cancel_job_result is True: + await job_manager.set_job_status(dispatch_id, task_id, str(RESULT_STATUS.CANCELLED)) + app_log.debug(f"Cancelled task {dispatch_id}:{task_id}") + + +def _to_cancel_kwargs( + index: int, node_id: int, node_metadata: List[dict], job_metadata: List[dict] +) -> dict: + """ + Convert node_metadata for a given node `node_id` into a dictionary + + Arg(s) + index: Index into the node_metadata list + node_id: Node ID + node_metadata: List of node metadata attributes + job_metadata: List of metadata for the current job + + Return(s) + Node metadata dictionary + """ + selected_executor = [node_metadata[index]["executor"], node_metadata[index]["executor_data"]] + return { + "task_id": node_id, + "selected_executor": selected_executor, + "job_handle": job_metadata[index]["job_handle"], + } + + +async def cancel_tasks(dispatch_id: str, task_ids: List[int]) -> None: + """ + Request all tasks with `task_ids` to be cancelled in the workflow identified by `dispatch_id` + + Arg(s) + dispatch_id: Dispatch ID of the workflow + task_ids: List of task ids to be cancelled + + Return(s) + None + """ + job_metadata = await job_manager.get_jobs_metadata(dispatch_id, task_ids) + node_metadata = await _get_metadata_for_nodes(dispatch_id, task_ids) + app_log.debug(f"node metadata: {node_metadata}") + app_log.debug(f"job metadata: {job_metadata}") + cancel_task_kwargs = [ + _to_cancel_kwargs(i, x, node_metadata, job_metadata) for i, x in enumerate(task_ids) + ] + + for kwargs in cancel_task_kwargs: + fut = asyncio.create_task(_cancel_task(dispatch_id, **kwargs)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + +async def _get_metadata_for_nodes(dispatch_id: str, node_ids: list) -> List[Any]: + """ + Returns all the metadata associated with the node(s) for the workflow identified by `dispatch_id` + + Arg(s) + dispatch_id: Dispatch ID of the workflow + node_ids: List of node ids from the workflow to retrive the metadata for + + Return(s) + List of node metadata for the given `node_ids` + """ + + attrs = await datasvc.electron.get_bulk( + dispatch_id, + node_ids, + ["executor", "executor_data"], + ) + return attrs diff --git a/covalent_dispatcher/_core/runner_modules/executor_proxy.py b/covalent_dispatcher/_core/runner_modules/executor_proxy.py index 448b7c69b..c143066dd 100644 --- a/covalent_dispatcher/_core/runner_modules/executor_proxy.py +++ b/covalent_dispatcher/_core/runner_modules/executor_proxy.py @@ -14,7 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -""" Monitor executor instances.""" +""" Monitor executor instances """ from typing import Any @@ -23,7 +23,13 @@ from covalent.executor.base import _AbstractBaseExecutor as _ABE from covalent.executor.utils import Signals -from ..data_modules import job_manager +from .jobs import ( + get_cancel_requested, + get_job_status, + get_version_info, + put_job_handle, + put_job_status, +) app_log = logger.app_log log_stack_info = logger.log_stack_info @@ -32,50 +38,11 @@ _getters = {} -async def _get_cancel_requested(dispatch_id: str, task_id: int): - """ - Query the database for the task's cancellation status - - Arg(s) - dispatch_id: Dispatch ID of the lattice - task_id: ID of the task within the lattice - - Return(s) - Cancellation status of the task - - """ - # Don't hit the DB for post-processing task - if task_id < 0: - return False - - app_log.debug(f"Get _handle_requested for executor {dispatch_id}:{task_id}") - job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) - app_log.debug(f"Job record: {job_records[0]}") - return job_records[0]["cancel_requested"] - - -async def _put_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> bool: - """ - Store the job handle of the task returned by the backend in the database - - Arg(s) - dispatch_id: Dispatch ID of the lattice - task_id: ID of the task within the lattice - job_handle: Unique identifier of the task returned by the execution backend - - Return(s) - True - """ - # Don't hit the DB for post-processing task - if task_id < 0: - return False - app_log.debug(f"Put job_handle for executor {dispatch_id}:{task_id}") - await job_manager.set_job_handle(dispatch_id, task_id, job_handle) - return True - - -_putters["job_handle"] = _put_job_handle -_getters["cancel_requested"] = _get_cancel_requested +_putters["job_handle"] = put_job_handle +_putters["job_status"] = put_job_status +_getters["cancel_requested"] = get_cancel_requested +_getters["job_status"] = get_job_status +_getters["version_info"] = get_version_info async def _handle_message( diff --git a/covalent_dispatcher/_core/runner_modules/jobs.py b/covalent_dispatcher/_core/runner_modules/jobs.py new file mode 100644 index 000000000..55d79eead --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/jobs.py @@ -0,0 +1,126 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" Handlers for the executor proxy """ + + +from covalent._shared_files import logger +from covalent._shared_files.util_classes import Status +from covalent_dispatcher._core.data_modules import lattice as lattice_query_module + +from .. import data_manager as datasvc +from ..data_modules import job_manager +from .utils import get_executor + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + + +async def get_cancel_requested(dispatch_id: str, task_id: int): + """ + Query the database for the task's cancellation status + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Return(s) + Canellation status of the task + """ + + app_log.debug(f"Get _handle_requested for task {dispatch_id}:{task_id}") + job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) + app_log.debug(f"Job record: {job_records[0]}") + return job_records[0]["cancel_requested"] + + +async def get_version_info(dispatch_id: str, task_id: int): + """ + Query the database for the dispatch version information + + Arg: + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Returns: + {"python": python_version, "covalent": covalent_version} + """ + + data = await lattice_query_module.get(dispatch_id, ["python_version", "covalent_version"]) + + return { + "python": data["python_version"], + "covalent": data["covalent_version"], + } + + +async def get_job_status(dispatch_id: str, task_id: int) -> Status: + """ + Queries the job state for (dispatch_id, task_id) + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + + Return(s) + Status + """ + app_log.debug(f"Get for task {dispatch_id}:{task_id}") + job_records = await job_manager.get_jobs_metadata(dispatch_id, [task_id]) + app_log.debug(f"Job record: {job_records[0]}") + return Status(job_records[0]["status"]) + + +async def put_job_handle(dispatch_id: str, task_id: int, job_handle: str) -> bool: + """ + Store the job handle of the task returned by the backend in the database + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + job_handle: Unique identifier of the task returned by the execution backend + + Return(s) + True + """ + app_log.debug(f"Put job_handle for executor {dispatch_id}:{task_id}") + await job_manager.set_job_handle(dispatch_id, task_id, job_handle) + return True + + +async def put_job_status(dispatch_id: str, task_id: int, status: Status) -> bool: + """ + Mark the job for (dispatch_id, task_id) as cancelled + + Arg(s) + dispatch_id: Dispatch ID of the lattice + task_id: ID of the task within the lattice + job_status: A `Status` type representing the job status + + Return(s) + True + """ + app_log.debug(f"Put cancel result for task {dispatch_id}:{task_id}") + executor_attrs = await datasvc.electron.get( + dispatch_id, task_id, ["executor", "executor_data"] + ) + selected_executor = [executor_attrs["executor"], executor_attrs["executor_data"]] + executor = get_executor(task_id, selected_executor, None, None) + if executor.validate_status(status): + await job_manager.set_job_status(dispatch_id, task_id, str(status)) + return True + else: + return False diff --git a/covalent_dispatcher/_core/runner_modules/utils.py b/covalent_dispatcher/_core/runner_modules/utils.py new file mode 100644 index 000000000..8126d313f --- /dev/null +++ b/covalent_dispatcher/_core/runner_modules/utils.py @@ -0,0 +1,43 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Defines the core functionality of the runner +""" + +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent.executor import _executor_manager +from covalent.executor.base import AsyncBaseExecutor + +app_log = logger.app_log +log_stack_info = logger.log_stack_info +debug_mode = get_config("sdk.log_level") == "debug" + + +def get_executor(node_id, selected_executor, loop=None, pool=None) -> AsyncBaseExecutor: + # Instantiate the executor from JSON + + short_name, object_dict = selected_executor + + app_log.debug(f"Running task {node_id} using executor {short_name}, {object_dict}") + + # the executor is determined during scheduling and provided in the execution metadata + executor = _executor_manager.get_executor(short_name) + executor.from_dict(object_dict) + executor._init_runtime(loop=loop, cancel_pool=pool) + + return executor diff --git a/covalent_dispatcher/_dal/__init__.py b/covalent_dispatcher/_dal/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent_dispatcher/_dal/__init__.py +++ b/covalent_dispatcher/_dal/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_dal/db_interfaces/__init__.py b/covalent_dispatcher/_dal/db_interfaces/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent_dispatcher/_dal/db_interfaces/__init__.py +++ b/covalent_dispatcher/_dal/db_interfaces/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_dal/exporters/__init__.py b/covalent_dispatcher/_dal/exporters/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent_dispatcher/_dal/exporters/__init__.py +++ b/covalent_dispatcher/_dal/exporters/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_dal/exporters/electron.py b/covalent_dispatcher/_dal/exporters/electron.py index af2bec692..d2d3b0780 100644 --- a/covalent_dispatcher/_dal/exporters/electron.py +++ b/covalent_dispatcher/_dal/exporters/electron.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -25,8 +25,7 @@ ElectronMetadata, ElectronSchema, ) - -from ..electron import ASSET_KEYS, Electron +from covalent_dispatcher._dal.electron import ASSET_KEYS, Electron app_log = logger.app_log diff --git a/covalent_dispatcher/_dal/exporters/lattice.py b/covalent_dispatcher/_dal/exporters/lattice.py index ac75cfe3f..4b630fcd1 100644 --- a/covalent_dispatcher/_dal/exporters/lattice.py +++ b/covalent_dispatcher/_dal/exporters/lattice.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -20,16 +20,13 @@ from covalent._shared_files.schemas.asset import AssetSchema from covalent._shared_files.schemas.lattice import LatticeAssets, LatticeMetadata, LatticeSchema +from covalent_dispatcher._dal.lattice import ASSET_KEYS, METADATA_KEYS, Lattice -from ..lattice import ASSET_KEYS, METADATA_KEYS, Lattice from .tg import export_transport_graph def _export_lattice_meta(lat: Lattice) -> LatticeMetadata: - metadata_kwargs = {} - for key in METADATA_KEYS: - metadata_kwargs[key] = lat.get_value(key, None, refresh=False) - + metadata_kwargs = {key: lat.get_value(key, None, refresh=False) for key in METADATA_KEYS} return LatticeMetadata(**metadata_kwargs) diff --git a/covalent_dispatcher/_dal/exporters/result.py b/covalent_dispatcher/_dal/exporters/result.py index 095aa6254..dfc21bca5 100644 --- a/covalent_dispatcher/_dal/exporters/result.py +++ b/covalent_dispatcher/_dal/exporters/result.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -29,9 +29,9 @@ ResultSchema, ) from covalent._shared_files.utils import format_server_url +from covalent_dispatcher._dal.electron import Electron +from covalent_dispatcher._dal.result import Result, get_result_object -from ..electron import Electron -from ..result import Result, get_result_object from ..utils.uri_filters import AssetScope, URIFilterPolicy, filter_asset_uri from .lattice import export_lattice @@ -44,12 +44,11 @@ # res is assumed to represent a full db record def _export_result_meta(res: Result) -> ResultMetadata: - metadata_kwargs = {} - for key in METADATA_KEYS: - if key in METADATA_KEYS_TO_OMIT: - continue - metadata_kwargs[key] = res.get_metadata(key, None, refresh=False) - + metadata_kwargs = { + key: res.get_metadata(key, None, refresh=False) + for key in METADATA_KEYS + if key not in METADATA_KEYS_TO_OMIT + } return ResultMetadata(**metadata_kwargs) diff --git a/covalent_dispatcher/_dal/exporters/tg.py b/covalent_dispatcher/_dal/exporters/tg.py index 8aa44d187..92e856d30 100644 --- a/covalent_dispatcher/_dal/exporters/tg.py +++ b/covalent_dispatcher/_dal/exporters/tg.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -23,8 +23,8 @@ from covalent._shared_files.schemas.edge import EdgeMetadata, EdgeSchema from covalent._shared_files.schemas.electron import ElectronSchema from covalent._shared_files.schemas.transport_graph import TransportGraphSchema +from covalent_dispatcher._dal.tg import _TransportGraph -from ..tg import _TransportGraph from .electron import export_electron app_log = logger.app_log @@ -34,10 +34,7 @@ def _export_nodes(tg: _TransportGraph) -> List[ElectronSchema]: g = tg.get_internal_graph_copy() internal_nodes = tg.get_nodes(list(g.nodes), None) - export_nodes = [] - for e in internal_nodes: - export_nodes.append(export_electron(e)) - + export_nodes = [export_electron(e) for e in internal_nodes] return export_nodes diff --git a/covalent_dispatcher/_dal/importers/__init__.py b/covalent_dispatcher/_dal/importers/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent_dispatcher/_dal/importers/__init__.py +++ b/covalent_dispatcher/_dal/importers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_dal/importers/electron.py b/covalent_dispatcher/_dal/importers/electron.py index fca4230cd..e429231c7 100644 --- a/covalent_dispatcher/_dal/importers/electron.py +++ b/covalent_dispatcher/_dal/importers/electron.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -40,13 +40,12 @@ ElectronAssets, ElectronSchema, ) - -from ..._db import models -from ..._db.write_result_to_db import get_electron_type -from ..._object_store.base import BaseProvider -from ..asset import Asset -from ..electron import ElectronMeta -from ..lattice import Lattice +from covalent_dispatcher._dal.asset import Asset +from covalent_dispatcher._dal.electron import ElectronMeta +from covalent_dispatcher._dal.lattice import Lattice +from covalent_dispatcher._db import models +from covalent_dispatcher._db.write_result_to_db import get_electron_type +from covalent_dispatcher._object_store.base import BaseProvider app_log = logger.app_log diff --git a/covalent_dispatcher/_dal/importers/lattice.py b/covalent_dispatcher/_dal/importers/lattice.py index 2282ba9c9..7c5870b30 100644 --- a/covalent_dispatcher/_dal/importers/lattice.py +++ b/covalent_dispatcher/_dal/importers/lattice.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -41,10 +41,9 @@ LatticeAssets, LatticeSchema, ) - -from ..._object_store.local import BaseProvider -from ..asset import Asset -from ..lattice import Lattice +from covalent_dispatcher._dal.asset import Asset +from covalent_dispatcher._dal.lattice import Lattice +from covalent_dispatcher._object_store.local import BaseProvider def _get_lattice_meta(lat: LatticeSchema, storage_path) -> dict: @@ -148,11 +147,9 @@ def import_lattice_assets( # Write asset records to DB session.flush() - # Link assets to lattice - lattice_asset_links = [] - for key, asset_rec in asset_ids.items(): - lattice_asset_links.append(record.associate_asset(session, key, asset_rec.id)) - + lattice_asset_links = [ + record.associate_asset(session, key, asset_rec.id) for key, asset_rec in asset_ids.items() + ] session.flush() return lat.assets diff --git a/covalent_dispatcher/_dal/importers/result.py b/covalent_dispatcher/_dal/importers/result.py index 1779e9cfc..395516b86 100644 --- a/covalent_dispatcher/_dal/importers/result.py +++ b/covalent_dispatcher/_dal/importers/result.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -28,12 +28,13 @@ from covalent._shared_files.schemas.lattice import LatticeSchema from covalent._shared_files.schemas.result import ResultAssets, ResultSchema from covalent._shared_files.utils import format_server_url +from covalent_dispatcher._dal.asset import Asset +from covalent_dispatcher._dal.electron import ElectronMeta +from covalent_dispatcher._dal.job import Job +from covalent_dispatcher._dal.result import Result, ResultMeta +from covalent_dispatcher._object_store.local import BaseProvider, local_store -from ..._object_store.local import BaseProvider, local_store -from ..asset import Asset, copy_asset_meta -from ..electron import ElectronMeta -from ..job import Job -from ..result import Result, ResultMeta +from ..asset import copy_asset_meta from ..tg_ops import TransportGraphOps from ..utils.uri_filters import AssetScope, URIFilterPolicy, filter_asset_uri from .lattice import _get_lattice_meta, import_lattice_assets @@ -374,4 +375,9 @@ def handle_redispatch( src, dest = item copy_asset_meta(session, src, dest) + # Copy asset data + # for item in assets_to_copy: + # src, dest = item + # copy_asset(src, dest) + return manifest, assets_to_copy diff --git a/covalent_dispatcher/_dal/importers/tg.py b/covalent_dispatcher/_dal/importers/tg.py index 0da92164a..280399658 100644 --- a/covalent_dispatcher/_dal/importers/tg.py +++ b/covalent_dispatcher/_dal/importers/tg.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -25,13 +25,13 @@ from covalent._shared_files import logger from covalent._shared_files.schemas.edge import EdgeSchema from covalent._shared_files.schemas.transport_graph import TransportGraphSchema +from covalent_dispatcher._dal.edge import ElectronDependency +from covalent_dispatcher._dal.electron import Electron +from covalent_dispatcher._dal.job import Job +from covalent_dispatcher._dal.lattice import Lattice +from covalent_dispatcher._db import models +from covalent_dispatcher._object_store.base import BaseProvider -from ..._db import models -from ..._object_store.base import BaseProvider -from ..edge import ElectronDependency -from ..electron import Electron -from ..job import Job -from ..lattice import Lattice from .electron import import_electron app_log = logger.app_log @@ -69,7 +69,7 @@ def import_transport_graph( # Maps node ids to asset record dictionaries electron_asset_links = {} - for gid, node_group in task_groups.items(): + for gid in task_groups: # Create a job record for each task group job_kwargs = { "cancel_requested": cancel_requested, @@ -106,7 +106,7 @@ def import_transport_graph( app_log.debug(f"Inserting {n_records} electron records took {delta} seconds") n_records = 0 - for _, asset_records_by_key in electron_asset_links.items(): + for asset_records_by_key in electron_asset_links.values(): n_records += len(asset_records_by_key) st = datetime.now() @@ -118,11 +118,10 @@ def import_transport_graph( meta_asset_associations = [] for node_id, asset_records in electron_asset_links.items(): electron_dal = Electron(session, electron_map[node_id]) - for key, asset_rec in asset_records.items(): - meta_asset_associations.append( - electron_dal.associate_asset(session, key, asset_rec.id) - ) - + meta_asset_associations.extend( + electron_dal.associate_asset(session, key, asset_rec.id) + for key, asset_rec in asset_records.items() + ) n_records = len(meta_asset_associations) st = datetime.now() diff --git a/covalent_dispatcher/_dal/tg_ops.py b/covalent_dispatcher/_dal/tg_ops.py index 1f3113b19..79f87fa4c 100644 --- a/covalent_dispatcher/_dal/tg_ops.py +++ b/covalent_dispatcher/_dal/tg_ops.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -213,7 +213,7 @@ def _max_cbms( self._flag_successors(A, A_node_status, y) continue - if y in B.adj[current_node] and B_node_status[y] == -1: + if B_node_status[y] == -1: app_log.debug(f"A: Node {y} is marked as failed in B") self._flag_successors(A, A_node_status, y) continue @@ -251,7 +251,7 @@ def _max_cbms( app_log.debug(f"B: {y} not adjacent to node {current_node} in A") self._flag_successors(B, B_node_status, y) continue - if y in A.adj[current_node] and B_node_status[y] == -1: + if B_node_status[y] == -1: app_log.debug(f"B: Node {y} is marked as failed in A") self._flag_successors(B, B_node_status, y) diff --git a/covalent_dispatcher/_dal/utils/__init__.py b/covalent_dispatcher/_dal/utils/__init__.py index cfc23bfdf..21d7eaa5c 100644 --- a/covalent_dispatcher/_dal/utils/__init__.py +++ b/covalent_dispatcher/_dal/utils/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_dal/utils/uri_filters.py b/covalent_dispatcher/_dal/utils/uri_filters.py index 36d0ef044..4d9afdbb8 100644 --- a/covalent_dispatcher/_dal/utils/uri_filters.py +++ b/covalent_dispatcher/_dal/utils/uri_filters.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # @@ -42,13 +42,14 @@ class URIFilterPolicy(enum.Enum): def _srv_asset_uri( uri: str, attrs: dict, scope: AssetScope, dispatch_id: str, node_id: Optional[int], key: str ) -> str: - base_uri = SERVER_URL + f"/api/v1/assets/{dispatch_id}/{scope.value}" + base_uri = f"{SERVER_URL}/api/v2/dispatches/{dispatch_id}" - if scope == AssetScope.DISPATCH or scope == AssetScope.LATTICE: - uri = base_uri + f"/{key}" + if scope == AssetScope.DISPATCH: + return f"{base_uri}/assets/{key}" + elif scope == AssetScope.LATTICE: + return f"{base_uri}/lattice/assets/{key}" else: - uri = base_uri + f"/{node_id}/{key}" - return uri + return f"{base_uri}/electrons/{node_id}/assets/{key}" def _raw( diff --git a/covalent_dispatcher/_db/__init__.py b/covalent_dispatcher/_db/__init__.py index 21d7eaa5c..ab6c0fedf 100644 --- a/covalent_dispatcher/_db/__init__.py +++ b/covalent_dispatcher/_db/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Agnostiq Inc. +# Copyright 2022 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/covalent_dispatcher/_db/load.py b/covalent_dispatcher/_db/load.py deleted file mode 100644 index e610a2fba..000000000 --- a/covalent_dispatcher/_db/load.py +++ /dev/null @@ -1,223 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Functions to load results from the database.""" - - -from typing import Dict, Union - -from covalent import lattice -from covalent._results_manager.result import Result -from covalent._shared_files import logger -from covalent._shared_files.util_classes import Status -from covalent._workflow.transport import TransportableObject -from covalent._workflow.transport import _TransportGraph as SDKGraph - -from .._dal.electron import ASSET_KEYS as ELECTRON_ASSETS -from .._dal.electron import METADATA_KEYS as ELECTRON_META -from .._dal.result import get_result_object -from .._dal.tg import _TransportGraph as SRVGraph -from .._object_store.local import local_store -from .datastore import workflow_db -from .models import Electron, Lattice - -app_log = logger.app_log -log_stack_info = logger.log_stack_info - -NODE_ATTRIBUTES = ELECTRON_META.union(ELECTRON_ASSETS) -SDK_NODE_META_KEYS = { - "executor", - "executor_data", - "deps", - "call_before", - "call_after", -} - - -def load_file(storage_path, filename): - return local_store.load_file(storage_path, filename) - - -def _to_client_graph(srv_graph: SRVGraph) -> SDKGraph: - """Render a SDK _TransportGraph from a server-side graph""" - - sdk_graph = SDKGraph() - - sdk_graph._graph = srv_graph.get_internal_graph_copy() - for node_id in srv_graph._graph.nodes: - attrs = list(sdk_graph._graph.nodes[node_id].keys()) - for k in attrs: - del sdk_graph._graph.nodes[node_id][k] - attributes = {} - for k in NODE_ATTRIBUTES: - if k not in SDK_NODE_META_KEYS: - attributes[k] = srv_graph.get_node_value(node_id, k) - if srv_graph.get_node_value(node_id, "type") == "parameter": - attributes["value"] = srv_graph.get_node_value(node_id, "value") - attributes["output"] = srv_graph.get_node_value(node_id, "output") - - node_meta = {k: srv_graph.get_node_value(node_id, k) for k in SDK_NODE_META_KEYS} - attributes["metadata"] = node_meta - - for k, v in attributes.items(): - sdk_graph.set_node_value(node_id, k, v) - - sdk_graph.lattice_metadata = {} - - return sdk_graph - - -def _result_from(lattice_record: Lattice) -> Result: - """Re-hydrate result object from the lattice record. - - Args: - lattice_record: Lattice record to re-hydrate from. - - Returns: - Result object. - - """ - - srv_res = get_result_object(lattice_record.dispatch_id, bare=False) - - function = srv_res.lattice.get_value("workflow_function") - - function_string = srv_res.lattice.get_value("workflow_function_string") - function_docstring = srv_res.lattice.get_value("doc") - - executor_data = srv_res.lattice.get_value("executor_data") - - workflow_executor_data = srv_res.lattice.get_value("workflow_executor_data") - - inputs = srv_res.lattice.get_value("inputs") - named_args = srv_res.lattice.get_value("named_args") - named_kwargs = srv_res.lattice.get_value("named_kwargs") - error = srv_res.get_value("error") - - transport_graph = _to_client_graph(srv_res.lattice.transport_graph) - - output = srv_res.get_value("result") - deps = srv_res.lattice.get_value("deps") - call_before = srv_res.lattice.get_value("call_before") - call_after = srv_res.lattice.get_value("call_after") - cova_imports = srv_res.lattice.get_value("cova_imports") - lattice_imports = srv_res.lattice.get_value("lattice_imports") - - name = lattice_record.name - executor = lattice_record.executor - workflow_executor = lattice_record.workflow_executor - num_nodes = lattice_record.electron_num - - attributes = { - "workflow_function": function, - "workflow_function_string": function_string, - "__name__": name, - "__doc__": function_docstring, - "metadata": { - "executor": executor, - "executor_data": executor_data, - "workflow_executor": workflow_executor, - "workflow_executor_data": workflow_executor_data, - "deps": deps, - "call_before": call_before, - "call_after": call_after, - }, - "inputs": inputs, - "named_args": named_args, - "named_kwargs": named_kwargs, - "transport_graph": transport_graph, - "cova_imports": cova_imports, - "lattice_imports": lattice_imports, - "post_processing": False, - "electron_outputs": {}, - "_bound_electrons": {}, - } - - def dummy_function(x): - return x - - lat = lattice(dummy_function) - lat.__dict__ = attributes - - result = Result( - lat, - dispatch_id=lattice_record.dispatch_id, - ) - result._root_dispatch_id = lattice_record.root_dispatch_id - result._status = Status(lattice_record.status) - result._error = error or "" - result._inputs = inputs - result._start_time = lattice_record.started_at - result._end_time = lattice_record.completed_at - result._result = output if output is not None else TransportableObject(None) - result._num_nodes = num_nodes - return result - - -def get_result_object_from_storage(dispatch_id: str) -> Result: - """Get the result object from the database. - - Args: - dispatch_id: The dispatch id of the result object to load. - - Returns: - The result object. - - """ - with workflow_db.session() as session: - lattice_record = session.query(Lattice).where(Lattice.dispatch_id == dispatch_id).first() - if not lattice_record: - app_log.debug(f"No result object found for dispatch {dispatch_id}") - raise RuntimeError(f"No result object found for dispatch {dispatch_id}") - - return _result_from(lattice_record) - - -def electron_record(dispatch_id: str, node_id: str) -> Dict: - """Get electron record for a given dispatch if and node id. - - Args: - dispatch_id: Dispatch id for lattice. - node_id: Node id of the electron. - - Returns: - Electron record. - - """ - with workflow_db.session() as session: - return ( - session.query(Lattice, Electron) - .filter(Lattice.id == Electron.parent_lattice_id) - .filter(Lattice.dispatch_id == dispatch_id) - .filter(Electron.transport_graph_node_id == node_id) - .first() - .Electron.__dict__ - ) - - -def sublattice_dispatch_id(electron_id: int) -> Union[str, None]: - """Get the dispatch id of the sublattice for a given electron id. - - Args: - electron_id: Electron ID. - - Returns: - Dispatch id of sublattice. None, if the electron is not a sublattice. - - """ - with workflow_db.session() as session: - if record := (session.query(Lattice).filter(Lattice.electron_id == electron_id).first()): - return record.dispatch_id diff --git a/covalent_dispatcher/_db/models.py b/covalent_dispatcher/_db/models.py index 727286d24..b5d1fe408 100644 --- a/covalent_dispatcher/_db/models.py +++ b/covalent_dispatcher/_db/models.py @@ -200,9 +200,6 @@ class Electron(Base): # Whether qelectron data exists or not qelectron_data_exists = Column(Boolean, nullable=False, default=False) - # Cancel requested flag - cancel_requested = Column(Boolean, nullable=False, default=False) - # Name of the file containing standard error generated by the task stderr_filename = Column(Text) diff --git a/covalent_dispatcher/_db/update.py b/covalent_dispatcher/_db/update.py index 823da0ef7..ecd9a324e 100644 --- a/covalent_dispatcher/_db/update.py +++ b/covalent_dispatcher/_db/update.py @@ -15,19 +15,15 @@ # limitations under the License. import os -from datetime import datetime from pathlib import Path -from typing import Any, Union +from typing import Union from covalent._results_manager import Result from covalent._shared_files import logger from covalent._shared_files.config import get_config -from covalent._shared_files.defaults import postprocess_prefix -from covalent._shared_files.util_classes import Status from covalent._workflow.lattice import Lattice from covalent._workflow.transport import _TransportGraph -from .._dal.result import get_result_object from . import upsert app_log = logger.app_log @@ -57,93 +53,3 @@ def _initialize_results_dir(result): f"{result.dispatch_id}", ) Path(result_folder_path).mkdir(parents=True, exist_ok=True) - - -# Temporary implementation using new DAL. Will be removed in the next -# patch which transitions core covalent to the new DAL. -def _node( - result, - node_id: int, - node_name: str = None, - start_time: datetime = None, - end_time: datetime = None, - status: "Status" = None, - output: Any = None, - error: Exception = None, - stdout: str = None, - stderr: str = None, - sub_dispatch_id=None, - sublattice_result=None, - qelectron_data_exists: bool = None, -) -> bool: - """ - Update the node result in the transport graph. - Called after any change in node's execution state. - - Args: - node_id: The node id. - node_name: The name of the node. - start_time: The start time of the node execution. - end_time: The end time of the node execution. - status: The status of the node execution. - output: The output of the node unless error occured in which case None. - error: The error of the node if occured else None. - stdout: The stdout of the node execution. - stderr: The stderr of the node execution. - - Returns: - True/False indicating whether the update succeeded - """ - - # Update the in-memory result object - result._update_node( - node_id=node_id, - node_name=node_name, - start_time=start_time, - end_time=end_time, - status=status, - output=output, - error=error, - stdout=stdout, - stderr=stderr, - sub_dispatch_id=sub_dispatch_id, - sublattice_result=sublattice_result, - qelectron_data_exists=qelectron_data_exists, - ) - - # Write out update to persistent storage - srvres = get_result_object(result.dispatch_id, bare=True) - srvres._update_node( - node_id=node_id, - node_name=node_name, - start_time=start_time, - end_time=end_time, - status=status, - output=output, - error=error, - stdout=stdout, - stderr=error, - qelectron_data_exists=qelectron_data_exists, - ) - - if node_name.startswith(postprocess_prefix) and end_time is not None: - app_log.warning( - f"Persisting postprocess result {output.get_deserialized()}, node_name: {node_name}" - ) - result._result = output - result._status = status - result._end_time = end_time - lattice_data(result) - - -# Temporary implementation of upsert.lattice_data using the new DAL. -# Will be removed in the next patch which transitions core covalent to -# the new DAL. -def lattice_data(result_object: Result) -> None: - srv_res = get_result_object(result_object.dispatch_id, bare=True) - srv_res._update_dispatch( - result_object.start_time, - result_object.end_time, - result_object.status, - result_object.error, - ) diff --git a/covalent_dispatcher/_db/upsert.py b/covalent_dispatcher/_db/upsert.py index da7699d04..9e242402b 100644 --- a/covalent_dispatcher/_db/upsert.py +++ b/covalent_dispatcher/_db/upsert.py @@ -45,7 +45,6 @@ ELECTRON_FUNCTION_FILENAME = ELECTRON_FILENAMES["function"] ELECTRON_FUNCTION_STRING_FILENAME = ELECTRON_FILENAMES["function_string"] ELECTRON_VALUE_FILENAME = ELECTRON_FILENAMES["value"] -# ELECTRON_EXECUTOR_DATA_FILENAME = "executor_data.pkl" ELECTRON_STDOUT_FILENAME = ELECTRON_FILENAMES["stdout"] ELECTRON_STDERR_FILENAME = ELECTRON_FILENAMES["stderr"] ELECTRON_ERROR_FILENAME = ELECTRON_FILENAMES["error"] @@ -57,14 +56,11 @@ LATTICE_FUNCTION_FILENAME = LATTICE_FILENAMES["workflow_function"] LATTICE_FUNCTION_STRING_FILENAME = LATTICE_FILENAMES["workflow_function_string"] LATTICE_DOCSTRING_FILENAME = LATTICE_FILENAMES["doc"] -# LATTICE_EXECUTOR_DATA_FILENAME = "executor_data.pkl" -# LATTICE_WORKFLOW_EXECUTOR_DATA_FILENAME = "workflow_executor_data.pkl" LATTICE_ERROR_FILENAME = LATTICE_FILENAMES["error"] LATTICE_INPUTS_FILENAME = LATTICE_FILENAMES["inputs"] LATTICE_NAMED_ARGS_FILENAME = LATTICE_FILENAMES["named_args"] LATTICE_NAMED_KWARGS_FILENAME = LATTICE_FILENAMES["named_kwargs"] LATTICE_RESULTS_FILENAME = LATTICE_FILENAMES["result"] -# LATTICE_TRANSPORT_GRAPH_FILENAME = "transport_graph.pkl" LATTICE_DEPS_FILENAME = LATTICE_FILENAMES["deps"] LATTICE_CALL_BEFORE_FILENAME = LATTICE_FILENAMES["call_before"] LATTICE_CALL_AFTER_FILENAME = LATTICE_FILENAMES["call_after"] @@ -113,16 +109,6 @@ def _lattice_data(session: Session, result: Result, electron_id: int = None) -> ("workflow_function", LATTICE_FUNCTION_FILENAME, result.lattice.workflow_function), ("workflow_function_string", LATTICE_FUNCTION_STRING_FILENAME, workflow_func_string), ("doc", LATTICE_DOCSTRING_FILENAME, result.lattice.__doc__), - # ( - # "executor_data", - # LATTICE_EXECUTOR_DATA_FILENAME, - # result.lattice.metadata["executor_data"], - # ), - # ( - # "workflow_executor_data", - # LATTICE_WORKFLOW_EXECUTOR_DATA_FILENAME, - # result.lattice.metadata["workflow_executor_data"], - # ), ("error", LATTICE_ERROR_FILENAME, result.error), ("inputs", LATTICE_INPUTS_FILENAME, result.lattice.inputs), ("named_args", LATTICE_NAMED_ARGS_FILENAME, result.lattice.named_args), @@ -175,10 +161,8 @@ def _lattice_data(session: Session, result: Result, electron_id: int = None) -> "function_string_filename": LATTICE_FUNCTION_STRING_FILENAME, "executor": result.lattice.metadata["executor"], "executor_data": json.dumps(result.lattice.metadata["executor_data"]), - # "executor_data_filename": LATTICE_EXECUTOR_DATA_FILENAME, "workflow_executor": result.lattice.metadata["workflow_executor"], "workflow_executor_data": json.dumps(result.lattice.metadata["workflow_executor_data"]), - # "workflow_executor_data_filename": LATTICE_WORKFLOW_EXECUTOR_DATA_FILENAME, "error_filename": LATTICE_ERROR_FILENAME, "inputs_filename": LATTICE_INPUTS_FILENAME, "named_args_filename": LATTICE_NAMED_ARGS_FILENAME, @@ -201,10 +185,9 @@ def _lattice_data(session: Session, result: Result, electron_id: int = None) -> lattice_row = Lattice.meta_type.create(session, insert_kwargs=lattice_record_kwarg, flush=True) lattice_record = Lattice(session, lattice_row, bare=True, keys={"id"}, electron_keys={"id"}) - lattice_asset_links = [] - for key, asset in assets.items(): - lattice_asset_links.append(lattice_record.associate_asset(session, key, asset.id)) - + lattice_asset_links = [ + lattice_record.associate_asset(session, key, asset.id) for key, asset in assets.items() + ] session.flush() return lattice_row.id @@ -304,11 +287,6 @@ def _electron_data( ("function", ELECTRON_FUNCTION_FILENAME, tg.get_node_value(node_id, "function")), ("function_string", ELECTRON_FUNCTION_STRING_FILENAME, function_string), ("value", ELECTRON_VALUE_FILENAME, node_value), - # ( - # "executor_data", - # ELECTRON_EXECUTOR_DATA_FILENAME, - # tg.get_node_value(node_id, "metadata")["executor_data"], - # ), ("deps", ELECTRON_DEPS_FILENAME, tg.get_node_value(node_id, "metadata")["deps"]), ( "call_before", @@ -367,7 +345,6 @@ def _electron_data( "function_string_filename": ELECTRON_FUNCTION_STRING_FILENAME, "executor": executor, "executor_data": json.dumps(executor_data), - # "executor_data_filename": ELECTRON_EXECUTOR_DATA_FILENAME, "results_filename": ELECTRON_RESULTS_FILENAME, "value_filename": ELECTRON_VALUE_FILENAME, "stdout_filename": ELECTRON_STDOUT_FILENAME, @@ -377,7 +354,6 @@ def _electron_data( "call_before_filename": ELECTRON_CALL_BEFORE_FILENAME, "call_after_filename": ELECTRON_CALL_AFTER_FILENAME, "qelectron_data_exists": node_qelectron_data_exists, - "cancel_requested": cancel_requested, "job_id": job_row.id, "created_at": timestamp, "updated_at": timestamp, diff --git a/covalent_dispatcher/_db/write_result_to_db.py b/covalent_dispatcher/_db/write_result_to_db.py index a343603e2..a334e395b 100644 --- a/covalent_dispatcher/_db/write_result_to_db.py +++ b/covalent_dispatcher/_db/write_result_to_db.py @@ -253,7 +253,6 @@ def transaction_insert_electrons_data( call_after_filename: str, job_id: int, qelectron_data_exists: bool, - cancel_requested: bool, created_at: dt, updated_at: dt, started_at: dt, @@ -296,7 +295,6 @@ def transaction_insert_electrons_data( call_before_filename=call_before_filename, call_after_filename=call_after_filename, qelectron_data_exists=qelectron_data_exists, - cancel_requested=cancel_requested, is_active=True, job_id=job_id, created_at=created_at, diff --git a/covalent_dispatcher/_service/app.py b/covalent_dispatcher/_service/app.py index 417477153..924d4174c 100644 --- a/covalent_dispatcher/_service/app.py +++ b/covalent_dispatcher/_service/app.py @@ -1,150 +1,341 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import codecs -import json -from typing import Optional -from uuid import UUID - -import cloudpickle as pickle -from fastapi import APIRouter, HTTPException, Request -from fastapi.responses import JSONResponse - -import covalent_dispatcher as dispatcher -from covalent._results_manager.result import Result -from covalent._shared_files import logger - -from .._db.datastore import workflow_db -from .._db.load import _result_from -from .._db.models import Lattice - -app_log = logger.app_log -log_stack_info = logger.log_stack_info - -router: APIRouter = APIRouter() - - -@router.post("/submit") -async def submit(request: Request, disable_run: bool = False) -> UUID: - """ - Function to accept the submit request of - new dispatch and return the dispatch id - back to the client. - - Args: - disable_run: Whether to disable the execution of this lattice - - Returns: - dispatch_id: The dispatch id in a json format - returned as a Fast API Response object - """ - try: - data = await request.json() - data = json.dumps(data).encode("utf-8") - - return await dispatcher.run_dispatcher(data, disable_run) - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to submit workflow: {e}", - ) from e - - -@router.post("/redispatch") -async def redispatch(request: Request, is_pending: bool = False) -> str: - """Endpoint to redispatch a workflow.""" - try: - data = await request.json() - dispatch_id = data["dispatch_id"] - json_lattice = data["json_lattice"] - electron_updates = data["electron_updates"] - reuse_previous_results = data["reuse_previous_results"] - app_log.debug( - f"Unpacked redispatch request for {dispatch_id}. reuse_previous_results: {reuse_previous_results}, electron_updates: {electron_updates}" - ) - return await dispatcher.run_redispatch( - dispatch_id, json_lattice, electron_updates, reuse_previous_results, is_pending - ) - - except Exception as e: - raise HTTPException( - status_code=400, - detail=f"Failed to redispatch workflow: {e}", - ) from e - - -@router.post("/cancel") -async def cancel(request: Request) -> str: - """ - Function to accept the cancel request of - a dispatch. - - Args: - None - - Returns: - Fast API Response object confirming that the dispatch - has been cancelled. - """ - - data = await request.json() - - dispatch_id = data["dispatch_id"] - task_ids = data["task_ids"] - - await dispatcher.cancel_running_dispatch(dispatch_id, task_ids) - if task_ids: - return f"Cancelled tasks {task_ids} in dispatch {dispatch_id}." - else: - return f"Dispatch {dispatch_id} cancelled." - - -@router.get("/result/{dispatch_id}") -async def get_result( - dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False -): - with workflow_db.session() as session: - lattice_record = session.query(Lattice).where(Lattice.dispatch_id == dispatch_id).first() - status = lattice_record.status if lattice_record else None - if not lattice_record: - return JSONResponse( - status_code=404, - content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, - ) - if not wait or status in [ - str(Result.COMPLETED), - str(Result.FAILED), - str(Result.CANCELLED), - str(Result.POSTPROCESSING_FAILED), - str(Result.PENDING_POSTPROCESSING), - ]: - output = { - "id": dispatch_id, - "status": lattice_record.status, - } - if not status_only: - output["result"] = codecs.encode( - pickle.dumps(_result_from(lattice_record)), "base64" - ).decode() - return output - - return JSONResponse( - status_code=503, - content={ - "message": "Result not ready to read yet. Please wait for a couple of seconds." - }, - headers={"Retry-After": "2"}, - ) +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +"""Endpoints for dispatch management""" + +import asyncio +import json +from contextlib import asynccontextmanager +from typing import List, Optional, Union +from uuid import UUID + +from fastapi import APIRouter, FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +import covalent_dispatcher.entry_point as dispatcher +from covalent._shared_files import logger +from covalent._shared_files.schemas.result import ResultSchema +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core import dispatcher as core_dispatcher + +from .._dal.exporters.result import export_result_manifest +from .._dal.result import Result, get_result_object +from .._db.datastore import workflow_db +from .._db.dispatchdb import DispatchDB +from .heartbeat import Heartbeat +from .models import DispatchStatusSetSchema, ExportResponseSchema, TargetDispatchStatus + +# from covalent_dispatcher._core import runner_ng as core_runner + + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +_background_tasks = set() + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Initialize global variables""" + + heartbeat = Heartbeat() + fut = asyncio.create_task(heartbeat.start()) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + # # Runner event queue and listener + # core_runner._job_events = asyncio.Queue() + # core_runner._job_event_listener = asyncio.create_task(core_runner._listen_for_job_events()) + + # Dispatcher event queue and listener + core_dispatcher._global_status_queue = asyncio.Queue() + core_dispatcher._global_event_listener = asyncio.create_task( + core_dispatcher._node_event_listener() + ) + + yield + + # Cancel all scheduled and running dispatches + for status in [ + RESULT_STATUS.NEW_OBJECT, + RESULT_STATUS.RUNNING, + ]: + await cancel_all_with_status(status) + + core_dispatcher._global_event_listener.cancel() + # core_runner._job_event_listener.cancel() + + Heartbeat.stop() + + +async def cancel_all_with_status(status: RESULT_STATUS): + """Cancel all dispatches with the specified status.""" + + with workflow_db.session() as session: + records = Result.get_db_records( + session, + keys=["dispatch_id"], + equality_filters={"status": str(status)}, + membership_filters={}, + ) + + for record in records: + dispatch_id = record.dispatch_id + await dispatcher.cancel_running_dispatch(dispatch_id) + + +@router.post("/dispatches/submit") +async def submit(request: Request) -> UUID: + """ + Function to accept the submit request of + new dispatch and return the dispatch id + back to the client. + + Args: + None + + Returns: + dispatch_id: The dispatch id in a json format + returned as a Fast API Response object. + """ + try: + data = await request.json() + data = json.dumps(data).encode("utf-8") + return await dispatcher.make_dispatch(data) + except Exception as e: + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +async def start(dispatch_id: str): + """Start a previously registered (re-)dispatch. + + Args: + `dispatch_id`: The dispatch's unique id. + + Returns: + `dispatch_id` + """ + fut = asyncio.create_task(dispatcher.start_dispatch(dispatch_id)) + _background_tasks.add(fut) + fut.add_done_callback(_background_tasks.discard) + + return dispatch_id + + +async def cancel(dispatch_id: str, task_ids: List[int] = None) -> str: + """ + Function to handle the cancel request of + a dispatch. + + Args: + dispatch_id: ID of the dispatch + task_ids: (Query) Optional list of specific task ids to cancel. + An empty list will cause all tasks to be cancelled. + + Returns: + Fast API Response object confirming that the dispatch + has been cancelled. + """ + + if task_ids is None: + task_ids = [] + + await dispatcher.cancel_running_dispatch(dispatch_id, task_ids) + if task_ids: + return f"Cancelled tasks {task_ids} in dispatch {dispatch_id}." + else: + return f"Dispatch {dispatch_id} cancelled." + + +@router.get("/db-path") +def db_path() -> str: + db_path = DispatchDB()._dbpath + return json.dumps(db_path) + + +@router.post("/dispatches", status_code=201) +async def register(manifest: ResultSchema) -> ResultSchema: + """Register a dispatch in the database. + + Args: + manifest: Declares all metadata and assets in the workflow + parent_dispatch_id: The parent dispatch id if registering a sublattice dispatch + + Returns: + The manifest with `dispatch_id` and remote URIs for each asset populated. + """ + try: + return await dispatcher.register_dispatch(manifest, None) + except Exception as e: + app_log.debug(f"Exception in register: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.post("/dispatches/{dispatch_id}/subdispatches", status_code=201) +async def register_subdispatch( + manifest: ResultSchema, + dispatch_id: str, +) -> ResultSchema: + """Register a subdispatch in the database. + + Args: + manifest: Declares all metadata and assets in the workflow + dispatch_id: The parent dispatch id + + Returns: + The manifest with `dispatch_id` and remote URIs for each asset populated. + """ + try: + return await dispatcher.register_dispatch(manifest, dispatch_id) + except Exception as e: + app_log.debug(f"Exception in register: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.post("/dispatches/{dispatch_id}/redispatches", status_code=201) +async def register_redispatch( + manifest: ResultSchema, + dispatch_id: str, + reuse_previous_results: bool = False, +): + """Register a redispatch in the database. + + Args: + manifest: Declares all metadata and assets in the workflow + dispatch_id: The original dispatch's id. + reuse_previous_results: Whether to try reusing the results of + previously completed electrons. + + Returns: + The manifest with `dispatch_id` and remote URIs for each asset populated. + """ + try: + return await dispatcher.register_redispatch( + manifest, + dispatch_id, + reuse_previous_results, + ) + except Exception as e: + app_log.debug(f"Exception in register_redispatch: {e}") + raise HTTPException( + status_code=400, + detail=f"Failed to submit workflow: {e}", + ) from e + + +@router.put("/dispatches/{dispatch_id}/status", status_code=202) +async def set_dispatch_status(dispatch_id: str, desired_status: DispatchStatusSetSchema): + """Set the status of a dispatch. + + Valid target statuses are: + - "RUNNING" to start a dispatch + - "CANCELLED" to cancel dispatch processing + + Args: + `dispatch_id`: The dispatch's unique id + `desired_status`: A `StatusSetSchema` object describing the desired status. + + """ + + if desired_status.status == TargetDispatchStatus.running: + return await start(dispatch_id) + else: + return await cancel(dispatch_id, desired_status.task_ids) + + +@router.get("/dispatches/{dispatch_id}") +async def export_result( + dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False +) -> ExportResponseSchema: + """Export all metadata about a registered dispatch + + Args: + `dispatch_id`: The dispatch's unique id. + + Returns: + { + id: `dispatch_id`, + status: status, + result_export: manifest for the result + } + + The manifest `result_export` has the same schema as that which is + submitted to `/register`. + + """ + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + _export_result_sync, + dispatch_id, + wait, + status_only, + ) + + +def _export_result_sync( + dispatch_id: str, wait: Optional[bool] = False, status_only: Optional[bool] = False +) -> ExportResponseSchema: + result_object = _try_get_result_object(dispatch_id) + if not result_object: + return JSONResponse( + status_code=404, + content={"message": f"The requested dispatch ID {dispatch_id} was not found."}, + ) + status = str(result_object.get_value("status", refresh=False)) + + if not wait or status in [ + str(RESULT_STATUS.COMPLETED), + str(RESULT_STATUS.FAILED), + str(RESULT_STATUS.CANCELLED), + ]: + output = { + "id": dispatch_id, + "status": status, + } + if not status_only: + output["result_export"] = export_result_manifest(dispatch_id) + + return output + + response = JSONResponse( + status_code=503, + content={"message": "Result not ready to read yet. Please wait for a couple of seconds."}, + headers={"Retry-After": "2"}, + ) + return response + + +def _try_get_result_object(dispatch_id: str) -> Union[Result, None]: + try: + res = get_result_object( + dispatch_id, bare=True, keys=["id", "dispatch_id", "status"], lattice_keys=["id"] + ) + except KeyError: + res = None + return res diff --git a/covalent_dispatcher/_service/assets.py b/covalent_dispatcher/_service/assets.py new file mode 100644 index 000000000..83caffd07 --- /dev/null +++ b/covalent_dispatcher/_service/assets.py @@ -0,0 +1,526 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Endpoints for uploading and downloading workflow assets""" + +import asyncio +import mmap +import os +from functools import lru_cache +from typing import Tuple, Union + +import aiofiles +import aiofiles.os +from fastapi import APIRouter, Header, HTTPException, Request +from fastapi.responses import StreamingResponse +from furl import furl + +from covalent._serialize.electron import ASSET_TYPES as ELECTRON_ASSET_TYPES +from covalent._serialize.lattice import ASSET_TYPES as LATTICE_ASSET_TYPES +from covalent._serialize.result import ASSET_TYPES as RESULT_ASSET_TYPES +from covalent._serialize.result import AssetType +from covalent._shared_files import logger +from covalent._shared_files.config import get_config +from covalent._workflow.transportable_object import TOArchiveUtils + +from .._dal.result import get_result_object +from .._db.datastore import workflow_db +from .models import ( + AssetRepresentation, + DispatchAssetKey, + ElectronAssetKey, + LatticeAssetKey, + range_pattern, + range_regex, +) + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +app_log = logger.app_log +log_stack_info = logger.log_stack_info + +router: APIRouter = APIRouter() + +_background_tasks = set() + +LRU_CACHE_SIZE = get_config("dispatcher.asset_cache_size") + + +@router.get("/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") +def get_node_asset( + dispatch_id: str, + node_id: int, + key: ElectronAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + """Returns an asset for an electron. + + Args: + dispatch_id: The dispatch's unique id. + node_id: The id of the electron. + key: The name of the asset + representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` + range: (optional) range request header + + If `representation` is specified, it will override the range request. + """ + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for node {dispatch_id}:{node_id}" + ) + + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + + node = result_object.lattice.transport_graph.get_node(node_id) + with workflow_db.session() as session: + asset = node.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or ELECTRON_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + + except Exception as e: + app_log.debug(e) + raise + + +@router.get("/dispatches/{dispatch_id}/assets/{key}") +def get_dispatch_asset( + dispatch_id: str, + key: DispatchAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + """Returns a dynamic asset for a workflow + + Args: + dispatch_id: The dispatch's unique id. + key: The name of the asset + representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` + range: (optional) range request header + + If `representation` is specified, it will override the range request. + """ + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested asset {key.value} ([{start_byte}:{end_byte}]) for dispatch {dispatch_id}" + ) + + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + with workflow_db.session() as session: + asset = result_object.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or RESULT_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + except Exception as e: + app_log.debug(e) + raise + + +@router.get("/dispatches/{dispatch_id}/lattice/assets/{key}") +def get_lattice_asset( + dispatch_id: str, + key: LatticeAssetKey, + representation: Union[AssetRepresentation, None] = None, + Range: Union[str, None] = Header(default=None, regex=range_regex), +): + """Returns a static asset for a workflow + + Args: + dispatch_id: The dispatch's unique id. + key: The name of the asset + representation: (optional) the representation ("string" or "pickle") of a `TransportableObject` + range: (optional) range request header + + If `representation` is specified, it will override the range request. + """ + start_byte = 0 + end_byte = -1 + + try: + if Range: + start_byte, end_byte = _extract_byte_range(Range) + + if end_byte >= 0 and end_byte < start_byte: + raise HTTPException( + status_code=400, + detail="Invalid byte range", + ) + app_log.debug( + f"Requested lattice asset {key.value} ([{start_byte}:{end_byte}])for dispatch {dispatch_id}" + ) + + result_object = get_cached_result_object(dispatch_id) + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + + with workflow_db.session() as session: + asset = result_object.lattice.get_asset(key=key.value, session=session) + + # Explicit representation overrides the byte range + if representation is None or LATTICE_ASSET_TYPES[key.value] != AssetType.TRANSPORTABLE: + start_byte = start_byte + end_byte = end_byte + elif representation == AssetRepresentation.string: + start_byte, end_byte = _get_tobj_string_offsets(asset.internal_uri) + else: + start_byte, end_byte = _get_tobj_pickle_offsets(asset.internal_uri) + + app_log.debug(f"Serving byte range {start_byte}:{end_byte} of {asset.internal_uri}") + generator = _generate_file_slice(asset.internal_uri, start_byte, end_byte) + return StreamingResponse(generator) + + except Exception as e: + app_log.debug(e) + raise e + + +@router.put("/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") +async def upload_node_asset( + req: Request, + dispatch_id: str, + node_id: int, + key: ElectronAssetKey, + content_length: int = Header(default=0), + digest_alg: Union[str, None] = Header(default=None), + digest: Union[str, None] = Header(default=None), +): + """Upload an electron asset. + + Args: + dispatch_id: The dispatch's unique id. + node_id: The electron id. + key: The name of the asset + asset_file: (body) The file to be uploaded + content_length: (header) + digest: (header) + """ + app_log.debug(f"Requested asset {key} for node {dispatch_id}:{node_id}") + + try: + metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} + internal_uri = await _run_in_executor( + _update_node_asset_metadata, + dispatch_id, + node_id, + key, + metadata, + ) + # Stream the request body to object store + await _transfer_data(req, internal_uri) + + return f"Uploaded file to {internal_uri}" + except Exception as e: + app_log.debug(e) + raise + + +@router.put("/dispatches/{dispatch_id}/assets/{key}") +async def upload_dispatch_asset( + req: Request, + dispatch_id: str, + key: DispatchAssetKey, + content_length: int = Header(default=0), + digest_alg: Union[str, None] = Header(default=None), + digest: Union[str, None] = Header(default=None), +): + """Upload a dispatch asset. + + Args: + dispatch_id: The dispatch's unique id. + key: The name of the asset + asset_file: (body) The file to be uploaded + content_length: (header) + digest: (header) + """ + try: + metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} + internal_uri = await _run_in_executor( + _update_dispatch_asset_metadata, + dispatch_id, + key, + metadata, + ) + # Stream the request body to object store + await _transfer_data(req, internal_uri) + return f"Uploaded file to {internal_uri}" + except Exception as e: + app_log.debug(e) + raise + + +@router.put("/dispatches/{dispatch_id}/lattice/assets/{key}") +async def upload_lattice_asset( + req: Request, + dispatch_id: str, + key: LatticeAssetKey, + content_length: int = Header(default=0), + digest_alg: Union[str, None] = Header(default=None), + digest: Union[str, None] = Header(default=None), +): + """Upload a lattice asset. + + Args: + dispatch_id: The dispatch's unique id. + key: The name of the asset + asset_file: (body) The file to be uploaded + content_length: (header) + digest: (header) + """ + try: + metadata = {"size": content_length, "digest_alg": digest_alg, "digest": digest} + internal_uri = await _run_in_executor( + _update_lattice_asset_metadata, + dispatch_id, + key, + metadata, + ) + # Stream the request body to object store + await _transfer_data(req, internal_uri) + return f"Uploaded file to {internal_uri}" + except Exception as e: + app_log.debug(e) + raise + + +def _generate_file_slice(file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + """Generator of a byte slice from a file. + + Args: + file_url: A file:/// type URL pointing to the file + start_byte: The beginning of the byte range + end_byte: The end of the byte range, or -1 to select [start_byte:] + chunk_size: The size of each chunk + + Returns: + Yields chunks of size <= chunk_size + """ + byte_pos = start_byte + file_path = str(furl(file_url).path) + with open(file_path, "rb") as f: + f.seek(start_byte) + if end_byte < 0: + for chunk in f: + yield chunk + else: + while byte_pos + chunk_size < end_byte: + byte_pos += chunk_size + yield f.read(chunk_size) + yield f.read(end_byte - byte_pos) + + +def _extract_byte_range(byte_range_header: str) -> Tuple[int, int]: + """Extract the byte range from a range request header.""" + start_byte = 0 + end_byte = -1 + match = range_pattern.match(byte_range_header) + start = match.group(1) + end = match.group(2) + start_byte = int(start) + if end: + end_byte = int(end) + + return start_byte, end_byte + + +# Helpers for TransportableObject + + +def _get_tobj_string_offsets(file_url: str) -> Tuple[int, int]: + """Get the byte range for the str rep of a stored TObj. + + For a first implementation we just query the filesystem directly. + + Args: + file_url: A file:/// URL pointing to the TransportableObject + + Returns: + (start_byte, end_byte) + """ + + file_path = str(furl(file_url).path) + filelen = os.path.getsize(file_path) + with open(file_path, "rb+") as f: + with mmap.mmap(f.fileno(), filelen) as mm: + # TOArchiveUtils operates on byte arrays + return TOArchiveUtils.string_byte_range(mm) + + +def _get_tobj_pickle_offsets(file_url: str) -> Tuple[int, int]: + """Get the byte range for the picklebytes of a stored TObj. + + For a first implementation we just query the filesystem directly. + + Args: + file_url: A file:/// URL pointing to the TransportableObject + + Returns: + (start_byte, -1) + """ + + file_path = str(furl(file_url).path) + filelen = os.path.getsize(file_path) + with open(file_path, "rb+") as f: + with mmap.mmap(f.fileno(), filelen) as mm: + # TOArchiveUtils operates on byte arrays + return TOArchiveUtils.data_byte_range(mm) + + +# This must only be used for static data as we don't have yet any +# intelligent invalidation logic. +@lru_cache(maxsize=LRU_CACHE_SIZE) +def get_cached_result_object(dispatch_id: str): + try: + with workflow_db.session() as session: + srv_res = get_result_object(dispatch_id, bare=False, session=session) + app_log.debug(f"Caching result {dispatch_id}") + + # Prepopulate asset maps to avoid DB lookups + + srv_res.populate_asset_map(session) + srv_res.lattice.populate_asset_map(session) + + tg = srv_res.lattice.transport_graph + g = tg.get_internal_graph_copy() + for node_id in g.nodes(): + node = tg.get_node(node_id, session) + node.populate_asset_map(session) + except KeyError: + raise HTTPException( + status_code=404, + detail=f"The requested dispatch ID {dispatch_id} was not found.", + ) + + return srv_res + + +def _filter_null_metadata(metadata): + # Filter out null updates + return {k: v for k, v in metadata.items() if v is not None} + + +def _update_node_asset_metadata(dispatch_id, node_id, key, metadata) -> str: + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + node = result_object.lattice.transport_graph.get_node(node_id) + with workflow_db.session() as session: + asset = node.get_asset(key=key.value, session=session) + app_log.debug(f"Asset uri {asset.internal_uri}") + + # Update asset metadata + update = _filter_null_metadata(metadata) + node.update_assets(updates={key: update}, session=session) + app_log.debug(f"Updated node asset {dispatch_id}:{node_id}:{key}") + + return asset.internal_uri + + +def _update_lattice_asset_metadata(dispatch_id, key, metadata) -> str: + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + with workflow_db.session() as session: + asset = result_object.lattice.get_asset(key=key.value, session=session) + + # Update asset metadata + update = _filter_null_metadata(metadata) + result_object.lattice.update_assets(updates={key: update}, session=session) + app_log.debug(f"Updated size for lattice asset {dispatch_id}:{key}") + + return asset.internal_uri + + +def _update_dispatch_asset_metadata(dispatch_id, key, metadata) -> str: + result_object = get_cached_result_object(dispatch_id) + + app_log.debug(f"LRU cache info: {get_cached_result_object.cache_info()}") + with workflow_db.session() as session: + asset = result_object.get_asset(key=key.value, session=session) + + # Update asset metadata + update = _filter_null_metadata(metadata) + result_object.update_assets(updates={key: update}, session=session) + app_log.debug(f"Updated size for dispatch asset {dispatch_id}:{key}") + return asset.internal_uri + + +async def _transfer_data(req: Request, destination_url: str): + dest_url = furl(destination_url) + dest_path = str(dest_url.path) + + # Stream data to a temporary file, then replace the destination + # file atomically + tmp_path = f"{dest_path}.tmp" + + async with aiofiles.open(tmp_path, "wb") as f: + async for chunk in req.stream(): + await f.write(chunk) + + await aiofiles.os.replace(tmp_path, dest_path) + + +def _run_in_executor(function, *args) -> asyncio.Future: + loop = asyncio.get_running_loop() + return loop.run_in_executor(None, function, *args) diff --git a/covalent_ui/heartbeat.py b/covalent_dispatcher/_service/heartbeat.py similarity index 58% rename from covalent_ui/heartbeat.py rename to covalent_dispatcher/_service/heartbeat.py index 213e95877..b6b3ce3a8 100644 --- a/covalent_ui/heartbeat.py +++ b/covalent_dispatcher/_service/heartbeat.py @@ -15,16 +15,11 @@ # limitations under the License. import asyncio -from contextlib import asynccontextmanager from datetime import datetime, timezone -from typing import List import aiofiles -from fastapi import FastAPI from covalent._shared_files.config import get_config -from covalent_ui.api.v1.routes.end_points.summary_routes import get_all_dispatches -from covalent_ui.api.v1.utils.status import Status class Heartbeat: @@ -55,48 +50,3 @@ def stop(): file.write( f"DEAD {datetime.now(tz=timezone.utc).strftime(Heartbeat.TIMESTAMP_FORMAT)}" ) - - -async def cancel_all_with_status(status: Status) -> List[str]: - from covalent_dispatcher._core.dispatcher import cancel_dispatch - - dispatch_ids = [] - page = 0 - count = 100 - - while True: - dispatches = get_all_dispatches( - count=count, - offset=page * count, - status_filter=status, - ) - - dispatch_ids += [dispatch.dispatch_id for dispatch in dispatches.items] - - if dispatches.total_count == page * count + len(dispatches.items): - break - - page += 1 - - for dispatch_id in dispatch_ids: - await cancel_dispatch(dispatch_id) - - return dispatch_ids - - -@asynccontextmanager -async def lifespan(app: FastAPI): - heartbeat = Heartbeat() - asyncio.create_task(heartbeat.start()) - - yield - - for status in [ - Status.NEW_OBJECT, - Status.POSTPROCESSING, - Status.PENDING_POSTPROCESSING, - Status.RUNNING, - ]: - await cancel_all_with_status(status) - - Heartbeat.stop() diff --git a/covalent_dispatcher/_service/models.py b/covalent_dispatcher/_service/models.py new file mode 100644 index 000000000..95c71efe3 --- /dev/null +++ b/covalent_dispatcher/_service/models.py @@ -0,0 +1,126 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""FastAPI models for /api/v1/resultv2 endpoints""" + +import re +from enum import Enum +from typing import List, Optional + +from pydantic import BaseModel + +from covalent._shared_files.schemas.result import ResultSchema + +# # Copied from _dal +# RESULT_ASSET_KEYS = { +# "inputs", +# "result", +# "error", +# } + +# # Copied from _dal +# LATTICE_ASSET_KEYS = { +# "workflow_function", +# "workflow_function_string", +# "__doc__", +# "named_args", +# "named_kwargs", +# "cova_imports", +# "lattice_imports", +# # metadata +# "executor_data", +# "workflow_executor_data", +# "deps", +# "call_before", +# "call_after", +# } + +# # Copied from _dal +# ELECTRON_ASSET_KEYS = { +# "function", +# "function_string", +# "output", +# "value", +# "error", +# "stdout", +# "stderr", +# # electron metadata +# "deps", +# "call_before", +# "call_after", +# } + +range_regex = "bytes=([0-9]+)-([0-9]*)" +range_pattern = re.compile(range_regex) + +digest_regex = "(sha|sha-256)=([0-9a-f]+)" +digest_pattern = re.compile(digest_regex) + + +class DispatchAssetKey(str, Enum): + result = "result" + error = "error" + + +class LatticeAssetKey(str, Enum): + workflow_function = "workflow_function" + workflow_function_string = "workflow_function_string" + doc = "doc" + inputs = "inputs" + named_args = "named_args" + named_kwargs = "named_kwargs" + deps = "deps" + call_before = "call_before" + call_after = "call_after" + cova_imports = "cova_imports" + lattice_imports = "lattice_imports" + + +class ElectronAssetKey(str, Enum): + function = "function" + function_string = "function_string" + output = "output" + value = "value" + deps = "deps" + error = "error" + stdout = "stdout" + stderr = "stderr" + call_before = "call_before" + call_after = "call_after" + + +class ExportResponseSchema(BaseModel): + id: str + status: str + result_export: Optional[ResultSchema] + + +class AssetRepresentation(str, Enum): + string = "string" + b64pickle = "object" + + +class TargetDispatchStatus(str, Enum): + running = "RUNNING" + cancelled = "CANCELLED" + + +class DispatchStatusSetSchema(BaseModel): + # The target status + status: TargetDispatchStatus + + # For cancellation, an optional list of task ids to cancel + task_ids: Optional[List] = [] diff --git a/covalent_dispatcher/entry_point.py b/covalent_dispatcher/entry_point.py index 534ecb47b..a78242d06 100644 --- a/covalent_dispatcher/entry_point.py +++ b/covalent_dispatcher/entry_point.py @@ -18,9 +18,11 @@ Self-contained entry point for the dispatcher """ -from typing import List +import asyncio +from typing import List, Optional from covalent._shared_files import logger +from covalent._shared_files.schemas.result import ResultSchema from ._core import cancel_dispatch @@ -28,7 +30,7 @@ log_stack_info = logger.log_stack_info -async def run_dispatcher(json_lattice: str, disable_run: bool = False): +async def make_dispatch(json_lattice: str): """ Run the dispatcher from the lattice asynchronously using Dask. Assign a new dispatch id to the result object and return it. @@ -36,47 +38,67 @@ async def run_dispatcher(json_lattice: str, disable_run: bool = False): Args: json_lattice: A JSON-serialized lattice - disable_run: Whether to disable execution of this lattice Returns: dispatch_id: A string containing the dispatch id of current dispatch. """ - from ._core import make_dispatch, run_dispatch + from ._core import make_dispatch dispatch_id = await make_dispatch(json_lattice) - if not disable_run: - run_dispatch(dispatch_id) - app_log.debug(f"Submitted dispatch_id {dispatch_id} to run_workflow.") + app_log.debug(f"Created new dispatch {dispatch_id}") return dispatch_id -async def run_redispatch( - dispatch_id: str, - json_lattice: str, - electron_updates: dict, - reuse_previous_results: bool, - is_pending: bool = False, -): - from ._core import make_derived_dispatch, run_dispatch - - app_log.debug("Running redispatch ...") - if is_pending: - run_dispatch(dispatch_id) - app_log.debug(f"Submitted pending dispatch_id {dispatch_id} to run_dispatch.") - return dispatch_id - - redispatch_id = make_derived_dispatch( - dispatch_id, json_lattice, electron_updates, reuse_previous_results - ) - app_log.debug(f"Redispatch id {redispatch_id} created.") - run_dispatch(redispatch_id) +async def start_dispatch(dispatch_id: str): + """ + Run the dispatcher from the lattice asynchronously using Dask. + Assign a new dispatch id to the result object and return it. + Also save the result in this initial stage to the file mentioned in the result object. + + Args: + json_lattice: A JSON-serialized lattice + + Returns: + dispatch_id: A string containing the dispatch id of current dispatch. + """ + + from ._core import copy_futures, run_dispatch + + # Wait for any pending asset transfers + _fut = copy_futures.get(dispatch_id, None) + if _fut is not None: + # _fut is a concurrent.future.Future, so we need to wrap it in + # an asyncio.Future + app_log.debug(f"Waiting on asset transfers for dispatch {dispatch_id}") + await asyncio.wrap_future(_fut) + + # Idempotent + run_dispatch(dispatch_id) + app_log.debug(f"Running dispatch {dispatch_id}") + + +async def run_dispatcher(json_lattice: str): + """ + Run the dispatcher from the lattice asynchronously using Dask. + Assign a new dispatch id to the result object and return it. + Also save the result in this initial stage to the file mentioned in the result object. + + Args: + json_lattice: A JSON-serialized lattice + + Returns: + dispatch_id: A string containing the dispatch id of current dispatch. + """ + + dispatch_id = await make_dispatch(json_lattice) + await start_dispatch(dispatch_id) - app_log.debug(f"Re-dispatching {dispatch_id} as {redispatch_id}") + app_log.debug("Submitted result object to run_workflow.") - return redispatch_id + return dispatch_id async def cancel_running_dispatch(dispatch_id: str, task_ids: List[int] = None) -> None: @@ -94,3 +116,25 @@ async def cancel_running_dispatch(dispatch_id: str, task_ids: List[int] = None) task_ids = [] await cancel_dispatch(dispatch_id, task_ids) + + +async def register_dispatch( + manifest: ResultSchema, parent_dispatch_id: Optional[str] +) -> ResultSchema: + from ._core.data_modules.importer import import_manifest + + return await import_manifest(manifest, parent_dispatch_id, None) + + +async def register_redispatch( + manifest: ResultSchema, + parent_dispatch_id: str, + reuse_previous_results: bool, +) -> ResultSchema: + from ._core.data_modules.importer import import_derived_manifest + + return await import_derived_manifest( + manifest, + parent_dispatch_id, + reuse_previous_results, + ) diff --git a/covalent_ui/api/main.py b/covalent_ui/api/main.py index e01ad864b..a5f765199 100644 --- a/covalent_ui/api/main.py +++ b/covalent_ui/api/main.py @@ -34,8 +34,8 @@ from covalent._shared_files import logger from covalent._shared_files.config import get_config +from covalent_dispatcher._service.app import lifespan from covalent_ui.api.v1.routes import routes -from covalent_ui.heartbeat import lifespan file_descriptor = None child_process_id = None diff --git a/covalent_ui/api/v1/data_layer/electron_dal.py b/covalent_ui/api/v1/data_layer/electron_dal.py index fb9665bfb..3dbd5dc63 100644 --- a/covalent_ui/api/v1/data_layer/electron_dal.py +++ b/covalent_ui/api/v1/data_layer/electron_dal.py @@ -31,7 +31,6 @@ from covalent._shared_files.qelectron_utils import QE_DB_DIRNAME from covalent.quantum.qserver.database import Database from covalent_dispatcher._core.execution import _get_task_inputs as get_task_inputs -from covalent_dispatcher._service.app import get_result from covalent_ui.api.v1.data_layer.lattice_dal import Lattices from covalent_ui.api.v1.database.schema.electron import Electron from covalent_ui.api.v1.database.schema.lattices import Lattice @@ -254,6 +253,7 @@ def get_electrons_id(self, dispatch_id, electron_id) -> Electron: Electron.function_filename, Electron.function_string_filename, Electron.executor, + Electron.executor_data, Electron.results_filename, Electron.value_filename, Electron.stdout_filename, diff --git a/covalent_ui/api/v1/data_layer/lattice_dal.py b/covalent_ui/api/v1/data_layer/lattice_dal.py index 78121ccf4..154a9fd59 100644 --- a/covalent_ui/api/v1/data_layer/lattice_dal.py +++ b/covalent_ui/api/v1/data_layer/lattice_dal.py @@ -95,7 +95,9 @@ def get_lattices_id_storage_file(self, dispatch_id: UUID): Lattice.error_filename, Lattice.function_string_filename, Lattice.executor, + Lattice.executor_data, Lattice.workflow_executor, + Lattice.workflow_executor_data, Lattice.error_filename, Lattice.inputs_filename, Lattice.results_filename, diff --git a/covalent_ui/api/v1/database/schema/electron.py b/covalent_ui/api/v1/database/schema/electron.py index e35550da0..7baf68f85 100644 --- a/covalent_ui/api/v1/database/schema/electron.py +++ b/covalent_ui/api/v1/database/schema/electron.py @@ -93,6 +93,9 @@ class Electron(Base): # Short name describing the executor ("local", "dask", etc) executor = Column(Text) + # JSONified executor attributes + executor_data = Column(Text) + # name of the file containing the serialized output results_filename = Column(Text) @@ -123,9 +126,6 @@ class Electron(Base): # ID for circuit_info job_id = Column(Integer, ForeignKey("jobs.id", name="job_id_link"), nullable=False) - # Cancel requested flag - cancel_requested = Column(Boolean, nullable=False, default=False) - # Flag that indicates if qelectron data exists in the electron qelectron_data_exists = Column(Boolean, nullable=False, default=False) diff --git a/covalent_ui/api/v1/database/schema/lattices.py b/covalent_ui/api/v1/database/schema/lattices.py index 1c5acff30..97212a9f1 100644 --- a/covalent_ui/api/v1/database/schema/lattices.py +++ b/covalent_ui/api/v1/database/schema/lattices.py @@ -82,9 +82,15 @@ class Lattice(Base): # Short name describing the executor ("local", "dask", etc) executor = Column(Text) + # JSONified executor attributes + executor_data = Column(Text) + # Short name describing the workflow executor ("local", "dask", etc) workflow_executor = Column(Text) + # JSONified workflow executor attributes + workflow_executor_data = Column(Text) + # Name of the file containing an error message for the workflow error_filename = Column(Text) diff --git a/covalent_ui/api/v1/models/lattices_model.py b/covalent_ui/api/v1/models/lattices_model.py index 3a3c9fae0..2c5a241bb 100644 --- a/covalent_ui/api/v1/models/lattices_model.py +++ b/covalent_ui/api/v1/models/lattices_model.py @@ -123,4 +123,3 @@ class LatticeFileOutput(str, Enum): EXECUTOR = "executor" WORKFLOW_EXECUTOR = "workflow_executor" FUNCTION = "function" - TRANSPORT_GRAPH = "transport_graph" diff --git a/covalent_ui/api/v1/routes/end_points/electron_routes.py b/covalent_ui/api/v1/routes/end_points/electron_routes.py index 60afab61e..4742c5e6d 100644 --- a/covalent_ui/api/v1/routes/end_points/electron_routes.py +++ b/covalent_ui/api/v1/routes/end_points/electron_routes.py @@ -16,6 +16,7 @@ """Electrons Route""" +import json import uuid from typing import List, Optional @@ -23,8 +24,9 @@ from sqlalchemy.orm import Session import covalent_ui.api.v1.database.config.db as db -from covalent._results_manager.results_manager import get_result -from covalent_dispatcher._core.execution import _get_task_inputs as get_task_inputs +from covalent._shared_files.defaults import WAIT_EDGE_NAME +from covalent_dispatcher._core.data_modules import graph as core_graph +from covalent_dispatcher._dal.result import get_result_object from covalent_ui.api.v1.data_layer.electron_dal import Electrons from covalent_ui.api.v1.models.electrons_model import ( ElectronExecutorResponse, @@ -93,6 +95,41 @@ def get_electron_details(dispatch_id: uuid.UUID, electron_id: int): ) +def _get_abstract_task_inputs(dispatch_id: str, node_id: int) -> dict: + """Return placeholders for the required inputs for a task execution. + + Args: + dispatch_id: id of the current dispatch + node_id: Node id of this task in the transport graph. + node_name: Name of the node. + + Returns: inputs: Input dictionary to be passed to the task with + `node_id` placeholders for args, kwargs. These are to be + resolved to their values later. + """ + + abstract_task_input = {"args": [], "kwargs": {}} + + in_edges = core_graph.get_incoming_edges_sync(dispatch_id, node_id) + for edge in in_edges: + parent = edge["source"] + + d = edge["attrs"] + + if d["edge_name"] != WAIT_EDGE_NAME: + if d["param_type"] == "arg": + abstract_task_input["args"].append((parent, d["arg_index"])) + elif d["param_type"] == "kwarg": + key = d["edge_name"] + abstract_task_input["kwargs"][key] = parent + + sorted_args = sorted(abstract_task_input["args"], key=lambda x: x[1]) + abstract_task_input["args"] = [x[0] for x in sorted_args] + + return abstract_task_input + + +# Domain: data def get_electron_inputs(dispatch_id: uuid.UUID, electron_id: int) -> str: """ Get Electron Inputs @@ -103,15 +140,29 @@ def get_electron_inputs(dispatch_id: uuid.UUID, electron_id: int) -> str: Returns the inputs data from Result object """ - result_object = get_result(dispatch_id=str(dispatch_id), wait=False) + abstract_inputs = _get_abstract_task_inputs(dispatch_id=str(dispatch_id), node_id=electron_id) + + # Resolve node ids to object strings + input_assets = {"args": [], "kwargs": {}} with Session(db.engine) as session: - electron = Electrons(session) - result = electron.get_electrons_id(dispatch_id, electron_id) - inputs = get_task_inputs( - node_id=electron_id, node_name=result.name, result_object=result_object - ) - return validate_data(inputs) + result_object = get_result_object(str(dispatch_id), bare=True) + tg = result_object.lattice.transport_graph + for arg in abstract_inputs["args"]: + node = tg.get_node(node_id=arg, session=session) + asset = node.get_asset(key="output", session=session) + input_assets["args"].append(asset) + for k, v in abstract_inputs["kwargs"].items(): + node = tg.get_node(node_id=v, session=session) + asset = node.get_asset(key="output", session=session) + input_assets["kwargs"][k] = asset + + # For now we load the picklefile from the object store into memory, but once + # TransportableObjects are no longer pickled we will be + # able to load the byte range for the object string. + input_args = [asset.load_data() for asset in input_assets["args"]] + input_kwargs = {k: asset.load_data() for k, asset in input_assets["kwargs"].items()} + return validate_data({"args": input_args, "kwargs": input_kwargs}) @routes.get("/{dispatch_id}/electron/{electron_id}/details/{name}") @@ -130,52 +181,7 @@ def get_electron_file(dispatch_id: uuid.UUID, electron_id: int, name: ElectronFi with Session(db.engine) as session: electron = Electrons(session) result = electron.get_electrons_id(dispatch_id, electron_id) - if result is not None: - handler = FileHandler(result["storage_path"]) - if name == "inputs": - response, python_object = get_electron_inputs( - dispatch_id=dispatch_id, electron_id=electron_id - ) - return ElectronFileResponse(data=str(response), python_object=str(python_object)) - elif name == "function_string": - response = handler.read_from_text(result["function_string_filename"]) - return ElectronFileResponse(data=response) - elif name == "function": - response, python_object = handler.read_from_pickle(result["function_filename"]) - return ElectronFileResponse(data=response, python_object=python_object) - elif name == "executor": - executor_name = result["executor"] - return ElectronExecutorResponse( - executor_name=executor_name, - ) - elif name == "result": - response, python_object = handler.read_from_pickle(result["results_filename"]) - return ElectronFileResponse(data=str(response), python_object=python_object) - elif name == "value": - response = handler.read_from_pickle(result["value_filename"]) - return ElectronFileResponse(data=str(response)) - elif name == "stdout": - response = handler.read_from_text(result["stdout_filename"]) - return ElectronFileResponse(data=response) - elif name == "deps": - response = handler.read_from_pickle(result["deps_filename"]) - return ElectronFileResponse(data=response) - elif name == "call_before": - response = handler.read_from_pickle(result["call_before_filename"]) - return ElectronFileResponse(data=response) - elif name == "call_after": - response = handler.read_from_pickle(result["call_after_filename"]) - return ElectronFileResponse(data=response) - elif name == "error": - # Error and stderr won't be both populated if `error` - # is only used for fatal dispatcher-executor interaction errors - error_response = handler.read_from_text(result["error_filename"]) - stderr_response = handler.read_from_text(result["stderr_filename"]) - response = stderr_response + error_response - return ElectronFileResponse(data=response) - else: - return ElectronFileResponse(data=None) - else: + if result is None: raise HTTPException( status_code=400, detail=[ @@ -186,6 +192,51 @@ def get_electron_file(dispatch_id: uuid.UUID, electron_id: int, name: ElectronFi } ], ) + handler = FileHandler(result["storage_path"]) + if name == "inputs": + response, python_object = get_electron_inputs( + dispatch_id=dispatch_id, electron_id=electron_id + ) + return ElectronFileResponse(data=str(response), python_object=str(python_object)) + elif name == "function_string": + response = handler.read_from_text(result["function_string_filename"]) + return ElectronFileResponse(data=response) + elif name == "function": + response, python_object = handler.read_from_serialized(result["function_filename"]) + return ElectronFileResponse(data=response, python_object=python_object) + elif name == "executor": + executor_name = result["executor"] + executor_data = json.loads(result["executor_data"]) + return ElectronExecutorResponse( + executor_name=executor_name, executor_details=executor_data + ) + elif name == "result": + response, python_object = handler.read_from_serialized(result["results_filename"]) + return ElectronFileResponse(data=str(response), python_object=python_object) + elif name == "value": + response = handler.read_from_serialized(result["value_filename"]) + return ElectronFileResponse(data=str(response)) + elif name == "stdout": + response = handler.read_from_text(result["stdout_filename"]) + return ElectronFileResponse(data=response) + elif name == "deps": + response = handler.read_from_serialized(result["deps_filename"]) + return ElectronFileResponse(data=response) + elif name == "call_before": + response = handler.read_from_serialized(result["call_before_filename"]) + return ElectronFileResponse(data=response) + elif name == "call_after": + response = handler.read_from_serialized(result["call_after_filename"]) + return ElectronFileResponse(data=response) + elif name == "error": + # Error and stderr won't be both populated if `error` + # is only used for fatal dispatcher-executor interaction errors + error_response = handler.read_from_text(result["error_filename"]) + stderr_response = handler.read_from_text(result["stderr_filename"]) + response = stderr_response + error_response + return ElectronFileResponse(data=response) + else: + return ElectronFileResponse(data=None) @routes.get("/{dispatch_id}/electron/{electron_id}/jobs", response_model=List[Job]) diff --git a/covalent_ui/api/v1/routes/end_points/lattice_route.py b/covalent_ui/api/v1/routes/end_points/lattice_route.py index ea4f2b943..06b1cc310 100644 --- a/covalent_ui/api/v1/routes/end_points/lattice_route.py +++ b/covalent_ui/api/v1/routes/end_points/lattice_route.py @@ -16,6 +16,7 @@ """Lattice route""" +import json import uuid from typing import Optional @@ -92,38 +93,8 @@ def get_lattice_files(dispatch_id: uuid.UUID, name: LatticeFileOutput): with Session(db.engine) as session: lattice = Lattices(session) lattice_data = lattice.get_lattices_id_storage_file(dispatch_id) - if lattice_data is not None: - handler = FileHandler(lattice_data["directory"]) - if name == "result": - response, python_object = handler.read_from_pickle( - lattice_data["results_filename"] - ) - return LatticeFileResponse(data=str(response), python_object=python_object) - if name == "inputs": - response, python_object = handler.read_from_pickle(lattice_data["inputs_filename"]) - return LatticeFileResponse(data=response, python_object=python_object) - elif name == "function_string": - response = handler.read_from_text(lattice_data["function_string_filename"]) - return LatticeFileResponse(data=response) - elif name == "executor": - executor_name = lattice_data["executor"] - return LatticeExecutorResponse(executor_name=executor_name) - elif name == "workflow_executor": - executor_name = lattice_data["workflow_executor"] - return LatticeWorkflowExecutorResponse( - workflow_executor_name=executor_name, - ) - elif name == "error": - response = handler.read_from_text(lattice_data["error_filename"]) - return LatticeFileResponse(data=response) - elif name == "function": - response, python_object = handler.read_from_pickle( - lattice_data["function_filename"] - ) - return LatticeFileResponse(data=response, python_object=python_object) - elif name == "transport_graph": - return LatticeFileResponse() - else: + + if lattice_data is None: raise HTTPException( status_code=400, detail=[ @@ -135,6 +106,50 @@ def get_lattice_files(dispatch_id: uuid.UUID, name: LatticeFileOutput): ], ) + handler = FileHandler(lattice_data["directory"]) + if name == "result": + response, python_object = handler.read_from_serialized( + lattice_data["results_filename"] + ) + return LatticeFileResponse(data=str(response), python_object=python_object) + + if name == "inputs": + response, python_object = handler.read_from_serialized(lattice_data["inputs_filename"]) + return LatticeFileResponse(data=response, python_object=python_object) + + elif name == "function_string": + response = handler.read_from_text(lattice_data["function_string_filename"]) + return LatticeFileResponse(data=response) + + elif name == "executor": + executor_name = lattice_data["executor"] + executor_data = json.loads(lattice_data["executor_data"]) + + return LatticeExecutorResponse( + executor_name=executor_name, executor_details=executor_data + ) + + elif name == "workflow_executor": + executor_name = lattice_data["workflow_executor"] + executor_data = json.loads(lattice_data["workflow_executor_data"]) + + return LatticeWorkflowExecutorResponse( + workflow_executor_name=executor_name, workflow_executor_details=executor_data + ) + + elif name == "error": + response = handler.read_from_text(lattice_data["error_filename"]) + return LatticeFileResponse(data=response) + + elif name == "function": + response, python_object = handler.read_from_serialized( + lattice_data["function_filename"] + ) + return LatticeFileResponse(data=response, python_object=python_object) + + else: + return LatticeFileResponse(data=None) + @routes.get("/{dispatch_id}/sublattices", response_model=SubLatticeDetailResponse) def get_sub_lattice( diff --git a/covalent_ui/api/v1/routes/routes.py b/covalent_ui/api/v1/routes/routes.py index eedd514dc..3aa5ae706 100644 --- a/covalent_ui/api/v1/routes/routes.py +++ b/covalent_ui/api/v1/routes/routes.py @@ -18,7 +18,7 @@ from fastapi import APIRouter -from covalent_dispatcher._service import app +from covalent_dispatcher._service import app, assets from covalent_dispatcher._triggers_app.app import router as tr_router from covalent_ui.api.v1.routes.end_points import ( electron_routes, @@ -39,6 +39,8 @@ routes.include_router(electron_routes.routes, prefix=dispatch_prefix, tags=["Electrons"]) routes.include_router(settings_routes.routes, prefix="/api/v1", tags=["Settings"]) routes.include_router(logs_route.routes, prefix="/api/v1/logs", tags=["Logs"]) -routes.include_router(app.router, prefix="/api", tags=["dispatcher"]) -routes.include_router(app.router, prefix="/api", tags=["dispatcher"]) routes.include_router(tr_router, prefix="/api", tags=["Triggers"]) +routes.include_router(app.router, prefix="/api/v2", tags=["Dispatcher"]) +routes.include_router(assets.router, prefix="/api/v2", tags=["Assets"]) +# This will be enabled in the next patch +# routes.include_router(runnersvc.router, prefix="/api/v1", tags=["Runner"]) diff --git a/covalent_ui/api/v1/utils/file_handle.py b/covalent_ui/api/v1/utils/file_handle.py index 12ba66629..7a947ca49 100644 --- a/covalent_ui/api/v1/utils/file_handle.py +++ b/covalent_ui/api/v1/utils/file_handle.py @@ -22,6 +22,7 @@ import cloudpickle as pickle from covalent._workflow.transport import TransportableObject, _TransportGraph +from covalent_dispatcher._dal.asset import local_store def transportable_object(obj): @@ -103,6 +104,14 @@ def read_from_pickle(self, path): except Exception: return None + def read_from_serialized(self, path): + """Return data from serialized object""" + try: + deserialized_obj = local_store.load_file(self.location, path) + return validate_data(deserialized_obj) + except Exception as e: + return None + def read_from_text(self, path): """Return data from text file""" try: diff --git a/requirements.txt b/requirements.txt index 35a113c83..26d836bd1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ orjson>=3.8.10 pennylane>=0.31.1 psutil>=5.9.0 pydantic>=2.1.1 +python-multipart>=0.0.6 python-socketio>=5.7.1 requests>=2.24.0 rich>=12.0.0,<=13.3.5 diff --git a/tests/__init__.py b/tests/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_dispatcher_tests/__init__.py b/tests/covalent_dispatcher_tests/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_dispatcher_tests/__init__.py +++ b/tests/covalent_dispatcher_tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_dispatcher_tests/_cli/__init__.py b/tests/covalent_dispatcher_tests/_cli/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_dispatcher_tests/_cli/__init__.py +++ b/tests/covalent_dispatcher_tests/_cli/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_dispatcher_tests/_core/__init__.py b/tests/covalent_dispatcher_tests/_core/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_dispatcher_tests/_core/__init__.py +++ b/tests/covalent_dispatcher_tests/_core/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_dispatcher_tests/_core/data_manager_test.py b/tests/covalent_dispatcher_tests/_core/data_manager_test.py new file mode 100644 index 000000000..847c0b152 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_manager_test.py @@ -0,0 +1,472 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the core functionality of the dispatcher. +""" + + +from unittest.mock import MagicMock + +import pytest + +import covalent as ct +from covalent._results_manager import Result +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.data_manager import ( + ResultSchema, + _legacy_sublattice_dispatch_helper, + _make_sublattice_dispatch, + _redirect_lattice, + _update_parent_electron, + ensure_dispatch, + finalize_dispatch, + get_result_object, + make_dispatch, + persist_result, + update_node_result, +) +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +@pytest.mark.parametrize( + "node_status,node_type,output_status,sub_id", + [ + (Result.COMPLETED, "function", Result.COMPLETED, ""), + (Result.FAILED, "function", Result.FAILED, ""), + (Result.CANCELLED, "function", Result.CANCELLED, ""), + (Result.COMPLETED, "sublattice", RESULT_STATUS.DISPATCHING, ""), + (Result.COMPLETED, "sublattice", RESULT_STATUS.COMPLETED, "asdf"), + (Result.FAILED, "sublattice", Result.FAILED, ""), + (Result.CANCELLED, "sublattice", Result.CANCELLED, ""), + ], +) +@pytest.mark.asyncio +async def test_update_node_result(mocker, node_status, node_type, output_status, sub_id): + """Check that update_node_result pushes the correct status updates""" + + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result" + + node_result = {"node_id": 0, "status": node_status} + mock_update_node = mocker.patch( + "covalent_dispatcher._dal.result.Result._update_node", return_value=True + ) + node_info = {"type": node_type, "sub_dispatch_id": sub_id, "status": Result.NEW_OBJ} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) + + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, + ) + + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + return_value=sub_id, + ) + + await update_node_result(result_object.dispatch_id, node_result) + detail = {"sub_dispatch_id": sub_id} if sub_id else {} + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, output_status, detail) + + if node_status == Result.COMPLETED and node_type == "sublattice" and not sub_id: + mock_make_dispatch.assert_awaited() + else: + mock_make_dispatch.assert_not_awaited() + + +@pytest.mark.parametrize( + "node_status,old_status,valid_update", + [ + (Result.COMPLETED, Result.RUNNING, True), + (Result.COMPLETED, Result.COMPLETED, False), + (Result.FAILED, Result.COMPLETED, False), + ], +) +@pytest.mark.asyncio +async def test_update_node_result_filters_illegal_updates( + mocker, node_status, old_status, valid_update +): + """Check that update_node_result pushes the correct status updates""" + + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_filters_illegal_updates" + result_object._update_node = MagicMock(return_value=valid_update) + node_result = {"node_id": 0, "status": node_status} + node_info = {"type": "function", "sub_dispatch_id": "", "status": old_status} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) + + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, + ) + + mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + ) + + await update_node_result(result_object.dispatch_id, node_result) + + if not valid_update: + mock_notify.assert_not_awaited() + else: + mock_notify.assert_awaited() + + +@pytest.mark.asyncio +async def test_update_node_result_handles_keyerrors(mocker): + """Check that update_node_result handles invalid dispatch id or node id""" + + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_keyerrors" + node_result = {"node_id": -5, "status": RESULT_STATUS.COMPLETED} + mock_update_node = mocker.patch("covalent_dispatcher._dal.result.Result._update_node") + node_info = {"type": "function", "sub_dispatch_id": "", "status": RESULT_STATUS.RUNNING} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", side_effect=KeyError()) + + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) + + await update_node_result(result_object.dispatch_id, node_result) + + mock_notify.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_update_node_result_handles_subl_exceptions(mocker): + """Check that update_node_result pushes the correct status updates""" + + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_subl_exception" + + node_type = "sublattice" + sub_id = "" + node_result = {"node_id": 0, "status": Result.COMPLETED} + mock_update_node = mocker.patch("covalent_dispatcher._dal.result.Result._update_node") + node_info = {"type": node_type, "sub_dispatch_id": sub_id, "status": Result.NEW_OBJ} + mocker.patch("covalent_dispatcher._core.data_manager.electron.get", return_value=node_info) + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, + ) + + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch", + side_effect=RuntimeError(), + ) + + mocker.patch("traceback.TracebackException.from_exception", return_value="error") + + await update_node_result(result_object.dispatch_id, node_result) + output_status = Result.FAILED + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, output_status, {}) + mock_make_dispatch.assert_awaited() + + +@pytest.mark.asyncio +async def test_update_node_result_handles_db_exceptions(mocker): + """Check that update_node_result handles db write failures""" + + result_object = MagicMock() + result_object.dispatch_id = "test_update_node_result_handles_db_exceptions" + result_object._update_node = MagicMock(side_effect=RuntimeError()) + mock_get_result = mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get_result_object", + return_value=result_object, + ) + mock_notify = mocker.patch( + "covalent_dispatcher._core.dispatcher.notify_node_status", + ) + + node_result = {"node_id": 0, "status": Result.COMPLETED} + await update_node_result(result_object.dispatch_id, node_result) + + mock_notify.assert_awaited_with(result_object.dispatch_id, 0, Result.FAILED, {}) + + +@pytest.mark.asyncio +async def test_make_dispatch(mocker): + res = MagicMock() + dispatch_id = "test_make_dispatch" + mock_resubmit_lattice = mocker.patch( + "covalent_dispatcher._core.data_manager._redirect_lattice", return_value=dispatch_id + ) + json_lattice = '{"workflow_function": "asdf"}' + assert dispatch_id == await make_dispatch(json_lattice) + + +def test_get_result_object(mocker): + result_object = MagicMock() + result_object.dispatch_id = "dispatch_1" + mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object_from_db", + return_value=result_object, + ) + + dispatch_id = result_object.dispatch_id + assert get_result_object(dispatch_id) is result_object + + +@pytest.mark.parametrize("stateless", [False, True]) +def test_unregister_result_object(mocker, stateless): + dispatch_id = "test_unregister_result_object" + finalize_dispatch(dispatch_id) + + +@pytest.mark.asyncio +async def test_persist_result(mocker): + dispatch_id = "test_persist_result" + mock_update_parent = mocker.patch( + "covalent_dispatcher._core.data_manager._update_parent_electron" + ) + + await persist_result(dispatch_id) + mock_update_parent.assert_awaited_with(dispatch_id) + + +@pytest.mark.parametrize( + "sub_status,mapped_status", + [(Result.COMPLETED, Result.COMPLETED), (Result.POSTPROCESSING_FAILED, Result.FAILED)], +) +@pytest.mark.asyncio +async def test_update_parent_electron(mocker, sub_status, mapped_status): + import datetime + + mock_res = MagicMock() + mock_res.dispatch_id = "test_update_parent_electron" + parent_result_obj = MagicMock() + sub_result_obj = MagicMock() + eid = 5 + + parent_result_obj.dispatch_id = mock_res.dispatch_id + + parent_dispatch_id = (parent_result_obj.dispatch_id,) + parent_node_id = 2 + sub_result_obj._electron_id = eid + sub_result_obj.status = sub_status + sub_result_obj._result = 42 + sub_result_obj._error = "" + sub_result_obj._end_time = datetime.datetime.now() + + mock_node_result = { + "node_id": parent_node_id, + "end_time": sub_result_obj._end_time, + "status": mapped_status, + "output": sub_result_obj._result, + "error": sub_result_obj._error, + } + + mock_gen_node_result = mocker.patch( + "covalent_dispatcher._core.data_manager.generate_node_result", + return_value=mock_node_result, + ) + + mock_update_node = mocker.patch("covalent_dispatcher._core.data_manager.update_node_result") + mock_resolve_eid = mocker.patch( + "covalent_dispatcher._core.data_manager.resolve_electron_id", + return_value=(parent_dispatch_id, parent_node_id), + ) + mock_get_res = mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=parent_result_obj, + ) + + mock_get_res = mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", + return_value=parent_result_obj, + ) + + await _update_parent_electron(sub_result_obj) + + mock_get_res.assert_called_with(parent_dispatch_id) + mock_update_node.assert_awaited_with(parent_result_obj.dispatch_id, mock_node_result) + + +@pytest.mark.asyncio +async def test_make_sublattice_dispatch(mocker): + node_result = {"node_id": 0, "status": Result.COMPLETED} + output_json = "lattice_json" + + mock_node = MagicMock() + mock_node._electron_id = 5 + + mock_bg_output = MagicMock() + mock_bg_output.object_string = output_json + + mock_node.get_value = MagicMock(return_value=mock_bg_output) + + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = "mock_sublattice_dispatch" + + result_object = MagicMock() + result_object.dispatch_id = "dispatch" + result_object.lattice.transport_graph.get_node = MagicMock(return_value=mock_node) + mocker.patch( + "covalent_dispatcher._core.data_manager.get_result_object", + return_value=result_object, + ) + mocker.patch("covalent._shared_files.schemas.result.ResultSchema.parse_raw") + mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer.import_manifest", + return_value=mock_manifest, + ) + + mock_make_dispatch = mocker.patch("covalent_dispatcher._core.data_manager.make_dispatch") + sub_dispatch_id = await _make_sublattice_dispatch(result_object.dispatch_id, node_result) + + assert sub_dispatch_id == mock_manifest.metadata.dispatch_id + + +@pytest.mark.asyncio +async def test_make_monolithic_sublattice_dispatch(mocker): + """Check that JSON sublattices are handled correctly""" + + dispatch_id = "test_make_monolithic_sublattice_dispatch" + + def _mock_helper(dispatch_id, node_result): + return ResultSchema.parse_raw("invalid_input") + + mocker.patch( + "covalent_dispatcher._core.data_manager._make_sublattice_dispatch_helper", _mock_helper + ) + + json_lattice = "json_lattice" + parent_electron_id = 5 + mock_legacy_subl_helper = mocker.patch( + "covalent_dispatcher._core.data_manager._legacy_sublattice_dispatch_helper", + return_value=(json_lattice, parent_electron_id), + ) + sub_dispatch_id = "sub_dispatch" + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.data_manager.make_dispatch", return_value=sub_dispatch_id + ) + + assert sub_dispatch_id == await _make_sublattice_dispatch(dispatch_id, {}) + + mock_make_dispatch.assert_awaited_with(json_lattice, dispatch_id, parent_electron_id) + + +def test_legacy_sublattice_dispatch_helper(mocker): + dispatch_id = "test_legacy_sublattice_dispatch_helper" + res_obj = MagicMock() + bg_output = MagicMock() + bg_output.object_string = "json_sublattice" + parent_node = MagicMock() + parent_node._electron_id = 2 + parent_node.get_value = MagicMock(return_value=bg_output) + res_obj.lattice.transport_graph.get_node = MagicMock(return_value=parent_node) + node_result = {"node_id": 0} + + mocker.patch("covalent_dispatcher._core.data_manager.get_result_object", return_value=res_obj) + + assert _legacy_sublattice_dispatch_helper(dispatch_id, node_result) == ("json_sublattice", 2) + + +def test_redirect_lattice(mocker): + """Test redirecting JSON lattices to new DAL.""" + + dispatch_id = "test_redirect_lattice" + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = dispatch_id + mock_prepare_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.prepare_manifest", + return_value=mock_manifest, + ) + mock_import_manifest = mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer._import_manifest", + return_value=mock_manifest, + ) + + mock_pull = mocker.patch( + "covalent_dispatcher._core.data_manager.manifest_importer._pull_assets", + ) + + mock_lat_deserialize = mocker.patch( + "covalent_dispatcher._core.data_manager.Lattice.deserialize_from_json" + ) + + json_lattice = "json_lattice" + + parent_dispatch_id = "parent_dispatch" + parent_electron_id = 3 + + assert ( + _redirect_lattice(json_lattice, parent_dispatch_id, parent_electron_id, None) + == dispatch_id + ) + + mock_import_manifest.assert_called_with(mock_manifest, parent_dispatch_id, parent_electron_id) + mock_pull.assert_called_with(mock_manifest) + + +@pytest.mark.asyncio +async def test_ensure_dispatch(mocker): + mock_ensure_run_once = mocker.patch( + "covalent_dispatcher._core.data_manager.SRVResult.ensure_run_once", + return_value=True, + ) + assert await ensure_dispatch("test_ensure_dispatch") is True + mock_ensure_run_once.assert_called_with("test_ensure_dispatch") diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py new file mode 100644 index 000000000..efe75dddd --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/asset_manager_db_integration_test.py @@ -0,0 +1,165 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for DB-backed Result""" + + +import os +import tempfile + +import pytest + +import covalent as ct +from covalent._results_manager import Result as SDKResult +from covalent._shared_files.schemas.asset import AssetUpdate +from covalent._workflow.lattice import Lattice as SDKLattice +from covalent_dispatcher._core.data_modules import asset_manager as am +from covalent_dispatcher._dal.result import Result, get_result_object +from covalent_dispatcher._db import update +from covalent_dispatcher._db.datastore import DataStore + +TEMP_RESULTS_DIR = os.environ.get("COVALENT_DATA_DIR") or ct.get_config("dispatcher.results_dir") + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> SDKResult: + """Construct a mock result object corresponding to a lattice.""" + + @ct.electron(executor="local") + def task(x): + return x + + @ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def workflow(x): + res1 = task(x) + res2 = task(res1) + return res2 + + workflow.build_graph(x=1) + received_workflow = SDKLattice.deserialize_from_json(workflow.serialize_to_json()) + result_object = SDKResult(received_workflow, "mock_dispatch") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> Result: + sdkres._initialize_nodes() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id) + + +@pytest.mark.asyncio +async def test_upload_asset_for_nodes(test_db, mocker): + sdkres = get_mock_result() + sdkres._initialize_nodes() + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + srvres = get_mock_srvresult(sdkres, test_db) + + srvres.lattice.transport_graph.set_node_value(0, "stdout", "Hello!\n") + srvres.lattice.transport_graph.set_node_value(2, "stdout", "Bye!\n") + + with tempfile.NamedTemporaryFile("w", delete=True, suffix=".txt") as temp: + dest_path_0 = temp.name + + with tempfile.NamedTemporaryFile("w", delete=True, suffix=".txt") as temp: + dest_path_2 = temp.name + + dest_uri_0 = os.path.join("file://", dest_path_0) + dest_uri_2 = os.path.join("file://", dest_path_2) + + await am.upload_asset_for_nodes(srvres.dispatch_id, "stdout", {0: dest_uri_0, 2: dest_uri_2}) + + with open(dest_path_0, "r") as f: + assert f.read() == "Hello!\n" + + with open(dest_path_2, "r") as f: + assert f.read() == "Bye!\n" + + os.unlink(dest_path_0) + os.unlink(dest_path_2) + + +@pytest.mark.asyncio +async def test_download_assets_for_node(test_db, mocker): + sdkres = get_mock_result() + sdkres._initialize_nodes() + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + mock_update_assets = mocker.patch("covalent_dispatcher._dal.electron.Electron.update_assets") + + srvres = get_mock_srvresult(sdkres, test_db) + + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as temp: + src_path_stdout = temp.name + temp.write("Hello!\n") + + with tempfile.NamedTemporaryFile("w", delete=False, suffix=".txt") as temp: + src_path_stderr = temp.name + temp.write("Bye!\n") + + src_uri_stdout = os.path.join("file://", src_path_stdout) + src_uri_stderr = os.path.join("file://", src_path_stderr) + + assets = { + "output": { + "remote_uri": "", + }, + "stdout": {"remote_uri": src_uri_stdout, "size": None, "digest": "0af23"}, + "stderr": { + "remote_uri": src_uri_stderr, + }, + } + assets = {k: AssetUpdate(**v) for k, v in assets.items()} + + expected_update = { + "output": { + "remote_uri": "", + }, + "stdout": { + "remote_uri": src_uri_stdout, + "digest": "0af23", + }, + "stderr": { + "remote_uri": src_uri_stderr, + }, + } + await am.download_assets_for_node( + srvres.dispatch_id, + 0, + assets, + ) + + mock_update_assets.assert_called_with(expected_update) + assert srvres.lattice.transport_graph.get_node_value(0, "stdout") == "Hello!\n" + assert srvres.lattice.transport_graph.get_node_value(0, "stderr") == "Bye!\n" diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py new file mode 100644 index 000000000..8caf15e4a --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/dispatch_test.py @@ -0,0 +1,69 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the querying and updating dispatches +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import dispatch + + +@pytest.mark.asyncio +async def test_get(mocker): + dispatch_id = "test_get_incoming_edges" + + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj.get_values = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await dispatch.get(dispatch_id, keys=["status"]) + + +@pytest.mark.asyncio +async def test_get_incomplete_tasks(mocker): + dispatch_id = "test_get_node_successors" + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj._get_incomplete_nodes = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await dispatch.get_incomplete_tasks(dispatch_id) + + +@pytest.mark.asyncio +async def test_update(mocker): + dispatch_id = "test_update_dispatch" + mock_result_obj = MagicMock() + mocker.patch( + "covalent_dispatcher._core.data_modules.dispatch.get_result_object", + return_value=mock_result_obj, + ) + + await dispatch.update(dispatch_id, {"status": "COMPLETED"}) + + mock_result_obj._update_dispatch.assert_called() diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py new file mode 100644 index 000000000..a4cbc1253 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/graph_test.py @@ -0,0 +1,94 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the graph querying functions +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import graph + + +@pytest.mark.asyncio +async def test_get_incoming_edges(mocker): + dispatch_id = "test_get_incoming_edges" + node_id = 0 + + mock_result_obj = MagicMock() + mock_return_val = [{"source": 1, "target": 0, "attrs": {"param_type": "arg"}}] + mock_result_obj.lattice.transport_graph.get_incoming_edges = MagicMock( + return_value=mock_return_val + ) + + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_return_val == await graph.get_incoming_edges(dispatch_id, node_id) + + +@pytest.mark.asyncio +async def test_get_node_successors(mocker): + dispatch_id = "test_get_node_successors" + node_id = 0 + + mock_result_obj = MagicMock() + mock_return_val = {"node_id": 0, "status": "NEW_OBJECT"} + mock_result_obj.lattice.transport_graph.get_successors = MagicMock( + return_value=mock_return_val + ) + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + assert mock_return_val == await graph.get_node_successors(dispatch_id, node_id) + + +@pytest.mark.asyncio +async def test_get_node_links(mocker): + dispatch_id = "test_get_node_links" + + mock_result_obj = MagicMock() + + mock_return_val = {"nodes": [0, 1], "links": [(1, 0, 0)]} + mocker.patch("networkx.readwrite.node_link_data", return_value=mock_return_val) + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_return_val == await graph.get_nodes_links(dispatch_id) + + +@pytest.mark.asyncio +async def test_get_nodes(mocker): + dispatch_id = "test_get_nodes" + mock_result_obj = MagicMock() + + g = MagicMock() + mock_result_obj.lattice.transport_graph.get_internal_graph_copy = MagicMock(return_value=g) + g.nodes = [1, 2, 3] + mocker.patch( + "covalent_dispatcher._core.data_modules.graph.get_result_object", + return_value=mock_result_obj, + ) + + assert [1, 2, 3] == await graph.get_nodes(dispatch_id) diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py new file mode 100644 index 000000000..21b410a53 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/importer_test.py @@ -0,0 +1,120 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the importer entry point""" + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules.importer import ( + _copy_assets, + import_derived_manifest, + import_manifest, +) + + +@pytest.mark.asyncio +async def test_import_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = None + + mock_srvres = MagicMock() + mocker.patch( + "covalent_dispatcher._dal.result.Result.from_dispatch_id", return_value=mock_srvres + ) + + mock_asset = MagicMock() + mock_asset.remote_uri = "s3://mybucket/object.pkl" + + mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_result", return_value=mock_manifest + ) + + mock_assets = {"lattice": [mock_asset], "nodes": [mock_asset]} + mocker.patch( + "covalent_dispatcher._core.data_modules.importer._get_all_assets", return_value=mock_assets + ) + + return_manifest = await import_manifest(mock_manifest, None, None) + + assert return_manifest.metadata.dispatch_id is not None + + +@pytest.mark.asyncio +async def test_import_sublattice_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = None + + mock_parent_res = MagicMock() + mock_parent_res.root_dispatch_id = "parent_dispatch_id" + + mock_asset = MagicMock() + mock_asset.remote_uri = "s3://mybucket/object.pkl" + + mock_srvres = MagicMock() + mocker.patch( + "covalent_dispatcher._dal.result.Result.from_dispatch_id", return_value=mock_parent_res + ) + + mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_result", return_value=mock_manifest + ) + + mock_assets = {"lattice": [MagicMock()], "nodes": [MagicMock()]} + + return_manifest = await import_manifest(mock_manifest, "parent_dispatch_id", None) + + assert return_manifest.metadata.dispatch_id is not None + assert return_manifest.metadata.root_dispatch_id == "parent_dispatch_id" + + +@pytest.mark.asyncio +async def test_import_derived_manifest(mocker): + mock_manifest = MagicMock() + mock_manifest.metadata.dispatch_id = "test_import_derived_manifest" + + mock_import_manifest = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._import_manifest", + ) + + mock_copy = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._copy_assets", + ) + + mock_handle_redispatch = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.handle_redispatch", + return_value=(mock_manifest, []), + ) + + mock_pull = mocker.patch( + "covalent_dispatcher._core.data_modules.importer._pull_assets", + ) + + mock_manifest = {} + await import_derived_manifest(mock_manifest, "parent_dispatch", True) + + mock_import_manifest.assert_called() + mock_pull.assert_called() + mock_handle_redispatch.assert_called() + mock_copy.assert_called_with([]) + + +def test_copy_assets(mocker): + mock_copy = mocker.patch("covalent_dispatcher._core.data_modules.importer.copy_asset") + + _copy_assets([("src", "dest")]) + mock_copy.assert_called_with("src", "dest") diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py index 77ae5bf85..05809a5b4 100644 --- a/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py +++ b/tests/covalent_dispatcher_tests/_core/data_modules/job_manager_test.py @@ -21,8 +21,8 @@ from covalent_dispatcher._core.data_modules.job_manager import ( get_jobs_metadata, set_cancel_requested, - set_cancel_result, set_job_handle, + set_job_status, ) @@ -100,9 +100,7 @@ async def test_set_job_handle(mocker): mock_update.assert_called_with([{"job_id": 1, "job_handle": "12356"}]) -@pytest.mark.asyncio -@pytest.mark.parametrize("cancel_requested", [True, False]) -async def test_set_cancel_result(cancel_requested, mocker): +async def test_set_job_status(mocker): """ Test requesting a task to be cancelled """ @@ -111,5 +109,5 @@ async def test_set_cancel_result(cancel_requested, mocker): mock_update = mocker.patch( "covalent_dispatcher._core.data_modules.job_manager.update_job_records" ) - await set_cancel_result("dispatch", 0, cancel_status=cancel_requested) - mock_update.assert_called_with([{"job_id": 1, "cancel_successful": cancel_requested}]) + await set_job_status("dispatch", 0, status="COMPLETED") + mock_update.assert_called_with([{"job_id": 1, "status": "COMPLEtED"}]) diff --git a/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py b/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py new file mode 100644 index 000000000..ab4add428 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/data_modules/lattice_query_test.py @@ -0,0 +1,41 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the querying lattices +""" + + +from unittest.mock import MagicMock + +import pytest + +from covalent_dispatcher._core.data_modules import lattice + + +@pytest.mark.asyncio +async def test_get(mocker): + dispatch_id = "test_get" + + mock_retval = MagicMock() + mock_result_obj = MagicMock() + mock_result_obj.lattice.get_values = MagicMock(return_value=mock_retval) + mocker.patch( + "covalent_dispatcher._core.data_modules.lattice.get_result_object", + return_value=mock_result_obj, + ) + + assert mock_retval == await lattice.get(dispatch_id, keys=["executor"]) diff --git a/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py new file mode 100644 index 000000000..53444a7b6 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/dispatcher_db_integration_test.py @@ -0,0 +1,322 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the core functionality of the dispatcher. +""" + + +from typing import Dict, List + +import pytest + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.dispatcher import ( + _get_abstract_task_inputs, + _get_initial_tasks_and_deps, + _handle_completed_node, +) +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import models, update +from covalent_dispatcher._db.datastore import DataStore + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + with test_db.session() as session: + record = session.query(models.Lattice).where(models.Lattice.id == 1).first() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id, bare=False) + + +@pytest.mark.asyncio +async def test_get_abstract_task_inputs(mocker, test_db): + """Test _get_abstract_task_inputs for both dicts and list parameter types""" + + @ct.electron + def list_task(arg: List): + return len(arg) + + @ct.electron + def dict_task(arg: Dict): + return len(arg) + + @ct.electron + def multivariable_task(x, y): + return x, y + + @ct.lattice + def list_workflow(arg): + return list_task(arg) + + @ct.lattice + def dict_workflow(arg): + return dict_task(arg) + + # 1 2 + # \ \ + # 0 3 + # / /\/ + # 4 5 + + @ct.electron + def identity(x): + return x + + @ct.lattice + def multivar_workflow(x, y): + electron_x = identity(x) + electron_y = identity(y) + res1 = multivariable_task(electron_x, electron_y) + res2 = multivariable_task(electron_y, electron_x) + res3 = multivariable_task(electron_y, electron_x) + res4 = multivariable_task(electron_x, electron_y) + return 1 + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + # list-type inputs + + # Nodes 0=task, 1=:electron_list:, 2=1, 3=2, 4=3 + list_workflow.build_graph([1, 2, 3]) + abstract_args = [2, 3, 4] + tg = list_workflow.transport_graph + + sdkres = Result(lattice=list_workflow, dispatch_id="list_input_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + dispatch_id = result_object.dispatch_id + + async def mock_get_incoming_edges(dispatch_id, node_id): + return result_object.lattice.transport_graph.get_incoming_edges(node_id) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + abs_task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 1, tg.get_node_value(1, "name") + ) + + expected_inputs = {"args": abstract_args, "kwargs": {}} + + assert abs_task_inputs == expected_inputs + + # dict-type inputs + + # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 + dict_workflow.build_graph({"a": 1, "b": 2}) + abstract_args = {"a": 2, "b": 3} + tg = dict_workflow.transport_graph + + sdkres = Result(lattice=dict_workflow, dispatch_id="dict_input_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 1, tg.get_node_value(1, "name") + ) + expected_inputs = {"args": [], "kwargs": abstract_args} + + assert task_inputs == expected_inputs + + # Check arg order + multivar_workflow.build_graph(1, 2) + received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) + sdkres = Result(lattice=received_lattice, dispatch_id="arg_order_dispatch") + result_object = get_mock_srvresult(sdkres, test_db) + tg = received_lattice.transport_graph + + # Account for injected postprocess electron + assert list(tg._graph.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + tg.set_node_value(0, "output", ct.TransportableObject(1)) + tg.set_node_value(2, "output", ct.TransportableObject(2)) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 4, tg.get_node_value(4, "name") + ) + assert task_inputs["args"] == [0, 2] + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 5, tg.get_node_value(5, "name") + ) + assert task_inputs["args"] == [2, 0] + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 6, tg.get_node_value(6, "name") + ) + assert task_inputs["args"] == [2, 0] + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_incoming_edges", + mock_get_incoming_edges, + ) + + task_inputs = await _get_abstract_task_inputs( + result_object.dispatch_id, 7, tg.get_node_value(7, "name") + ) + assert task_inputs["args"] == [0, 2] + + +@pytest.mark.asyncio +async def test_handle_completed_node(mocker, test_db): + """Unit test for completed node handler""" + + from covalent_dispatcher._core.dispatcher import _initialize_caches, _pending_parents + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + pending_parents = {} + sorted_task_groups = {} + sdkres = get_mock_result() + result_object = get_mock_srvresult(sdkres, test_db) + + async def get_node_successors(dispatch_id: str, node_id: int): + return result_object.lattice.transport_graph.get_successors(node_id, ["task_group_id"]) + + async def electron_get(dispatch_id, node_id, keys): + return {keys[0]: node_id} + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_node_successors", + get_node_successors, + ) + + mocker.patch( + "covalent_dispatcher._core.data_manager.electron.get", + electron_get, + ) + + # tg edges are (1, 0), (0, 2) + pending_parents[0] = 1 + pending_parents[1] = 0 + pending_parents[2] = 1 + sorted_task_groups[0] = [0] + sorted_task_groups[1] = [1] + sorted_task_groups[2] = [2] + + await _initialize_caches(result_object.dispatch_id, pending_parents, sorted_task_groups) + + node_result = {"node_id": 1, "status": Result.COMPLETED} + assert await _pending_parents.get_pending(result_object.dispatch_id, 0) == 1 + assert await _pending_parents.get_pending(result_object.dispatch_id, 1) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 2) == 1 + + next_nodes = await _handle_completed_node(result_object.dispatch_id, 1) + assert next_nodes == [0] + + assert await _pending_parents.get_pending(result_object.dispatch_id, 0) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 1) == 0 + assert await _pending_parents.get_pending(result_object.dispatch_id, 2) == 1 + + +@pytest.mark.asyncio +async def test_get_initial_tasks_and_deps(mocker, test_db): + """Test internal function for initializing status_queue and pending_parents""" + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + pending_parents = {} + + sdkres = get_mock_result() + result_object = get_mock_srvresult(sdkres, test_db) + dispatch_id = result_object.dispatch_id + + async def get_graph_nodes_links(dispatch_id: str) -> dict: + import networkx as nx + + """Return the internal transport graph in NX node-link form""" + g = result_object.lattice.transport_graph.get_internal_graph_copy() + return nx.readwrite.node_link_data(g) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.tg_utils.get_nodes_links", + side_effect=get_graph_nodes_links, + ) + + initial_nodes, pending_parents, sorted_task_groups = await _get_initial_tasks_and_deps( + dispatch_id + ) + + assert initial_nodes == [1] + + # Account for injected postprocess electron + assert pending_parents == {0: 1, 1: 0, 2: 1, 3: 3} + assert sorted_task_groups == {0: [0], 1: [1], 2: [2], 3: [3]} diff --git a/tests/covalent_dispatcher_tests/_core/dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py new file mode 100644 index 000000000..5b5b79414 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/dispatcher_test.py @@ -0,0 +1,849 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the core functionality of the dispatcher. + +This will be replaced in the next patch. +""" + + +from unittest.mock import call + +import pytest + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.dispatcher import ( + _clear_caches, + _finalize_dispatch, + _handle_cancelled_node, + _handle_event, + _handle_failed_node, + _handle_node_status_update, + _submit_initial_tasks, + _submit_task_group, + cancel_dispatch, + run_dispatch, + run_workflow, +) +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +@pytest.mark.asyncio +async def test_handle_failed_node(mocker): + """Unit test for failed node handler""" + dispatch_id = "failed_dispatch" + await _handle_failed_node(dispatch_id, 1) + + +@pytest.mark.asyncio +async def test_handle_cancelled_node(mocker, test_db): + """Unit test for cancelled node handler""" + dispatch_id = "cancelled_dispatch" + + await _handle_cancelled_node(dispatch_id, 1) + + +@pytest.mark.parametrize( + "wait,expected_status", [(True, Result.COMPLETED), (False, Result.RUNNING)] +) +@pytest.mark.asyncio +async def test_run_workflow_normal(mocker, wait, expected_status): + import asyncio + + dispatch_id = "mock_dispatch" + + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=True) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.NEW_OBJ}, + ) + _futures = {dispatch_id: asyncio.Future()} + mocker.patch("covalent_dispatcher._core.dispatcher._futures", _futures) + + async def mark_future_done(dispatch_id): + _futures[dispatch_id].set_result(Result.COMPLETED) + return Result.RUNNING + + mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_initial_tasks", + return_value=Result.RUNNING, + side_effect=mark_future_done, + ) + + dispatch_status = await run_workflow(dispatch_id, wait) + assert dispatch_status == expected_status + if wait: + mock_unregister.assert_called_with(dispatch_id) + + +@pytest.mark.parametrize("wait", [True, False]) +@pytest.mark.asyncio +async def test_run_completed_workflow(mocker, wait): + import asyncio + + dispatch_id = "completed_dispatch" + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=False + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.COMPLETED}, + ) + + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.COMPLETED}, + ) + mock_plan = mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") + dispatch_status = await run_workflow(dispatch_id, wait) + + mock_unregister.assert_not_called() + assert dispatch_status == Result.COMPLETED + + +@pytest.mark.parametrize("wait", [True, False]) +@pytest.mark.asyncio +async def test_run_workflow_exception(mocker, wait): + import asyncio + + dispatch_id = "mock_dispatch" + + mock_unregister = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" + ) + mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") + mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_initial_tasks", + side_effect=RuntimeError("Error"), + ) + + mock_dispatch_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + ) + mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.ensure_dispatch", return_value=True) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value={"status": Result.NEW_OBJ}, + ) + + status = await run_workflow(dispatch_id, wait) + + assert status == Result.FAILED + mock_unregister.assert_called_with(dispatch_id) + + +@pytest.mark.asyncio +async def test_run_dispatch(mocker): + dispatch_id = "test_dispatch" + mock_run = mocker.patch("covalent_dispatcher._core.dispatcher.run_workflow") + run_dispatch(dispatch_id) + mock_run.assert_called_with(dispatch_id) + + +@pytest.mark.asyncio +async def test_handle_completed_node_update(mocker): + import asyncio + + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + detail = {} + next_groups = [0, 1] + + mock_handle_cancelled = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_completed_node", return_value=next_groups + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) + + mock_increment = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.increment" + ) + + async def get_task_group(dispatch_id, gid): + return [gid] + + mock_get_sorted_task_groups = mocker.patch( + "covalent_dispatcher._core.dispatcher._sorted_task_groups.get_task_group", + get_task_group, + ) + mock_submit_task_group = mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_task_group" + ) + + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_decrement.assert_awaited() + assert mock_increment.await_count == 2 + assert mock_submit_task_group.await_count == 2 + + +@pytest.mark.asyncio +async def test_handle_cancelled_node_update(mocker): + import asyncio + + dispatch_id = "mock_dispatch" + node_id = 0 + status = Result.CANCELLED + detail = {} + mock_handle_cancelled = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_cancelled_node", + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) + + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_handle_cancelled.assert_awaited_with(dispatch_id, 0) + mock_decrement.assert_awaited() + + +@pytest.mark.asyncio +async def test_run_handle_failed_node_update(mocker): + import asyncio + + dispatch_id = "mock_dispatch" + node_id = 0 + status = Result.FAILED + detail = {} + mock_handle_failed = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_failed_node", + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) + + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_handle_failed.assert_awaited_with(dispatch_id, 0) + mock_decrement.assert_awaited() + + +@pytest.mark.asyncio +async def test_run_handle_sublattice_node_update(mocker): + import asyncio + + from covalent._shared_files.util_classes import RESULT_STATUS + + dispatch_id = "mock_dispatch" + node_id = 0 + status = RESULT_STATUS.DISPATCHING + detail = {"sub_dispatch_id": "sub_dispatch"} + mock_run_dispatch = mocker.patch( + "covalent_dispatcher._core.dispatcher.run_dispatch", + ) + mock_decrement = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.decrement" + ) + await _handle_node_status_update(dispatch_id, node_id, status, detail) + mock_run_dispatch.assert_called_with("sub_dispatch") + mock_decrement.assert_not_awaited() + + +@pytest.mark.parametrize("unresolved_count", [1, 0]) +@pytest.mark.asyncio +async def test_handle_event(mocker, unresolved_count): + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", + ) + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + ) + + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + return_value=Result.COMPLETED, + ) + + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=unresolved_count, + ) + + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} + + await _handle_event(msg) + + if unresolved_count < 1: + mock_finalize.assert_awaited() + mock_persist.assert_awaited() + else: + mock_finalize.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_handle_event_exception(mocker): + import asyncio + + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", + side_effect=RuntimeError(), + ) + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + return_value=Result.FAILED, + ) + + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + return_value=Result.COMPLETED, + ) + + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=2, + ) + + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} + + _futures = {dispatch_id: asyncio.Future()} + + mocker.patch( + "covalent_dispatcher._core.dispatcher._futures", + _futures, + ) + + assert await _handle_event(msg) == Result.FAILED + + assert _futures[dispatch_id].result() == Result.FAILED + + mock_persist.assert_awaited() + mock_finalize.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_handle_event_finalize_exception(mocker): + import asyncio + + mock_handle_status_update = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_node_status_update", + ) + mock_handle_dispatch_exception = mocker.patch( + "covalent_dispatcher._core.dispatcher._handle_dispatch_exception", + return_value=Result.FAILED, + ) + + mock_persist = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.persist_result", + ) + + mock_finalize = mocker.patch( + "covalent_dispatcher._core.dispatcher._finalize_dispatch", + side_effect=RuntimeError(), + ) + + mock_get_unresolved = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.get_unresolved", + return_value=0, + ) + + dispatch_id = "mock_dispatch" + node_id = 2 + status = Result.COMPLETED + msg = {"dispatch_id": dispatch_id, "node_id": node_id, "status": status, "detail": {}} + + _futures = {dispatch_id: asyncio.Future()} + + mocker.patch( + "covalent_dispatcher._core.dispatcher._futures", + _futures, + ) + + assert await _handle_event(msg) == Result.FAILED + + assert _futures[dispatch_id].result() == Result.FAILED + + mock_persist.assert_awaited() + + +@pytest.mark.parametrize( + "failed,cancelled,final_status", + [ + (False, False, Result.COMPLETED), + (False, True, Result.CANCELLED), + (True, False, Result.FAILED), + (True, True, Result.FAILED), + ], +) +@pytest.mark.asyncio +async def test_finalize_dispatch(mocker, failed, cancelled, final_status): + mock_clear = mocker.patch("covalent_dispatcher._core.dispatcher._clear_caches") + failed_tasks = [(0, "task_0")] if failed else [] + cancelled_tasks = [(1, "task_1")] if cancelled else [] + + query_result = {"failed": failed_tasks, "cancelled": cancelled_tasks} + mock_incomplete = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get_incomplete_tasks", + return_value=query_result, + ) + + mock_dispatch_info = {"status": Result.COMPLETED} + + def mock_gen_dispatch_result(dispatch_id, **kwargs): + return {"status": kwargs["status"]} + + async def mock_dispatch_update(dispatch_id, dispatch_result): + mock_dispatch_info["status"] = dispatch_result["status"] + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.generate_dispatch_result", + mock_gen_dispatch_result, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + mock_dispatch_update, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.get", + return_value=mock_dispatch_info, + ) + + dispatch_id = "dispatch_1" + + assert await _finalize_dispatch(dispatch_id) == final_status + + +@pytest.mark.asyncio +async def test_submit_initial_tasks(mocker): + dispatch_id = "dispatch_1" + + initial_groups = [1, 2] + sorted_groups = {1: [1], 2: [2]} + + mocker.patch( + "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", + return_value=(initial_groups, {1: 0, 2: 0}, sorted_groups), + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.generate_dispatch_result", + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher._initialize_caches", + ) + + mock_inc = mocker.patch("covalent_dispatcher._core.dispatcher._unresolved_tasks.increment") + mock_submit_task_group = mocker.patch( + "covalent_dispatcher._core.dispatcher._submit_task_group", + ) + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.dispatch.update", + ) + + assert await _submit_initial_tasks(dispatch_id) == Result.RUNNING + + assert mock_submit_task_group.await_count == 2 + assert mock_inc.await_count == 2 + + +@pytest.mark.asyncio +async def test_submit_task_group_single(mocker): + """Test submitting a singleton task groups""" + dispatch_id = "dispatch_1" + gid = 2 + nodes = [2] + + mock_get_abs_input = mocker.patch( + "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", + return_value={"args": [], "kwargs": {}}, + ) + + mock_attrs = { + "name": "task", + "value": 5, + "executor": "local", + "executor_data": {}, + } + + mock_statuses = [ + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + ] + + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + return_value=mock_statuses, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) + + # This will be removed in the next patch + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + ) + + await _submit_task_group(dispatch_id, nodes, gid) + mock_run_abs_task.assert_called() + assert mock_get_abs_input.await_count == len(nodes) + + +# Temporary only because the current runner does not support +# nontrivial task groups. +@pytest.mark.asyncio +async def test_submit_task_group_multiple(mocker): + """Check that submitting multiple tasks errors out""" + dispatch_id = "dispatch_1" + gid = 2 + nodes = [4, 3, 2] + + mock_get_abs_input = mocker.patch( + "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", + return_value={"args": [], "kwargs": {}}, + ) + + mock_attrs = { + "name": "task", + "value": 5, + "executor": "local", + "executor_data": {}, + } + + mock_statuses = [ + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + {"status": Result.NEW_OBJ}, + ] + + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + return_value=mock_statuses, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) + + # This will be removed in the next patch + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + ) + + with pytest.raises(RuntimeError): + await _submit_task_group(dispatch_id, nodes, gid) + + +@pytest.mark.asyncio +async def test_submit_task_group_skips_reusable(mocker): + """Check that submit_task_group skips reusable groups""" + dispatch_id = "dispatch_1" + gid = 2 + nodes = [4, 3, 2] + + mock_get_abs_input = mocker.patch( + "covalent_dispatcher._core.dispatcher._get_abstract_task_inputs", + return_value={"args": [], "kwargs": {}}, + ) + + mock_attrs = { + "name": "task", + "value": 5, + "executor": "local", + "executor_data": {}, + } + + mock_statuses = [ + {"status": Result.PENDING_REUSE}, + {"status": Result.PENDING_REUSE}, + {"status": Result.PENDING_REUSE}, + ] + + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, + ) + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + return_value=mock_statuses, + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) + + # Will be removed next patch + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + ) + + await _submit_task_group(dispatch_id, nodes, gid) + mock_run_abs_task.assert_not_called() + mock_get_abs_input.assert_not_awaited() + assert mock_update.await_count == len(nodes) + + +@pytest.mark.asyncio +async def test_submit_parameter(mocker): + from covalent._shared_files.defaults import parameter_prefix + + dispatch_id = "dispatch_1" + node_id = 2 + + mock_attrs = { + "name": parameter_prefix, + "value": 5, + "executor": "local", + "executor_data": {}, + } + + async def get_electron_attrs(dispatch_id, node_id, keys): + return {key: mock_attrs[key] for key in keys} + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get", + get_electron_attrs, + ) + + mock_update = mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.update_node_result", + ) + + # Will be removed next patch + mock_run_abs_task = mocker.patch( + "covalent_dispatcher._core.dispatcher.runner.run_abstract_task", + ) + + await _submit_task_group(dispatch_id, [node_id], node_id) + + mock_run_abs_task.assert_not_called() + mock_update.assert_awaited() + + +@pytest.mark.asyncio +async def test_clear_caches(mocker): + import networkx as nx + + g = nx.MultiDiGraph() + g.add_node(0, task_group_id=0) + g.add_node(1, task_group_id=0) + g.add_node(2, task_group_id=0) + g.add_node(3, task_group_id=3) + + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes_links") + mocker.patch("networkx.readwrite.node_link_graph", return_value=g) + mock_unresolved_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._unresolved_tasks.remove" + ) + mock_pending_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._pending_parents.remove" + ) + + mock_groups_remove = mocker.patch( + "covalent_dispatcher._core.dispatcher._sorted_task_groups.remove" + ) + + await _clear_caches("dispatch") + + assert mock_unresolved_remove.await_count == 1 + assert mock_pending_remove.await_count == 2 + assert mock_groups_remove.await_count == 2 + + +@pytest.mark.asyncio +async def test_cancel_dispatch(mocker): + """Test cancelling a dispatch, including sub-lattices""" + res = get_mock_result() + sub_res = get_mock_result() + + sub_dispatch_id = "sub_pipeline_workflow" + sub_res._dispatch_id = sub_dispatch_id + + mock_data_cancel = mocker.patch( + "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" + ) + + mock_cancel_tasks = mocker.patch("covalent_dispatcher._core.dispatcher.cancel_tasks") + + res._initialize_nodes() + sub_res._initialize_nodes() + + tg = res.lattice.transport_graph + tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) + sub_tg = sub_res.lattice.transport_graph + + async def mock_get_nodes(dispatch_id): + if dispatch_id == res.dispatch_id: + return list(tg._graph.nodes) + else: + return list(sub_tg._graph.nodes) + + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes) + + node_attrs = [ + {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes + ] + sub_node_attrs = [ + {"sub_dispatch_id": sub_tg.get_node_value(i, "sub_dispatch_id")} + for i in sub_tg._graph.nodes + ] + + async def mock_get(dispatch_id, task_ids, keys): + return node_attrs if dispatch_id == res.dispatch_id else sub_node_attrs + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + mock_get, + ) + + await cancel_dispatch("pipeline_workflow") + + task_ids = list(tg._graph.nodes) + sub_task_ids = list(sub_tg._graph.nodes) + + calls = [call("pipeline_workflow", task_ids), call(sub_dispatch_id, sub_task_ids)] + mock_data_cancel.assert_has_awaits(calls) + mock_cancel_tasks.assert_has_awaits(calls) + + +@pytest.mark.asyncio +async def test_cancel_dispatch_with_task_ids(mocker): + """Test cancelling a dispatch, including sub-lattices and with task ids""" + res = get_mock_result() + sub_res = get_mock_result() + + res._initialize_nodes() + sub_res._initialize_nodes() + + sub_dispatch_id = "sub_pipeline_workflow" + sub_res._dispatch_id = sub_dispatch_id + tg = res.lattice.transport_graph + tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) + sub_tg = sub_res.lattice.transport_graph + + mock_data_cancel = mocker.patch( + "covalent_dispatcher._core.dispatcher.jbmgr.set_cancel_requested" + ) + + mock_cancel_tasks = mocker.patch("covalent_dispatcher._core.dispatcher.cancel_tasks") + + async def mock_get_nodes(dispatch_id): + if dispatch_id == res.dispatch_id: + return list(tg._graph.nodes) + else: + return list(sub_tg._graph.nodes) + + mocker.patch("covalent_dispatcher._core.dispatcher.tg_utils.get_nodes", mock_get_nodes) + + node_attrs = [ + {"sub_dispatch_id": tg.get_node_value(i, "sub_dispatch_id")} for i in tg._graph.nodes + ] + sub_node_attrs = [ + {"sub_dispatch_id": sub_tg.get_node_value(i, "sub_dispatch_id")} + for i in sub_tg._graph.nodes + ] + + async def mock_get(dispatch_id, task_ids, keys): + return node_attrs if dispatch_id == res.dispatch_id else sub_node_attrs + + mocker.patch( + "covalent_dispatcher._core.dispatcher.datasvc.electron.get_bulk", + mock_get, + ) + + mock_app_log = mocker.patch("covalent_dispatcher._core.dispatcher.app_log.debug") + task_ids = [2] + sub_task_ids = list(sub_tg._graph.nodes) + + await cancel_dispatch("pipeline_workflow", task_ids) + + calls = [call("pipeline_workflow", task_ids), call(sub_dispatch_id, sub_task_ids)] + mock_data_cancel.assert_has_awaits(calls) + mock_cancel_tasks.assert_has_awaits(calls) + assert mock_app_log.call_count == 2 diff --git a/tests/covalent_dispatcher_tests/_core/execution_test.py b/tests/covalent_dispatcher_tests/_core/execution_test.py new file mode 100644 index 000000000..5dc5712fe --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/execution_test.py @@ -0,0 +1,279 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Integration tests for the dispatcher, runner, and result modules +""" + +import asyncio +import uuid +from typing import Dict, List + +import pytest +import pytest_asyncio +from sqlalchemy.pool import StaticPool + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.dispatcher import run_workflow +from covalent_dispatcher._core.execution import _get_task_inputs +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import models, update +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + + +@pytest_asyncio.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + with test_db.session() as session: + record = session.query(models.Lattice).where(models.Lattice.id == 1).first() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id, bare=False) + + +@pytest.mark.asyncio +async def test_get_task_inputs(mocker, test_db): + """Test _get_task_inputs for both dicts and list parameter types""" + + @ct.electron + def list_task(arg: List): + return len(arg) + + @ct.electron + def dict_task(arg: Dict): + return len(arg) + + @ct.electron + def multivariable_task(x, y): + return x, y + + @ct.lattice + def list_workflow(arg): + return list_task(arg) + + @ct.lattice + def dict_workflow(arg): + return dict_task(arg) + + # 1 2 + # \ \ + # 0 3 + # / /\/ + # 4 5 + + @ct.electron + def identity(x): + return x + + @ct.lattice + def multivar_workflow(x, y): + electron_x = identity(x) + electron_y = identity(y) + res1 = multivariable_task(electron_x, electron_y) + res2 = multivariable_task(electron_y, electron_x) + res3 = multivariable_task(electron_y, electron_x) + res4 = multivariable_task(electron_x, electron_y) + return 1 + + # list-type inputs + + list_workflow.build_graph([1, 2, 3]) + serialized_args = [ct.TransportableObject(i) for i in [1, 2, 3]] + + # Nodes 0=task, 1=:electron_list:, 2=1, 3=2, 4=3 + sdkres = Result(lattice=list_workflow, dispatch_id="asdf") + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph + tg.set_node_value(2, "output", ct.TransportableObject(1)) + tg.set_node_value(3, "output", ct.TransportableObject(2)) + tg.set_node_value(4, "output", ct.TransportableObject(3)) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) + + expected_inputs = {"args": serialized_args, "kwargs": {}} + + assert task_inputs == expected_inputs + + # dict-type inputs + + dict_workflow.build_graph({"a": 1, "b": 2}) + serialized_args = {"a": ct.TransportableObject(1), "b": ct.TransportableObject(2)} + + # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 + sdkres = Result(lattice=dict_workflow, dispatch_id="asdf_dict_workflow") + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph + tg.set_node_value(2, "output", ct.TransportableObject(1)) + tg.set_node_value(3, "output", ct.TransportableObject(2)) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + task_inputs = await _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) + expected_inputs = {"args": [], "kwargs": serialized_args} + + assert task_inputs == expected_inputs + + # Check arg order + multivar_workflow.build_graph(1, 2) + received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) + sdkres = Result(lattice=received_lattice, dispatch_id="asdf_multivar_workflow") + result_object = get_mock_srvresult(sdkres, test_db) + tg = result_object.lattice.transport_graph + + # Account for injected postprocess electron + assert list(tg._graph.nodes) == [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + tg.set_node_value(0, "output", ct.TransportableObject(1)) + tg.set_node_value(2, "output", ct.TransportableObject(2)) + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + task_inputs = await _get_task_inputs(4, tg.get_node_value(4, "name"), result_object) + + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [1, 2] + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + task_inputs = await _get_task_inputs(5, tg.get_node_value(5, "name"), result_object) + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [2, 1] + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + task_inputs = await _get_task_inputs(6, tg.get_node_value(6, "name"), result_object) + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [2, 1] + + mock_get_result = mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + task_inputs = await _get_task_inputs(7, tg.get_node_value(7, "name"), result_object) + input_args = [arg.get_deserialized() for arg in task_inputs["args"]] + assert input_args == [1, 2] + + +@pytest.mark.asyncio +async def test_run_workflow_does_not_deserialize(test_db, mocker): + """Check that dispatcher does not deserialize user data when using + out-of-process `workflow_executor`""" + + from dask.distributed import LocalCluster + + from covalent._workflow.lattice import Lattice + from covalent.executor import DaskExecutor + + lc = LocalCluster() + dask_exec = DaskExecutor(lc.scheduler_address) + + @ct.electron(executor=dask_exec) + def task(x): + return x + + @ct.lattice(executor=dask_exec, workflow_executor=dask_exec) + def workflow(x): + # Exercise both sublatticing and postprocessing + sublattice_task = ct.lattice(task, workflow_executor=dask_exec) + res1 = ct.electron(sublattice_task(x), executor=dask_exec) + return res1 + + workflow.build_graph(5) + + json_lattice = workflow.serialize_to_json() + dispatch_id = str(uuid.uuid4()) + lattice = Lattice.deserialize_from_json(json_lattice) + result_object = Result(lattice) + result_object._dispatch_id = dispatch_id + result_object._initialize_nodes() + + mocker.patch("covalent_dispatcher._db.datastore.DataStore.factory", return_value=test_db) + mocker.patch( + "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object + ) + + mocker.patch("covalent_dispatcher._core.dispatcher._global_status_queue", asyncio.Queue()) + + update.persist(result_object) + + mock_to_deserialize = mocker.patch("covalent.TransportableObject.get_deserialized") + + await run_workflow(result_object.dispatch_id) + + mock_to_deserialize.assert_not_called() diff --git a/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py b/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py new file mode 100644 index 000000000..3366b8aad --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/runner_db_integration_test.py @@ -0,0 +1,123 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the core functionality of the runner. +""" + + +import pytest +from sqlalchemy.pool import StaticPool + +import covalent as ct +from covalent._results_manager import Result +from covalent._workflow.lattice import Lattice +from covalent_dispatcher._core.runner import _gather_deps +from covalent_dispatcher._dal.result import Result as SRVResult +from covalent_dispatcher._dal.result import get_result_object +from covalent_dispatcher._db import update +from covalent_dispatcher._db.datastore import DataStore + +TEST_RESULTS_DIR = "/tmp/results" + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + + return DataStore( + db_URL="sqlite+pysqlite:///:memory:", + initialize_db=True, + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + + +def get_mock_result() -> Result: + """Construct a mock result object corresponding to a lattice.""" + + import sys + + @ct.electron(executor="local") + def task(x): + print(f"stdout: {x}") + print("Error!", file=sys.stderr) + return x + + @ct.lattice(deps_bash=ct.DepsBash(["ls"])) + def pipeline(x): + res1 = task(x) + res2 = task(res1) + return res2 + + pipeline.build_graph(x="absolute") + received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) + result_object = Result(received_workflow, "pipeline_workflow") + + return result_object + + +def get_mock_srvresult(sdkres, test_db) -> SRVResult: + sdkres._initialize_nodes() + + update.persist(sdkres) + + return get_result_object(sdkres.dispatch_id) + + +@pytest.mark.asyncio +async def test_gather_deps(mocker, test_db): + """Test internal _gather_deps for assembling deps into call_before and + call_after""" + + def square(x): + return x * x + + @ct.electron( + deps_bash=ct.DepsBash("ls -l"), + deps_pip=ct.DepsPip(["pandas"]), + call_before=[ct.DepsCall(square, [5])], + call_after=[ct.DepsCall(square, [3])], + ) + def task(x): + return x + + @ct.lattice + def workflow(x): + return task(x) + + mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) + mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) + workflow.build_graph(5) + + received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) + sdkres = Result(received_workflow, "test_gather_deps") + result_object = get_mock_srvresult(sdkres, test_db) + + async def get_electron_attrs(dispatch_id, node_id, keys): + return { + key: result_object.lattice.transport_graph.get_node_value(node_id, key) for key in keys + } + + mocker.patch( + "covalent_dispatcher._core.data_manager.electron.get", + get_electron_attrs, + ) + + before, after = await _gather_deps(result_object.dispatch_id, 0) + assert len(before) == 3 + assert len(after) == 1 diff --git a/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py b/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py new file mode 100644 index 000000000..239102a98 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/runner_modules/cancel_test.py @@ -0,0 +1,102 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the cancellation module +""" + +import json +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core.runner_modules import cancel + + +@pytest.mark.asyncio +async def test_cancel_tasks(mocker): + """Test the public `cancel_tasks` function""" + dispatch_id = "test_cancel_tasks" + node_id = 0 + mock_node_metadata = [{"executor": "dask", "executor_data": {}}] + mock_job_metadata = [{"job_handle": 42}] + mock_cancel_priv = mocker.patch("covalent_dispatcher._core.runner_modules.cancel._cancel_task") + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel._get_metadata_for_nodes", + return_value=mock_node_metadata, + ) + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + return_value=mock_job_metadata, + ) + + await cancel.cancel_tasks(dispatch_id, [node_id]) + + assert mock_cancel_priv.call_count == 1 + + +@pytest.mark.asyncio +async def test_cancel_task_priv(mocker): + """Test the internal `_cancel_task` function""" + mock_executor = MagicMock() + mock_executor._cancel = AsyncMock(return_value=True) + mock_set_status = mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.set_job_status" + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel.get_executor", return_value=mock_executor + ) + + dispatch_id = "test_cancel_task_priv" + job_handle = json.dumps(42) + task_id = 0 + + await cancel._cancel_task(dispatch_id, task_id, ["dask", {}], job_handle) + + task_meta = {"dispatch_id": dispatch_id, "node_id": task_id} + + mock_executor._cancel.assert_awaited_with(task_meta, 42) + + mock_set_status.assert_awaited_with(dispatch_id, task_id, str(RESULT_STATUS.CANCELLED)) + + +@pytest.mark.asyncio +async def test_cancel_task_priv_exception(mocker): + """Test the internal `_cancel_task` function""" + mock_executor = MagicMock() + mock_executor._cancel = AsyncMock(side_effect=RuntimeError()) + mock_set_status = mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.set_job_status" + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.cancel.get_executor", return_value=mock_executor + ) + + dispatch_id = "test_cancel_task_priv" + job_handle = json.dumps(42) + task_id = 0 + + await cancel._cancel_task(dispatch_id, task_id, ["dask", {}], job_handle) + + task_meta = {"dispatch_id": dispatch_id, "node_id": task_id} + + mock_executor._cancel.assert_awaited_with(task_meta, 42) + + mock_set_status.assert_not_awaited() diff --git a/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py b/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py new file mode 100644 index 000000000..12d06443d --- /dev/null +++ b/tests/covalent_dispatcher_tests/_core/runner_modules/jobs_test.py @@ -0,0 +1,99 @@ +# Copyright 2023 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for the executor proxy handlers to get/set job info +""" + +from unittest.mock import MagicMock + +import pytest + +from covalent._shared_files.util_classes import RESULT_STATUS +from covalent_dispatcher._core.runner_modules import jobs + + +@pytest.mark.asyncio +async def test_get_cancel_requested(mocker): + dispatch_id = "test_get_cancel_requested" + mock_job_records = [{"cancel_requested": True}] + + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + return_value=mock_job_records, + ) + + assert await jobs.get_cancel_requested(dispatch_id, 0) is True + + +@pytest.mark.asyncio +async def test_get_version_info(mocker): + dispatch_id = "test_get_version_info" + mock_ver_info = {"python_version": "3.10", "covalent_version": "0.220"} + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.lattice_query_module.get", + return_value=mock_ver_info, + ) + assert await jobs.get_version_info(dispatch_id, 0) == {"python": "3.10", "covalent": "0.220"} + + +@pytest.mark.asyncio +async def test_get_job_status(mocker): + dispatch_id = "test_job_status" + mock_job_records = [{"status": str(RESULT_STATUS.RUNNING)}] + + mocker.patch( + "covalent_dispatcher._core.data_modules.job_manager.get_jobs_metadata", + return_value=mock_job_records, + ) + + assert await jobs.get_job_status(dispatch_id, 0) == RESULT_STATUS.RUNNING + + +@pytest.mark.asyncio +async def test_put_job_handle(mocker): + dispatch_id = "test_put_job_handle" + task_id = 0 + job_handle = "jobArn" + + mock_set = mocker.patch("covalent_dispatcher._core.data_modules.job_manager.set_job_handle") + + assert await jobs.put_job_handle(dispatch_id, task_id, job_handle) is True + mock_set.assert_awaited_with(dispatch_id, task_id, job_handle) + + +@pytest.mark.asyncio +async def test_put_job_status(mocker): + dispatch_id = "test_put_job_handle" + task_id = 0 + status = RESULT_STATUS.RUNNING + + mock_exec_attrs = {"executor": "dask", "executor_data": {}} + executor = MagicMock() + executor.validate_status = MagicMock(return_value=True) + + mocker.patch( + "covalent_dispatcher._core.data_modules.electron.get", return_value=mock_exec_attrs + ) + + mocker.patch( + "covalent_dispatcher._core.runner_modules.jobs.get_executor", return_value=executor + ) + mock_set = mocker.patch("covalent_dispatcher._core.data_modules.job_manager.set_job_status") + + assert await jobs.put_job_status(dispatch_id, task_id, status) is True + mock_set.assert_awaited_with(dispatch_id, task_id, str(status)) diff --git a/tests/covalent_dispatcher_tests/_core/runner_test.py b/tests/covalent_dispatcher_tests/_core/runner_test.py index 644f6a8c1..7f8b6de52 100644 --- a/tests/covalent_dispatcher_tests/_core/runner_test.py +++ b/tests/covalent_dispatcher_tests/_core/runner_test.py @@ -19,25 +19,14 @@ """ -import json from unittest.mock import AsyncMock, MagicMock import pytest -from mock import call import covalent as ct from covalent._results_manager import Result from covalent._workflow.lattice import Lattice -from covalent_dispatcher._core.runner import ( - _cancel_task, - _gather_deps, - _get_metadata_for_nodes, - _run_abstract_task, - _run_task, - cancel_tasks, - get_executor, -) -from covalent_dispatcher._core.runner_modules.executor_proxy import _get_cancel_requested +from covalent_dispatcher._core.runner import _run_abstract_task, _run_task from covalent_dispatcher._db.datastore import DataStore TEST_RESULTS_DIR = "/tmp/results" @@ -73,161 +62,91 @@ def pipeline(x): pipeline.build_graph(x="absolute") received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) result_object = Result(received_workflow, "pipeline_workflow") - result_object._initialize_nodes() return result_object -def test_get_executor(mocker): - """Test that get_executor returns the correct executor""" - - executor_manager_mock = mocker.patch("covalent_dispatcher._core.runner._executor_manager") - executor = get_executor(["local", {"mock-key": "mock-value"}], "mock-loop", "mock-pool") - assert executor_manager_mock.get_executor.mock_calls == [ - call("local"), - call().from_dict({"mock-key": "mock-value"}), - call()._init_runtime(loop="mock-loop", cancel_pool="mock-pool"), - ] - assert executor == executor_manager_mock.get_executor() - - -def test_gather_deps(): - """Test internal _gather_deps for assembling deps into call_before and - call_after""" - - def square(x): - return x * x - - @ct.electron( - deps_bash=ct.DepsBash("ls -l"), - deps_pip=ct.DepsPip(["pandas"]), - call_before=[ct.DepsCall(square, [5])], - call_after=[ct.DepsCall(square, [3])], - ) - def task(x): - return x - - @ct.lattice - def workflow(x): - return task(x) - - workflow.build_graph(5) - - received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = Result(received_workflow, "asdf") - - before, after = _gather_deps(result_object, 0) - assert len(before) == 3 - assert len(after) == 1 - - @pytest.mark.asyncio async def test_run_abstract_task_exception_handling(mocker): """Test that exceptions from resolving abstract inputs are handled""" - result_object = get_mock_result() + dispatch_id = "mock_dispatch" + inputs = {"args": [], "kwargs": {}} - mock_get_result = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) + mocker.patch("covalent_dispatcher._core.runner._gather_deps", side_effect=RuntimeError()) mocker.patch( - "covalent_dispatcher._core.runner._get_task_input_values", - side_effect=RuntimeError(), + "covalent_dispatcher._core.data_manager.electron.get", + return_value={"function": "function"}, ) node_result = await _run_abstract_task( - dispatch_id=result_object.dispatch_id, + dispatch_id=dispatch_id, node_id=0, node_name="test_node", abstract_inputs=inputs, - executor=["local", {}], + selected_executor=["local", {}], ) assert node_result["status"] == Result.FAILED @pytest.mark.asyncio -async def test_run_abstract_task_get_cancel_requested(mocker): - """Test that get_cancel_requested is properly handled""" - mock_result = MagicMock() - - result_object = get_mock_result() +async def test_run_task_runtime_exception_handling(mocker): inputs = {"args": [], "kwargs": {}} - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - mock_get_result = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - mock_get_task_input_values = mocker.patch( - "covalent_dispatcher._core.runner._get_task_input_values", - side_effect=RuntimeError(), - ) - mock_get_cancel_requested = mocker.patch( - "covalent_dispatcher._core.runner_modules.executor_proxy._get_cancel_requested", - return_value=AsyncMock(return_value=True), - ) - mock_generate_node_result = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.generate_node_result", - return_value=mock_result, - ) - - node_result = await _run_abstract_task( - dispatch_id=result_object.dispatch_id, - node_id=0, - node_name="test_node", - abstract_inputs=inputs, - executor=["local", {}], + mock_executor = MagicMock() + mock_executor._execute = AsyncMock(return_value=("", "", "error", Result.FAILED)) + mock_get_executor = mocker.patch( + "covalent_dispatcher._core.runner.get_executor", + return_value=mock_executor, ) - mock_get_result.assert_called_with(result_object.dispatch_id) - mock_get_cancel_requested.assert_awaited_once_with(result_object.dispatch_id, 0) - mock_generate_node_result.assert_called() - mock_app_log.assert_called_with(f"Don't run cancelled task {result_object.dispatch_id}:0") - assert node_result == mock_result - - -@pytest.mark.asyncio -async def test_run_task_executor_exception_handling(mocker): - """Test that exceptions from initializing executors are caught""" - - result_object = get_mock_result() - inputs = {"args": [], "kwargs": {}} - mock_get_executor = mocker.patch( - "covalent_dispatcher._core.runner._executor_manager.get_executor", - side_effect=Exception(), + dispatch_id = "mock_dispatch" + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, ) node_result = await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=1, inputs=inputs, serialized_callable=None, - executor=["nonexistent", {}], + selected_executor=["local", {}], call_before=[], call_after=[], - node_name="test_node", + node_name="task", ) + mock_executor._execute.assert_awaited_once() + assert node_result["status"] == Result.FAILED + assert node_result["stderr"] == "error" @pytest.mark.asyncio -async def test_run_task_runtime_exception_handling(mocker): - result_object = get_mock_result() +async def test_run_task_exception_handling(mocker): + dispatch_id = "mock_dispatch" inputs = {"args": [], "kwargs": {}} mock_executor = MagicMock() - mock_executor._execute = AsyncMock(return_value=("", "", "error", True)) + mock_executor._execute = AsyncMock(side_effect=RuntimeError("error")) + mock_get_executor = mocker.patch( - "covalent_dispatcher._core.runner._executor_manager.get_executor", + "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor, ) + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, + ) + mocker.patch("traceback.TracebackException.from_exception", return_value="error") node_result = await _run_task( - result_object=result_object, + dispatch_id=dispatch_id, node_id=1, inputs=inputs, serialized_callable=None, - executor=["local", {}], + selected_executor=["local", {}], call_before=[], call_after=[], node_name="task", @@ -235,119 +154,35 @@ async def test_run_task_runtime_exception_handling(mocker): mock_executor._execute.assert_awaited_once() - assert node_result["stderr"] == "error" - - -@pytest.mark.asyncio -async def test__cancel_task(mocker): - """ - Test module private _cancel_task method - """ - mock_executor = AsyncMock() - mock_executor.from_dict = MagicMock() - mock_executor._init_runtime = MagicMock() - mock_executor._cancel = AsyncMock() - - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - get_executor_mock = mocker.patch( - "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor - ) - mock_set_cancel_result = mocker.patch("covalent_dispatcher._core.runner.set_cancel_result") - - dispatch_id = "abcd" - task_id = 0 - executor = "mock_executor" - executor_data = {} - job_handle = "42" - - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - - await _cancel_task(dispatch_id, task_id, executor, executor_data, job_handle) - - assert mock_app_log.call_count == 2 - get_executor_mock.assert_called_once() - mock_executor._cancel.assert_called_with(task_metadata, json.loads(job_handle)) - mock_set_cancel_result.assert_called() + assert node_result["status"] == Result.FAILED + assert node_result["error"] == "error" @pytest.mark.asyncio -async def test__cancel_task_exception(mocker): - """ - Test exception raised in module private _cancel task exception - """ - mock_executor = AsyncMock() - mock_executor.from_dict = MagicMock() - mock_executor._init_runtime = MagicMock() - mock_executor._cancel = AsyncMock(side_effect=Exception("cancel")) - - mock_app_log = mocker.patch("covalent_dispatcher._core.runner.app_log.debug") - get_executor_mock = mocker.patch( - "covalent_dispatcher._core.runner.get_executor", return_value=mock_executor - ) - mocker.patch("covalent_dispatcher._core.runner.set_cancel_result") - - dispatch_id = "abcd" - task_id = 0 - executor = "mock_executor" - executor_data = {} - job_handle = "42" - - task_metadata = {"dispatch_id": dispatch_id, "node_id": task_id} - - cancel_result = await _cancel_task(dispatch_id, task_id, executor, executor_data, job_handle) - assert mock_app_log.call_count == 3 - get_executor_mock.assert_called_once() - mock_executor._cancel.assert_called_with(task_metadata, json.loads(job_handle)) - assert cancel_result is False - +async def test_run_task_executor_exception_handling(mocker): + """Test that exceptions from initializing executors are caught""" -@pytest.mark.asyncio -async def test_cancel_tasks(mocker): - """ - Test cancelling multiple tasks - """ - mock_get_jobs_metadata = mocker.patch( - "covalent_dispatcher._core.runner.get_jobs_metadata", return_value=AsyncMock() - ) - mock_get_metadata_for_nodes = mocker.patch( - "covalent_dispatcher._core.runner._get_metadata_for_nodes", return_value=MagicMock() + dispatch_id = "mock_dispatch" + inputs = {"args": [], "kwargs": {}} + mock_get_executor = mocker.patch( + "covalent_dispatcher._core.runner.get_executor", + side_effect=Exception(), ) - dispatch_id = "abcd" - task_ids = [0, 1] - - await cancel_tasks(dispatch_id, task_ids) - - mock_get_jobs_metadata.assert_awaited_with(dispatch_id, task_ids) - mock_get_metadata_for_nodes.assert_called_with(dispatch_id, task_ids) - - -def test__get_metadata_for_nodes(mocker): - """ - Test module private method for getting nodes metadata - """ - dispatch_id = "abcd" - node_ids = [0, 1] - - mock_get_result_object = mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=MagicMock() + mocker.patch( + "covalent_dispatcher._core.data_manager.dispatch.get", + return_value={"results_dir": "/tmp/result"}, ) - _get_metadata_for_nodes(dispatch_id, node_ids) - mock_get_result_object.assert_called_with(dispatch_id) - -@pytest.mark.asyncio -async def test__get_cancel_requested(mocker): - """ - Test module private method for querying if a task was requested to be cancelled - """ - dispatch_id = "abcd" - task_id = 0 - mock_get_jobs_metadata = mocker.patch( - "covalent_dispatcher._core.runner_modules.executor_proxy.job_manager.get_jobs_metadata", - return_value=AsyncMock(), + node_result = await _run_task( + dispatch_id=dispatch_id, + node_id=1, + inputs=inputs, + serialized_callable=None, + selected_executor=["nonexistent", {}], + call_before=[], + call_after=[], + node_name="test_node", ) - await _get_cancel_requested(dispatch_id, task_id) - - mock_get_jobs_metadata.assert_awaited_with(dispatch_id, [task_id]) + assert node_result["status"] == Result.FAILED diff --git a/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py b/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py deleted file mode 100644 index d3538579d..000000000 --- a/tests/covalent_dispatcher_tests/_core/tmp_data_manager_test.py +++ /dev/null @@ -1,543 +0,0 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Tests for the core functionality of the dispatcher. -""" - - -from unittest.mock import AsyncMock, MagicMock - -import pytest - -import covalent as ct -from covalent._results_manager import Result -from covalent._shared_files.defaults import sublattice_prefix -from covalent._shared_files.util_classes import RESULT_STATUS -from covalent._workflow.lattice import Lattice -from covalent_dispatcher._core.data_manager import ( - _dispatch_status_queues, - _get_result_object_from_new_lattice, - _get_result_object_from_old_result, - _handle_built_sublattice, - _register_result_object, - _registered_dispatches, - _update_parent_electron, - finalize_dispatch, - generate_node_result, - get_result_object, - get_status_queue, - initialize_result_object, - make_derived_dispatch, - make_dispatch, - make_sublattice_dispatch, - persist_result, - update_node_result, - upsert_lattice_data, -) -from covalent_dispatcher._db.datastore import DataStore - -TEST_RESULTS_DIR = "/tmp/results" - - -@pytest.fixture -def test_db(): - """Instantiate and return an in-memory database.""" - - return DataStore( - db_URL="sqlite+pysqlite:///:memory:", - initialize_db=True, - ) - - -def get_mock_result() -> Result: - """Construct a mock result object corresponding to a lattice.""" - - import sys - - @ct.electron(executor="local") - def task(x): - print(f"stdout: {x}") - print("Error!", file=sys.stderr) - return x - - @ct.lattice - def pipeline(x): - res1 = task(x) - res2 = task(res1) - return res2 - - pipeline.build_graph(x="absolute") - received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) - result_object = Result(received_workflow, "pipeline_workflow") - - return result_object - - -@pytest.mark.asyncio -async def test_handle_built_sublattice(mocker): - """Test the handle_built_sublattice function.""" - - get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value="mock-result" - ) - make_sublattice_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_sublattice_dispatch", - return_value="mock-sub-dispatch-id", - ) - mock_node_result = generate_node_result( - dispatch_id="mock-dispatch-id", - node_id=0, - node_name="mock_node_name", - status=RESULT_STATUS.COMPLETED, - ) - - await _handle_built_sublattice("mock-dispatch-id", mock_node_result) - get_result_object_mock.assert_called_with("mock-dispatch-id") - make_sublattice_dispatch_mock.assert_called_with("mock-result", mock_node_result) - assert mock_node_result["status"] == RESULT_STATUS.DISPATCHING_SUBLATTICE - assert mock_node_result["start_time"] is not None - assert mock_node_result["end_time"] is None - assert mock_node_result["sub_dispatch_id"] == "mock-sub-dispatch-id" - - -@pytest.mark.asyncio -async def test_handle_built_sublattice_exception(mocker): - """Test the handle_built_sublattice function exception case.""" - - get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", side_effect=Exception - ) - make_sublattice_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_sublattice_dispatch", - return_value="mock-sub-dispatch-id", - ) - mock_node_result = generate_node_result( - dispatch_id="mock-dispatch-id", - node_id=0, - node_name="mock_node_name", - status=RESULT_STATUS.COMPLETED, - ) - - await _handle_built_sublattice("mock-dispatch-id", mock_node_result) - mock_node_result["error"] - get_result_object_mock.assert_called_with("mock-dispatch-id") - make_sublattice_dispatch_mock.assert_not_called() - assert mock_node_result["status"] == RESULT_STATUS.FAILED - assert "exception" in mock_node_result["error"].lower() - - -def test_initialize_result_object(mocker, test_db): - """Test the `initialize_result_object` function""" - - @ct.electron - def task(x): - return x - - @ct.lattice - def workflow(x): - return task(x) - - workflow.build_graph(1) - json_lattice = workflow.serialize_to_json() - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", return_value=test_db) - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", return_value=test_db) - result_object = get_mock_result() - - mock_persist = mocker.patch("covalent_dispatcher._db.update.persist") - - sub_result_object = initialize_result_object( - json_lattice=json_lattice, parent_result_object=result_object, parent_electron_id=5 - ) - - mock_persist.assert_called_with(sub_result_object, electron_id=5) - assert sub_result_object._root_dispatch_id == result_object.dispatch_id - - -@pytest.mark.parametrize( - "node_name, node_status, sub_dispatch_id, detail", - [ - ( - f"{sublattice_prefix}workflow", - RESULT_STATUS.COMPLETED, - "mock-sub-dispatch-id", - {"sub_dispatch_id": "mock-sub-dispatch-id"}, - ), - (f"{sublattice_prefix}workflow", RESULT_STATUS.COMPLETED, None, {}), - ("mock-node-name", RESULT_STATUS.COMPLETED, None, {}), - ("mock-node-name", RESULT_STATUS.FAILED, None, {}), - ("mock-node-name", RESULT_STATUS.CANCELLED, None, {}), - ], -) -@pytest.mark.asyncio -async def test_update_node_result(mocker, node_name, node_status, sub_dispatch_id, detail): - """Check that update_node_result pushes the correct status updates""" - - status_queue = AsyncMock() - - result_object = get_mock_result() - mock_update_node = mocker.patch("covalent_dispatcher._db.update._node") - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - handle_built_sublattice_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._handle_built_sublattice" - ) - - node_result = { - "node_id": 0, - "node_name": node_name, - "status": node_status, - "sub_dispatch_id": sub_dispatch_id, - } - await update_node_result(result_object, node_result) - - status_queue.put.assert_awaited_with((0, node_status, detail)) - mock_update_node.assert_called_with(result_object, **node_result) - - if ( - node_status == RESULT_STATUS.COMPLETED - and sub_dispatch_id is None - and node_name.startswith(sublattice_prefix) - ): - handle_built_sublattice_mock.assert_called_with(result_object.dispatch_id, node_result) - else: - handle_built_sublattice_mock.assert_not_called() - - -@pytest.mark.asyncio -async def test_update_node_result_handles_db_exceptions(mocker): - """Check that update_node_result handles db write failures""" - - status_queue = AsyncMock() - - result_object = get_mock_result() - mock_update_node = mocker.patch( - "covalent_dispatcher._db.update._node", side_effect=RuntimeError() - ) - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - node_result = { - "node_id": 0, - "node_name": "mock_node_name", - "status": RESULT_STATUS.COMPLETED, - "sub_dispatch_id": None, - } - await update_node_result(result_object, node_result) - - status_queue.put.assert_awaited_with((0, RESULT_STATUS.FAILED, {})) - - -@pytest.mark.asyncio -async def test_make_dispatch(mocker): - res = get_mock_result() - mock_init_result = mocker.patch( - "covalent_dispatcher._core.data_manager.initialize_result_object", return_value=res - ) - mock_register = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object", return_value=res - ) - json_lattice = '{"workflow_function": "asdf"}' - dispatch_id = await make_dispatch(json_lattice) - assert dispatch_id == res.dispatch_id - mock_register.assert_called_with(res) - - -@pytest.mark.asyncio -async def test_make_sublattice_dispatch(mocker): - """Test the make sublattice dispatch method.""" - - mock_result_object = get_mock_result() - output_mock = MagicMock() - mock_node_result = {"node_id": 0, "output": output_mock} - load_electron_record_mock = mocker.patch( - "covalent_dispatcher._db.load.electron_record", return_value={"id": "mock-electron-id"} - ) - make_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.make_dispatch", return_value="mock-dispatch-id" - ) - - res = await make_sublattice_dispatch(mock_result_object, mock_node_result) - assert res == "mock-dispatch-id" - load_electron_record_mock.assert_called_with( - mock_result_object.dispatch_id, mock_node_result["node_id"] - ) - make_dispatch_mock.assert_called_with( - output_mock.object_string, mock_result_object, "mock-electron-id" - ) - - -@pytest.mark.parametrize("reuse", [True, False]) -def test_get_result_object_from_new_lattice(mocker, reuse): - """Test the get result object from new lattice json function.""" - lattice_mock = mocker.patch("covalent_dispatcher._core.data_manager.Lattice") - result_object_mock = mocker.patch("covalent_dispatcher._core.data_manager.Result") - transport_graph_ops_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.TransportGraphOps" - ) - old_result_mock = MagicMock() - res = _get_result_object_from_new_lattice( - json_lattice="mock-lattice", - old_result_object=old_result_mock, - reuse_previous_results=reuse, - ) - assert res == result_object_mock.return_value - lattice_mock.deserialize_from_json.assert_called_with("mock-lattice") - result_object_mock()._initialize_nodes.assert_called_with() - - if reuse: - transport_graph_ops_mock().get_reusable_nodes.assert_called_with( - result_object_mock().lattice.transport_graph - ) - transport_graph_ops_mock().copy_nodes_from.assert_called_once_with( - old_result_mock.lattice.transport_graph, - transport_graph_ops_mock().get_reusable_nodes.return_value, - ) - - else: - transport_graph_ops_mock().get_reusable_nodes.assert_not_called() - transport_graph_ops_mock().copy_nodes_from.assert_not_called() - - -@pytest.mark.parametrize("reuse", [True, False]) -def test_get_result_object_from_old_result(mocker, reuse): - """Test the get result object from old result function.""" - result_object_mock = mocker.patch("covalent_dispatcher._core.data_manager.Result") - old_result_mock = MagicMock() - res = _get_result_object_from_old_result( - old_result_object=old_result_mock, - reuse_previous_results=reuse, - ) - assert res == result_object_mock.return_value - - if reuse: - result_object_mock()._initialize_nodes.assert_not_called() - else: - result_object_mock()._initialize_nodes.assert_called_with() - - assert res._num_nodes == old_result_mock._num_nodes - - -@pytest.mark.parametrize("reuse", [True, False]) -def test_make_derived_dispatch_from_lattice(mocker, reuse): - """Test the make derived dispatch function.""" - - def mock_func(): - pass - - mock_old_result = MagicMock() - mock_new_result = MagicMock() - mock_new_result.dispatch_id = "mock-redispatch-id" - mock_new_result.lattice.transport_graph._graph.nodes = ["mock-nodes"] - load_get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.load", return_value=mock_old_result - ) - get_result_object_from_new_lattice_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_new_lattice", - return_value=mock_new_result, - ) - get_result_object_from_old_result_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_old_result" - ) - update_mock = mocker.patch("covalent_dispatcher._core.data_manager.update") - register_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object" - ) - mock_electron_updates = {"mock-electron-id": mock_func} - redispatch_id = make_derived_dispatch( - parent_dispatch_id="mock-dispatch-id", - json_lattice="mock-json-lattice", - electron_updates=mock_electron_updates, - reuse_previous_results=reuse, - ) - load_get_result_object_mock.called_once_with("mock-dispatch-id", wait=reuse) - get_result_object_from_new_lattice_mock.called_once_with( - "mock-json-lattice", mock_old_result, reuse - ) - get_result_object_from_old_result_mock.assert_not_called() - mock_new_result.lattice.transport_graph.apply_electron_updates.assert_called_once_with( - mock_electron_updates - ) - update_mock().persist.called_once_with(mock_new_result) - register_result_object_mock.assert_called_once_with(mock_new_result) - assert redispatch_id == "mock-redispatch-id" - assert mock_new_result.lattice.transport_graph.dirty_nodes == ["mock-nodes"] - - -@pytest.mark.parametrize("reuse", [True, False]) -def test_make_derived_dispatch_from_old_result(mocker, reuse): - """Test the make derived dispatch function.""" - mock_old_result = MagicMock() - mock_new_result = MagicMock() - mock_new_result.dispatch_id = "mock-redispatch-id" - mock_new_result.lattice.transport_graph._graph.nodes = ["mock-nodes"] - load_get_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager.load", return_value=mock_old_result - ) - get_result_object_from_new_lattice_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_new_lattice", - ) - get_result_object_from_old_result_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._get_result_object_from_old_result", - return_value=mock_new_result, - ) - update_mock = mocker.patch("covalent_dispatcher._core.data_manager.update") - register_result_object_mock = mocker.patch( - "covalent_dispatcher._core.data_manager._register_result_object" - ) - redispatch_id = make_derived_dispatch( - parent_dispatch_id="mock-dispatch-id", - reuse_previous_results=reuse, - ) - load_get_result_object_mock.called_once_with("mock-dispatch-id", wait=reuse) - get_result_object_from_new_lattice_mock.assert_not_called() - get_result_object_from_old_result_mock.called_once_with(mock_old_result, reuse) - mock_new_result.lattice.transport_graph.apply_electron_updates.assert_called_once_with({}) - update_mock().persist.called_once_with(mock_new_result) - register_result_object_mock.assert_called_once_with(mock_new_result) - assert redispatch_id == "mock-redispatch-id" - assert mock_new_result.lattice.transport_graph.dirty_nodes == ["mock-nodes"] - - -def test_get_result_object(mocker): - """ - Test get result object - """ - result_object = get_mock_result() - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object - assert get_result_object(dispatch_id) is result_object - del _registered_dispatches[dispatch_id] - - -def test_register_result_object(mocker): - """ - Test registering a result object - """ - result_object = get_mock_result() - dispatch_id = result_object.dispatch_id - _register_result_object(result_object) - assert _registered_dispatches[dispatch_id] is result_object - del _registered_dispatches[dispatch_id] - - -def test_unregister_result_object(mocker): - """ - Test unregistering a result object from lattice - """ - result_object = get_mock_result() - dispatch_id = result_object.dispatch_id - _registered_dispatches[dispatch_id] = result_object - finalize_dispatch(dispatch_id) - assert dispatch_id not in _registered_dispatches - - -def test_get_status_queue(): - """ - Test querying the dispatch status from the queue - """ - import asyncio - - dispatch_id = "dispatch" - q = asyncio.Queue() - _dispatch_status_queues[dispatch_id] = q - assert get_status_queue(dispatch_id) is q - - -@pytest.mark.asyncio -async def test_persist_result(mocker): - """ - Test persisting the result object - """ - result_object = get_mock_result() - - mock_get_result = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object - ) - mock_update_parent = mocker.patch( - "covalent_dispatcher._core.data_manager._update_parent_electron" - ) - mock_update_lattice = mocker.patch( - "covalent_dispatcher._core.data_manager.update.lattice_data" - ) - - await persist_result(result_object.dispatch_id) - mock_update_parent.assert_awaited_with(result_object) - mock_update_lattice.assert_called_with(result_object) - - -@pytest.mark.parametrize( - "sub_status,mapped_status", - [ - (RESULT_STATUS.COMPLETED, RESULT_STATUS.COMPLETED), - (RESULT_STATUS.POSTPROCESSING_FAILED, RESULT_STATUS.FAILED), - ], -) -@pytest.mark.asyncio -async def test_update_parent_electron(mocker, sub_status, mapped_status): - """ - Test updating parent electron data - """ - parent_result_obj = get_mock_result() - sub_result_obj = get_mock_result() - eid = 5 - parent_dispatch_id = (parent_result_obj.dispatch_id,) - parent_node_id = 2 - sub_result_obj._electron_id = eid - sub_result_obj._status = sub_status - sub_result_obj._result = 42 - - mock_node_result = { - "node_id": parent_node_id, - "end_time": sub_result_obj._end_time, - "status": mapped_status, - "output": sub_result_obj._result, - "error": sub_result_obj._error, - } - - mocker.patch( - "covalent_dispatcher._core.data_manager.generate_node_result", - return_value=mock_node_result, - ) - - mock_update_node = mocker.patch("covalent_dispatcher._core.data_manager.update_node_result") - mocker.patch( - "covalent_dispatcher._core.data_manager.resolve_electron_id", - return_value=(parent_dispatch_id, parent_node_id), - ) - mock_get_res = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=parent_result_obj - ) - load_mock = mocker.patch("covalent_dispatcher._core.data_manager.load") - load_mock.sublattice_dispatch_id.return_value = "mock-sub-dispatch-id" - await _update_parent_electron(sub_result_obj) - - mock_get_res.assert_called_with(parent_dispatch_id) - mock_update_node.assert_awaited_with(parent_result_obj, mock_node_result) - - -def test_upsert_lattice_data(mocker): - """ - Test updating lattice data in database - """ - result_object = get_mock_result() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object - ) - mock_update_lattice = mocker.patch("covalent_dispatcher._db.update.lattice_data") - upsert_lattice_data(result_object.dispatch_id) - mock_update_lattice.assert_called_with(result_object) diff --git a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py b/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py deleted file mode 100644 index ce659dfa1..000000000 --- a/tests/covalent_dispatcher_tests/_core/tmp_dispatcher_test.py +++ /dev/null @@ -1,592 +0,0 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Tests for the core functionality of the dispatcher. - -This will be replaced in the next patch. -""" - - -from typing import Dict, List -from unittest.mock import AsyncMock, call - -import cloudpickle as pickle -import pytest -from mock import MagicMock - -import covalent as ct -from covalent._results_manager import Result -from covalent._shared_files.util_classes import RESULT_STATUS -from covalent._workflow.lattice import Lattice -from covalent_dispatcher._core.dispatcher import ( - _get_abstract_task_inputs, - _get_initial_tasks_and_deps, - _handle_cancelled_node, - _handle_completed_node, - _handle_failed_node, - _plan_workflow, - _run_planned_workflow, - _submit_task, - cancel_dispatch, - run_dispatch, - run_workflow, -) -from covalent_dispatcher._db.datastore import DataStore - -TEST_RESULTS_DIR = "/tmp/results" - - -@pytest.fixture -def test_db(): - """Instantiate and return an in-memory database.""" - - return DataStore( - db_URL="sqlite+pysqlite:///:memory:", - initialize_db=True, - ) - - -def get_mock_result() -> Result: - """Construct a mock result object corresponding to a lattice.""" - - import sys - - @ct.electron(executor="local") - def task(x): - print(f"stdout: {x}") - print("Error!", file=sys.stderr) - return x - - @ct.lattice - def pipeline(x): - res1 = task(x) - res2 = task(res1) - return res2 - - pipeline.build_graph(x="absolute") - received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) - result_object = Result(received_workflow, "pipeline_workflow") - - return result_object - - -def test_plan_workflow(): - """Test workflow planning method.""" - - @ct.electron - def task(x): - return x - - @ct.lattice - def workflow(x): - return task(x) - - workflow.build_graph(1) - workflow.metadata["schedule"] = True - received_workflow = Lattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = Result(received_workflow, "asdf") - _plan_workflow(result_object=result_object) - - # Updated transport graph post planning - updated_tg = pickle.loads(result_object.lattice.transport_graph.serialize(metadata_only=True)) - - assert updated_tg["lattice_metadata"]["schedule"] - - -def test_get_abstract_task_inputs(): - """Test _get_abstract_task_inputs for both dicts and list parameter types""" - - @ct.electron - def list_task(arg: List): - return len(arg) - - @ct.electron - def dict_task(arg: Dict): - return len(arg) - - @ct.electron - def multivariable_task(x, y): - return x, y - - @ct.lattice - def list_workflow(arg): - return list_task(arg) - - @ct.lattice - def dict_workflow(arg): - return dict_task(arg) - - # 1 2 - # \ \ - # 0 3 - # / /\/ - # 4 5 - - @ct.electron - def identity(x): - return x - - @ct.lattice - def multivar_workflow(x, y): - electron_x = identity(x) - electron_y = identity(y) - res1 = multivariable_task(electron_x, electron_y) - res2 = multivariable_task(electron_y, electron_x) - res3 = multivariable_task(electron_y, electron_x) - res4 = multivariable_task(electron_x, electron_y) - return 1 - - # list-type inputs - - # Nodes 0=task, 1=:electron_list:, 2=1, 3=2, 4=3 - list_workflow.build_graph([1, 2, 3]) - abstract_args = [2, 3, 4] - tg = list_workflow.transport_graph - - result_object = Result(lattice=list_workflow, dispatch_id="asdf") - abs_task_inputs = _get_abstract_task_inputs(1, tg.get_node_value(1, "name"), result_object) - - expected_inputs = {"args": abstract_args, "kwargs": {}} - - assert abs_task_inputs == expected_inputs - - # dict-type inputs - - # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 - dict_workflow.build_graph({"a": 1, "b": 2}) - abstract_args = {"a": 2, "b": 3} - tg = dict_workflow.transport_graph - - result_object = Result(lattice=dict_workflow, dispatch_id="asdf") - task_inputs = _get_abstract_task_inputs(1, tg.get_node_value(1, "name"), result_object) - expected_inputs = {"args": [], "kwargs": abstract_args} - - assert task_inputs == expected_inputs - - # Check arg order - multivar_workflow.build_graph(1, 2) - received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) - result_object = Result(lattice=received_lattice, dispatch_id="asdf") - tg = received_lattice.transport_graph - - assert list(tg._graph.nodes) == list(range(10)) - tg.set_node_value(0, "output", ct.TransportableObject(1)) - tg.set_node_value(2, "output", ct.TransportableObject(2)) - - task_inputs = _get_abstract_task_inputs(4, tg.get_node_value(4, "name"), result_object) - assert task_inputs["args"] == [0, 2] - - task_inputs = _get_abstract_task_inputs(5, tg.get_node_value(5, "name"), result_object) - assert task_inputs["args"] == [2, 0] - - task_inputs = _get_abstract_task_inputs(6, tg.get_node_value(6, "name"), result_object) - assert task_inputs["args"] == [2, 0] - - task_inputs = _get_abstract_task_inputs(7, tg.get_node_value(7, "name"), result_object) - assert task_inputs["args"] == [0, 2] - - -@pytest.mark.asyncio -async def test_handle_completed_node(mocker): - """Unit test for completed node handler""" - pending_parents = {} - - result_object = get_mock_result() - - # tg edges are (1, 0), (0, 2) - pending_parents[0] = 1 - pending_parents[1] = 0 - pending_parents[2] = 1 - - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" - ) - - node_result = {"node_id": 1, "status": Result.COMPLETED} - - next_nodes = await _handle_completed_node(result_object, 1, pending_parents) - assert next_nodes == [0] - assert pending_parents == {0: 0, 1: 0, 2: 1} - - -@pytest.mark.asyncio -async def test_handle_failed_node(mocker): - """Unit test for failed node handler""" - pending_parents = {} - - result_object = get_mock_result() - # tg edges are (1, 0), (0, 2) - - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" - ) - await _handle_failed_node(result_object, 1) - - mock_upsert_lattice.assert_called() - - -@pytest.mark.asyncio -async def test_handle_cancelled_node(mocker): - """Unit test for cancelled node handler""" - pending_parents = {} - - result_object = get_mock_result() - # tg edges are (1, 0), (0, 2) - - mock_upsert_lattice = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data" - ) - - node_result = {"node_id": 1, "status": Result.CANCELLED} - - await _handle_cancelled_node(result_object, 1) - assert result_object._task_cancelled is True - mock_upsert_lattice.assert_called() - - -@pytest.mark.asyncio -async def test_get_initial_tasks_and_deps(mocker): - """Test internal function for initializing status_queue and pending_parents""" - pending_parents = {} - - result_object = get_mock_result() - num_tasks, initial_nodes, pending_parents = await _get_initial_tasks_and_deps(result_object) - - assert initial_nodes == [1] - assert pending_parents == {0: 1, 1: 0, 2: 1, 3: 3} - assert num_tasks == len(result_object.lattice.transport_graph._graph.nodes) - - -@pytest.mark.asyncio -async def test_run_dispatch(mocker): - """ - Test running a mock dispatch - """ - res = get_mock_result() - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", return_value=res - ) - mock_run = mocker.patch("covalent_dispatcher._core.dispatcher.run_workflow") - run_dispatch(res.dispatch_id) - mock_run.assert_called_with(res) - - -@pytest.mark.asyncio -async def test_run_workflow_normal(mocker): - """ - Test a normal workflow execution - """ - import asyncio - - result_object = get_mock_result() - msg_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue - ) - mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") - mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", return_value=result_object - ) - mock_get_result_object = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object - ) - mock_upsert = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - await run_workflow(result_object) - - mock_upsert.assert_called_with(result_object.dispatch_id) - mock_unregister.assert_called_with(result_object.dispatch_id) - - -@pytest.mark.asyncio -async def test_run_completed_workflow(mocker): - """ - Test run completed workflow - """ - import asyncio - - result_object = get_mock_result() - result_object._status = Result.COMPLETED - msg_queue = asyncio.Queue() - mock_get_status_queue = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mock_plan = mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") - mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", return_value=result_object - ) - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - - await run_workflow(result_object) - - mock_plan.assert_not_called() - mock_get_status_queue.assert_not_called() - mock_unregister.assert_called_with(result_object.dispatch_id) - - -@pytest.mark.asyncio -async def test_run_workflow_exception(mocker): - """ - Test any exception raised when running workflow - """ - import asyncio - - result_object = get_mock_result() - msg_queue = asyncio.Queue() - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_status_queue", return_value=msg_queue - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch("covalent_dispatcher._core.dispatcher._plan_workflow") - mocker.patch( - "covalent_dispatcher._core.dispatcher._run_planned_workflow", - return_value=result_object, - side_effect=RuntimeError("Error"), - ) - mock_get_result_object = mocker.patch( - "covalent_dispatcher._core.data_manager.get_result_object", return_value=result_object - ) - mock_upsert = mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - - result = await run_workflow(result_object) - - assert result.status == Result.FAILED - mock_upsert.assert_called_with(result_object.dispatch_id) - mock_unregister.assert_called_with(result_object.dispatch_id) - - -@pytest.mark.asyncio -async def test_run_planned_workflow_cancelled_update(mocker): - """ - Test run planned workflow with cancelled update - """ - import asyncio - - result_object = get_mock_result() - - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} - - mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), - ) - - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") - - def side_effect(result_object, node_id): - result_object._task_cancelled = True - - mock_handle_cancelled = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_cancelled_node", side_effect=side_effect - ) - status_queue = asyncio.Queue() - status_queue.put_nowait((0, Result.CANCELLED, {})) - await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_cancelled.assert_awaited_with(result_object, 0) - - -@pytest.mark.asyncio -async def test_run_planned_workflow_failed_update(mocker): - """ - Test run planned workflow with mocking a failed job update - """ - import asyncio - - result_object = get_mock_result() - - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} - - mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), - ) - - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") - - def side_effect(result_object, node_id): - result_object._task_failed = True - - mock_handle_failed = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_failed_node", side_effect=side_effect - ) - status_queue = asyncio.Queue() - status_queue.put_nowait((0, Result.FAILED, {})) - await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_failed.assert_awaited_with(result_object, 0) - - -@pytest.mark.asyncio -async def test_run_planned_workflow_dispatching(mocker): - """Test the run planned workflow for a dispatching node.""" - import asyncio - - result_object = get_mock_result() - - mocker.patch("covalent_dispatcher._core.dispatcher.datasvc.upsert_lattice_data") - tasks_left = 1 - initial_nodes = [0] - pending_deps = {0: 0} - - mocker.patch( - "covalent_dispatcher._core.dispatcher._get_initial_tasks_and_deps", - return_value=(tasks_left, initial_nodes, pending_deps), - ) - - mock_submit_task = mocker.patch("covalent_dispatcher._core.dispatcher._submit_task") - - def side_effect(result_object, node_id): - result_object._task_failed = True - - mock_handle_failed = mocker.patch( - "covalent_dispatcher._core.dispatcher._handle_failed_node", side_effect=side_effect - ) - mock_run_dispatch = mocker.patch("covalent_dispatcher._core.dispatcher.run_dispatch") - status_queue = asyncio.Queue() - status_queue.put_nowait( - (0, RESULT_STATUS.DISPATCHING_SUBLATTICE, {"sub_dispatch_id": "mock_sub_dispatch_id"}) - ) - status_queue.put_nowait((0, RESULT_STATUS.FAILED, {})) # This ensures that the loop is exited. - await _run_planned_workflow(result_object, status_queue) - assert mock_submit_task.await_count == 1 - mock_handle_failed.assert_awaited_with(result_object, 0) - mock_run_dispatch.assert_called_once_with("mock_sub_dispatch_id") - - -@pytest.mark.asyncio -async def test_cancel_dispatch(mocker): - """Test cancelling a dispatch, including sub-lattices""" - res = get_mock_result() - sub_res = get_mock_result() - - sub_dispatch_id = "sub_pipeline_workflow" - sub_res._dispatch_id = sub_dispatch_id - - def mock_get_result_object(dispatch_id): - objs = {res._dispatch_id: res, sub_res._dispatch_id: sub_res} - return objs[dispatch_id] - - mock_data_cancel = mocker.patch("covalent_dispatcher._core.dispatcher.set_cancel_requested") - - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") - mock_runner.cancel_tasks = AsyncMock() - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", mock_get_result_object - ) - - res._initialize_nodes() - sub_res._initialize_nodes() - - tg = res.lattice.transport_graph - tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) - sub_tg = sub_res.lattice.transport_graph - - await cancel_dispatch("pipeline_workflow") - - task_ids = list(tg._graph.nodes) - sub_task_ids = list(sub_tg._graph.nodes) - - calls = [call("pipeline_workflow", task_ids), call(sub_dispatch_id, sub_task_ids)] - mock_data_cancel.assert_has_awaits(calls) - mock_runner.cancel_tasks.assert_has_awaits(calls) - - -@pytest.mark.asyncio -async def test_cancel_dispatch_with_task_ids(mocker): - """Test cancelling a dispatch, including sub-lattices and with task ids""" - res = get_mock_result() - sub_res = get_mock_result() - - sub_dispatch_id = "sub_pipeline_workflow" - sub_res._dispatch_id = sub_dispatch_id - - def mock_get_result_object(dispatch_id): - objs = {res._dispatch_id: res, sub_res._dispatch_id: sub_res} - return objs[dispatch_id] - - mock_data_cancel = mocker.patch("covalent_dispatcher._core.dispatcher.set_cancel_requested") - - mock_runner = mocker.patch("covalent_dispatcher._core.dispatcher.runner") - mock_runner.cancel_tasks = AsyncMock() - - mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.get_result_object", mock_get_result_object - ) - mock_app_log = mocker.patch("covalent_dispatcher._core.dispatcher.app_log.debug") - - res._initialize_nodes() - sub_res._initialize_nodes() - - tg = res.lattice.transport_graph - tg.set_node_value(2, "sub_dispatch_id", sub_dispatch_id) - sub_tg = sub_res.lattice.transport_graph - task_ids = list(tg._graph.nodes) - sub_task_ids = list(sub_tg._graph.nodes) - - await cancel_dispatch("pipeline_workflow", task_ids) - - calls = [call("pipeline_workflow", task_ids), call(sub_dispatch_id, sub_task_ids)] - mock_data_cancel.assert_has_awaits(calls) - mock_runner.cancel_tasks.assert_has_awaits(calls) - assert mock_app_log.call_count == 2 - - -@pytest.mark.asyncio -async def test_submit_task(mocker): - """Test the submit task function.""" - - def transport_graph_get_value_side_effect(node_id, key): - if key == "name": - return "mock-name" - if key == "status": - return RESULT_STATUS.COMPLETED - - mock_result = MagicMock() - mock_result.lattice.transport_graph.get_node_value.side_effect = ( - transport_graph_get_value_side_effect - ) - - generate_node_result_mock = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.generate_node_result" - ) - update_node_result_mock = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.update_node_result" - ) - await _submit_task(mock_result, 0) - assert mock_result.lattice.transport_graph.get_node_value.mock_calls == [ - call(0, "name"), - call(0, "status"), - call(0, "output"), - ] - update_node_result_mock.assert_called_with(mock_result, generate_node_result_mock.return_value) - generate_node_result_mock.assert_called_once() diff --git a/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py b/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py deleted file mode 100644 index 37edd198f..000000000 --- a/tests/covalent_dispatcher_tests/_core/tmp_execution_test.py +++ /dev/null @@ -1,391 +0,0 @@ -# Copyright 2021 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -""" -Integration tests for the dispatcher, runner, and result modules -""" - -import asyncio -from typing import Dict, List - -import pytest - -import covalent as ct -from covalent._results_manager import Result -from covalent._workflow.lattice import Lattice -from covalent_dispatcher._core.dispatcher import run_workflow -from covalent_dispatcher._core.execution import _get_task_inputs -from covalent_dispatcher._db import update -from covalent_dispatcher._db.datastore import DataStore - -TEST_RESULTS_DIR = "/tmp/results" - - -@pytest.fixture -def test_db(): - """Instantiate and return an in-memory database.""" - - return DataStore( - db_URL="sqlite+pysqlite:///:memory:", - initialize_db=True, - ) - - -def get_mock_result() -> Result: - """Construct a mock result object corresponding to a lattice.""" - - import sys - - @ct.electron(executor="local") - def task(x): - print(f"stdout: {x}") - print("Error!", file=sys.stderr) - return x - - @ct.lattice(deps_bash=ct.DepsBash(["ls"])) - def pipeline(x): - res1 = task(x) - res2 = task(res1) - return res2 - - pipeline.build_graph(x="absolute") - received_workflow = Lattice.deserialize_from_json(pipeline.serialize_to_json()) - result_object = Result(received_workflow, "pipeline_workflow") - - return result_object - - -def test_get_task_inputs(): - """Test _get_task_inputs for both dicts and list parameter types""" - - @ct.electron - def list_task(arg: List): - return len(arg) - - @ct.electron - def dict_task(arg: Dict): - return len(arg) - - @ct.electron - def multivariable_task(x, y): - return x, y - - @ct.lattice - def list_workflow(arg): - return list_task(arg) - - @ct.lattice - def dict_workflow(arg): - return dict_task(arg) - - # 1 2 - # \ \ - # 0 3 - # / /\/ - # 4 5 - - @ct.electron - def identity(x): - return x - - @ct.lattice - def multivar_workflow(x, y): - electron_x = identity(x) - electron_y = identity(y) - res1 = multivariable_task(electron_x, electron_y) - res2 = multivariable_task(electron_y, electron_x) - res3 = multivariable_task(electron_y, electron_x) - res4 = multivariable_task(electron_x, electron_y) - return 1 - - # list-type inputs - - list_workflow.build_graph([1, 2, 3]) - serialized_args = [ct.TransportableObject(i) for i in [1, 2, 3]] - tg = list_workflow.transport_graph - # Nodes 0=task, 1=:electron_list:, 2=1, 3=2, 4=3 - tg.set_node_value(2, "output", ct.TransportableObject(1)) - tg.set_node_value(3, "output", ct.TransportableObject(2)) - tg.set_node_value(4, "output", ct.TransportableObject(3)) - - result_object = Result(lattice=list_workflow, dispatch_id="asdf") - task_inputs = _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) - - expected_inputs = {"args": serialized_args, "kwargs": {}} - - assert task_inputs == expected_inputs - - # dict-type inputs - - dict_workflow.build_graph({"a": 1, "b": 2}) - serialized_args = {"a": ct.TransportableObject(1), "b": ct.TransportableObject(2)} - tg = dict_workflow.transport_graph - # Nodes 0=task, 1=:electron_dict:, 2=1, 3=2 - tg.set_node_value(2, "output", ct.TransportableObject(1)) - tg.set_node_value(3, "output", ct.TransportableObject(2)) - - result_object = Result(lattice=dict_workflow, dispatch_id="asdf") - task_inputs = _get_task_inputs(1, tg.get_node_value(1, "name"), result_object) - expected_inputs = {"args": [], "kwargs": serialized_args} - - assert task_inputs == expected_inputs - - # Check arg order - multivar_workflow.build_graph(1, 2) - received_lattice = Lattice.deserialize_from_json(multivar_workflow.serialize_to_json()) - result_object = Result(lattice=received_lattice, dispatch_id="asdf") - tg = received_lattice.transport_graph - - assert list(tg._graph.nodes) == list(range(10)) - tg.set_node_value(0, "output", ct.TransportableObject(1)) - tg.set_node_value(2, "output", ct.TransportableObject(2)) - - task_inputs = _get_task_inputs(4, tg.get_node_value(4, "name"), result_object) - - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [1, 2] - - task_inputs = _get_task_inputs(5, tg.get_node_value(5, "name"), result_object) - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [2, 1] - - task_inputs = _get_task_inputs(6, tg.get_node_value(6, "name"), result_object) - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [2, 1] - - task_inputs = _get_task_inputs(7, tg.get_node_value(7, "name"), result_object) - input_args = [arg.get_deserialized() for arg in task_inputs["args"]] - assert input_args == [1, 2] - - -@pytest.mark.asyncio -async def test_run_workflow_with_failing_nonleaf(mocker, test_db): - """Test running workflow with a failing intermediate node""" - - @ct.electron - def failing_task(x): - assert False - - @ct.lattice - def workflow(x): - res1 = failing_task(x) - res2 = failing_task(res1) - return res2 - - from covalent._workflow.lattice import Lattice - - workflow.build_graph(5) - - json_lattice = workflow.serialize_to_json() - dispatch_id = "asdf" - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice) - result_object._dispatch_id = dispatch_id - result_object._root_dispatch_id = dispatch_id - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - mocker.patch( - "covalent._results_manager.result.Result._get_node_name", return_value="failing_task" - ) - mocker.patch( - "covalent._results_manager.result.Result._get_node_error", return_value="AssertionError" - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_get_failed_nodes = mocker.patch( - "covalent._results_manager.result.Result._get_failed_nodes", - return_value=[(0, "failing_task")], - ) - - update.persist(result_object) - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - assert result_object.status == Result.FAILED - - mock_get_failed_nodes.assert_called() - assert result_object._error == "The following tasks failed:\n0: failing_task" - - -@pytest.mark.asyncio -async def test_run_workflow_with_failing_leaf(mocker, test_db): - """Test running workflow with a failing leaf node""" - - @ct.electron - def failing_task(x): - assert False - return x - - @ct.lattice - def workflow(x): - res1 = failing_task(x) - return res1 - - from covalent._workflow.lattice import Lattice - - workflow.build_graph(5) - - json_lattice = workflow.serialize_to_json() - dispatch_id = "asdf" - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice) - result_object._dispatch_id = dispatch_id - result_object._root_dispatch_id = dispatch_id - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - mocker.patch( - "covalent._results_manager.result.Result._get_node_name", return_value="failing_task" - ) - mocker.patch( - "covalent._results_manager.result.Result._get_node_error", return_value="AssertionError" - ) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_get_failed_nodes = mocker.patch( - "covalent._results_manager.result.Result._get_failed_nodes", - return_value=[(0, "failing_task")], - ) - - update.persist(result_object) - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - assert result_object.status == Result.FAILED - assert result_object._error == "The following tasks failed:\n0: failing_task" - - -@pytest.mark.asyncio -async def test_run_workflow_does_not_deserialize(test_db, mocker): - """Check that dispatcher does not deserialize user data when using - out-of-process `workflow_executor`""" - - @ct.electron(executor="local") - def task(x): - return x - - @ct.lattice(executor="local", workflow_executor="local") - def workflow(x): - # Exercise both sublatticing and postprocessing - sublattice_task = ct.lattice(task, workflow_executor="local") - res1 = ct.electron(sublattice_task(x), executor="local") - return res1 - - dispatch_id = "asdf" - workflow.build_graph(5) - - json_lattice = workflow.serialize_to_json() - lattice = Lattice.deserialize_from_json(json_lattice) - result_object = Result(lattice, dispatch_id=dispatch_id) - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mock_run_abstract_task = mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - - update.persist(result_object) - - mock_to_deserialize = mocker.patch("covalent.TransportableObject.get_deserialized") - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - - mock_to_deserialize.assert_not_called() - assert result_object.status == Result.RUNNING - assert mock_run_abstract_task.call_count == 1 - - -@pytest.mark.asyncio -async def test_run_workflow_with_failed_postprocess(test_db, mocker): - """Check that run_workflow handles postprocessing failures""" - - dispatch_id = "test_run_workflow_with_failed_postprocess" - result_object = get_mock_result() - result_object._dispatch_id = dispatch_id - result_object._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.datastore.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - mock_unregister = mocker.patch( - "covalent_dispatcher._core.dispatcher.datasvc.finalize_dispatch" - ) - mocker.patch( - "covalent_dispatcher._core.runner.datasvc.get_result_object", return_value=result_object - ) - mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - - update.persist(result_object) - - status_queue = asyncio.Queue() - mocker.patch( - "covalent_dispatcher._core.data_manager.get_status_queue", return_value=status_queue - ) - mock_run_abstract_task = mocker.patch("covalent_dispatcher._core.runner._run_abstract_task") - - def failing_workflow(x): - assert False - - result_object.lattice.set_metadata("workflow_executor", "bogus") - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - - assert result_object.status == Result.RUNNING - - result_object.lattice.workflow_function = ct.TransportableObject(failing_workflow) - result_object.lattice.set_metadata("workflow_executor", "local") - - result_object = await run_workflow(result_object) - mock_unregister.assert_called_with(result_object.dispatch_id) - - assert result_object.status == Result.RUNNING - assert mock_run_abstract_task.call_count == 2 diff --git a/tests/covalent_dispatcher_tests/_dal/exporters/result_export_test.py b/tests/covalent_dispatcher_tests/_dal/exporters/result_export_test.py index 09caada73..f6136d9b0 100644 --- a/tests/covalent_dispatcher_tests/_dal/exporters/result_export_test.py +++ b/tests/covalent_dispatcher_tests/_dal/exporters/result_export_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py index a6e6eb70d..b612e8ac8 100644 --- a/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py +++ b/tests/covalent_dispatcher_tests/_dal/importers/result_import_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py b/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py index 892934b0e..3740b3c78 100644 --- a/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py +++ b/tests/covalent_dispatcher_tests/_dal/tg_ops_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_dispatcher_tests/_db/load_test.py b/tests/covalent_dispatcher_tests/_db/load_test.py deleted file mode 100644 index efbe02071..000000000 --- a/tests/covalent_dispatcher_tests/_db/load_test.py +++ /dev/null @@ -1,159 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -"""Unit tests for result loading (from database) module.""" - -from unittest.mock import call - -import pytest -from sqlalchemy import select - -import covalent as ct -from covalent._results_manager.result import Result as SDKResult -from covalent._shared_files.util_classes import Status -from covalent._workflow.lattice import Lattice as SDKLattice -from covalent_dispatcher._db import models, update -from covalent_dispatcher._db.datastore import DataStore -from covalent_dispatcher._db.load import ( - _result_from, - electron_record, - get_result_object_from_storage, - sublattice_dispatch_id, -) - - -@pytest.fixture -def test_db(): - """Instantiate and return an in-memory database.""" - - return DataStore( - db_URL="sqlite+pysqlite:///:memory:", - initialize_db=True, - ) - - -def get_mock_result(dispatch_id) -> SDKResult: - """Construct a mock result object corresponding to a lattice.""" - - @ct.electron - def task(x): - return x - - @ct.lattice - def workflow(x): - res1 = task(x) - return res1 - - workflow.build_graph(x=1) - received_workflow = SDKLattice.deserialize_from_json(workflow.serialize_to_json()) - result_object = SDKResult(received_workflow, dispatch_id) - - return result_object - - -def test_result_from(mocker, test_db): - """Test the result from function in the load module.""" - - dispatch_id = "test_result_from" - res = get_mock_result(dispatch_id) - res._initialize_nodes() - - mocker.patch("covalent_dispatcher._db.write_result_to_db.workflow_db", test_db) - mocker.patch("covalent_dispatcher._db.upsert.workflow_db", test_db) - mocker.patch("covalent_dispatcher._dal.base.workflow_db", test_db) - - update.persist(res) - - with test_db.session() as session: - mock_lattice_record = session.scalars( - select(models.Lattice).where(models.Lattice.dispatch_id == dispatch_id) - ).first() - - result_object = _result_from(mock_lattice_record) - - assert result_object._root_dispatch_id == mock_lattice_record.root_dispatch_id - assert result_object._status == Status(mock_lattice_record.status) - assert result_object._error == "" - assert result_object.inputs == res.inputs - assert result_object._start_time == mock_lattice_record.started_at - assert result_object._end_time == mock_lattice_record.completed_at - assert result_object.result == res.result - - -def test_get_result_object_from_storage(mocker): - """Test the get_result_object_from_storage method.""" - from covalent_dispatcher._db.load import Lattice - - result_from_mock = mocker.patch("covalent_dispatcher._db.load._result_from") - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - result_object = get_result_object_from_storage("mock-dispatch-id") - - assert call(Lattice) in session_mock.query.mock_calls - session_mock.query().where().first.assert_called_once() - - assert result_object == result_from_mock.return_value - result_from_mock.assert_called_once_with(session_mock.query().where().first.return_value) - - -def test_get_result_object_from_storage_exception(mocker): - """Test the get_result_object_from_storage method.""" - from covalent_dispatcher._db.load import Lattice - - result_from_mock = mocker.patch("covalent_dispatcher._db.load._result_from") - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - session_mock.query().where().first.return_value = None - - with pytest.raises(RuntimeError): - get_result_object_from_storage("mock-dispatch-id") - - assert call(Lattice) in session_mock.query.mock_calls - session_mock.query().where().first.assert_called_once() - - result_from_mock.assert_not_called() - - -def test_electron_record(mocker): - """Test the electron_record method.""" - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - electron_record("mock-dispatch-id", "mock-node-id") - session_mock.query().filter().filter().filter().first.assert_called_once() - - -def test_sublattice_dispatch_id(mocker): - """Test the sublattice_dispatch_id method.""" - - class MockObject: - dispatch_id = "mock-dispatch-id" - - workflow_db_mock = mocker.patch("covalent_dispatcher._db.load.workflow_db") - session_mock = workflow_db_mock.session.return_value.__enter__.return_value - - session_mock.query().filter().first.return_value = MockObject() - res = sublattice_dispatch_id("mock-electron-id") - assert res == "mock-dispatch-id" - - session_mock.query().filter().first.return_value = [] - res = sublattice_dispatch_id("mock-electron-id") - assert res is None diff --git a/tests/covalent_dispatcher_tests/_db/update_test.py b/tests/covalent_dispatcher_tests/_db/update_test.py index 72f949b06..1d9230f67 100644 --- a/tests/covalent_dispatcher_tests/_db/update_test.py +++ b/tests/covalent_dispatcher_tests/_db/update_test.py @@ -200,9 +200,6 @@ def test_result_persist_workflow_1(test_db, result_1, mocker): ) if electron.transport_graph_node_id == 3: executor_data = json.loads(electron.executor_data) - # executor_data = local_store.load_file( - # storage_path=electron.storage_path, filename=electron.executor_data_filename - # ) assert executor_data["short_name"] == le.short_name() assert executor_data["attributes"] == le.__dict__ @@ -210,54 +207,6 @@ def test_result_persist_workflow_1(test_db, result_1, mocker): # Check that there are the appropriate amount of electron dependency records assert len(electron_dependency_rows) == 7 - # # Update some node / lattice statuses - # cur_time = dt.now(timezone.utc) - # result_1._end_time = cur_time - # result_1._status = "COMPLETED" - # result_1._result = ct.TransportableObject({"helo": 1, "world": 2}) - - # for node_id in range(6): - # result_1._update_node( - # node_id=node_id, - # start_time=cur_time, - # end_time=cur_time, - # status="COMPLETED", - # # output={"test_data": "test_data"}, # TODO - Put back in later - # # sublattice_result=None, # TODO - Add a test where this is not None - # ) - - # Call Result.persist - # update.persist(result_1) - - # Query lattice / electron / electron dependency - # with test_db.session() as session: - # lattice_row = session.query(Lattice).first() - # electron_rows = session.query(Electron).all() - # electron_dependency_rows = session.query(ElectronDependency).all() - - # # Check that the lattice records are as expected - # assert lattice_row.completed_at.strftime("%Y-%m-%d %H:%M") == cur_time.strftime( - # "%Y-%m-%d %H:%M" - # ) - # assert lattice_row.status == "COMPLETED" - # result = local_store.load_file( - # storage_path=lattice_storage_path, filename=lattice_row.results_filename - # ) - # assert result_1.result == result.get_deserialized() - - # # Check that the electron records are as expected - # for electron in electron_rows: - # assert electron.status == "COMPLETED" - # assert electron.parent_lattice_id == 1 - # assert ( - # electron.started_at.strftime("%Y-%m-%d %H:%M") - # == electron.completed_at.strftime("%Y-%m-%d %H:%M") - # == cur_time.strftime("%Y-%m-%d %H:%M") - # ) - # assert Path(electron.storage_path) == Path( - # f"{TEMP_RESULTS_DIR}/dispatch_1/node_{electron.transport_graph_node_id}" - # ) - # Tear down temporary results directory teardown_temp_results_dir(dispatch_id="dispatch_1") @@ -368,7 +317,7 @@ def workflow(arr): update.persist(result) tg = workflow.transport_graph - task_groups = set([tg.get_node_value(node_id, "task_group_id") for node_id in tg._graph.nodes]) + task_groups = {tg.get_node_value(node_id, "task_group_id") for node_id in tg._graph.nodes} with test_db.session() as session: job_records = session.query(Job).all() @@ -400,43 +349,3 @@ def workflow(arr): with pytest.raises(RuntimeError): update.persist(result) - - -# @pytest.mark.parametrize("node_name", [None, "mock_node_name", postprocess_prefix]) -# def test_node(mocker, node_name): -# """Test the _node method.""" -# electron_data_mock = mocker.patch("covalent_dispatcher._db.upsert.electron_data") -# lattice_data_mock = mocker.patch("covalent_dispatcher._db.upsert.lattice_data") -# mock_result = mocker.MagicMock() -# update._node( -# mock_result, -# node_id=0, -# node_name=node_name, -# start_time="mock_time", -# end_time="mock_time", -# status="COMPLETED", -# output="mock_output", -# qelectron_data_exists=False, -# ) -# if node_name is None: -# node_name = mock_result.lattice.transport_graph.get_node_value() -# mock_result._update_node.assert_called_once_with( -# node_id=0, -# node_name=node_name, -# start_time="mock_time", -# end_time="mock_time", -# status="COMPLETED", -# output="mock_output", -# qelectron_data_exists=False, -# error=None, -# sub_dispatch_id=None, -# sublattice_result=None, -# stdout=None, -# stderr=None, -# ) -# if node_name.startswith(postprocess_prefix): -# assert mock_result._result == "mock_output" -# assert mock_result._status == "COMPLETED" -# else: -# assert mock_result._result != "mock_output" -# assert mock_result._status != "COMPLETED" diff --git a/tests/covalent_dispatcher_tests/_db/write_result_to_db_test.py b/tests/covalent_dispatcher_tests/_db/write_result_to_db_test.py index ea4571d31..84d647db1 100644 --- a/tests/covalent_dispatcher_tests/_db/write_result_to_db_test.py +++ b/tests/covalent_dispatcher_tests/_db/write_result_to_db_test.py @@ -124,10 +124,8 @@ def get_lattice_kwargs( function_string_filename=FUNCTION_STRING_FILENAME, executor="dask", executor_data=json.dumps({}), - # executor_data_filename=EXECUTOR_DATA_FILENAME, workflow_executor="dask", workflow_executor_data=json.dumps({}), - # workflow_executor_data_filename=WORKFLOW_EXECUTOR_DATA_FILENAME, error_filename=ERROR_FILENAME, inputs_filename=INPUTS_FILENAME, named_args_filename=NAMED_ARGS_FILENAME, @@ -161,10 +159,8 @@ def get_lattice_kwargs( "function_string_filename": function_string_filename, "executor": executor, "executor_data": executor_data, - # "executor_data_filename": executor_data_filename, "workflow_executor": workflow_executor, "workflow_executor_data": workflow_executor_data, - # "workflow_executor_data_filename": workflow_executor_data_filename, "error_filename": error_filename, "inputs_filename": inputs_filename, "named_args_filename": named_args_filename, @@ -197,7 +193,6 @@ def get_electron_kwargs( function_string_filename=FUNCTION_STRING_FILENAME, executor="dask", executor_data=json.dumps({}), - # executor_data_filename=EXECUTOR_DATA_FILENAME, results_filename=RESULTS_FILENAME, value_filename=VALUE_FILENAME, stdout_filename=STDOUT_FILENAME, @@ -208,7 +203,6 @@ def get_electron_kwargs( call_after_filename=CALL_AFTER_FILENAME, job_id=1, qelectron_data_exists=False, - cancel_requested=False, created_at=None, updated_at=None, started_at=None, @@ -229,7 +223,6 @@ def get_electron_kwargs( "function_string_filename": function_string_filename, "executor": executor, "executor_data": executor_data, - # "executor_data_filename": executor_data_filename, "results_filename": results_filename, "value_filename": value_filename, "stdout_filename": stdout_filename, @@ -240,7 +233,6 @@ def get_electron_kwargs( "call_after_filename": call_after_filename, "job_id": job_id, "qelectron_data_exists": qelectron_data_exists, - "cancel_requested": cancel_requested, "created_at": created_at, "updated_at": updated_at, "started_at": started_at, @@ -301,9 +293,7 @@ def test_insert_lattices_data(test_db, mocker): assert lattice.function_filename == FUNCTION_FILENAME assert lattice.function_string_filename == FUNCTION_STRING_FILENAME assert lattice.executor == "dask" - # assert lattice.executor_data_filename == EXECUTOR_DATA_FILENAME assert lattice.workflow_executor == "dask" - # assert lattice.workflow_executor_data_filename == WORKFLOW_EXECUTOR_DATA_FILENAME assert lattice.error_filename == ERROR_FILENAME assert lattice.inputs_filename == INPUTS_FILENAME assert lattice.named_args_filename == NAMED_ARGS_FILENAME diff --git a/tests/covalent_dispatcher_tests/_object_store/__init__.py b/tests/covalent_dispatcher_tests/_object_store/__init__.py index 883ec0eda..55e011b94 100644 --- a/tests/covalent_dispatcher_tests/_object_store/__init__.py +++ b/tests/covalent_dispatcher_tests/_object_store/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_dispatcher_tests/_object_store/local_test.py b/tests/covalent_dispatcher_tests/_object_store/local_test.py index 52e0509f3..9f869c6df 100644 --- a/tests/covalent_dispatcher_tests/_object_store/local_test.py +++ b/tests/covalent_dispatcher_tests/_object_store/local_test.py @@ -1,4 +1,4 @@ -# Copyright 2021 Agnostiq Inc. +# Copyright 2023 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_dispatcher_tests/_service/app_test.py b/tests/covalent_dispatcher_tests/_service/app_test.py index f05c492ca..7877fe673 100644 --- a/tests/covalent_dispatcher_tests/_service/app_test.py +++ b/tests/covalent_dispatcher_tests/_service/app_test.py @@ -17,17 +17,21 @@ """Unit tests for the FastAPI app.""" import json -import os +import tempfile from contextlib import contextmanager from typing import Generator +from unittest.mock import MagicMock import pytest from fastapi.testclient import TestClient from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.orm import Session, declarative_base, sessionmaker -from covalent._results_manager.result import Result +import covalent as ct +from covalent._dispatcher_plugins.local import LocalDispatcher +from covalent._shared_files.util_classes import RESULT_STATUS from covalent_dispatcher._db.dispatchdb import DispatchDB +from covalent_dispatcher._service.app import _try_get_result_object, cancel_all_with_status from covalent_ui.app import fastapi_app as fast_app DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" @@ -74,6 +78,25 @@ def test_db(): return MockDataStore(db_URL="sqlite+pysqlite:///:memory:") +@pytest.fixture +def mock_manifest(): + """Create a mock workflow manifest""" + + @ct.electron + def task(x): + return x**2 + + @ct.lattice + def workflow(x): + return task(x) + + workflow.build_graph(3) + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + return manifest + + @pytest.fixture def test_db_file(): """Instantiate and return a database.""" @@ -81,61 +104,38 @@ def test_db_file(): @pytest.mark.asyncio -@pytest.mark.parametrize("disable_run", [True, False]) -async def test_submit(mocker, client, disable_run): +async def test_submit(mocker, client): """Test the submit endpoint.""" mock_data = json.dumps({}).encode("utf-8") run_dispatcher_mock = mocker.patch( - "covalent_dispatcher.run_dispatcher", return_value=DISPATCH_ID + "covalent_dispatcher.entry_point.make_dispatch", return_value=DISPATCH_ID ) - response = client.post("/api/submit", data=mock_data, params={"disable_run": disable_run}) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + response = client.post("/api/v2/dispatches/submit", data=mock_data) assert response.json() == DISPATCH_ID - run_dispatcher_mock.assert_called_once_with(mock_data, disable_run) + run_dispatcher_mock.assert_called_once_with(mock_data) @pytest.mark.asyncio async def test_submit_exception(mocker, client): """Test the submit endpoint.""" mock_data = json.dumps({}).encode("utf-8") - mocker.patch("covalent_dispatcher.run_dispatcher", side_effect=Exception("mock")) - response = client.post("/api/submit", data=mock_data) + mocker.patch("covalent_dispatcher.entry_point.make_dispatch", side_effect=Exception("mock")) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + response = client.post("/api/v2/dispatches/submit", data=mock_data) assert response.status_code == 400 assert response.json()["detail"] == "Failed to submit workflow: mock" -@pytest.mark.asyncio -@pytest.mark.parametrize("is_pending", [True, False]) -async def test_redispatch(mocker, client, is_pending): - """Test the redispatch endpoint.""" - json_lattice = None - electron_updates = None - reuse_previous_results = False - mock_data = json.dumps( - { - "dispatch_id": DISPATCH_ID, - "json_lattice": json_lattice, - "electron_updates": electron_updates, - "reuse_previous_results": reuse_previous_results, - } - ).encode("utf-8") - run_redispatch_mock = mocker.patch( - "covalent_dispatcher.run_redispatch", return_value=DISPATCH_ID - ) - - response = client.post("/api/redispatch", data=mock_data, params={"is_pending": is_pending}) - assert response.json() == DISPATCH_ID - run_redispatch_mock.assert_called_once_with( - DISPATCH_ID, json_lattice, electron_updates, reuse_previous_results, is_pending - ) - - def test_cancel_dispatch(mocker, app, client): """ Test cancelling dispatch """ - mocker.patch("covalent_dispatcher.cancel_running_dispatch") - response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) + mocker.patch("covalent_dispatcher.entry_point.cancel_running_dispatch") + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + response = client.put( + f"/api/v2/dispatches/{DISPATCH_ID}/status", + json={"status": "CANCELLED"}, ) assert response.json() == f"Dispatch {DISPATCH_ID} cancelled." @@ -144,104 +144,196 @@ def test_cancel_tasks(mocker, app, client): """ Test cancelling tasks within a lattice after dispatch """ - mocker.patch("covalent_dispatcher.cancel_running_dispatch") - response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": [0, 1]}) + mocker.patch("covalent_dispatcher.entry_point.cancel_running_dispatch") + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + response = client.put( + f"/api/v2/dispatches/{DISPATCH_ID}/status", + json={"status": "CANCELLED", "task_ids": [0, 1]}, ) assert response.json() == f"Cancelled tasks [0, 1] in dispatch {DISPATCH_ID}." -@pytest.mark.asyncio -async def test_redispatch_exception(mocker, client): - """Test the redispatch endpoint.""" - response = client.post("/api/redispatch", data="bad data") - assert response.status_code == 400 - assert ( - response.json()["detail"] - == "Failed to redispatch workflow: Expecting value: line 1 column 1 (char 0)" - ) - - -@pytest.mark.asyncio -async def test_cancel(mocker, client): - """Test the cancel endpoint.""" - cancel_running_dispatch_mock = mocker.patch( - "covalent_dispatcher.cancel_running_dispatch", return_value=DISPATCH_ID - ) - response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) - ) - assert response.json() == f"Dispatch {DISPATCH_ID} cancelled." - cancel_running_dispatch_mock.assert_called_once_with(DISPATCH_ID, []) - - @pytest.mark.asyncio async def test_cancel_exception(mocker, client): """Test the cancel endpoint.""" cancel_running_dispatch_mock = mocker.patch( - "covalent_dispatcher.cancel_running_dispatch", side_effect=Exception("mock") + "covalent_dispatcher.entry_point.cancel_running_dispatch", side_effect=Exception("mock") ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") with pytest.raises(Exception): - response = client.post( - "/api/cancel", data=json.dumps({"dispatch_id": DISPATCH_ID, "task_ids": []}) + response = client.put( + f"/api/v2/dispatches/{DISPATCH_ID}/status", + json={"status": "CANCELLED", "task_ids": []}, ) assert response.status_code == 400 assert response.json()["detail"] == "Failed to cancel workflow: mock" cancel_running_dispatch_mock.assert_called_once_with(DISPATCH_ID, []) -def test_get_result(mocker, client, test_db_file): - """Test the get-result endpoint.""" - lattice = MockLattice( - status=str(Result.COMPLETED), - dispatch_id=DISPATCH_ID, +def test_db_path_get_config(mocker): + """Test that the db path is retrieved from the config.""" "" + get_config_mock = mocker.patch("covalent_dispatcher._db.dispatchdb.get_config") + + DispatchDB() + + get_config_mock.assert_called_once() + + +def test_register(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", return_value=mock_manifest ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post("/api/v2/dispatches", data=mock_manifest.json()) - with test_db_file.session() as session: - session.add(lattice) - session.commit() + assert resp.json() == json.loads(mock_manifest.json()) + mock_register_dispatch.assert_awaited_with(mock_manifest, None) - mocker.patch("covalent_dispatcher._service.app._result_from", return_value={}) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}") - result = response.json() - assert result["id"] == DISPATCH_ID - assert result["status"] == Result.COMPLETED - os.remove("/tmp/testdb.sqlite") + +def test_register_exception(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", side_effect=RuntimeError() + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post("/api/v2/dispatches", data=mock_manifest.json()) + assert resp.status_code == 400 -def test_get_result_503(mocker, client, test_db_file): - """Test the get-result endpoint.""" - lattice = MockLattice( - status=str(Result.NEW_OBJ), - dispatch_id=DISPATCH_ID, +def test_register_sublattice(mocker, app, client, mock_manifest): + mock_register_dispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_dispatch", return_value=mock_manifest + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post( + "/api/v2/dispatches/parent_dispatch/subdispatches", + data=mock_manifest.json(), ) - with test_db_file.session() as session: - session.add(lattice) - session.commit() - mocker.patch("covalent_dispatcher._service.app._result_from", side_effect=FileNotFoundError()) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}?wait=True&status_only=True") - assert response.status_code == 503 - os.remove("/tmp/testdb.sqlite") + assert resp.json() == json.loads(mock_manifest.json()) + mock_register_dispatch.assert_awaited_with(mock_manifest, "parent_dispatch") -def test_get_result_dispatch_id_not_found(mocker, test_db_file, client): - """Test the get-result endpoint and that 404 is returned if the dispatch ID is not found in the database.""" - mocker.patch("covalent_dispatcher._service.app._result_from", return_value={}) - mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db_file) - mocker.patch("covalent_dispatcher._service.app.Lattice", MockLattice) - response = client.get(f"/api/result/{DISPATCH_ID}") - assert response.status_code == 404 +def test_register_redispatch(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + return_value=mock_manifest, + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post(f"/api/v2/dispatches/{dispatch_id}/redispatches", data=mock_manifest.json()) + mock_register_redispatch.assert_awaited_with(mock_manifest, dispatch_id, False) + assert resp.json() == json.loads(mock_manifest.json()) -def test_db_path_get_config(mocker): - """Test that the db path is retrieved from the config.""" "" - get_config_mock = mocker.patch("covalent_dispatcher._db.dispatchdb.get_config") - DispatchDB() +def test_register_redispatch_reuse(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + return_value=mock_manifest, + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post( + f"/api/v2/dispatches/{dispatch_id}/redispatches", + data=mock_manifest.json(), + params={"reuse_previous_results": True}, + ) + mock_register_redispatch.assert_awaited_with(mock_manifest, dispatch_id, True) + assert resp.json() == json.loads(mock_manifest.json()) - get_config_mock.assert_called_once() + +def test_register_redispatch_exception(mocker, app, client, mock_manifest): + dispatch_id = "test_register_redispatch" + mock_register_redispatch = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.register_redispatch", + side_effect=RuntimeError(), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.post(f"/api/v2/dispatches/{dispatch_id}/redispatches", data=mock_manifest.json()) + assert resp.status_code == 400 + + +def test_start(mocker, app, client): + dispatch_id = "test_start" + mock_start = mocker.patch("covalent_dispatcher._service.app.dispatcher.start_dispatch") + mock_create_task = mocker.patch("asyncio.create_task") + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.put(f"/api/v2/dispatches/{dispatch_id}/status", json={"status": "RUNNING"}) + assert resp.json() == dispatch_id + + +def test_export_result_nowait(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) + mocker.patch( + "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object + ) + mock_export = mocker.patch( + "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.get(f"/api/v2/dispatches/{dispatch_id}") + assert resp.status_code == 200 + assert resp.json()["id"] == dispatch_id + assert resp.json()["status"] == str(RESULT_STATUS.NEW_OBJECT) + assert resp.json()["result_export"] == json.loads(mock_manifest.json()) + + +def test_export_result_wait_not_ready(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.RUNNING)) + mocker.patch( + "covalent_dispatcher._service.app._try_get_result_object", return_value=mock_result_object + ) + mock_export = mocker.patch( + "covalent_dispatcher._service.app.export_result_manifest", return_value=mock_manifest + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.get(f"/api/v2/dispatches/{dispatch_id}", params={"wait": True}) + assert resp.status_code == 503 + + +def test_export_result_bad_dispatch_id(mocker, app, client, mock_manifest): + dispatch_id = "test_export_result" + mock_result_object = MagicMock() + mock_result_object.get_value = MagicMock(return_value=str(RESULT_STATUS.NEW_OBJECT)) + mocker.patch("covalent_dispatcher._service.app._try_get_result_object", return_value=None) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + resp = client.get(f"/api/v2/dispatches/{dispatch_id}") + assert resp.status_code == 404 + + +def test_try_get_result_object(mocker, app, client, mock_manifest): + dispatch_id = "test_try_get_result_object" + mock_result_object = MagicMock() + mocker.patch( + "covalent_dispatcher._service.app.get_result_object", return_value=mock_result_object + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + assert _try_get_result_object(dispatch_id) == mock_result_object + + +def test_try_get_result_object_not_found(mocker, app, client, mock_manifest): + dispatch_id = "test_try_get_result_object" + mock_result_object = MagicMock() + mocker.patch("covalent_dispatcher._service.app.get_result_object", side_effect=KeyError()) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + assert _try_get_result_object(dispatch_id) is None + + +@pytest.mark.asyncio +async def test_cancel_all_with_status(mocker, test_db): + mock_rec = MagicMock() + mock_rec.dispatch_id = "mock_dispatch" + + mocker.patch("covalent_dispatcher._service.app.workflow_db", test_db) + mocker.patch("covalent_dispatcher._dal.result.Result.get_db_records", return_value=[mock_rec]) + mock_cancel = mocker.patch( + "covalent_dispatcher._service.app.dispatcher.cancel_running_dispatch" + ) + + await cancel_all_with_status(RESULT_STATUS.RUNNING) + + mock_cancel.assert_awaited_with("mock_dispatch") diff --git a/tests/covalent_dispatcher_tests/_service/assets_test.py b/tests/covalent_dispatcher_tests/_service/assets_test.py new file mode 100644 index 000000000..5f704ca43 --- /dev/null +++ b/tests/covalent_dispatcher_tests/_service/assets_test.py @@ -0,0 +1,739 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unit tests for the FastAPI asset endpoints""" + +import tempfile +from contextlib import contextmanager +from typing import Generator +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient +from sqlalchemy import Column, Integer, String, create_engine +from sqlalchemy.orm import Session, declarative_base, sessionmaker + +from covalent._workflow.transportable_object import TransportableObject +from covalent_dispatcher._service.assets import ( + _generate_file_slice, + _get_tobj_pickle_offsets, + _get_tobj_string_offsets, + get_cached_result_object, +) +from covalent_ui.app import fastapi_app as fast_app + +DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" + +INTERNAL_URI = "file:///tmp/object.pkl" + +# Mock SqlAlchemy models +MockBase = declarative_base() + + +class MockLattice(MockBase): + __tablename__ = "lattices" + id = Column(Integer, primary_key=True) + dispatch_id = Column(String(64), nullable=False) + status = Column(String(24), nullable=False) + + +class MockDataStore: + def __init__(self, db_URL): + self.db_URL = db_URL + self.engine = create_engine(self.db_URL) + self.Session = sessionmaker(self.engine) + + MockBase.metadata.create_all(self.engine) + + @contextmanager + def session(self) -> Generator[Session, None, None]: + with self.Session.begin() as session: + yield session + + +@pytest.fixture +def app(): + yield fast_app + + +@pytest.fixture +def client(): + with TestClient(fast_app) as c: + yield c + + +@pytest.fixture +def test_db(): + """Instantiate and return an in-memory database.""" + return MockDataStore(db_URL="sqlite+pysqlite:///:memory:") + + +@pytest.fixture +def mock_result_object(): + res_obj = MagicMock() + mock_node = MagicMock() + mock_asset = MagicMock() + mock_asset.internal_uri = INTERNAL_URI + + res_obj.get_asset = MagicMock(return_value=mock_asset) + res_obj.update_assets = MagicMock() + res_obj.lattice.get_asset = MagicMock(return_value=mock_asset) + res_obj.lattice.update_assets = MagicMock() + + res_obj.lattice.transport_graph.get_node = MagicMock(return_value=mock_node) + + mock_node.get_asset = MagicMock(return_value=mock_asset) + mock_node.update_assets = MagicMock() + + return res_obj + + +def test_get_node_asset(mocker, client, test_db, mock_result_object): + """ + Test get node asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] + + +def test_get_node_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get node asset + """ + + test_str = "test_get_node_asset_string_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + headers = {"Range": "bytes=0-6"} + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get( + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", headers=headers + ) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_node_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get node asset + """ + + test_str = "test_get_node_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + + params = {"representation": rep} + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get( + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", params=params + ) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_node_asset_bad_dispatch_id(mocker, client): + """ + Test get node asset + """ + key = "output" + node_id = 0 + dispatch_id = "test_get_node_asset" + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}") + assert resp.status_code == 400 + + +def test_get_lattice_asset(mocker, client, test_db, mock_result_object): + """ + Test get lattice asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] + + +def test_get_lattice_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get lattice asset + """ + + test_str = "test_lattice_asset_byte_range" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + headers = {"Range": "bytes=0-6"} + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", headers=headers) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_lattice_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get lattice asset + """ + + test_str = "test_get_lattice_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_rep" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + params = {"representation": rep} + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", params=params) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_lattice_asset_bad_dispatch_id(mocker, client): + """ + Test get lattice asset + """ + + key = "workflow_function" + dispatch_id = "test_get_lattice_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}") + assert resp.status_code == 400 + + +def test_get_dispatch_asset(mocker, client, test_db, mock_result_object): + """ + Test get dispatch asset + """ + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + yield "Hi" + + key = "result" + dispatch_id = "test_get_dispatch_asset" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}") + + assert resp.text == "Hi" + assert (INTERNAL_URI, 0, -1, 65536) == mock_generator.calls[0] + + +def test_get_dispatch_asset_byte_range(mocker, client, test_db, mock_result_object): + """ + Test get dispatch asset + """ + + test_str = "test_dispatch_asset_byte_range" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "result" + dispatch_id = "test_get_dispatch_asset_byte_range" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + + headers = {"Range": "bytes=0-6"} + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", headers=headers) + + assert resp.text == test_str[0:6] + assert (INTERNAL_URI, 0, 6, 65536) == mock_generator.calls[0] + + +@pytest.mark.parametrize("rep,start_byte,end_byte", [("string", 0, 6), ("object", 6, 12)]) +def test_get_dispatch_asset_rep( + mocker, client, test_db, mock_result_object, rep, start_byte, end_byte +): + """ + Test get dispatch asset + """ + + test_str = "test_get_dispatch_asset_rep" + + class MockGenerateFileSlice: + def __init__(self): + self.calls = [] + + def __call__(self, file_url: str, start_byte: int, end_byte: int, chunk_size: int = 65536): + self.calls.append((file_url, start_byte, end_byte, chunk_size)) + if end_byte >= 0: + yield test_str[start_byte:end_byte] + else: + yield test_str[start_byte:] + + key = "result" + dispatch_id = "test_get_dispatch_asset_rep" + mock_generator = MockGenerateFileSlice() + + mocker.patch("fastapi.responses.StreamingResponse") + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mock_generate_file_slice = mocker.patch( + "covalent_dispatcher._service.assets._generate_file_slice", mock_generator + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_string_offsets", return_value=(0, 6) + ) + mocker.patch( + "covalent_dispatcher._service.assets._get_tobj_pickle_offsets", return_value=(6, 12) + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + params = {"representation": rep} + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", params=params) + + assert resp.text == test_str[start_byte:end_byte] + assert (INTERNAL_URI, start_byte, end_byte, 65536) == mock_generator.calls[0] + + +def test_get_dispatch_asset_bad_dispatch_id(mocker, client): + """ + Test get dispatch asset + """ + + key = "result" + dispatch_id = "test_get_dispatch_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + resp = client.get(f"/api/v2/dispatches/{dispatch_id}/assets/{key}") + assert resp.status_code == 400 + + +def test_put_node_asset(test_db, mocker, client, mock_result_object): + """ + Test put node asset + """ + + key = "function" + node_id = 0 + dispatch_id = "test_put_node_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + headers = {"Digest-alg": "sha", "Digest": "0bf"} + with open(writer.name, "rb") as reader: + resp = client.put( + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", + data=reader, + headers=headers, + ) + mock_node = mock_result_object.lattice.transport_graph.get_node(node_id) + mock_node.update_assets.assert_called() + assert resp.status_code == 200 + + +def test_put_node_asset_bad_dispatch_id(mocker, client): + """ + Test put node asset + """ + key = "function" + node_id = 0 + dispatch_id = "test_put_node_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + with open(writer.name, "rb") as reader: + resp = client.put( + f"/api/v2/dispatches/{dispatch_id}/electrons/{node_id}/assets/{key}", data=reader + ) + + assert resp.status_code == 400 + + +def test_put_lattice_asset(mocker, client, test_db, mock_result_object): + """ + Test put lattice asset + """ + key = "workflow_function" + dispatch_id = "test_put_lattice_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + with open(writer.name, "rb") as reader: + resp = client.put( + f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", data=reader + ) + mock_lattice = mock_result_object.lattice + mock_lattice.update_assets.assert_called() + assert resp.status_code == 200 + + +def test_put_lattice_asset_bad_dispatch_id(mocker, client): + """ + Test put lattice asset + """ + key = "workflow_function" + dispatch_id = "test_put_lattice_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=404), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + with open(writer.name, "rb") as reader: + resp = client.put( + f"/api/v2/dispatches/{dispatch_id}/lattice/assets/{key}", data=reader + ) + + assert resp.status_code == 400 + + +def test_put_dispatch_asset(mocker, client, test_db, mock_result_object): + """ + Test put dispatch asset + """ + key = "result" + dispatch_id = "test_put_dispatch_asset" + + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch( + "covalent_dispatcher._service.assets.get_result_object", return_value=mock_result_object + ) + + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + with open(writer.name, "rb") as reader: + resp = client.put(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", data=reader) + mock_result_object.update_assets.assert_called() + assert resp.status_code == 200 + + +def test_put_dispatch_asset_bad_dispatch_id(mocker, client): + """ + Test put dispatch asset + """ + key = "result" + dispatch_id = "test_put_dispatch_asset_no_dispatch_id" + + mocker.patch( + "covalent_dispatcher._service.assets.get_cached_result_object", + side_effect=HTTPException(status_code=400), + ) + mocker.patch("covalent_dispatcher._service.app.cancel_all_with_status") + + with tempfile.NamedTemporaryFile("w") as writer: + writer.write(f"{dispatch_id}") + writer.flush() + + with open(writer.name, "rb") as reader: + resp = client.put(f"/api/v2/dispatches/{dispatch_id}/assets/{key}", data=reader) + + assert resp.status_code == 400 + + +def test_get_string_offsets(): + tobj = TransportableObject("test_get_string_offsets") + + data = tobj.serialize() + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + + start, end = _get_tobj_string_offsets(f"file://{write_file.name}") + + assert data[start:end].decode("utf-8") == tobj.object_string + + +def test_get_pickle_offsets(): + tobj = TransportableObject("test_get_pickle_offsets") + + data = tobj.serialize() + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + + start, end = _get_tobj_pickle_offsets(f"file://{write_file.name}") + + assert data[start:].decode("utf-8") == tobj.get_serialized() + + +def test_generate_partial_file_slice(): + """Test generating slices of files.""" + + data = "test_generate_file_slice".encode("utf-8") + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + gen = _generate_file_slice(f"file://{write_file.name}", 1, 5, 2) + assert next(gen) == data[1:3] + assert next(gen) == data[3:5] + with pytest.raises(StopIteration): + next(gen) + + +def test_generate_whole_file_slice(): + """Test generating slices of files.""" + + data = "test_generate_file_slice".encode("utf-8") + with tempfile.NamedTemporaryFile("wb") as write_file: + write_file.write(data) + write_file.flush() + gen = _generate_file_slice(f"file://{write_file.name}", 0, -1) + assert next(gen) == data + + +def test_get_cached_result_obj(mocker, test_db): + mocker.patch("covalent_dispatcher._service.assets.workflow_db", test_db) + mocker.patch("covalent_dispatcher._service.assets.get_result_object", side_effect=KeyError()) + with pytest.raises(HTTPException): + get_cached_result_object("test_get_cached_result_obj") diff --git a/tests/covalent_dispatcher_tests/entry_point_test.py b/tests/covalent_dispatcher_tests/entry_point_test.py index e4c55c352..53f92fece 100644 --- a/tests/covalent_dispatcher_tests/entry_point_test.py +++ b/tests/covalent_dispatcher_tests/entry_point_test.py @@ -17,69 +17,98 @@ """Unit tests for the FastAPI app.""" +from concurrent.futures import ThreadPoolExecutor +from datetime import datetime +from unittest.mock import MagicMock + import pytest -from covalent_dispatcher.entry_point import cancel_running_dispatch, run_dispatcher, run_redispatch +from covalent_dispatcher.entry_point import ( + cancel_running_dispatch, + register_dispatch, + register_redispatch, + run_dispatcher, + start_dispatch, +) DISPATCH_ID = "f34671d1-48f2-41ce-89d9-9a8cb5c60e5d" -class MockObject: - pass +@pytest.mark.asyncio +async def test_run_dispatcher(mocker): + mock_run_dispatch = mocker.patch("covalent_dispatcher._core.run_dispatch") + mock_make_dispatch = mocker.patch( + "covalent_dispatcher._core.make_dispatch", return_value=DISPATCH_ID + ) + json_lattice = '{"workflow_function": "asdf"}' + dispatch_id = await run_dispatcher(json_lattice) + assert dispatch_id == DISPATCH_ID + mock_make_dispatch.assert_awaited_with(json_lattice) + mock_run_dispatch.assert_called_with(dispatch_id) -def mock_initialize_result_object(lattice): - result = MockObject() - result.dispatch_id = lattice["dispatch_id"] - return result +@pytest.mark.asyncio +async def test_cancel_running_dispatch(mocker): + mock_cancel_workflow = mocker.patch("covalent_dispatcher.entry_point.cancel_dispatch") + await cancel_running_dispatch(DISPATCH_ID) + mock_cancel_workflow.assert_awaited_once_with(DISPATCH_ID, []) @pytest.mark.asyncio -@pytest.mark.parametrize("disable_run", [True, False]) -async def test_run_dispatcher(mocker, disable_run): - """ - Test run_dispatcher is called with the - right arguments in different conditions - """ +async def test_start_dispatch_waits(mocker): + """Check that start_dispatch waits for any assets to be copied.""" + + dispatch_id = "test_start_dispatch_waits" + + def mock_copy(): + import time + + time.sleep(3) + + mock_futures = {} + ex = ThreadPoolExecutor(max_workers=1) + mocker.patch("covalent_dispatcher._core.copy_futures", mock_futures) mock_run_dispatch = mocker.patch("covalent_dispatcher._core.run_dispatch") - mock_make_dispatch = mocker.patch( - "covalent_dispatcher._core.make_dispatch", return_value=DISPATCH_ID - ) - json_lattice = '{"workflow_function": "asdf"}' - dispatch_id = await run_dispatcher(json_lattice, disable_run) - assert dispatch_id == DISPATCH_ID + fut = ex.submit(mock_copy) + mock_futures[dispatch_id] = fut + fut.add_done_callback(lambda x: mock_futures.pop(dispatch_id)) + + start_time = datetime.now() + await start_dispatch(dispatch_id) + end_time = datetime.now() - mock_make_dispatch.assert_called_with(json_lattice) - if not disable_run: - mock_run_dispatch.assert_called_with(dispatch_id) + assert (end_time - start_time).total_seconds() > 2 + + mock_run_dispatch.assert_called() @pytest.mark.asyncio -@pytest.mark.parametrize("is_pending", [True, False]) -async def test_run_redispatch(mocker, is_pending): - """ - Test the run_redispatch function is called - with the right arguments in differnet conditions - """ - - make_derived_dispatch_mock = mocker.patch( - "covalent_dispatcher._core.make_derived_dispatch", return_value="mock-redispatch-id" - ) - run_dispatch_mock = mocker.patch("covalent_dispatcher._core.run_dispatch") - redispatch_id = await run_redispatch(DISPATCH_ID, "mock-json-lattice", {}, False, is_pending) +async def test_register_dispatch(mocker): + """Check register_dispatch""" - if not is_pending: - make_derived_dispatch_mock.assert_called_once_with( - DISPATCH_ID, "mock-json-lattice", {}, False - ) + mock_manifest = MagicMock() + + mock_importer = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_manifest", + return_value=mock_manifest, + ) - run_dispatch_mock.assert_called_once_with(redispatch_id) + assert await register_dispatch("manifest", "parent_dispatch_id") is mock_manifest + mock_importer.assert_awaited_with("manifest", "parent_dispatch_id", None) @pytest.mark.asyncio -async def test_cancel_running_dispatch(mocker): - mock_cancel_workflow = mocker.patch("covalent_dispatcher.entry_point.cancel_dispatch") - await cancel_running_dispatch(DISPATCH_ID) - mock_cancel_workflow.assert_awaited_once_with(DISPATCH_ID, []) +async def test_register_redispatch(mocker): + """Check register_dispatch""" + + mock_manifest = MagicMock() + + mock_importer = mocker.patch( + "covalent_dispatcher._core.data_modules.importer.import_derived_manifest", + return_value=mock_manifest, + ) + + assert await register_redispatch("manifest", "parent_dispatch_id", True) is mock_manifest + mock_importer.assert_awaited_with("manifest", "parent_dispatch_id", True) diff --git a/tests/covalent_tests/__init__.py b/tests/covalent_tests/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_tests/__init__.py +++ b/tests/covalent_tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_tests/dispatcher_plugins/__init__.py b/tests/covalent_tests/dispatcher_plugins/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_tests/dispatcher_plugins/__init__.py +++ b/tests/covalent_tests/dispatcher_plugins/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_tests/dispatcher_plugins/local_test.py b/tests/covalent_tests/dispatcher_plugins/local_test.py index da756a152..d34592ed7 100644 --- a/tests/covalent_tests/dispatcher_plugins/local_test.py +++ b/tests/covalent_tests/dispatcher_plugins/local_test.py @@ -17,132 +17,142 @@ """Unit tests for local module in dispatcher_plugins.""" +import tempfile from unittest.mock import MagicMock import pytest from requests import Response -from requests.exceptions import HTTPError +from requests.exceptions import ConnectionError, HTTPError import covalent as ct -from covalent._dispatcher_plugins.local import LocalDispatcher, get_redispatch_request_body +from covalent._dispatcher_plugins.local import LocalDispatcher, get_redispatch_request_body_v2 +from covalent._results_manager.result import Result +from covalent._shared_files.utils import format_server_url -def test_get_redispatch_request_body_null_arguments(): - """Test the get request body function with null arguments.""" +def test_dispatching_a_non_lattice(): + """test dispatching a non-lattice""" @ct.electron - def identity(a): - return a + def task(a, b, c): + return a + b + c @ct.electron - def add(a, b): - return a + b + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - response = get_redispatch_request_body( - "mock-dispatch-id", - ) - assert response == { - "json_lattice": None, - "dispatch_id": "mock-dispatch-id", - "electron_updates": {}, - "reuse_previous_results": False, - } + with pytest.raises( + TypeError, match="Dispatcher expected a Lattice, received instead." + ): + LocalDispatcher.dispatch(workflow)(1, 2) -def test_get_redispatch_request_body_args_kwargs(mocker): - """Test the get request body function when args/kwargs is not null.""" - mock_electron = MagicMock() - get_result_mock = mocker.patch("covalent._dispatcher_plugins.local.get_result") - get_result_mock().lattice.serialize_to_json.return_value = "mock-json-lattice" +def test_dispatch_when_no_server_is_running(mocker): + """test dispatching a lattice when no server is running""" - response = get_redispatch_request_body( - "mock-dispatch-id", - new_args=[1, 2], - new_kwargs={"a": 1, "b": 2}, - replace_electrons={"mock-task-id": mock_electron}, - ) - assert response == { - "json_lattice": "mock-json-lattice", - "dispatch_id": "mock-dispatch-id", - "electron_updates": {"mock-task-id": mock_electron.electron_object.as_transportable_dict}, - "reuse_previous_results": False, - } - get_result_mock().lattice.build_graph.assert_called_once_with(*[1, 2], **{"a": 1, "b": 2}) + # the test suite is using another port, thus, with the dummy address below + # the covalent server is not running in some sense. + dummy_dispatcher_addr = "http://localhost:12345" + endpoint = "/api/v2/dispatches" + url = dummy_dispatcher_addr + endpoint + message = f"The Covalent server cannot be reached at {url}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + @ct.electron + def task(a, b, c): + return a + b + c -@pytest.mark.parametrize("is_pending", [True, False]) -@pytest.mark.parametrize( - "replace_electrons, expected_arg", - [(None, {}), ({"mock-electron-1": "mock-electron-2"}, {"mock-electron-1": "mock-electron-2"})], -) -def test_redispatch(mocker, replace_electrons, expected_arg, is_pending): - """Test the local re-dispatch function.""" + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value="mock-config") - requests_mock = mocker.patch("covalent._dispatcher_plugins.local.requests") - get_request_body_mock = mocker.patch( - "covalent._dispatcher_plugins.local.get_redispatch_request_body", - return_value={"mock-request-body"}, - ) + mock_print = mocker.patch("covalent._api.apiclient.print") - local_dispatcher = LocalDispatcher() - func = local_dispatcher.redispatch( - "mock-dispatch-id", replace_electrons=replace_electrons, is_pending=is_pending - ) - func() - requests_mock.post.assert_called_once_with( - "http://mock-config:mock-config/api/redispatch", - json={"mock-request-body"}, - params={"is_pending": is_pending}, - timeout=5, - ) - requests_mock.post().raise_for_status.assert_called_once() - requests_mock.post().content.decode().strip().replace.assert_called_once_with('"', "") + with pytest.raises(ConnectionError): + LocalDispatcher.dispatch(workflow, dispatcher_addr=dummy_dispatcher_addr)(1, 2) - get_request_body_mock.assert_called_once_with("mock-dispatch-id", (), {}, expected_arg, False) + mock_print.assert_called_once_with(message) -def test_redispatch_unreachable(mocker): - """Test the local re-dispatch function when the server is unreachable.""" +def test_dispatcher_dispatch_single(mocker): + """test dispatching a lattice with submit api""" - mock_dispatch_id = "mock-dispatch-id" - dummy_dispatcher_addr = "http://localhost:12345" + @ct.electron + def task(a, b, c): + return a + b + c - message = f"The Covalent server cannot be reached at {dummy_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) - mock_print = mocker.patch("covalent._dispatcher_plugins.local.print") + # test when api raises an implicit error - LocalDispatcher.redispatch(mock_dispatch_id, dispatcher_addr=dummy_dispatcher_addr)() + dispatch_id = "test_dispatcher_dispatch_single" + # multistage = False + mocker.patch("covalent._dispatcher_plugins.local.get_config", return_value=False) - mock_print.assert_called_once_with(message) + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) + + assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + + mock_submit_callable.assert_called() + mock_start.assert_called() -def test_dispatching_a_non_lattice(): - """test dispatching a non-lattice""" + +def test_dispatcher_dispatch_multi(mocker): + """test dispatching a lattice with multistage api""" @ct.electron def task(a, b, c): return a + b + c - @ct.electron @ct.lattice def workflow(a, b): return task(a, b, c=4) - with pytest.raises( - TypeError, match="Dispatcher expected a Lattice, received instead." - ): - LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = "test_dispatcher_dispatch_multi" + # multistage = True + mocker.patch("covalent._shared_files.config.get_config", return_value=True) + mock_register_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register", + return_value=mock_register_callable, + ) -def test_dispatch_when_no_server_is_running(mocker): - """test dispatching a lattice when no server is running""" + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) - # the test suite is using another port, thus, with the dummy address below - # the covalent server is not running in some sense. - dummy_dispatcher_addr = "http://localhost:12345" + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) - message = f"The Covalent server cannot be reached at {dummy_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." + assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + + mock_submit_callable.assert_not_called() + mock_register_callable.assert_called() + mock_start.assert_called() + + +def test_dispatcher_dispatch_with_triggers(mocker): + """test dispatching a lattice with triggers""" @ct.electron def task(a, b, c): @@ -152,11 +162,32 @@ def task(a, b, c): def workflow(a, b): return task(a, b, c=4) - mock_print = mocker.patch("covalent._dispatcher_plugins.local.print") + dispatch_id = "test_dispatcher_dispatch_with_triggers" - LocalDispatcher.dispatch(workflow, dispatcher_addr=dummy_dispatcher_addr)(1, 2) + workflow.metadata["triggers"] = {"dir_trigger": {}} - mock_print.assert_called_once_with(message) + mock_register_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register", + return_value=mock_register_callable, + ) + + mock_submit_callable = MagicMock(return_value=dispatch_id) + mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.submit", + return_value=mock_submit_callable, + ) + + mock_reg_tr = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_triggers" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=dispatch_id + ) + + assert dispatch_id == LocalDispatcher.dispatch(workflow)(1, 2) + mock_reg_tr.assert_called() + mock_start.assert_not_called() def test_dispatcher_submit_api(mocker): @@ -176,10 +207,10 @@ def workflow(a, b): r.url = "http://dummy" r.reason = "dummy reason" - mocker.patch("covalent._dispatcher_plugins.local.requests.post", return_value=r) + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): - dispatch_id = LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = LocalDispatcher.submit(workflow)(1, 2) assert dispatch_id is None # test when api doesn't raise an implicit error @@ -188,7 +219,366 @@ def workflow(a, b): r.url = "http://dummy" r._content = b"abcde" - mocker.patch("covalent._dispatcher_plugins.local.requests.post", return_value=r) + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) - dispatch_id = LocalDispatcher.dispatch(workflow)(1, 2) + dispatch_id = LocalDispatcher.submit(workflow)(1, 2) assert dispatch_id == "abcde" + + +def test_dispatcher_start(mocker): + """Test starting a dispatch""" + + dispatch_id = "test_dispatcher_start" + r = Response() + r.status_code = 404 + r.url = "http://dummy" + r.reason = "dummy reason" + + mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + + with pytest.raises(HTTPError, match="404 Client Error: dummy reason for url: http://dummy"): + LocalDispatcher.start(dispatch_id) + + # test when api doesn't raise an implicit error + r = Response() + r.status_code = 202 + r.url = "http://dummy" + r._content = dispatch_id.encode("utf-8") + + mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + + assert LocalDispatcher.start(dispatch_id) == dispatch_id + + +def test_register(mocker): + """test dispatching a lattice with register api""" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = "test_register" + + mock_upload = mocker.patch("covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets") + mock_prepare_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.prepare_manifest", + return_value=manifest, + ) + mock_register_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_manifest" + ) + + dispatch_id = LocalDispatcher.register(workflow)(1, 2) + assert dispatch_id == "test_register" + mock_upload.assert_called() + + +def test_redispatch(mocker): + """test redispatching a lattice with register api""" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + dispatch_id = "test_register_redispatch" + manifest.metadata.dispatch_id = dispatch_id + parent_id = "parent_dispatch_id" + + mock_upload = mocker.patch("covalent._dispatcher_plugins.local.LocalDispatcher.upload_assets") + mock_get_redispatch_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.get_redispatch_request_body_v2", return_value=manifest + ) + mock_register_derived_manifest = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.register_derived_manifest" + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", + return_value="test_register_redispatch", + ) + + new_args = (1, 2) + new_kwargs = {} + redispatch_id = LocalDispatcher.redispatch( + dispatch_id=parent_id, replace_electrons={"f": "callable"}, reuse_previous_results=False + )(*new_args, **new_kwargs) + + assert dispatch_id == redispatch_id + mock_upload.assert_called() + + mock_start.assert_called_with(dispatch_id, format_server_url()) + + +def test_register_manifest(mocker): + """Test registering a dispatch manifest.""" + + dispatch_id = "test_register_manifest" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = dispatch_id + + r = Response() + r.status_code = 201 + r.json = MagicMock(return_value=manifest.dict()) + + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) + + mock_merge = mocker.patch( + "covalent._dispatcher_plugins.local.merge_response_manifest", return_value=manifest + ) + + return_manifest = LocalDispatcher.register_manifest(manifest) + assert return_manifest.metadata.dispatch_id == dispatch_id + mock_merge.assert_called_with(manifest, manifest) + + +def test_register_derived_manifest(mocker): + """Test registering a redispatch manifest.""" + + dispatch_id = "test_register_derived_manifest" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + manifest.metadata.dispatch_id = dispatch_id + + r = Response() + r.status_code = 201 + r.json = MagicMock(return_value=manifest.dict()) + + mocker.patch("covalent._api.apiclient.requests.Session.post", return_value=r) + + mock_merge = mocker.patch( + "covalent._dispatcher_plugins.local.merge_response_manifest", return_value=manifest + ) + + return_manifest = LocalDispatcher.register_derived_manifest(manifest, "original_dispatch") + assert return_manifest.metadata.dispatch_id == dispatch_id + mock_merge.assert_called_with(manifest, manifest) + + +def test_upload_assets(mocker): + """Test uploading assets to HTTP endpoints""" + + dispatch_id = "test_upload_assets_http" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + + num_assets = 0 + # Populate the lattice asset schemas with dummy URLs + for key, asset in manifest.lattice.assets: + num_assets += 1 + asset.remote_uri = ( + f"http://localhost:48008/api/v2/dispatches/{dispatch_id}/lattice/assets/dummy" + ) + + endpoint = f"/api/v2/dispatches/{dispatch_id}/lattice/assets/dummy" + r = Response() + r.status_code = 200 + mock_post = mocker.patch("covalent._api.apiclient.requests.Session.put", return_value=r) + + LocalDispatcher.upload_assets(manifest) + + assert mock_post.call_count == num_assets + + +def test_get_redispatch_request_body_norebuild(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where the dispatch is to be retried with no + # changes to inputs or electrons. + + dispatch_id = "test_get_redispatch_request_body_norebuild" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [], {}, replace_electrons={} + ) + + assert redispatch_manifest is manifest + + +def test_get_redispatch_request_body_replace_electrons(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where electrons are to be replaced but lattice + # inputs stay the same. + + dispatch_id = "test_get_redispatch_request_body_replace_electrons" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.electron + def new_task(a, b, c): + return a * b * c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [], {}, replace_electrons={"task": new_task} + ) + + assert redispatch_manifest is manifest + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function_string") + mock_resmgr.download_lattice_asset.assert_any_call("inputs") + + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function_string") + mock_resmgr.load_lattice_asset.assert_any_call("inputs") + + +def test_get_redispatch_request_body_replace_inputs(mocker): + """Test constructing the request body for redispatch""" + + # Consider the case where only lattice + # inputs are changed. + + dispatch_id = "test_get_redispatch_request_body_replace_inputs" + + @ct.electron + def task(a, b, c): + return a + b + c + + @ct.lattice + def workflow(a, b): + return task(a, b, c=4) + + workflow.build_graph(1, 2) + + # "Old" result object + res_obj = Result(workflow) + + # Mock result manager + mock_resmgr = MagicMock() + + with tempfile.TemporaryDirectory() as staging_dir: + manifest = LocalDispatcher.prepare_manifest(workflow, staging_dir) + mock_resmgr._manifest = manifest + mock_resmgr.result_object = res_obj + + mock_serialize = mocker.patch( + "covalent._dispatcher_plugins.local.serialize_result", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.ResultSchema.parse_obj", return_value=manifest + ) + mocker.patch( + "covalent._dispatcher_plugins.local.get_result_manager", return_value=mock_resmgr + ) + + with tempfile.TemporaryDirectory() as redispatch_dir: + redispatch_manifest = get_redispatch_request_body_v2( + dispatch_id, redispatch_dir, [3, 4], {}, replace_electrons=None + ) + + assert redispatch_manifest is manifest + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.download_lattice_asset.assert_any_call("workflow_function_string") + + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function") + mock_resmgr.load_lattice_asset.assert_any_call("workflow_function_string") diff --git a/tests/covalent_tests/file_transfer/__init__.py b/tests/covalent_tests/file_transfer/__init__.py index 21d7eaa5c..ab6c0fedf 100644 --- a/tests/covalent_tests/file_transfer/__init__.py +++ b/tests/covalent_tests/file_transfer/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Agnostiq Inc. +# Copyright 2022 Agnostiq Inc. # # This file is part of Covalent. # diff --git a/tests/covalent_tests/results_manager_tests/__init__.py b/tests/covalent_tests/results_manager_tests/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/covalent_tests/results_manager_tests/__init__.py +++ b/tests/covalent_tests/results_manager_tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/covalent_tests/results_manager_tests/results_manager_test.py b/tests/covalent_tests/results_manager_tests/results_manager_test.py index 203dea5b2..f7bb829fb 100644 --- a/tests/covalent_tests/results_manager_tests/results_manager_test.py +++ b/tests/covalent_tests/results_manager_tests/results_manager_test.py @@ -16,123 +16,281 @@ """Tests for results manager.""" -from http.client import HTTPMessage -from unittest.mock import ANY, MagicMock, Mock, call +import os +import tempfile +from datetime import datetime, timezone +from unittest.mock import MagicMock import pytest -import requests +from requests import Response -from covalent._results_manager import wait +import covalent as ct from covalent._results_manager.results_manager import ( - _get_result_from_dispatcher, + MissingLatticeRecordError, + Result, + ResultManager, + _get_result_export_from_dispatcher, cancel, + download_asset, get_result, ) -from covalent._shared_files.config import get_config +from covalent._serialize.result import serialize_result +from covalent._workflow.transportable_object import TransportableObject -def test_get_result_unreachable_dispatcher(mocker): - """ - Test that get_result returns None when - the dispatcher server is unreachable. - """ - mock_dispatch_id = "mock_dispatch_id" +def get_test_manifest(staging_dir): + @ct.electron + def identity(x): + return x - mocker.patch( - "covalent._results_manager.results_manager._get_result_from_dispatcher", - side_effect=requests.exceptions.ConnectionError, + @ct.electron + def add(x, y): + return x + y + + @ct.lattice + def workflow(x, y): + res1 = identity(x) + res2 = identity(y) + return add(res1, res2) + + workflow.build_graph(2, 3) + result_object = Result(workflow) + ts = datetime.now(timezone.utc) + result_object._start_time = ts + result_object._end_time = ts + result_object._result = TransportableObject(42) + result_object.lattice.transport_graph.set_node_value(0, "status", Result.COMPLETED) + result_object.lattice.transport_graph.set_node_value(0, "output", TransportableObject(2)) + manifest = serialize_result(result_object, staging_dir) + + # Swap asset uri and remote_uri to simulate an exported manifest + for key, asset in manifest.assets: + asset.remote_uri = asset.uri + asset.uri = None + + for key, asset in manifest.lattice.assets: + asset.remote_uri = asset.uri + asset.uri = None + + for node in manifest.lattice.transport_graph.nodes: + for key, asset in node.assets: + asset.remote_uri = asset.uri + asset.uri = None + + return manifest + + +def test_cancel_with_single_task_id(mocker): + mock_request_put = mocker.patch( + "covalent._api.apiclient.requests.Session.put", ) - assert get_result(mock_dispatch_id) is None + cancel(dispatch_id="dispatch", task_ids=1) + mock_request_put.assert_called_once() + mock_request_put.return_value.raise_for_status.assert_called_once() -@pytest.mark.parametrize( - "dispatcher_addr", - [ - "http://" + get_config("dispatcher.address") + ":" + str(get_config("dispatcher.port")), - "http://localhost:48008", - ], -) -def test_get_result_from_dispatcher(mocker, dispatcher_addr): - retries = 10 - getconn_mock = mocker.patch("urllib3.connectionpool.HTTPConnectionPool._get_conn") - mocker.patch("requests.Response.json", return_value=True) - headers = HTTPMessage() - headers.add_header("Retry-After", "2") - - mock_response = [Mock(status=503, msg=headers)] * (retries - 1) - mock_response.append(Mock(status=200, msg=HTTPMessage())) - getconn_mock.return_value.getresponse.side_effect = mock_response - dispatch_id = "9d1b308b-4763-4990-ae7f-6a6e36d35893" - _get_result_from_dispatcher( - dispatch_id, wait=wait.LONG, dispatcher_addr=dispatcher_addr, status_only=False + +def test_cancel_with_multiple_task_ids(mocker): + mock_task_ids = [0, 1] + + mock_request_put = mocker.patch( + "covalent._api.apiclient.requests.Session.put", ) - assert ( - getconn_mock.return_value.request.mock_calls - == [ - call( - "GET", - f"/api/result/{dispatch_id}?wait=True&status_only=False", - body=None, - headers=ANY, - ), - ] - * retries + + cancel(dispatch_id="dispatch", task_ids=[1, 2, 3]) + + mock_request_put.assert_called_once() + mock_request_put.return_value.raise_for_status.assert_called_once() + + +def test_result_export(mocker): + with tempfile.TemporaryDirectory() as staging_dir: + test_manifest = get_test_manifest(staging_dir) + + dispatch_id = "test_result_export" + + mock_body = {"id": "test_result_export", "status": "COMPLETED"} + + mock_client = MagicMock() + mock_response = Response() + mock_response.status_code = 200 + mock_response.json = MagicMock(return_value=mock_body) + + mocker.patch("covalent._api.apiclient.requests.Session.get", return_value=mock_response) + + endpoint = f"/api/v2/dispatches/{dispatch_id}" + assert mock_body == _get_result_export_from_dispatcher( + dispatch_id, wait=False, status_only=True ) -def test_get_result_from_dispatcher_unreachable(mocker): - """ - Test that _get_result_from_dispatcher raises an exception when - the dispatcher server is unreachable. - """ +def test_result_manager_assets_local_copies(): + """Test downloading and loading assets using local asset uris.""" + dispatch_id = "test_result_manager" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) + with tempfile.TemporaryDirectory() as results_dir: + rm = ResultManager(manifest, results_dir) + rm.download_lattice_asset("workflow_function") + rm.load_lattice_asset("workflow_function") + rm.download_result_asset("result") + rm.load_result_asset("result") + os.makedirs(f"{results_dir}/node_0") + rm.download_node_asset(0, "output") + rm.load_node_asset(0, "output") + + res_obj = rm.result_object + assert res_obj.lattice(3, 5) == 8 + assert res_obj.result == 42 - # TODO: Will need to edit this once `_get_result_from_dispatcher` is fixed - # to actually throw an exception when the dispatcher server is unreachable - # instead of just hanging. + output = res_obj.lattice.transport_graph.get_node_value(0, "output") + assert output.get_deserialized() == 2 - mock_dispatcher_addr = "mock_dispatcher_addr" - mock_dispatch_id = "mock_dispatch_id" - message = f"The Covalent server cannot be reached at {mock_dispatcher_addr}. Local servers can be started using `covalent start` in the terminal. If you are using a remote Covalent server, contact your systems administrator to report an outage." +def test_result_manager_save_manifest(): + """Test saving and loading manifests""" + dispatch_id = "test_result_manager_save_load" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) + with tempfile.TemporaryDirectory() as results_dir: + rm = ResultManager(manifest, results_dir) + rm.save() + path = os.path.join(results_dir, "manifest.json") + rm2 = ResultManager.load(path, results_dir) + assert rm2._results_dir == results_dir + assert rm2._manifest == rm._manifest - mocker.patch("covalent._results_manager.results_manager.HTTPAdapter") - mock_session = mocker.patch("covalent._results_manager.results_manager.requests.Session") - mock_session.return_value.get.side_effect = requests.exceptions.ConnectionError - mock_print = mocker.patch("covalent._results_manager.results_manager.print") +def test_get_result(mocker): + dispatch_id = "test_result_manager" + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) - with pytest.raises(requests.exceptions.ConnectionError): - _get_result_from_dispatcher( - mock_dispatch_id, wait=wait.LONG, dispatcher_addr=mock_dispatcher_addr + mock_result_export = { + "id": dispatch_id, + "status": "COMPLETED", + "result_export": manifest.dict(), + } + mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", + return_value=mock_result_export, ) + with tempfile.TemporaryDirectory() as results_dir: + res_obj = get_result(dispatch_id, results_dir=results_dir) - mock_print.assert_called_once_with(message) + assert res_obj.result == 42 -def test_cancel_with_single_task_id(mocker): - mock_get_config = mocker.patch("covalent._results_manager.results_manager.get_config") - mock_request_post = mocker.patch( - "covalent._results_manager.results_manager.requests.post", MagicMock() +def test_get_result_sublattice(mocker): + dispatch_id = "test_result_manager_sublattice" + sub_dispatch_id = "test_result_manager_sublattice_sub" + + with tempfile.TemporaryDirectory() as server_dir: + # This will have uri and remote_uri swapped so as to simulate + # a manifest exported from the server. All "downloads" will be + # local file copies from server_dir to results_dir. + manifest = get_test_manifest(server_dir) + + node = manifest.lattice.transport_graph.nodes[0] + node.metadata.sub_dispatch_id = sub_dispatch_id + + with tempfile.TemporaryDirectory() as server_dir_sub: + # Sublattice manifest + sub_manifest = get_test_manifest(server_dir_sub) + + mock_result_export = { + "id": dispatch_id, + "status": "COMPLETED", + "result_export": manifest.dict(), + } + + mock_subresult_export = { + "id": sub_dispatch_id, + "status": "COMPLETED", + "result_export": sub_manifest.dict(), + } + + exports = {dispatch_id: mock_result_export, sub_dispatch_id: mock_subresult_export} + + def mock_get_export(dispatch_id, *args, **kwargs): + return exports[dispatch_id] + + mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", + mock_get_export, + ) + with tempfile.TemporaryDirectory() as results_dir: + res_obj = get_result(dispatch_id, results_dir=results_dir) + + assert res_obj.result == 42 + tg = res_obj.lattice.transport_graph + for node_id in tg._graph.nodes: + if node_id == 0: + assert tg.get_node_value(node_id, "sub_dispatch_id") == sub_dispatch_id + assert tg.get_node_value(node_id, "sublattice_result") is not None + + else: + assert tg.get_node_value(1, "sublattice_result") is None + + +def test_get_result_404(mocker): + """Check exception handing for invalid dispatch ids.""" + + dispatch_id = "test_get_result_404" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 404 + + mock_client.get = MagicMock(return_value=mock_response) + + mocker.patch( + "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client ) - cancel(dispatch_id="dispatch", task_ids=1) + with pytest.raises(MissingLatticeRecordError): + get_result(dispatch_id) - assert mock_get_config.call_count == 2 - mock_request_post.assert_called_once() - mock_request_post.return_value.raise_for_status.assert_called_once() +def test_get_status_only(mocker): + """Check get_result when status_only=True""" -def test_cancel_with_multiple_task_ids(mocker): - mock_get_config = mocker.patch("covalent._results_manager.results_manager.get_config") - mock_task_ids = [0, 1] + dispatch_id = "test_get_result_st" + mock_get_result_export = mocker.patch( + "covalent._results_manager.results_manager._get_result_export_from_dispatcher", + return_value={"id": dispatch_id, "status": "RUNNING"}, + ) - mock_request_post = mocker.patch( - "covalent._results_manager.results_manager.requests.post", MagicMock() + status_report = get_result(dispatch_id, status_only=True) + assert status_report["status"] == "RUNNING" + + +def test_download_asset(mocker): + dispatch_id = "test_download_asset" + remote_uri = f"http://localhost:48008/api/v2/dispatches/{dispatch_id}/assets/result" + mock_client = MagicMock() + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client.get = MagicMock(return_value=mock_response) + mocker.patch( + "covalent._results_manager.results_manager.CovalentAPIClient", return_value=mock_client ) - cancel(dispatch_id="dispatch", task_ids=[1, 2, 3]) + def mock_generator(): + yield "Hello".encode("utf-8") + + mock_response.iter_content = MagicMock(return_value=mock_generator()) - assert mock_get_config.call_count == 2 - mock_request_post.assert_called_once() - mock_request_post.return_value.raise_for_status.assert_called_once() + with tempfile.NamedTemporaryFile() as local_file: + download_asset(remote_uri, local_file.name) + assert local_file.read().decode("utf-8") == "Hello" diff --git a/tests/covalent_tests/triggers/base_test.py b/tests/covalent_tests/triggers/base_test.py index fedca845f..0ca0c8d7c 100644 --- a/tests/covalent_tests/triggers/base_test.py +++ b/tests/covalent_tests/triggers/base_test.py @@ -61,7 +61,7 @@ def test_get_status(mocker, use_internal_func, mock_status): base_trigger.use_internal_funcs = use_internal_func if use_internal_func: - mocker.patch("covalent_dispatcher._service.app.get_result") + mocker.patch("covalent_dispatcher._service.app.export_result") mock_fut_res = mock.Mock() mock_fut_res.result.return_value = mock_status @@ -98,29 +98,25 @@ def test_do_redispatch(mocker, use_internal_func, is_pending): with the right arguments in different conditions """ - base_trigger = BaseTrigger() - base_trigger.use_internal_funcs = use_internal_func - mock_redispatch_id = "test_dispatch_id" + mock_wrapper = mock.MagicMock(return_value=mock_redispatch_id) + mock_redispatch = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.redispatch", return_value=mock_wrapper + ) + mock_start = mocker.patch( + "covalent._dispatcher_plugins.local.LocalDispatcher.start", return_value=mock_redispatch_id + ) - if use_internal_func: - mocker.patch("covalent_dispatcher.run_redispatch") - mock_fut_res = mock.Mock() - mock_fut_res.result.return_value = mock_redispatch_id - mock_run_coro = mocker.patch( - "covalent.triggers.base.asyncio.run_coroutine_threadsafe", return_value=mock_fut_res - ) + base_trigger = BaseTrigger() + base_trigger.use_internal_funcs = use_internal_func - redispatch_id = base_trigger._do_redispatch(is_pending) + redispatch_id = base_trigger._do_redispatch(is_pending) - mock_run_coro.assert_called_once() - mock_fut_res.result.assert_called_once() + if is_pending: + mock_start.assert_called_once() + mock_wrapper.assert_not_called() else: - mock_redispatch = mocker.patch("covalent.redispatch")() - mock_redispatch.return_value = mock_redispatch_id - redispatch_id = base_trigger._do_redispatch(is_pending) - - mock_redispatch.assert_called_once() + mock_redispatch.assert_called() assert redispatch_id == mock_redispatch_id diff --git a/tests/covalent_tests/workflow/electron_metadata_test.py b/tests/covalent_tests/workflow/electron_metadata_test.py index 85a42da30..a2b45f1a8 100644 --- a/tests/covalent_tests/workflow/electron_metadata_test.py +++ b/tests/covalent_tests/workflow/electron_metadata_test.py @@ -44,7 +44,6 @@ def hello_world(x): hello_world.build_graph(1) data = hello_world.transport_graph.serialize_to_json() - # data = json.loads(data) tg = _TransportGraph() tg.deserialize_from_json(data) diff --git a/tests/covalent_tests/workflow/electron_test.py b/tests/covalent_tests/workflow/electron_test.py index 00d79ab61..18d6cedc5 100644 --- a/tests/covalent_tests/workflow/electron_test.py +++ b/tests/covalent_tests/workflow/electron_test.py @@ -75,7 +75,6 @@ def workflow_2(): return res_3 -@pytest.mark.skip(reason="Will be re-enabled next PR") def test_build_sublattice_graph(mocker): """ Test building a sublattice graph @@ -99,6 +98,7 @@ def workflow(x): "call_before": [], "call_after": [], "triggers": "mock-trigger", + "qelectron_data_exists": False, "results_dir": None, } mock_environ = { @@ -138,10 +138,7 @@ def mock_register(manifest, *args, **kwargs): assert lat.metadata.workflow_executor == parent_metadata["workflow_executor"] assert lat.metadata.workflow_executor_data == parent_metadata["workflow_executor_data"] - # lattice = Lattice.deserialize_from_json(json_lattice) - -@pytest.mark.skip(reason="Will be re-enabled next PR") def test_build_sublattice_graph_fallback(mocker): """ Test falling back to monolithic sublattice dispatch @@ -165,6 +162,7 @@ def workflow(x): "call_before": [], "call_after": [], "triggers": "mock-trigger", + "qelectron_data_exists": False, "results_dir": None, } @@ -289,7 +287,7 @@ def workflow(x): # Account for postprocessing node assert list(g.nodes) == [0, 1, 2] - assert set(g.edges) == set([(1, 0, 0), (0, 2, 0), (0, 2, 1)]) + assert set(g.edges) == {(1, 0, 0), (0, 2, 0), (0, 2, 1)} def test_metadata_in_electron_list(): @@ -313,8 +311,8 @@ def workflow(x): task_metadata = workflow.transport_graph.get_node_value(0, "metadata") e_list_metadata = workflow.transport_graph.get_node_value(1, "metadata") - assert list(e_list_metadata["call_before"]) == [] - assert list(e_list_metadata["call_after"]) == [] + assert not list(e_list_metadata["call_before"]) + assert not list(e_list_metadata["call_after"]) assert e_list_metadata["executor"] == task_metadata["executor"] @@ -368,7 +366,14 @@ def workflow(x): assert g.nodes[2]["value"].get_deserialized() == 5 assert g.nodes[3]["value"].get_deserialized() == 7 - assert set(g.edges) == set([(1, 0, 0), (2, 1, 0), (3, 1, 0), (0, 4, 0), (0, 4, 1), (1, 4, 0)]) + assert set(g.edges) == { + (1, 0, 0), + (2, 1, 0), + (3, 1, 0), + (0, 4, 0), + (0, 4, 1), + (1, 4, 0), + } def test_autogen_dict_electrons(): @@ -390,7 +395,14 @@ def workflow(x): assert fn(x=2, y=5, z=7) == {"x": 2, "y": 5, "z": 7} assert g.nodes[2]["value"].get_deserialized() == 5 assert g.nodes[3]["value"].get_deserialized() == 7 - assert set(g.edges) == set([(1, 0, 0), (2, 1, 0), (3, 1, 0), (0, 4, 0), (0, 4, 1), (1, 4, 0)]) + assert set(g.edges) == { + (1, 0, 0), + (2, 1, 0), + (3, 1, 0), + (0, 4, 0), + (0, 4, 1), + (1, 4, 0), + } def test_as_transportable_dict(): @@ -462,7 +474,7 @@ def workflow(x): assert all(tg.get_node_value(i, "task_group_id") == 0 for i in [0, 3, 4]) assert all(tg.get_node_value(i, "task_group_id") == i for i in [1, 2, 5, 6, 7, 8]) else: - assert all(tg.get_node_value(i, "task_group_id") == i for i in range(0, 9)) + assert all(tg.get_node_value(i, "task_group_id") == i for i in range(9)) @pytest.mark.parametrize("task_packing", ["true", "false"]) @@ -507,7 +519,7 @@ def workflow(): assert getitem_y_gid == point_electron_gid assert all(tg.get_node_value(i, "task_group_id") == i for i in [2, 4, 5, 6]) else: - assert all(tg.get_node_value(i, "task_group_id") == i for i in range(0, 7)) + assert all(tg.get_node_value(i, "task_group_id") == i for i in range(7)) @pytest.mark.parametrize("task_packing", ["true", "false"]) @@ -549,7 +561,7 @@ def workflow(): assert getitem_y_gid == arr_electron_gid assert all(tg.get_node_value(i, "task_group_id") == i for i in [2, 4, 5, 6]) else: - assert all(tg.get_node_value(i, "task_group_id") == i for i in range(0, 7)) + assert all(tg.get_node_value(i, "task_group_id") == i for i in range(7)) @pytest.mark.parametrize("task_packing", ["true", "false"]) @@ -592,7 +604,7 @@ def workflow(): assert getitem_y_gid == tup_electron_gid assert all(tg.get_node_value(i, "task_group_id") == i for i in [2, 4, 5, 6]) else: - assert all(tg.get_node_value(i, "task_group_id") == i for i in range(0, 7)) + assert all(tg.get_node_value(i, "task_group_id") == i for i in range(7)) def test_electron_executor_property(): diff --git a/tests/covalent_tests/workflow/transport_graph_ops_test.py b/tests/covalent_tests/workflow/transport_graph_ops_test.py deleted file mode 100644 index e12a5e8eb..000000000 --- a/tests/covalent_tests/workflow/transport_graph_ops_test.py +++ /dev/null @@ -1,248 +0,0 @@ -# Copyright 2023 Agnostiq Inc. -# -# This file is part of Covalent. -# -# Licensed under the Apache License 2.0 (the "License"). A copy of the -# License may be obtained with this software package or at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Use of this file is prohibited except in compliance with the License. -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""Unit tests for transport graph operations module.""" - -import pytest - -from covalent._workflow.transport import _TransportGraph -from covalent._workflow.transport_graph_ops import TransportGraphOps - - -def add(x, y): - return x + y - - -def multiply(x, y): - return x * y - - -def identity(x): - return x - - -@pytest.fixture -def tg(): - """Transport graph operations fixture.""" - tg = _TransportGraph() - tg.add_node(name="add", function=add, metadata={"0-mock-key": "0-mock-value"}) - tg.add_node(name="multiply", function=multiply, metadata={"1-mock-key": "1-mock-value"}) - tg.add_node(name="identity", function=identity, metadata={"2-mock-key": "2-mock-value"}) - return tg - - -@pytest.fixture -def tg_2(): - """Transport graph operations fixture - different from tg.""" - tg_2 = _TransportGraph() - tg_2.add_node(name="not-add", function=add, metadata={"0- mock-key": "0-mock-value"}) - tg_2.add_node(name="multiply", function=multiply, metadata={"1- mock-key": "1-mock-value"}) - tg_2.add_node(name="identity", function=identity, metadata={"2- mock-key": "2-mock-value"}) - return tg_2 - - -@pytest.fixture -def tg_ops(tg): - """Transport graph operations fixture.""" - return TransportGraphOps(tg) - - -def test_init(tg): - """Test initialization of transport graph operations.""" - tg_ops = TransportGraphOps(tg) - assert tg_ops.tg == tg - assert tg_ops._status_map == {1: True, -1: False} - - -def test_flag_successors_no_successors(tg, tg_ops): - """Test flagging successors of a node.""" - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=0) - assert node_statuses == {0: -1, 1: 1, 2: 1} - - -@pytest.mark.parametrize( - "n_1,n_2,n_start,label,new_statuses", - [ - (0, 1, 0, "01", {0: -1, 1: -1, 2: 1}), - (1, 2, 0, "12", {0: -1, 1: 1, 2: 1}), - (1, 2, 1, "12", {0: 1, 1: -1, 2: -1}), - (1, 2, 2, "12", {0: 1, 1: 1, 2: -1}), - ], -) -def test_flag_successors_with_one_successors(tg, tg_ops, n_1, n_2, n_start, label, new_statuses): - """Test flagging successors of a node.""" - tg.add_edge(n_1, n_2, label) - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=n_start) - assert node_statuses == new_statuses - - -@pytest.mark.parametrize( - "n_1,n_2,n_3,n_4,label_1,label_2,n_start,new_statuses", - [ - (0, 1, 1, 2, "01", "12", 0, {0: -1, 1: -1, 2: -1}), - (0, 1, 0, 2, "01", "02", 0, {0: -1, 1: -1, 2: -1}), - (0, 1, 0, 2, "01", "12", 1, {0: 1, 1: -1, 2: 1}), - ], -) -def test_flag_successors_with_successors_3( - tg, tg_ops, n_1, n_2, n_3, n_4, label_1, n_start, label_2, new_statuses -): - """Test flagging successors of a node.""" - tg.add_edge(n_1, n_2, label_1) - tg.add_edge(n_3, n_4, label_2) - node_statuses = {0: 1, 1: 1, 2: 1} - tg_ops._flag_successors(tg._graph, node_statuses=node_statuses, starting_node=n_start) - assert node_statuses == new_statuses - - -def test_is_same_node_true(tg, tg_ops): - """Test the is same node method.""" - assert tg_ops.is_same_node(tg._graph, tg._graph, 0) is True - assert tg_ops.is_same_node(tg._graph, tg._graph, 1) is True - - -def test_is_same_node_false(tg, tg_ops): - """Test the is same node method.""" - tg_2 = _TransportGraph() - tg_2.add_node(name="multiply", function=add, metadata={"0- mock-key": "0-mock-value"}) - assert tg_ops.is_same_node(tg._graph, tg_2._graph, 0) is False - - -def test_is_same_edge_attributes_true(tg, tg_ops): - """Test the is same edge attributes method.""" - tg.add_edge(0, 1, edge_name="01", kwargs={"x": 1, "y": 2}) - assert tg_ops.is_same_edge_attributes(tg._graph, tg._graph, 0, 1) is True - - -def test_is_same_edge_attributes_false(tg, tg_ops): - """Test the is same edge attributes method.""" - tg.add_edge(0, 1, edge_name="01", kwargs={"x": 1, "y": 2}) - - tg_2 = _TransportGraph() - tg_2.add_node(name="add", function=add, metadata={"0- mock-key": "0-mock-value"}) - tg_2.add_node(name="multiply", function=multiply, metadata={"1- mock-key": "1-mock-value"}) - tg_2.add_node(name="identity", function=identity, metadata={"2- mock-key": "2-mock-value"}) - tg_2.add_edge(0, 1, edge_name="01", kwargs={"x": 1}) - - assert tg_ops.is_same_edge_attributes(tg._graph, tg_2._graph, 0, 1) is False - - -def test_copy_nodes_from(tg_ops): - """Test the node copying method.""" - - def replacement(x): - return x + 1 - - tg_new = _TransportGraph() - tg_new.add_node( - name="replacement", function=replacement, metadata={"0-mock-key": "0-mock-value"} - ) - tg_new.add_node(name="multiply", function=multiply, metadata={"1-mock-key": "1-mock-value"}) - tg_new.add_node( - name="replacement", function=replacement, metadata={"2-mock-key": "2-mock-value"} - ) - - tg_ops.copy_nodes_from(tg_new, [0, 2]) - tg_ops.tg._graph.nodes(data=True)[0]["name"] == tg_ops.tg._graph.nodes(data=True)[2][ - "name" - ] == "replacement" - tg_ops.tg._graph.nodes(data=True)[2]["name"] == "multiply" - - -def test_max_cbms(tg_ops): - """Test method for determining a largest cbms""" - import networkx as nx - - A = nx.MultiDiGraph() - B = nx.MultiDiGraph() - C = nx.MultiDiGraph() - D = nx.MultiDiGraph() - - # 0 5 6 - # / \ - # 1 2 - A.add_edge(0, 1) - A.add_edge(0, 2) - A.nodes[1]["color"] = "red" - A.add_node(5) - A.add_node(6) - - # 0 5 - # / \\ - # 1 2 - B.add_edge(0, 1) - B.add_edge(0, 2) - B.add_edge(0, 2) - B.nodes[1]["color"] = "black" - B.add_node(5) - - # 0 3 - # / \ / - # 1 2 - C.add_edge(0, 1) - C.add_edge(0, 2) - C.add_edge(3, 2) - - # 0 3 - # / \ / - # 1 2 - # / - # 4 - D.add_edge(0, 1) - D.add_edge(0, 2) - D.add_edge(3, 2) - D.add_edge(2, 4) - - A_node_status, B_node_status = tg_ops._max_cbms(A, B) - assert A_node_status == {0: True, 1: False, 2: False, 5: True, 6: False} - assert B_node_status == {0: True, 1: False, 2: False, 5: True} - - A_node_status, C_node_status = tg_ops._max_cbms(A, C) - assert A_node_status == {0: True, 1: False, 2: False, 5: False, 6: False} - assert C_node_status == {0: True, 1: False, 2: False, 3: False} - - C_node_status, D_node_status = tg_ops._max_cbms(C, D) - assert C_node_status == {0: True, 1: True, 2: True, 3: True} - assert D_node_status == {0: True, 1: True, 2: True, 3: True, 4: False} - - -def test_cmp_name_and_pval_true(tg, tg_ops): - """Test the name and parameter value comparison method.""" - assert tg_ops._cmp_name_and_pval(tg._graph, tg._graph, 0) is True - - -def test_cmp_name_and_pval_false(tg, tg_2, tg_ops): - """Test the name and parameter value comparison method.""" - assert tg_ops._cmp_name_and_pval(tg._graph, tg_2._graph, 0) is False - - -def test_get_reusable_nodes(mocker, tg, tg_2, tg_ops): - """Test the get reusable nodes method.""" - max_cbms_mock = mocker.patch( - "covalent._workflow.transport_graph_ops.TransportGraphOps._max_cbms", - return_value=({"mock-key-A": "mock-value-A"}, {"mock-key-B": "mock-value-B"}), - ) - reusable_nodes = tg_ops.get_reusable_nodes(tg_2) - assert reusable_nodes == ["mock-key-A"] - max_cbms_mock.assert_called_once() - - -def test_get_diff_nodes_integration_test(tg_2, tg_ops): - """Test the get reusable nodes method.""" - reusable_nodes = tg_ops.get_reusable_nodes(tg_2) - assert reusable_nodes == [1, 2] diff --git a/tests/covalent_ui_backend_tests/utils/assert_data/electrons.py b/tests/covalent_ui_backend_tests/utils/assert_data/electrons.py index 61de28089..3930a22c5 100644 --- a/tests/covalent_ui_backend_tests/utils/assert_data/electrons.py +++ b/tests/covalent_ui_backend_tests/utils/assert_data/electrons.py @@ -95,7 +95,7 @@ def seed_electron_data(): "electron_id": VALID_NODE_ID, "name": "executor", }, - "response_data": {"executor_name": "dask", "executor_details": None}, + "response_data": {"executor_name": "dask", "executor_details": {}}, }, "case_result_1": { "status_code": 200, diff --git a/tests/covalent_ui_backend_tests/utils/assert_data/lattices.py b/tests/covalent_ui_backend_tests/utils/assert_data/lattices.py index 1c04f490a..7ba09f7b2 100644 --- a/tests/covalent_ui_backend_tests/utils/assert_data/lattices.py +++ b/tests/covalent_ui_backend_tests/utils/assert_data/lattices.py @@ -108,7 +108,7 @@ def seed_lattice_data(): "dispatch_id": VALID_DISPATCH_ID, "name": "executor", }, - "response_data": {"executor_name": "dask", "executor_details": None}, + "response_data": {"executor_name": "dask", "executor_details": {}}, }, "case_workflow_executor_1": { "status_code": 200, @@ -118,7 +118,7 @@ def seed_lattice_data(): }, "response_data": { "workflow_executor_name": "dask", - "workflow_executor_details": None, + "workflow_executor_details": {}, }, }, "case_transport_graph_1": { diff --git a/tests/covalent_ui_backend_tests/utils/data/electrons.json b/tests/covalent_ui_backend_tests/utils/data/electrons.json index 999ed832f..7fe8b7670 100644 --- a/tests/covalent_ui_backend_tests/utils/data/electrons.json +++ b/tests/covalent_ui_backend_tests/utils/data/electrons.json @@ -6,6 +6,7 @@ "created_at": "2022-09-23 10:01:11.062647", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -37,6 +38,7 @@ "created_at": "2022-09-23 10:01:11.075465", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -68,6 +70,7 @@ "created_at": "2022-09-23 10:01:11.085971", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -99,6 +102,7 @@ "created_at": "2022-09-23 10:01:11.098325", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -130,6 +134,7 @@ "created_at": "2022-09-23 10:01:11.109305", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -161,6 +166,7 @@ "created_at": "2022-09-23 10:01:11.121100", "deps_filename": "deps.pkl", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -210,6 +216,7 @@ "started_at": "2022-10-27 10:08:33.861837", "completed_at": "2022-10-27 10:08:33.933100", "executor": "dask", + "executor_data": {}, "job_id": 5, "qelectron_data_exists": false, @@ -241,6 +248,7 @@ "started_at": "2022-10-27 10:08:33.827366", "completed_at": "2022-10-27 10:08:33.827372", "executor": "dask", + "executor_data": {}, "job_id": 6, "qelectron_data_exists": false, @@ -272,6 +280,7 @@ "started_at": "2022-10-27 10:08:33.967565", "completed_at": "2022-10-27 10:08:36.028194", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -303,6 +312,7 @@ "started_at": "2022-10-27 10:08:36.065830", "completed_at": "2022-10-27 10:08:43.905519", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -334,6 +344,7 @@ "started_at": "2022-10-27 10:08:34.939603", "completed_at": "2022-10-27 10:08:35.159092", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -365,6 +376,7 @@ "started_at": "2022-10-27 10:08:34.523975", "completed_at": "2022-10-27 10:08:34.523987", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -396,6 +408,7 @@ "started_at": "2022-10-27 10:08:34.968949", "completed_at": "2022-10-27 10:08:35.238154", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -427,6 +440,7 @@ "started_at": "2022-10-27 10:08:34.576178", "completed_at": "2022-10-27 10:08:34.576181", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -458,6 +472,7 @@ "started_at": "2022-10-27 10:08:35.011096", "completed_at": "2022-10-27 10:08:35.324016", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -489,6 +504,7 @@ "started_at": "2022-10-27 10:08:34.614049", "completed_at": "2022-10-27 10:08:34.614051", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -520,6 +536,7 @@ "started_at": "2022-10-27 10:08:35.058712", "completed_at": "2022-10-27 10:08:35.408068", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -551,6 +568,7 @@ "started_at": "2022-10-27 10:08:34.647511", "completed_at": "2022-10-27 10:08:34.647516", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -582,6 +600,7 @@ "started_at": "2022-10-27 10:08:35.109570", "completed_at": "2022-10-27 10:08:35.558177", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -613,6 +632,7 @@ "started_at": "2022-10-27 10:08:34.689654", "completed_at": "2022-10-27 10:08:34.689659", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -644,6 +664,7 @@ "started_at": "2022-10-27 10:08:35.193220", "completed_at": "2022-10-27 10:08:35.514718", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -675,6 +696,7 @@ "started_at": "2022-10-27 10:08:34.729529", "completed_at": "2022-10-27 10:08:34.729533", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -706,6 +728,7 @@ "started_at": "2022-10-27 10:08:35.287892", "completed_at": "2022-10-27 10:08:35.673424", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -737,6 +760,7 @@ "started_at": "2022-10-27 10:08:34.769243", "completed_at": "2022-10-27 10:08:34.769254", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -768,6 +792,7 @@ "started_at": "2022-10-27 10:08:35.354986", "completed_at": "2022-10-27 10:08:35.763519", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -799,6 +824,7 @@ "started_at": "2022-10-27 10:08:34.810072", "completed_at": "2022-10-27 10:08:34.810076", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -830,6 +856,7 @@ "started_at": "2022-10-27 10:08:35.447829", "completed_at": "2022-10-27 10:08:35.809561", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -861,6 +888,7 @@ "started_at": "2022-10-27 10:08:34.853785", "completed_at": "2022-10-27 10:08:34.853790", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -892,6 +920,7 @@ "started_at": "2022-10-27 10:08:35.596881", "completed_at": "2022-10-27 10:08:35.861501", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -923,6 +952,7 @@ "started_at": "2022-10-27 10:08:34.895470", "completed_at": "2022-10-27 10:08:34.895475", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -954,6 +984,7 @@ "started_at": "2022-10-27 10:08:42.917693", "completed_at": "2022-10-27 10:08:43.165874", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -985,6 +1016,7 @@ "started_at": "2022-10-27 10:08:42.299329", "completed_at": "2022-10-27 10:08:42.506677", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1016,6 +1048,7 @@ "started_at": "2022-10-27 10:08:38.780790", "completed_at": "2022-10-27 10:08:38.780796", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1047,6 +1080,7 @@ "started_at": "2022-10-27 10:08:38.820257", "completed_at": "2022-10-27 10:08:38.820263", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1078,6 +1112,7 @@ "started_at": "2022-10-27 10:08:38.854643", "completed_at": "2022-10-27 10:08:38.854648", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1109,6 +1144,7 @@ "started_at": "2022-10-27 10:08:38.887957", "completed_at": "2022-10-27 10:08:38.887961", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1140,6 +1176,7 @@ "started_at": "2022-10-27 10:08:38.922953", "completed_at": "2022-10-27 10:08:38.922958", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1171,6 +1208,7 @@ "started_at": "2022-10-27 10:08:38.964715", "completed_at": "2022-10-27 10:08:38.964719", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1202,6 +1240,7 @@ "started_at": "2022-10-27 10:08:39.035460", "completed_at": "2022-10-27 10:08:39.035471", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1233,6 +1272,7 @@ "started_at": "2022-10-27 10:08:39.090917", "completed_at": "2022-10-27 10:08:39.090918", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1264,6 +1304,7 @@ "started_at": "2022-10-27 10:08:39.115263", "completed_at": "2022-10-27 10:08:39.115265", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1295,6 +1336,7 @@ "started_at": "2022-10-27 10:08:39.140423", "completed_at": "2022-10-27 10:08:39.140424", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1326,6 +1368,7 @@ "started_at": "2022-10-27 10:08:42.962785", "completed_at": "2022-10-27 10:08:43.253146", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1357,6 +1400,7 @@ "started_at": "2022-10-27 10:08:42.322474", "completed_at": "2022-10-27 10:08:42.587122", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1388,6 +1432,7 @@ "started_at": "2022-10-27 10:08:39.170609", "completed_at": "2022-10-27 10:08:39.170613", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1419,6 +1464,7 @@ "started_at": "2022-10-27 10:08:39.205854", "completed_at": "2022-10-27 10:08:39.205857", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1450,6 +1496,7 @@ "started_at": "2022-10-27 10:08:39.248043", "completed_at": "2022-10-27 10:08:39.248047", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1481,6 +1528,7 @@ "started_at": "2022-10-27 10:08:39.277848", "completed_at": "2022-10-27 10:08:39.277851", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1512,6 +1560,7 @@ "started_at": "2022-10-27 10:08:39.306032", "completed_at": "2022-10-27 10:08:39.306035", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1543,6 +1592,7 @@ "started_at": "2022-10-27 10:08:39.339834", "completed_at": "2022-10-27 10:08:39.339839", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1574,6 +1624,7 @@ "started_at": "2022-10-27 10:08:39.372707", "completed_at": "2022-10-27 10:08:39.372710", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1605,6 +1656,7 @@ "started_at": "2022-10-27 10:08:39.406279", "completed_at": "2022-10-27 10:08:39.406281", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1636,6 +1688,7 @@ "started_at": "2022-10-27 10:08:39.438271", "completed_at": "2022-10-27 10:08:39.438274", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1667,6 +1720,7 @@ "started_at": "2022-10-27 10:08:39.471663", "completed_at": "2022-10-27 10:08:39.471666", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1698,6 +1752,7 @@ "started_at": "2022-10-27 10:08:43.012521", "completed_at": "2022-10-27 10:08:43.340161", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1729,6 +1784,7 @@ "started_at": "2022-10-27 10:08:42.364016", "completed_at": "2022-10-27 10:08:42.650965", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1760,6 +1816,7 @@ "started_at": "2022-10-27 10:08:39.504710", "completed_at": "2022-10-27 10:08:39.504713", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1791,6 +1848,7 @@ "started_at": "2022-10-27 10:08:39.536280", "completed_at": "2022-10-27 10:08:39.536282", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1822,6 +1880,7 @@ "started_at": "2022-10-27 10:08:39.567594", "completed_at": "2022-10-27 10:08:39.567598", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1853,6 +1912,7 @@ "started_at": "2022-10-27 10:08:39.602978", "completed_at": "2022-10-27 10:08:39.602981", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1884,6 +1944,7 @@ "started_at": "2022-10-27 10:08:39.636935", "completed_at": "2022-10-27 10:08:39.636938", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1915,6 +1976,7 @@ "started_at": "2022-10-27 10:08:39.676238", "completed_at": "2022-10-27 10:08:39.676241", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1946,6 +2008,7 @@ "started_at": "2022-10-27 10:08:39.710482", "completed_at": "2022-10-27 10:08:39.710485", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -1977,6 +2040,7 @@ "started_at": "2022-10-27 10:08:39.743820", "completed_at": "2022-10-27 10:08:39.743822", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2008,6 +2072,7 @@ "started_at": "2022-10-27 10:08:39.776311", "completed_at": "2022-10-27 10:08:39.776316", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2039,6 +2104,7 @@ "started_at": "2022-10-27 10:08:39.813526", "completed_at": "2022-10-27 10:08:39.813527", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2070,6 +2136,7 @@ "started_at": "2022-10-27 10:08:43.124746", "completed_at": "2022-10-27 10:08:43.502348", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2101,6 +2168,7 @@ "started_at": "2022-10-27 10:08:42.410990", "completed_at": "2022-10-27 10:08:42.767106", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2132,6 +2200,7 @@ "started_at": "2022-10-27 10:08:39.840526", "completed_at": "2022-10-27 10:08:39.840528", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2163,6 +2232,7 @@ "started_at": "2022-10-27 10:08:39.872532", "completed_at": "2022-10-27 10:08:39.872534", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2194,6 +2264,7 @@ "started_at": "2022-10-27 10:08:39.905887", "completed_at": "2022-10-27 10:08:39.905891", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2225,6 +2296,7 @@ "started_at": "2022-10-27 10:08:39.937792", "completed_at": "2022-10-27 10:08:39.937795", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2256,6 +2328,7 @@ "started_at": "2022-10-27 10:08:39.973990", "completed_at": "2022-10-27 10:08:39.973995", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2287,6 +2360,7 @@ "started_at": "2022-10-27 10:08:40.004694", "completed_at": "2022-10-27 10:08:40.004696", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2318,6 +2392,7 @@ "started_at": "2022-10-27 10:08:40.038140", "completed_at": "2022-10-27 10:08:40.038143", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2349,6 +2424,7 @@ "started_at": "2022-10-27 10:08:40.071383", "completed_at": "2022-10-27 10:08:40.071386", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2380,6 +2456,7 @@ "started_at": "2022-10-27 10:08:40.105934", "completed_at": "2022-10-27 10:08:40.105939", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2411,6 +2488,7 @@ "started_at": "2022-10-27 10:08:40.142740", "completed_at": "2022-10-27 10:08:40.142743", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2442,6 +2520,7 @@ "started_at": "2022-10-27 10:08:43.041103", "completed_at": "2022-10-27 10:08:43.419548", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2473,6 +2552,7 @@ "started_at": "2022-10-27 10:08:42.460853", "completed_at": "2022-10-27 10:08:42.796610", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2504,6 +2584,7 @@ "started_at": "2022-10-27 10:08:40.181654", "completed_at": "2022-10-27 10:08:40.181662", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2535,6 +2616,7 @@ "started_at": "2022-10-27 10:08:40.220720", "completed_at": "2022-10-27 10:08:40.220725", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2566,6 +2648,7 @@ "started_at": "2022-10-27 10:08:40.254308", "completed_at": "2022-10-27 10:08:40.254312", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2597,6 +2680,7 @@ "started_at": "2022-10-27 10:08:40.293041", "completed_at": "2022-10-27 10:08:40.293044", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2628,6 +2712,7 @@ "started_at": "2022-10-27 10:08:40.339703", "completed_at": "2022-10-27 10:08:40.339706", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2659,6 +2744,7 @@ "started_at": "2022-10-27 10:08:40.388397", "completed_at": "2022-10-27 10:08:40.388409", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2690,6 +2776,7 @@ "started_at": "2022-10-27 10:08:40.427023", "completed_at": "2022-10-27 10:08:40.427026", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2721,6 +2808,7 @@ "started_at": "2022-10-27 10:08:40.476868", "completed_at": "2022-10-27 10:08:40.476872", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2752,6 +2840,7 @@ "started_at": "2022-10-27 10:08:40.524104", "completed_at": "2022-10-27 10:08:40.524107", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2783,6 +2872,7 @@ "started_at": "2022-10-27 10:08:40.568992", "completed_at": "2022-10-27 10:08:40.569005", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2814,6 +2904,7 @@ "started_at": "2022-10-27 10:08:43.207640", "completed_at": "2022-10-27 10:08:43.589035", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2845,6 +2936,7 @@ "started_at": "2022-10-27 10:08:42.540112", "completed_at": "2022-10-27 10:08:42.874105", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2876,6 +2968,7 @@ "started_at": "2022-10-27 10:08:40.606339", "completed_at": "2022-10-27 10:08:40.606344", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2907,6 +3000,7 @@ "started_at": "2022-10-27 10:08:40.641995", "completed_at": "2022-10-27 10:08:40.641996", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2938,6 +3032,7 @@ "started_at": "2022-10-27 10:08:40.672991", "completed_at": "2022-10-27 10:08:40.672994", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -2969,6 +3064,7 @@ "started_at": "2022-10-27 10:08:40.696115", "completed_at": "2022-10-27 10:08:40.696117", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3000,6 +3096,7 @@ "started_at": "2022-10-27 10:08:40.722953", "completed_at": "2022-10-27 10:08:40.722956", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3031,6 +3128,7 @@ "started_at": "2022-10-27 10:08:40.755865", "completed_at": "2022-10-27 10:08:40.755868", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3062,6 +3160,7 @@ "started_at": "2022-10-27 10:08:40.793381", "completed_at": "2022-10-27 10:08:40.793384", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3093,6 +3192,7 @@ "started_at": "2022-10-27 10:08:40.828280", "completed_at": "2022-10-27 10:08:40.828284", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3124,6 +3224,7 @@ "started_at": "2022-10-27 10:08:40.865834", "completed_at": "2022-10-27 10:08:40.865850", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3155,6 +3256,7 @@ "started_at": "2022-10-27 10:08:40.908278", "completed_at": "2022-10-27 10:08:40.908311", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3186,6 +3288,7 @@ "started_at": "2022-10-27 10:08:43.290300", "completed_at": "2022-10-27 10:08:43.649360", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3217,6 +3320,7 @@ "started_at": "2022-10-27 10:08:42.622934", "completed_at": "2022-10-27 10:08:42.939996", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3248,6 +3352,7 @@ "started_at": "2022-10-27 10:08:40.941095", "completed_at": "2022-10-27 10:08:40.941099", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3279,6 +3384,7 @@ "started_at": "2022-10-27 10:08:40.976058", "completed_at": "2022-10-27 10:08:40.976061", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3310,6 +3416,7 @@ "started_at": "2022-10-27 10:08:41.009802", "completed_at": "2022-10-27 10:08:41.009805", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3341,6 +3448,7 @@ "started_at": "2022-10-27 10:08:41.048453", "completed_at": "2022-10-27 10:08:41.048457", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3372,6 +3480,7 @@ "started_at": "2022-10-27 10:08:41.086414", "completed_at": "2022-10-27 10:08:41.086417", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3403,6 +3512,7 @@ "started_at": "2022-10-27 10:08:41.116755", "completed_at": "2022-10-27 10:08:41.116759", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3434,6 +3544,7 @@ "started_at": "2022-10-27 10:08:41.149923", "completed_at": "2022-10-27 10:08:41.149926", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3465,6 +3576,7 @@ "started_at": "2022-10-27 10:08:41.186466", "completed_at": "2022-10-27 10:08:41.186469", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3496,6 +3608,7 @@ "started_at": "2022-10-27 10:08:41.218531", "completed_at": "2022-10-27 10:08:41.218536", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3527,6 +3640,7 @@ "started_at": "2022-10-27 10:08:41.250702", "completed_at": "2022-10-27 10:08:41.250704", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3558,6 +3672,7 @@ "started_at": "2022-10-27 10:08:43.375045", "completed_at": "2022-10-27 10:08:43.687405", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3589,6 +3704,7 @@ "started_at": "2022-10-27 10:08:42.677025", "completed_at": "2022-10-27 10:08:42.989021", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3620,6 +3736,7 @@ "started_at": "2022-10-27 10:08:41.288911", "completed_at": "2022-10-27 10:08:41.288914", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3651,6 +3768,7 @@ "started_at": "2022-10-27 10:08:41.322042", "completed_at": "2022-10-27 10:08:41.322045", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3682,6 +3800,7 @@ "started_at": "2022-10-27 10:08:41.356030", "completed_at": "2022-10-27 10:08:41.356033", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3713,6 +3832,7 @@ "started_at": "2022-10-27 10:08:41.389522", "completed_at": "2022-10-27 10:08:41.389525", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3744,6 +3864,7 @@ "started_at": "2022-10-27 10:08:41.421862", "completed_at": "2022-10-27 10:08:41.421865", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3775,6 +3896,7 @@ "started_at": "2022-10-27 10:08:41.457205", "completed_at": "2022-10-27 10:08:41.457208", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3806,6 +3928,7 @@ "started_at": "2022-10-27 10:08:41.491181", "completed_at": "2022-10-27 10:08:41.491184", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3837,6 +3960,7 @@ "started_at": "2022-10-27 10:08:41.522585", "completed_at": "2022-10-27 10:08:41.522588", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3868,6 +3992,7 @@ "started_at": "2022-10-27 10:08:41.556600", "completed_at": "2022-10-27 10:08:41.556602", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3899,6 +4024,7 @@ "started_at": "2022-10-27 10:08:41.590899", "completed_at": "2022-10-27 10:08:41.590902", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3930,6 +4056,7 @@ "started_at": "2022-10-27 10:08:43.541105", "completed_at": "2022-10-27 10:08:43.761303", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3961,6 +4088,7 @@ "started_at": "2022-10-27 10:08:42.724995", "completed_at": "2022-10-27 10:08:43.086071", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -3992,6 +4120,7 @@ "started_at": "2022-10-27 10:08:41.622560", "completed_at": "2022-10-27 10:08:41.622563", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4023,6 +4152,7 @@ "started_at": "2022-10-27 10:08:41.656441", "completed_at": "2022-10-27 10:08:41.656443", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4054,6 +4184,7 @@ "started_at": "2022-10-27 10:08:41.691678", "completed_at": "2022-10-27 10:08:41.691681", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4085,6 +4216,7 @@ "started_at": "2022-10-27 10:08:41.725404", "completed_at": "2022-10-27 10:08:41.725407", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4116,6 +4248,7 @@ "started_at": "2022-10-27 10:08:41.759767", "completed_at": "2022-10-27 10:08:41.759771", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4147,6 +4280,7 @@ "started_at": "2022-10-27 10:08:41.799118", "completed_at": "2022-10-27 10:08:41.799121", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4178,6 +4312,7 @@ "started_at": "2022-10-27 10:08:41.830166", "completed_at": "2022-10-27 10:08:41.830169", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4209,6 +4344,7 @@ "started_at": "2022-10-27 10:08:41.864510", "completed_at": "2022-10-27 10:08:41.864513", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4240,6 +4376,7 @@ "started_at": "2022-10-27 10:08:41.900656", "completed_at": "2022-10-27 10:08:41.900659", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4271,6 +4408,7 @@ "started_at": "2022-10-27 10:08:41.935954", "completed_at": "2022-10-27 10:08:41.935957", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4302,6 +4440,7 @@ "started_at": "2022-10-27 10:08:43.459156", "completed_at": "2022-10-27 10:08:43.724659", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4333,6 +4472,7 @@ "started_at": "2022-10-27 10:08:42.835169", "completed_at": "2022-10-27 10:08:43.060465", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4364,6 +4504,7 @@ "started_at": "2022-10-27 10:08:41.970618", "completed_at": "2022-10-27 10:08:41.970621", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4395,6 +4536,7 @@ "started_at": "2022-10-27 10:08:41.998870", "completed_at": "2022-10-27 10:08:41.998874", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4426,6 +4568,7 @@ "started_at": "2022-10-27 10:08:42.032233", "completed_at": "2022-10-27 10:08:42.032236", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4457,6 +4600,7 @@ "started_at": "2022-10-27 10:08:42.065366", "completed_at": "2022-10-27 10:08:42.065369", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4488,6 +4632,7 @@ "started_at": "2022-10-27 10:08:42.099840", "completed_at": "2022-10-27 10:08:42.099844", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4519,6 +4664,7 @@ "started_at": "2022-10-27 10:08:42.134805", "completed_at": "2022-10-27 10:08:42.134812", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4550,6 +4696,7 @@ "started_at": "2022-10-27 10:08:42.170433", "completed_at": "2022-10-27 10:08:42.170438", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4581,6 +4728,7 @@ "started_at": "2022-10-27 10:08:42.202646", "completed_at": "2022-10-27 10:08:42.202650", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4612,6 +4760,7 @@ "started_at": "2022-10-27 10:08:42.233364", "completed_at": "2022-10-27 10:08:42.233366", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4643,6 +4792,7 @@ "started_at": "2022-10-27 10:08:42.265656", "completed_at": "2022-10-27 10:08:42.265658", "executor": "dask", + "executor_data": {}, "job_id": 1, "qelectron_data_exists": false, @@ -4657,6 +4807,7 @@ "deps_filename": "deps.pkl", "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -4688,6 +4839,7 @@ "deps_filename": "deps.pkl", "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", diff --git a/tests/covalent_ui_backend_tests/utils/data/lattices.json b/tests/covalent_ui_backend_tests/utils/data/lattices.json index 1209c6286..c166bad06 100644 --- a/tests/covalent_ui_backend_tests/utils/data/lattices.json +++ b/tests/covalent_ui_backend_tests/utils/data/lattices.json @@ -13,6 +13,7 @@ "electron_num": 6, "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -32,7 +33,8 @@ "storage_type": "local", "updated_at": "2022-09-23 10:01:11.720140", - "workflow_executor": "dask" + "workflow_executor": "dask", + "workflow_executor_data": {} }, { "call_after_filename": "call_after.pkl", @@ -48,6 +50,7 @@ "electron_num": 4, "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -67,7 +70,8 @@ "storage_type": "local", "updated_at": "2022-10-27 10:08:43.997619", - "workflow_executor": "dask" + "workflow_executor": "dask", + "workflow_executor_data": {} }, { "call_after_filename": "call_after.pkl", @@ -83,6 +87,7 @@ "electron_num": 20, "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -102,7 +107,8 @@ "storage_type": "local", "updated_at": "2022-10-27 10:08:36.004030", - "workflow_executor": "dask" + "workflow_executor": "dask", + "workflow_executor_data": {} }, { "call_after_filename": "call_after.pkl", @@ -118,6 +124,7 @@ "electron_num": 120, "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -137,7 +144,8 @@ "storage_type": "local", "updated_at": "2022-10-27 10:08:43.890454", - "workflow_executor": "dask" + "workflow_executor": "dask", + "workflow_executor_data": {} }, { "call_after_filename": "call_after.pkl", @@ -153,6 +161,7 @@ "electron_num": 2, "error_filename": "error.log", "executor": "dask", + "executor_data": {}, "function_filename": "function.pkl", "function_string_filename": "function_string.txt", @@ -172,6 +181,7 @@ "storage_type": "local", "updated_at": "2023-08-10 10:08:55.946668", - "workflow_executor": "dask" + "workflow_executor": "dask", + "workflow_executor_data": {} } ] diff --git a/tests/covalent_ui_backend_tests/utils/seed_script.py b/tests/covalent_ui_backend_tests/utils/seed_script.py index e73d05d9b..da24b03e4 100644 --- a/tests/covalent_ui_backend_tests/utils/seed_script.py +++ b/tests/covalent_ui_backend_tests/utils/seed_script.py @@ -61,7 +61,9 @@ def seed(engine): function_filename=item["function_filename"], function_string_filename=item["function_string_filename"], executor=item["executor"], + executor_data=json.dumps(item["executor_data"]), workflow_executor=item["workflow_executor"], + workflow_executor_data=json.dumps(item["workflow_executor_data"]), error_filename=item["error_filename"], inputs_filename=item["inputs_filename"], named_args_filename=item["named_args_filename"], @@ -100,6 +102,7 @@ def seed(engine): function_filename=item["function_filename"], function_string_filename=item["function_string_filename"], executor=item["executor"], + executor_data=json.dumps(item["executor_data"]), results_filename=item["results_filename"], value_filename=item["value_filename"], stdout_filename=item["stdout_filename"], @@ -115,7 +118,6 @@ def seed(engine): completed_at=convert_to_date(item["completed_at"]), job_id=item["job_id"], qelectron_data_exists=item["qelectron_data_exists"], - cancel_requested=item["cancel_requested"], ) ) diff --git a/tests/functional_tests/__init__.py b/tests/functional_tests/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/functional_tests/__init__.py +++ b/tests/functional_tests/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/functional_tests/file_transfer_test.py b/tests/functional_tests/file_transfer_test.py index 66b8b0790..ea5f4029d 100644 --- a/tests/functional_tests/file_transfer_test.py +++ b/tests/functional_tests/file_transfer_test.py @@ -202,7 +202,7 @@ def test_local_file_transfer_transfer_from(tmp_path: Path, mocker): Popen.returncode = 0 mocker.patch("covalent._file_transfer.strategies.rsync_strategy.Popen", return_value=Popen) - ft = ct.fs.TransferFromRemote(str(source_file)) + ft = ct.fs.TransferFromRemote(str(source_file), strategy=ct.fs_strategies.Rsync()) @ct.electron(files=[ft]) def test_transfer(files=[]): @@ -239,7 +239,7 @@ def test_local_file_transfer_transfer_to(tmp_path: Path, mocker): Popen.returncode = 0 mocker.patch("covalent._file_transfer.strategies.rsync_strategy.Popen", return_value=Popen) - ft = ct.fs.TransferToRemote(str(dest_file)) + ft = ct.fs.TransferToRemote(str(dest_file), strategy=ct.fs_strategies.Rsync()) @ct.electron(files=[ft]) def test_transfer(files=[]): diff --git a/tests/functional_tests/local_executor_test.py b/tests/functional_tests/local_executor_test.py index aa1a9beb6..27db56c11 100644 --- a/tests/functional_tests/local_executor_test.py +++ b/tests/functional_tests/local_executor_test.py @@ -16,6 +16,8 @@ import covalent as ct +import covalent._results_manager.results_manager as rm +from covalent._results_manager.result import Result def test_local_executor_returns_stdout_stderr(): @@ -41,3 +43,33 @@ def workflow(x): assert tg.get_node_value(0, "stdout") == "Hello\n" assert tg.get_node_value(0, "stderr") == "Error\n" assert tg.get_node_value(0, "output").get_deserialized() == 5 + + +def test_local_executor_build_sublattice_graph(): + """ + Check using local executor to build_sublattice_graph. + + This will exercise the /register endpoint for sublattices. + """ + + def add(a, b): + return a + b + + @ct.electron(executor="local") + def identity(a): + return a + + sublattice_add = ct.lattice(add) + + @ct.lattice(executor="local", workflow_executor="local") + def workflow(a, b): + res_1 = ct.electron(sublattice_add, executor="local")(a=a, b=b) + return identity(a=res_1) + + dispatch_id = ct.dispatch(workflow)(a=1, b=2) + workflow_result = rm.get_result(dispatch_id, wait=True) + + assert workflow_result.error == "" + assert workflow_result.status == Result.COMPLETED + assert workflow_result.result == 3 + assert workflow_result.get_node_result(node_id=0)["sublattice_result"].result == 3 diff --git a/tests/functional_tests/results_manager_test.py b/tests/functional_tests/results_manager_test.py new file mode 100644 index 000000000..16fa01b90 --- /dev/null +++ b/tests/functional_tests/results_manager_test.py @@ -0,0 +1,77 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Testing methods to retrieve workflow artifacts""" + +import pytest + +import covalent as ct +from covalent._shared_files.exceptions import MissingLatticeRecordError + + +def test_granular_get_result(): + def add(a, b): + return a + b + + @ct.electron + def identity(a): + return a + + sublattice_add = ct.lattice(add) + + @ct.lattice + def workflow(a, b): + res_1 = ct.electron(sublattice_add)(a=a, b=b) + return identity(a=res_1) + + dispatch_id = ct.dispatch(workflow)(a=1, b=2) + res_obj = ct.get_result( + dispatch_id, + wait=True, + workflow_output=False, + intermediate_outputs=False, + sublattice_results=False, + ) + + assert res_obj.result is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=False, sublattice_results=False + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"] is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=False, sublattice_results=True + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"].result == 3 + assert res_obj.get_node_result(0)["output"] is None + + res_obj = ct.get_result( + dispatch_id, workflow_output=True, intermediate_outputs=True, sublattice_results=False + ) + assert res_obj.result == 3 + + assert res_obj.get_node_result(0)["sublattice_result"] is None + assert res_obj.get_node_result(0)["output"].get_deserialized() == 3 + + +def test_get_result_nonexistent(): + with pytest.raises(MissingLatticeRecordError): + result_object = ct.get_result("nonexistent", wait=False) diff --git a/tests/functional_tests/triggers_test.py b/tests/functional_tests/triggers_test.py index 7490a04d9..d43af2a5c 100644 --- a/tests/functional_tests/triggers_test.py +++ b/tests/functional_tests/triggers_test.py @@ -63,7 +63,7 @@ def dir_workflow(): with open(read_file_path, "a") as f: f.write(f"{i}\n") - time.sleep(2) + time.sleep(5) with open(write_file_path, "r") as f: actual_sums = f.readlines() diff --git a/tests/functional_tests/workflow_cancellation_test.py b/tests/functional_tests/workflow_cancellation_test.py index 9ff2c8ae4..6a1ec3c26 100644 --- a/tests/functional_tests/workflow_cancellation_test.py +++ b/tests/functional_tests/workflow_cancellation_test.py @@ -48,6 +48,7 @@ def workflow(x): ct.cancel(dispatch_id) result = ct.get_result(dispatch_id, wait=True) + assert result.status == ct.status.CANCELLED rm._delete_result(dispatch_id) @@ -108,7 +109,7 @@ def workflow(x): return sub_workflow(3) dispatch_id = ct.dispatch(workflow)(3) - time.sleep(0.5) + time.sleep(1) ct.cancel(dispatch_id, task_ids=[0]) @@ -116,7 +117,6 @@ def workflow(x): tg = result.lattice.transport_graph sub_dispatch_id = tg.get_node_value(0, "sub_dispatch_id") - - print("Sublattice dispatch id:", sub_dispatch_id) - sub_res = ct.get_result(sub_dispatch_id) - assert sub_res.status == ct.status.CANCELLED + if sub_dispatch_id: + sub_res = ct.get_result(sub_dispatch_id) + assert sub_res.status == ct.status.CANCELLED diff --git a/tests/functional_tests/workflow_stack_test.py b/tests/functional_tests/workflow_stack_test.py index dc7a72d96..8d2a8f90e 100644 --- a/tests/functional_tests/workflow_stack_test.py +++ b/tests/functional_tests/workflow_stack_test.py @@ -17,13 +17,14 @@ """Workflow stack testing of TransportGraph, Lattice and Electron classes.""" import os +import tempfile import pytest import covalent as ct +import covalent._dispatcher_plugins.local as local import covalent._results_manager.results_manager as rm from covalent._results_manager.result import Result -from covalent_dispatcher._db import update def construct_temp_cache_dir(): @@ -116,7 +117,7 @@ def workflow(a, b): dispatch_id = ct.dispatch(workflow)(a=1, b=2) workflow_result = rm.get_result(dispatch_id, wait=True) - assert workflow_result.error is None + assert workflow_result.error == "" assert workflow_result.status == Result.COMPLETED assert workflow_result.result == 3 assert workflow_result.get_node_result(node_id=0)["sublattice_result"].result == 3 @@ -172,11 +173,11 @@ def test_parallelization(): def heavy_function(a): import time - time.sleep(1) + time.sleep(10) return a @ct.lattice -def workflow(x=10): +def workflow(x=2): for i in range(x): heavy_function(a=i) return x @@ -259,7 +260,7 @@ def workflow(file_path): dispatch_id = ct.dispatch(workflow)(file_path=tmp_path) res = ct.get_result(dispatch_id, wait=True) - assert res.error is None + assert res.error == "" assert res.result == (True, "Hello") @@ -605,8 +606,9 @@ def workflow(a, /, b, *args, c, **kwargs): result = rm.get_result(dispatch_id, wait=True) rm._delete_result(dispatch_id) - assert ct.TransportableObject.deserialize_list(result.inputs["args"]) == [1, 2, 3, 4] - assert ct.TransportableObject.deserialize_dict(result.inputs["kwargs"]) == { + workflow_inputs = result.inputs.get_deserialized() + assert workflow_inputs["args"] == (1, 2, 3, 4) + assert workflow_inputs["kwargs"] == { "c": 5, "d": 6, "e": 7, @@ -677,7 +679,6 @@ def workflow(): dispatch_id = ct.dispatch(workflow)() result = ct.get_result(dispatch_id, wait=True) - update.persist(result) assert result.status == Result.COMPLETED assert ( @@ -925,3 +926,26 @@ def failing_workflow(x, y): assert int(result.result) == 1 assert result.status == "COMPLETED" assert result.get_node_result(0)["start_time"] == result.get_node_result(0)["end_time"] + + +def test_multistage_dispatch_with_pull_assets(): + """Test submitting a dispatch with assets to be pulled.""" + + @ct.electron + def task(x): + return x**3 + + @ct.lattice + def workflow(x): + return task(x) + + workflow.build_graph(5) + with tempfile.TemporaryDirectory() as staging_dir: + manifest = local.LocalDispatcher.prepare_manifest(workflow, staging_dir) + return_manifest = local.LocalDispatcher.register_manifest(manifest, push_assets=False) + dispatch_id = return_manifest.metadata.dispatch_id + + local.LocalDispatcher.start(dispatch_id) + + res = rm.get_result(dispatch_id, wait=True) + assert res.result == 125 diff --git a/tests/load_tests/locustfiles/basic.py b/tests/load_tests/locustfiles/basic.py index 2071072f5..ddd80051c 100644 --- a/tests/load_tests/locustfiles/basic.py +++ b/tests/load_tests/locustfiles/basic.py @@ -44,18 +44,20 @@ def serialize_workflow(workflow, lattice_args): @task def submit_identity_workflow(self): - self.client.post("/api/submit", data=self.serialize_workflow(identity_workflow, [1])) + self.client.post( + "/api/v2/dispatches/submit", data=self.serialize_workflow(identity_workflow, [1]) + ) @task def submit_horizontal_workflow(self): self.client.post( - "/api/submit", + "/api/v2/dispatches/submit", data=self.serialize_workflow(horizontal_workflow, [random.randint(5, 10)]), ) @task def submit_add_multiply_workflow(self): self.client.post( - "/api/submit", + "/api/v2/dispatches/submit", data=self.serialize_workflow(add_multiply_workflow, [1, 2]), ) diff --git a/tests/load_tests/workflows/horizontal.py b/tests/load_tests/workflows/horizontal.py index 9bdad4135..ee525e209 100644 --- a/tests/load_tests/workflows/horizontal.py +++ b/tests/load_tests/workflows/horizontal.py @@ -1,3 +1,19 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import covalent as ct diff --git a/tests/stress_tests/benchmarks/__init__.py b/tests/stress_tests/benchmarks/__init__.py index e69de29bb..cfc23bfdf 100644 --- a/tests/stress_tests/benchmarks/__init__.py +++ b/tests/stress_tests/benchmarks/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2021 Agnostiq Inc. +# +# This file is part of Covalent. +# +# Licensed under the Apache License 2.0 (the "License"). A copy of the +# License may be obtained with this software package or at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Use of this file is prohibited except in compliance with the License. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From 95402676f2cc9c76950979f79258c1be6f0a8309 Mon Sep 17 00:00:00 2001 From: Kirill Pushkarev <71515921+kirill-push@users.noreply.github.com> Date: Thu, 19 Oct 2023 23:31:43 +0300 Subject: [PATCH 2/2] fixes #1835 (#1837) * fixes #1835 * Add Electron __pow__ method * add tests to Electron __pow__ method * current version of test_electron_pow_method * fixing test_electron_pow after a change request * Empty commit --------- Co-authored-by: Sankalp Sanand --- CHANGELOG.md | 1 + covalent/_workflow/electron.py | 4 ++++ .../covalent_tests/workflow/electron_test.py | 21 ++++++++++++++++++- 3 files changed, 25 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9dc84c08b..756d14cfb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -35,6 +35,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Documentation and test cases for database triggers. +- Added the `__pow__` method to the `Electron` class ### Docs diff --git a/covalent/_workflow/electron.py b/covalent/_workflow/electron.py index 05e4ee6d8..d8697d83b 100644 --- a/covalent/_workflow/electron.py +++ b/covalent/_workflow/electron.py @@ -155,6 +155,7 @@ def get_op_function( "-": operator.sub, "*": operator.mul, "/": operator.truediv, + "**": operator.pow, } def rename(op1: Any, op: str, op2: Any) -> Callable: @@ -243,6 +244,9 @@ def __truediv__(self, other): def __rtruediv__(self, other): return self.get_op_function(other, self, "/") + def __pow__(self, other): + return self.get_op_function(self, other, "**") + def __int__(self): return int() diff --git a/tests/covalent_tests/workflow/electron_test.py b/tests/covalent_tests/workflow/electron_test.py index 18d6cedc5..5b9f45a7c 100644 --- a/tests/covalent_tests/workflow/electron_test.py +++ b/tests/covalent_tests/workflow/electron_test.py @@ -17,7 +17,7 @@ """Unit tests for electron""" import json -from unittest.mock import MagicMock +from unittest.mock import ANY, MagicMock import pytest @@ -655,3 +655,22 @@ def workflow(x): assert ( workflow.transport_graph.get_node_value(0, "status") == RESULT_STATUS.PENDING_REPLACEMENT ) + + +def test_electron_pow_method(mocker): + mock_electron_get_op_function = mocker.patch.object( + Electron, "get_op_function", return_value=Electron + ) + + @ct.electron + def g(x): + return 42 * x + + @ct.lattice + def workflow(x): + res = g(x) + return res**2 + + workflow.build_graph(2) + + mock_electron_get_op_function.assert_called_with(ANY, 2, "**")