diff --git a/dioptra.toml b/dioptra.toml new file mode 100644 index 000000000..ce3497057 --- /dev/null +++ b/dioptra.toml @@ -0,0 +1,37 @@ +[[plugins]] +path = "plugins/hello_world" +description = "A simple plugin used for testing and demonstration purposes." + + + [[plugins.tasks]] + filename = "tasks.py" + name = "hello" + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "greeting", type = "message", required = true} ] + output_params = [ { name = "loud_greeting", type = "message" } ] + +[[plugin_param_types]] +name = "message" + +[[entrypoints]] +path = "examples/hello-world.yaml" +name = "Hello World" +description = "A simple example using the hello_world plugin." +params = [ + { name = "name", type = "string", default_value = "World" } +] +plugins = [ "hello_world" ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 35e692e60..3b2f55d96 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -65,6 +65,7 @@ Email us: dioptra@nist.gov :caption: Reference reference/task-engine-reference + reference/resource-import-reference reference/api-reference-restapi .. reference/api-reference-sdk reference/api-reference-client diff --git a/docs/source/reference/resource-import-reference.rst b/docs/source/reference/resource-import-reference.rst new file mode 100644 index 000000000..ece031dc9 --- /dev/null +++ b/docs/source/reference/resource-import-reference.rst @@ -0,0 +1,207 @@ +=========================== + Resource Import Reference +=========================== + +This document describes the contract for importing resources into a Dioptra instance. + +.. contents:: + +Overview +======== + +Dioptra provides functionality for importing Plugins, Entrypoints, and +PluginParameterTypes. This allows users to easily publish and share +resources across Dioptra instances. It is also the mechanism used to +distribute core plugins and examples developed by the Dioptra maintainers. + +Resources are described via a combination of a TOML configuration file, +Python source code (for plugins), and YAML task graphs (for entrypoints). +The TOML file is the central configuration that references required sources +and includes metadata for fully registering the resources in Dioptra (such +as plugin task specifications, and entrypoint parameters). + +Collections of resources can be distributed via git repositories or by sharing +files manually. The resourceImport workflow can import from a repository or +archive file. See the API Reference for details. + +TOML Configuration Format +========================= + +The TOML format consists of three arrays of tables, one for each of the +importable resource types: Plugins, PluginParameterTypes, and Entrypoints. +This allows for importing of zero or more of each of these resources. + +Example +------- + +The following example illustrates how to configure a collection of resources +including a Plugin, PluginParameterType, and Entrypoint. + +.. code:: TOML + + # Plugins point to a python package and include metadata for registering them in Dioptra + [[plugins]] + # the path to the Python plugin relative to the root directory + path = "plugins/hello_world" + # an optional description + description = "A simple plugin used for testing and demonstration purposes." + + # an array of plugin task definitions + [[plugins.tasks]] + # the name of the file relative to the root plugin directory + filename = "tasks.py" + # the name must match the name of the function + name = "hello" + # input parameter names must match the Python function definition + # types are plugin parameter types and are matched by name + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "message", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "message", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "message", type = "message", required = true} ] + output_params = [ { name = "message", type = "message" } ] + + # PluginParameterTypes are fully described in the TOML + [[plugin_param_types]] + name = "message" + + # Entrypoints point to a task graph yaml and include metadata for registering them in Dioptra + [[entrypoints]] + # path to the task graph yaml + path = "examples/hello-world.yaml" + # the name to register the entrypoint under (the task graph filename is use if not provided) + name = "Hello World" + # an optional description + description = "A simple example using the hello_world plugin." + # entrypoint parameters to register that should match the task graph + # here, type is an entrypoint parameter type, NOT a plugin parameter type + params = [ + { name = "name", type = "string", default_value = "World" } + ] + # plugins to register with the entrypoint (matched by name) + plugins = [ "hello_world" ] + + +JSON Schema +----------- + +The following JSON schema fully describes the Dioptra resource TOML. +It is used to validate Dioptra TOML files in the resource import workflow. + +.. code:: JSON + + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/usnistgov/dioptra", + "title": "Dioptra Resource Schema", + "description": "A schema defining objects that can be imported into a Dioptra instance.", + "type": "object", + "properties": { + "plugins": { + "type": "array", + "description": "An array of Dioptra plugins", + "items": { + "type": "object", + "description": "A Dioptra plugin", + "properties": { + "path": { "type": "string" }, + "description": { "type": "string" }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "filename": { "type": "string" }, + "name": { "type": "string" }, + "input_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "required": { "type": "boolean" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "output_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + } + }, + "required": [ "filename", "name", "input_params", "output_params" ], + "additionalProperties": false + } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + }, + "plugin_param_types": { + "type": "array", + "description": "An array of Dioptra plugin parameter types", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "structure": { "type": "object" } + }, + "required": [ "name" ], + "additionalProperties": false + } + }, + "entrypoints": { + "type": "array", + "description": "An array of Dioptra entrypoints", + "items": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "name": { "type": "string" }, + "description": { "type": "string" }, + "params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "default_value": { "type": [ "string", "number", "boolean", "null" ] } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "plugins": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + } + } + } diff --git a/examples/hello-world.yaml b/examples/hello-world.yaml new file mode 100644 index 000000000..0b7609dcb --- /dev/null +++ b/examples/hello-world.yaml @@ -0,0 +1,10 @@ +hello_step: + hello: + name: $name +goodbye_step: + greet: + greeting: Goodbye + name: $name +shout_step: + shout: + greeting: $goodbye_step.greeting diff --git a/plugins/hello_world/__init__.py b/plugins/hello_world/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/plugins/hello_world/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/plugins/hello_world/tasks.py b/plugins/hello_world/tasks.py new file mode 100644 index 000000000..875fb2b82 --- /dev/null +++ b/plugins/hello_world/tasks.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import structlog +from dioptra import pyplugs + +LOGGER = structlog.get_logger() + + +@pyplugs.register() +def hello(name: str) -> str: + message = f"Hello, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def greet(greeting: str, name: str) -> str: + message = f"{greeting}, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def shout(greeting: str) -> str: + loud_greeting = greeting.upper() + LOGGER.info(loud_greeting) + return loud_greeting diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 8dfa4f6c6..f31f3ce46 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -15,13 +15,19 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode from pathlib import Path -from typing import ClassVar, Final, TypeVar +from typing import Any, ClassVar, Final, Literal, TypeVar, overload -from .base import CollectionClient, IllegalArgumentError +from .base import ( + CollectionClient, + DioptraFile, + DioptraResponseProtocol, + IllegalArgumentError, +) T = TypeVar("T") JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload" +RESOURCE_IMPORT: Final[str] = "resourceImport" class WorkflowsCollectionClient(CollectionClient[T]): @@ -86,3 +92,98 @@ def download_job_files( return self._session.download( self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params ) + + @overload + def import_resources( + self, + git_url: str, + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from git repo""" + ... # pragma: nocover + + @overload + def import_resources( + self, + archive_file: DioptraFile, + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from archive file""" + ... # pragma: nocover + + @overload + def import_resources( + self, + files: list[DioptraFile], + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from archive file""" + ... # pragma: nocover + + def import_resources( + self, + group_id, + git_url=None, + archive_file=None, + files=None, + config_path="dioptra.toml", + resolve_name_conflicts_strategy="fail", + ): + """ + Import resources from a archive file or git repository + + Args: + group_id: The group to import resources into + source_type: The source to import from + git_url: The url to the git repository if source_type is "git" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" + config_path: The path to the toml configuration file in the import source. + resolve_name_conflicts_strategy: The strategy for resolving name conflicts. + Either "fail" or "overwrite" + Raises: + IllegalArgumentError: If more than one import source is provided or if no + import source is provided. + """ + + import_source_args = [git_url, archive_file, files] + num_provided_import_source_args = sum( + arg is not None for arg in import_source_args + ) + + if num_provided_import_source_args == 0: + raise IllegalArgumentError( + "One of (git_url, archive_file, or files) must be provided" + ) + elif num_provided_import_source_args > 1: + raise IllegalArgumentError( + "Only one of (git_url, archive_file and files) can be provided" + ) + + data: dict[str, Any] = {"group": str(group_id)} + files_: dict[str, DioptraFile | list[DioptraFile]] = {} + + if git_url is not None: + data["sourceType"] = "git" + data["gitUrl"] = git_url + + if archive_file is not None: + data["sourceType"] = "upload_archive" + files_["archiveFile"] = archive_file + + if files is not None: + data["sourceType"] = "upload_files" + files_["files"] = files + + if config_path is not None: + data["configPath"] = config_path + + if resolve_name_conflicts_strategy is not None: + data["resolveNameConflictsStrategy"] = resolve_name_conflicts_strategy + + return self._session.post( + self.url, RESOURCE_IMPORT, data=data, files=files_ or None + ) diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index 1bb547ec9..1616532c2 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -339,6 +339,20 @@ def __init__(self, message: str): super().__init__(message) +class GitError(DioptraError): + """Git Error.""" + + def __init__(self, message: str): + super().__init__(message) + + +class ImportFailedError(DioptraError): + """Import failed Error.""" + + def __init__(self, message: str): + super().__init__(message) + + def error_result( error: DioptraError, status: http.HTTPStatus, detail: dict[str, typing.Any] ) -> tuple[dict[str, typing.Any], int]: @@ -443,6 +457,16 @@ def handle_mlflow_error(error: MLFlowError): log.debug(error.to_message()) return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {}) + @api.errorhandler(GitError) + def handle_git_error(error: GitError): + log.debug(error.to_message()) + return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {}) + + @api.errorhandler(GitError) + def handle_import_failed_error(error: ImportFailedError): + log.debug(error.to_message()) + return error_result(error, http.HTTPStatus.BAD_REQUEST, {}) + @api.errorhandler(DioptraError) def handle_base_error(error: DioptraError): log.debug(error.to_message()) diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index ed564151e..4da0a5f5d 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -151,7 +151,7 @@ def create_parameters_schema( location = "files" parameters_schema = ParametersSchema( - name=cast(str, field.name), + name=cast(str, field.data_key or field.name), type=parameter_type, location=location, required=field.required, @@ -310,6 +310,7 @@ def setup_injection(api: Api, injector: Injector) -> None: ma.Decimal: float, ma.Dict: dict, ma.Email: str, + ma.Enum: str, FileUpload: FileStorage, MultiFileUpload: FileStorage, ma.Float: float, diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index e4ce40fa8..9034f2669 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -566,6 +566,57 @@ def __init__( """ self._group_id_service = group_id_service + def get( + self, + group_id: int, + error_if_not_found: bool = False, + **kwargs, + ) -> list[models.PluginTaskParameterType]: + """Fetch a list of builtin plugin parameter types. + + Args: + group_id: The the group id of the plugin parameter type. + error_if_not_found: If True, raise an error if the plugin parameter + type is not found. Defaults to False. + + Returns: + The plugin parameter type object if found, otherwise None. + + Raises: + PluginParameterTypeDoesNotExistError: If the plugin parameter type + is not found and `error_if_not_found` is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug( + "Get builtin plugin parameter types", + group_id=group_id, + ) + + builtin_types = list(BUILTIN_TYPES.keys()) + + stmt = ( + select(models.PluginTaskParameterType) + .join(models.Resource) + .where( + models.PluginTaskParameterType.name.in_(builtin_types), + models.Resource.group_id == group_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.PluginTaskParameterType.resource_snapshot_id, + ) + ) + plugin_parameter_types = list(db.session.scalars(stmt).all()) + + if len(plugin_parameter_types) != len(builtin_types): + retrieved_names = {param_type.name for param_type in plugin_parameter_types} + missing_names = set(builtin_types) - retrieved_names + if error_if_not_found: + raise EntityDoesNotExistError( + RESOURCE_TYPE, missing_names=missing_names + ) + + return plugin_parameter_types + def create_all( self, user: models.User, diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index fadb12ec4..5e6da90f2 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -711,7 +711,7 @@ def create( self, filename: str, contents: str, - description: str, + description: str | None, tasks: list[dict[str, Any]], plugin_id: int, commit: bool = True, diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 428619cdc..e469b6e23 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -25,8 +25,14 @@ from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema -from .service import JobFilesDownloadService +from dioptra.restapi.utils import as_api_parser, as_parameters_schema_list + +from .schema import ( + FileTypes, + JobFilesDownloadQueryParametersSchema, + ResourceImportSchema, +) +from .service import JobFilesDownloadService, ResourceImportService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -78,3 +84,50 @@ def get(self): mimetype=mimetype[parsed_query_params["file_type"]], download_name=download_name[parsed_query_params["file_type"]], ) + + +@api.route("/resourceImport") +class ResourceImport(Resource): + @inject + def __init__( + self, resource_import_service: ResourceImportService, *args, **kwargs + ) -> None: + """Initialize the workflow resource. + + All arguments are provided via dependency injection. + + Args: + resource_import_service: A ResourceImportService object. + """ + self._resource_import_service = resource_import_service + super().__init__(*args, **kwargs) + + @login_required + @api.expect( + as_api_parser( + api, + as_parameters_schema_list( + ResourceImportSchema, operation="load", location="form" + ), + ) + ) + @accepts(form_schema=ResourceImportSchema, api=api) + def post(self): + """Import resources from an external source.""" # noqa: B950 + log = LOGGER.new( # noqa: F841 + request_id=str(uuid.uuid4()), resource="ResourceImport", request_type="POST" + ) + parsed_form = request.parsed_form + + return self._resource_import_service.import_resources( + group_id=parsed_form["group_id"], + source_type=parsed_form["source_type"], + git_url=parsed_form.get("git_url", None), + archive_file=request.files.get("archiveFile", None), + files=request.files.getlist("files", None), + config_path=parsed_form["config_path"], + resolve_name_conflicts_strategy=parsed_form[ + "resolve_name_conflicts_strategy" + ], + log=log, + ) diff --git a/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json b/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json new file mode 100644 index 000000000..ea7ef143f --- /dev/null +++ b/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json @@ -0,0 +1,104 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/usnistgov/dioptra", + "title": "Dioptra Resource Schema", + "description": "A schema defining objects that can be imported into a Dioptra instance.", + "type": "object", + "properties": { + "plugins": { + "type": "array", + "description": "An array of Dioptra plugins", + "items": { + "type": "object", + "description": "A Dioptra plugin", + "properties": { + "path": { "type": "string" }, + "description": { "type": "string" }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "filename": { "type": "string" }, + "name": { "type": "string" }, + "input_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "required": { "type": "boolean" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "output_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + } + }, + "required": [ "filename", "name", "input_params", "output_params" ], + "additionalProperties": false + } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + }, + "plugin_param_types": { + "type": "array", + "description": "An array of Dioptra plugin parameter types", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "structure": { "type": "object" } + }, + "required": [ "name" ], + "additionalProperties": false + } + }, + "entrypoints": { + "type": "array", + "description": "An array of Dioptra entrypoints", + "items": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "name": { "type": "string" }, + "description": { "type": "string" }, + "params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "default_value": { "type": [ "string", "number", "boolean", "null" ] } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "plugins": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + } + } +} diff --git a/src/dioptra/restapi/v1/workflows/lib/__init__.py b/src/dioptra/restapi/v1/workflows/lib/__init__.py index db520e904..577ea2000 100644 --- a/src/dioptra/restapi/v1/workflows/lib/__init__.py +++ b/src/dioptra/restapi/v1/workflows/lib/__init__.py @@ -15,12 +15,14 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode from . import views +from .clone_git_repository import clone_git_repository from .export_plugin_files import export_plugin_files from .export_task_engine_yaml import export_task_engine_yaml from .package_job_files import package_job_files __all__ = [ "views", + "clone_git_repository", "export_plugin_files", "export_task_engine_yaml", "package_job_files", diff --git a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py new file mode 100644 index 000000000..f4a1105c5 --- /dev/null +++ b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py @@ -0,0 +1,57 @@ +import subprocess +from pathlib import Path +from urllib.parse import urlparse + + +def clone_git_repository(url: str, dir: Path) -> str: + parsed_url = urlparse(url) + git_branch = parsed_url.fragment or None + git_paths = parsed_url.params or None + git_url = parsed_url._replace(fragment="", params="").geturl() + + git_env = ["GIT_TERMINAL_PROMPT=0"] + git_sparse_args = ["--filter=blob:none", "--no-checkout", "--depth=1"] + git_branch_args = ["-b", git_branch] if git_branch else [] + clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, str(dir)] + clone_result = subprocess.run(git_env + clone_cmd, capture_output=True, text=True) + + if clone_result.returncode != 0: + raise subprocess.CalledProcessError( + clone_result.returncode, clone_result.stderr + ) + + if git_paths is not None: + paths = git_paths.split(",") + sparse_checkout_cmd = ["git", "sparse-checkout", "set", "--cone", *paths] + sparse_checkout_result = subprocess.run( + sparse_checkout_cmd, cwd=dir, capture_output=True, text=True + ) + + if sparse_checkout_result.returncode != 0: + raise subprocess.CalledProcessError( + sparse_checkout_result.returncode, sparse_checkout_result.stderr + ) + + checkout_cmd = ["git", "checkout"] + checkout_result = subprocess.run( + checkout_cmd, cwd=dir, capture_output=True, text=True + ) + + if checkout_result.returncode != 0: + raise subprocess.CalledProcessError( + checkout_result.returncode, checkout_result.stderr + ) + + hash_cmd = ["git", "rev-parse", "HEAD"] + hash_result = subprocess.run(hash_cmd, cwd=dir, capture_output=True, text=True) + + if hash_result.returncode != 0: + raise subprocess.CalledProcessError(hash_result.returncode, hash_result.stderr) + + return str(hash) + + +if __name__ == "__main__": + clone_git_repository( + "https://github.com/usnistgov/dioptra.git;plugins#dev", Path("dioptra-plugins") + ) diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 92ea28ec7..4b7fcda5e 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -17,7 +17,9 @@ """The schemas for serializing/deserializing Workflow resources.""" from enum import Enum -from marshmallow import Schema, fields +from marshmallow import Schema, ValidationError, fields, validates_schema + +from dioptra.restapi.custom_schema_fields import FileUpload, MultiFileUpload class FileTypes(Enum): @@ -41,3 +43,86 @@ class JobFilesDownloadQueryParametersSchema(Schema): by_value=True, default=FileTypes.TAR_GZ.value, ) + + +class ResourceImportSourceTypes(Enum): + GIT = "git" + UPLOAD_ARCHIVE = "upload_archive" + UPLOAD_FILES = "upload_files" + + +class ResourceImportResolveNameConflictsStrategy(Enum): + FAIL = "fail" + OVERWRITE = "overwrite" + + +class ResourceImportSchema(Schema): + """The request schema for importing resources""" + + groupId = fields.Integer( + attribute="group_id", + data_key="group", + metadata=dict( + description="ID of the Group that will own the imported resources." + ), + required=True, + ) + sourceType = fields.Enum( + ResourceImportSourceTypes, + attribute="source_type", + metadata=dict( + description="The source of the resources to import" + "('upload_archive', 'upload_files', or 'git'." + ), + by_value=True, + required=True, + ) + gitUrl = fields.String( + attribute="git_url", + metadata=dict( + description="The URL of the git repository containing resources to import." + "A git branch can optionally be specified by appending #BRANCH_NAME. " + "Used when sourceType is 'git'." + ), + required=False, + ) + archiveFile = FileUpload( + attribute="archive_file", + metadata=dict( + type="file", + format="binary", + description="The archive file containing resources to import (.tar.gz)." + "Used when sourceType is 'upload_archive'.", + ), + required=False, + ) + files = MultiFileUpload( + attribute="files", + metadata=dict( + type="file", + format="binary", + description="The files containing the resources to import." + "Used when sourceType is 'upload_files'.", + ), + required=False, + ) + configPath = fields.String( + attribute="config_path", + metdata=dict(description="The path to the toml configuration file."), + load_default="dioptra.toml", + ) + resolveNameConflictsStrategy = fields.Enum( + ResourceImportResolveNameConflictsStrategy, + attribute="resolve_name_conflicts_strategy", + metadata=dict(description="Strategy for resolving resource name conflicts"), + by_value=True, + load_default=ResourceImportResolveNameConflictsStrategy.FAIL.value, + ) + + @validates_schema + def validate_source(self, data, **kwargs): + if ( + data["source_type"] == ResourceImportSourceTypes.GIT + and "git_url" not in data + ): + raise ValidationError({"gitUrl": "field required when sourceType is 'git'"}) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index d5769e274..8ea0182b7 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,19 +15,58 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" -from typing import IO, Final +import json +import tarfile +from collections import defaultdict +from hashlib import sha256 +from io import BytesIO +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import IO, Any, Final, cast +import jsonschema import structlog +import toml +from injector import inject from structlog.stdlib import BoundLogger +from werkzeug.datastructures import FileStorage + +from dioptra.restapi.db import db, models +from dioptra.restapi.errors import GitError, ImportFailedError +from dioptra.restapi.v1.entrypoints.service import ( + EntrypointIdService, + EntrypointNameService, + EntrypointService, +) +from dioptra.restapi.v1.plugin_parameter_types.service import ( + BuiltinPluginParameterTypeService, + PluginParameterTypeIdService, + PluginParameterTypeNameService, + PluginParameterTypeService, +) +from dioptra.restapi.v1.plugins.service import ( + PluginIdFileService, + PluginIdService, + PluginNameService, + PluginService, +) +from dioptra.sdk.utilities.paths import set_cwd from .lib import views +from .lib.clone_git_repository import clone_git_repository from .lib.package_job_files import package_job_files -from .schema import FileTypes +from .schema import ( + FileTypes, + ResourceImportResolveNameConflictsStrategy, + ResourceImportSourceTypes, +) LOGGER: BoundLogger = structlog.stdlib.get_logger() RESOURCE_TYPE: Final[str] = "workflow" +DIOPTRA_RESOURCES_SCHEMA_PATH: Final[str] = "dioptra-resources.schema.json" + class JobFilesDownloadService(object): """The service methods for packaging job files for download.""" @@ -65,3 +104,386 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]: file_type=file_type, logger=log, ) + + +class ResourceImportService(object): + """The service methods for packaging job files for download.""" + + @inject + def __init__( + self, + plugin_service: PluginService, + plugin_id_service: PluginIdService, + plugin_name_service: PluginNameService, + plugin_id_file_service: PluginIdFileService, + plugin_parameter_type_service: PluginParameterTypeService, + plugin_parameter_type_id_service: PluginParameterTypeIdService, + plugin_parameter_type_name_service: PluginParameterTypeNameService, + builtin_plugin_parameter_type_service: BuiltinPluginParameterTypeService, + entrypoint_service: EntrypointService, + entrypoint_id_service: EntrypointIdService, + entrypoint_name_service: EntrypointNameService, + ) -> None: + """Initialize the resource import service. + + All arguments are provided via dependency injection. + + Args: + plugin_service: A PluginService object, + plugin_name_service: A PluginNameService object. + plugin_id_service: A PluginIdService object. + plugin_id_file_service: A PluginIdFileService object. + plugin_parameter_type_service: A PluginParameterTypeService object. + plugin_parameter_type_id_service: A PluginParameterTypeIdService object. + plugin_parameter_type_name_service: A PluginParameterTypeNameService object. + builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService + object. + entrypoint_service: An EntrypointService object. + entrypoint_id_service: An EntrypointIdService object. + entrypoint_name_service: An EntrypointNameService object. + """ + self._plugin_service = plugin_service + self._plugin_id_service = plugin_id_service + self._plugin_name_service = plugin_name_service + self._plugin_id_file_service = plugin_id_file_service + self._plugin_parameter_type_service = plugin_parameter_type_service + self._plugin_parameter_type_id_service = plugin_parameter_type_id_service + self._plugin_parameter_type_name_service = plugin_parameter_type_name_service + self._builtin_plugin_parameter_type_service = ( + builtin_plugin_parameter_type_service + ) + self._entrypoint_service = entrypoint_service + self._entrypoint_id_service = entrypoint_id_service + self._entrypoint_name_service = entrypoint_name_service + + def import_resources( + self, + group_id: int, + source_type: str, + git_url: str | None, + archive_file: FileStorage | None, + files: list[FileStorage] | None, + config_path: str, + resolve_name_conflicts_strategy: str, + **kwargs, + ) -> dict[str, Any]: + """Import resources from a archive file or git repository + + Args: + group_id: The group to import resources into + source_type: The source to import from (either "upload" or "git") + git_url: The url to the git repository if source_type is "git" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" + config_path: The path to the toml configuration file in the import source. + resolve_name_conflicts_strategy: The strategy for resolving name conflicts. + Either "fail" or "overwrite" + + Returns: + A message summarizing imported resources + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Import resources", group_id=group_id) + + overwrite = ( + resolve_name_conflicts_strategy + == ResourceImportResolveNameConflictsStrategy.OVERWRITE + ) + + with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir): + working_dir = Path(tmp_dir) + + if source_type == ResourceImportSourceTypes.UPLOAD_ARCHIVE: + bytes = archive_file.stream.read() + try: + with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: + tar.extractall(path=working_dir, filter="data") + except Exception as e: + raise ImportFailedError("Failed to read uploaded tarfile") from e + hash = str(sha256(bytes).hexdigest()) + elif source_type == ResourceImportSourceTypes.UPLOAD_FILES: + hashes = b"" + for file in files: + Path(file.filename).parent.mkdir(parents=True, exist_ok=True) + bytes = file.stream.read() + with open(working_dir / file.filename, "wb") as f: + f.write(bytes) + hashes = hashes + sha256(bytes).digest() + hash = str(sha256(hashes).hexdigest()) + else: + try: + hash = clone_git_repository(cast(str, git_url), working_dir) + except Exception as e: + raise GitError(f"Failed to clone repository: {git_url}") from e + + try: + config = toml.load(working_dir / config_path) + except Exception as e: + raise ImportFailedError( + f"Failed to load resource import config from {config_path}." + ) from e + + # validate the config file + with open( + Path(__file__).resolve().parent / DIOPTRA_RESOURCES_SCHEMA_PATH, "rb" + ) as f: + schema = json.load(f) + jsonschema.validate(config, schema) + + param_types = self._register_plugin_param_types( + group_id, config.get("plugin_param_types", []), overwrite, log=log + ) + plugins = self._register_plugins( + group_id, config.get("plugins", []), param_types, overwrite, log=log + ) + entrypoints = self._register_entrypoints( + group_id, config.get("entrypoints", []), plugins, overwrite, log=log + ) + + db.session.commit() + + return { + "message": "successfully imported", + "hash": hash, + "resources": { + "plugins": list(plugins.keys()), + "plugin_param_types": list(param_types.keys()), + "entrypoints": list(entrypoints.keys()), + }, + } + + def _register_plugin_param_types( + self, + group_id: int, + param_types_config: list[dict[str, Any]], + overwrite: bool, + log: BoundLogger, + ) -> dict[str, models.PluginTaskParameterType]: + """ + Registers a list of PluginParameterTypes. + + Args: + group_id: The identifier of the group that will manage imported resources + param_types_config: A list of dictionaries describing a plugin param types + overwrite: Whether imported resources should replace existing resources with + a conflicting name + + Returns: + A dictionary mapping newly registered PluginParameterType names to the ORM + object + """ + + param_types = dict() + for param_type in param_types_config: + if overwrite: + existing = self._plugin_parameter_type_name_service.get( + param_type["name"], group_id=group_id, log=log + ) + if existing: + self._plugin_parameter_type_id_service.delete( + plugin_parameter_type_id=existing.resource_id, + log=log, + ) + + param_type_dict = self._plugin_parameter_type_service.create( + name=param_type["name"], + description=param_type.get("description", None), + structure=param_type.get("structure", None), + group_id=group_id, + commit=False, + log=log, + ) + param_types[param_type["name"]] = param_type_dict[ + "plugin_task_parameter_type" + ] + + db.session.flush() + + return param_types + + def _register_plugins( + self, + group_id: int, + plugins_config: list[dict[str, Any]], + param_types: dict[str, models.PluginTaskParameterType], + overwrite: bool, + log: BoundLogger, + ) -> dict[str, models.PluginTaskParameterType]: + """ + Registers a list of Plugins and their PluginFiles. + + Args: + group_id: The identifier of the group that will manage imported resources + plugins_config: A list of dictionaries describing a plugin and its tasks + param_types: A dictionary mapping param type name to the ORM object + overwrite: Whether imported resources should replace existing resources with + a conflicting name + + Returns: + A dictionary mapping newly registered Plugin names to the ORM objects + """ + + param_types = param_types.copy() + builtin_param_types = self._builtin_plugin_parameter_type_service.get( + group_id=group_id, error_if_not_found=False, log=log + ) + param_types.update( + {param_type.name: param_type for param_type in builtin_param_types} + ) + + plugins = {} + for plugin in plugins_config: + if overwrite: + existing = self._plugin_name_service.get( + Path(plugin["path"]).stem, group_id=group_id + ) + if existing: + self._plugin_id_service.delete( + plugin_id=existing.resource_id, + log=log, + ) + + plugin_dict = self._plugin_service.create( + name=Path(plugin["path"]).stem, + description=plugin.get("description", None), + group_id=group_id, + commit=False, + log=log, + ) + plugins[plugin_dict["plugin"].name] = plugin_dict["plugin"] + db.session.flush() + + tasks = self._build_tasks(plugin.get("tasks", []), param_types) + for plugin_file_path in Path(plugin["path"]).rglob("*.py"): + filename = str(plugin_file_path.relative_to(plugin["path"])) + try: + contents = plugin_file_path.read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {plugin_file_path}" + ) from e + + self._plugin_id_file_service.create( + filename, + contents=contents, + description=None, + tasks=tasks[filename], + plugin_id=plugin_dict["plugin"].resource_id, + commit=False, + log=log, + ) + + db.session.flush() + + return plugins + + def _register_entrypoints( + self, + group_id: int, + entrypoints_config: list[dict[str, Any]], + plugins, + overwrite: bool, + log: BoundLogger, + ) -> dict[str, models.EntryPoint]: + """ + Registers a list of Entrypoints + + Args: + group_id: The identifier of the group that will manage imported resources + entrypoints_config: A list of dictionaries describing entrypoints + plugins: A dictionary mapping Plugin names to the ORM objects + overwrite: Whether imported resources should replace existing resources with + a conflicting name + + Returns: + A dictionary mapping newly registered Entrypoint names to ORM object + """ + + entrypoints = dict() + for entrypoint in entrypoints_config: + if overwrite: + existing = self._entrypoint_name_service.get( + entrypoint["name"], group_id=group_id, log=log + ) + if existing is not None: + self._entrypoint_id_service.delete( + entrypoint_id=existing.resource_id + ) + + try: + contents = Path(entrypoint["path"]).read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {entrypoint['path']}" + ) from e + + params = [ + { + "name": param["name"], + "parameter_type": param["type"], + "default_value": param.get("default_value", None), + } + for param in entrypoint.get("params", []) + ] + plugin_ids = [ + plugins[plugin].resource_id for plugin in entrypoint.get("plugins", []) + ] + entrypoint_dict = self._entrypoint_service.create( + name=entrypoint.get("name", Path(entrypoint["path"]).stem), + description=entrypoint.get("description", None), + task_graph=contents, + parameters=params, + plugin_ids=plugin_ids, + queue_ids=[], + group_id=group_id, + commit=False, + log=log, + ) + entrypoints[entrypoint_dict["entry_point"].name] = entrypoint_dict[ + "entry_point" + ] + + db.session.flush() + + return entrypoints + + def _build_tasks( + self, + tasks_config: list[dict[str, Any]], + param_types: dict[str, models.PluginTaskParameterType], + ) -> dict[str, list]: + """ + Builds dictionaries describing plugin tasks from a configuration file + + Args: + tasks_config: A list of dictionaries describing plugin tasks + param_types: A dictionary mapping param type name to the ORM object + + Returns: + A dictionary mapping PluginFile name to a list of tasks + """ + + tasks = defaultdict(list) + for task in tasks_config: + tasks[task["filename"]].append( + { + "name": task["name"], + "description": task.get("description", None), + "input_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + "required": param.get("required", False), + } + for param in task["input_params"] + ], + "output_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + } + for param in task["output_params"] + ], + } + ) + return tasks diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 676725d26..44290f0e6 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -15,12 +15,17 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Fixtures representing resources needed for test suites""" +import os +import tarfile import textwrap from collections.abc import Iterator from http import HTTPStatus +from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Any, cast import pytest +import toml import uuid from flask import Flask from flask.testing import FlaskClient @@ -28,6 +33,12 @@ from injector import Injector from pytest import MonkeyPatch +from dioptra.client import ( + DioptraFile, + select_files_in_directory, + select_one_or_more_files, +) + from ..lib import actions, mock_rq @@ -720,3 +731,31 @@ def registered_mlflowrun_incomplete( ) return responses + + +@pytest.fixture +def resources_tar_file() -> DioptraFile: + os.chdir(Path(__file__).absolute().parent / "resource_import_files") + + with NamedTemporaryFile(suffix=".tar.gz", delete=False) as f: + with tarfile.open(fileobj=f, mode="w:gz") as tar: + tar.add("dioptra.toml") + tar.add("plugins", recursive=True) + tar.add(Path("examples", "hello-world.yaml")) + + yield select_one_or_more_files([f.name])[0] + + os.unlink(f.name) + + +@pytest.fixture +def resources_files() -> DioptraFile: + os.chdir(Path(__file__).absolute().parent / "resource_import_files") + + return select_files_in_directory(".", recursive=True) + + +@pytest.fixture +def resources_import_config() -> dict[str, Any]: + root_dir = Path(__file__).absolute().parent / "resource_import_files" + return toml.load(root_dir / "dioptra.toml") diff --git a/tests/unit/restapi/v1/resource_import_files/dioptra.toml b/tests/unit/restapi/v1/resource_import_files/dioptra.toml new file mode 100644 index 000000000..ce3497057 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/dioptra.toml @@ -0,0 +1,37 @@ +[[plugins]] +path = "plugins/hello_world" +description = "A simple plugin used for testing and demonstration purposes." + + + [[plugins.tasks]] + filename = "tasks.py" + name = "hello" + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "greeting", type = "message", required = true} ] + output_params = [ { name = "loud_greeting", type = "message" } ] + +[[plugin_param_types]] +name = "message" + +[[entrypoints]] +path = "examples/hello-world.yaml" +name = "Hello World" +description = "A simple example using the hello_world plugin." +params = [ + { name = "name", type = "string", default_value = "World" } +] +plugins = [ "hello_world" ] diff --git a/tests/unit/restapi/v1/resource_import_files/examples/hello-world.yaml b/tests/unit/restapi/v1/resource_import_files/examples/hello-world.yaml new file mode 100644 index 000000000..0b7609dcb --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/examples/hello-world.yaml @@ -0,0 +1,10 @@ +hello_step: + hello: + name: $name +goodbye_step: + greet: + greeting: Goodbye + name: $name +shout_step: + shout: + greeting: $goodbye_step.greeting diff --git a/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/__init__.py b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/tasks.py b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/tasks.py new file mode 100644 index 000000000..875fb2b82 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/tasks.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import structlog +from dioptra import pyplugs + +LOGGER = structlog.get_logger() + + +@pyplugs.register() +def hello(name: str) -> str: + message = f"Hello, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def greet(greeting: str, name: str) -> str: + message = f"{greeting}, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def shout(greeting: str) -> str: + loud_greeting = greeting.upper() + LOGGER.info(loud_greeting) + return loud_greeting diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py new file mode 100644 index 000000000..f997c3d2c --- /dev/null +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -0,0 +1,142 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for the resource import workflow + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the queue entity. The tests ensure that the queues can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" + +from pathlib import Path +from tempfile import NamedTemporaryFile +from typing import Any + +from flask_sqlalchemy import SQLAlchemy + +from dioptra.client import DioptraClient, DioptraFile +from dioptra.client.base import DioptraResponseProtocol + + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_imported_resources_match_expected( + dioptra_client: DioptraClient[DioptraResponseProtocol], + expected: dict[str, Any], +): + response = dioptra_client.plugins.get() + response_plugins = set(plugin["name"] for plugin in response.json()["data"]) + expected_plugins = set(Path(plugin["path"]).stem for plugin in expected["plugins"]) + assert response.status_code == 200 and response_plugins == expected_plugins + + response = dioptra_client.plugin_parameter_types.get() + response_types = set(param["name"] for param in response.json()["data"]) + expected_types = set(param["name"] for param in expected["plugin_param_types"]) + assert ( + response.status_code == 200 + and response_types & expected_types == expected_types + ) + + response = dioptra_client.entrypoints.get() + response_entrypoints = set(ep["name"] for ep in response.json()["data"]) + expected_entrypoints = set(ep["name"] for ep in expected["entrypoints"]) + assert response.status_code == 200 and response_entrypoints == expected_entrypoints + + +def assert_resource_import_fails_due_to_name_clash( + dioptra_client: DioptraClient[DioptraResponseProtocol], + group_id: int, + archive_file: DioptraFile, +): + dioptra_client.plugins.create(group_id=group_id, name="hello_world") + response = dioptra_client.workflows.import_resources( + group_id=group_id, + archive_file=archive_file, + resolve_name_conflicts_strategy="fail", + ) + + assert response.status_code == 409 + + +def assert_resource_import_overwrite_works( + dioptra_client: DioptraClient[DioptraResponseProtocol], + group_id: int, + archive_file: DioptraFile, +): + dioptra_client.plugins.create(group_id=group_id, name="hello_world") + response = dioptra_client.workflows.import_resources( + group_id=group_id, + archive_file=archive_file, + resolve_name_conflicts_strategy="overwrite", + ) + + assert response.status_code == 200 + + +# -- Tests ----------------------------------------------------------------------------- + + +def test_resource_import_from_archive_file( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: DioptraFile, + resources_import_config: dict[str, Any], +): + group_id = auth_account["groups"][0]["id"] + dioptra_client.workflows.import_resources(group_id, archive_file=resources_tar_file) + + assert_imported_resources_match_expected(dioptra_client, resources_import_config) + + +def test_resource_import_from_files( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_files: list[DioptraFile], + resources_import_config: dict[str, Any], +): + group_id = auth_account["groups"][0]["id"] + dioptra_client.workflows.import_resources(group_id, files=resources_files) + + assert_imported_resources_match_expected(dioptra_client, resources_import_config) + + +def test_resource_import_fails_from_name_clash( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: NamedTemporaryFile, +): + group_id = auth_account["groups"][0]["id"] + + assert_resource_import_fails_due_to_name_clash( + dioptra_client, group_id=group_id, archive_file=resources_tar_file + ) + + +def test_resource_import_overwrite( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: NamedTemporaryFile, +): + group_id = auth_account["groups"][0]["id"] + + assert_resource_import_overwrite_works( + dioptra_client, group_id=group_id, archive_file=resources_tar_file + ) diff --git a/tox.ini b/tox.ini index 6df5e06f7..bb0291a61 100644 --- a/tox.ini +++ b/tox.ini @@ -186,6 +186,7 @@ deps = types-PyYAML types-redis types-requests + types-toml typing-extensions>=3.7.4.3 skip_install = false commands = mypy {posargs:"{tox_root}{/}src{/}dioptra" "{tox_root}{/}task-plugins{/}dioptra_builtins"}