From ed01570c72c95bf689f4efec48f4e4e2a907ac3c Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Fri, 18 Oct 2024 09:17:14 -0400 Subject: [PATCH 01/28] feat: import resources from git repo or archive file --- dioptra.toml | 38 ++++ entrypoints/hello-world.yaml | 3 + examples/hello-world.yaml | 10 + plugins/hello_world/__init__.py | 16 ++ plugins/hello_world/tasks.py | 41 ++++ src/dioptra/restapi/utils.py | 1 + .../v1/plugin_parameter_types/service.py | 50 +++++ .../restapi/v1/workflows/controller.py | 58 ++++- .../restapi/v1/workflows/lib/__init__.py | 2 + .../v1/workflows/lib/clone_git_repository.py | 56 +++++ src/dioptra/restapi/v1/workflows/schema.py | 87 +++++++- src/dioptra/restapi/v1/workflows/service.py | 206 +++++++++++++++++- 12 files changed, 561 insertions(+), 7 deletions(-) create mode 100644 dioptra.toml create mode 100644 entrypoints/hello-world.yaml create mode 100644 examples/hello-world.yaml create mode 100644 plugins/hello_world/__init__.py create mode 100644 plugins/hello_world/tasks.py create mode 100644 src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py diff --git a/dioptra.toml b/dioptra.toml new file mode 100644 index 000000000..295bc3f3e --- /dev/null +++ b/dioptra.toml @@ -0,0 +1,38 @@ +[[plugins]] +path = "plugins/hello_world" +description = "A simple plugin used for testing and demonstration purposes." + + + [[plugins.tasks]] + filename = "tasks.py" + name = "hello" + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "message", type = "string" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "message", type = "string" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "message", type = "string", required = true} ] + output_params = [ { name = "message", type = "string" } ] + +[[plugin_param_types]] +name = "image" +structure = { list = ["int", "int", "int"] } + +[[entrypoints]] +path = "examples/hello-world.yaml" +name = "Hello World" +description = "A simple example using the hello_world plugin." +params = [ + { name = "name", type = "string", default_value = "World" } +] +plugins = [ "hello_world" ] diff --git a/entrypoints/hello-world.yaml b/entrypoints/hello-world.yaml new file mode 100644 index 000000000..49be5fe26 --- /dev/null +++ b/entrypoints/hello-world.yaml @@ -0,0 +1,3 @@ +message: + greet: + name: $name diff --git a/examples/hello-world.yaml b/examples/hello-world.yaml new file mode 100644 index 000000000..3ab57da83 --- /dev/null +++ b/examples/hello-world.yaml @@ -0,0 +1,10 @@ +hello_step: + hello: + name: $name +goodbye_step: + greet: + greeting: Goodbye + name: $name +shout_step: + shout: + message: $goodbye_step.message diff --git a/plugins/hello_world/__init__.py b/plugins/hello_world/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/plugins/hello_world/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/plugins/hello_world/tasks.py b/plugins/hello_world/tasks.py new file mode 100644 index 000000000..45b9ccda2 --- /dev/null +++ b/plugins/hello_world/tasks.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import structlog +from dioptra import pyplugs + +LOGGER = structlog.get_logger() + + +@pyplugs.register() +def hello(name: str) -> str: + message = f"Hello, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def greet(greeting: str, name: str) -> str: + message = f"{greeting}, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def shout(message: str) -> str: + message = message.upper() + LOGGER.info(message) + return message diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index ed564151e..7236d875f 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -310,6 +310,7 @@ def setup_injection(api: Api, injector: Injector) -> None: ma.Decimal: float, ma.Dict: dict, ma.Email: str, + ma.Enum: str, FileUpload: FileStorage, MultiFileUpload: FileStorage, ma.Float: float, diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index e4ce40fa8..728157f05 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -566,6 +566,56 @@ def __init__( """ self._group_id_service = group_id_service + def get( + self, + group_id: int, + error_if_not_found: bool = False, + **kwargs, + ) -> models.PluginTaskParameterType | None: + """Fetch a list of plugin parameter types by their names. + + Args: + group_id: The the group id of the plugin parameter type. + error_if_not_found: If True, raise an error if the plugin parameter + type is not found. Defaults to False. + + Returns: + The plugin parameter type object if found, otherwise None. + + Raises: + PluginParameterTypeDoesNotExistError: If the plugin parameter type + is not found and `error_if_not_found` is True. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug( + "Get builtin plugin parameter types", + group_id=group_id, + ) + + builtin_types = list(BUILTIN_TYPES.keys()) + + stmt = ( + select(models.PluginTaskParameterType) + .join(models.Resource) + .where( + models.PluginTaskParameterType.name.in_(builtin_types), + models.Resource.group_id == group_id, + models.Resource.is_deleted == False, # noqa: E712 + models.Resource.latest_snapshot_id + == models.PluginTaskParameterType.resource_snapshot_id, + ) + ) + plugin_parameter_types = list(db.session.scalars(stmt).all()) + + if len(plugin_parameter_types) != len(builtin_types): + retrieved_names = {param_type.name for param_type in plugin_parameter_types} + missing_names = set(builtin_types) - retrieved_names + if error_if_not_found: + log.debug("Plugin Parameter Type(s) not found", names=missing_names) + raise PluginParameterTypeDoesNotExistError + + return plugin_parameter_types + def create_all( self, user: models.User, diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 428619cdc..5265d8709 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -25,8 +25,14 @@ from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema -from .service import JobFilesDownloadService +from dioptra.restapi.utils import as_api_parser, as_parameters_schema_list + +from .schema import ( + FileTypes, + JobFilesDownloadQueryParametersSchema, + ResourceImportSchema, +) +from .service import JobFilesDownloadService, ResourceImportService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -78,3 +84,51 @@ def get(self): mimetype=mimetype[parsed_query_params["file_type"]], download_name=download_name[parsed_query_params["file_type"]], ) + + +@api.route("/resourceImport") +class ResourceImport(Resource): + @inject + def __init__( + self, resource_import_service: ResourceImportService, *args, **kwargs + ) -> None: + """Initialize the workflow resource. + + All arguments are provided via dependency injection. + + Args: + resource_import_service: A ResourceImportService object. + """ + self._resource_import_service = resource_import_service + super().__init__(*args, **kwargs) + + @login_required + @api.expect( + as_api_parser( + api, + as_parameters_schema_list( + ResourceImportSchema, operation="load", location="form" + ), + ) + ) + @accepts(form_schema=ResourceImportSchema, api=api) + def post(self): + """Import resources from an external source.""" # noqa: B950 + log = LOGGER.new( # noqa: F841 + request_id=str(uuid.uuid4()), resource="ResourceImport", request_type="POST" + ) + parsed_form = request.parsed_form + + log.info("HERE") + return self._resource_import_service.import_resources( + group_id=parsed_form["group_id"], + source_type=parsed_form["source_type"], + git_url=parsed_form.get("git_url", None), + archive_file=request.files.get("archiveFile", None), + config_path=parsed_form["config_path"], + read_only=parsed_form["read_only"], + resolve_name_conflicts_strategy=parsed_form[ + "resolve_name_conflicts_strategy" + ], + log=log, + ) diff --git a/src/dioptra/restapi/v1/workflows/lib/__init__.py b/src/dioptra/restapi/v1/workflows/lib/__init__.py index db520e904..577ea2000 100644 --- a/src/dioptra/restapi/v1/workflows/lib/__init__.py +++ b/src/dioptra/restapi/v1/workflows/lib/__init__.py @@ -15,12 +15,14 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode from . import views +from .clone_git_repository import clone_git_repository from .export_plugin_files import export_plugin_files from .export_task_engine_yaml import export_task_engine_yaml from .package_job_files import package_job_files __all__ = [ "views", + "clone_git_repository", "export_plugin_files", "export_task_engine_yaml", "package_job_files", diff --git a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py new file mode 100644 index 000000000..3fa682bf1 --- /dev/null +++ b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py @@ -0,0 +1,56 @@ +import subprocess +from pathlib import Path +from urllib.parse import urlparse + + +def clone_git_repository(url: str, dir: Path): + parsed_url = urlparse(url) + git_branch = parsed_url.fragment or None + git_paths = parsed_url.params or None + git_url = parsed_url._replace(fragment="", params="").geturl() + + git_sparse_args = ["--filter=blob:none", "--no-checkout", "--depth=1"] + git_branch_args = ["-b", git_branch] if git_branch else [] + clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, dir] + clone_result = subprocess.run(clone_cmd, capture_output=True, text=True) + + if clone_result.returncode != 0: + raise subprocess.CalledProcessError( + clone_result.returncode, clone_result.stderr + ) + + if git_paths is not None: + paths = git_paths.split(",") + sparse_checkout_cmd = ["git", "sparse-checkout", "set", "--cone", *paths] + sparse_checkout_result = subprocess.run( + sparse_checkout_cmd, cwd=dir, capture_output=True, text=True + ) + + if sparse_checkout_result.returncode != 0: + raise subprocess.CalledProcessError( + sparse_checkout_result.returncode, sparse_checkout_result.stderr + ) + + checkout_cmd = ["git", "checkout"] + checkout_result = subprocess.run( + checkout_cmd, cwd=dir, capture_output=True, text=True + ) + + if checkout_result.returncode != 0: + raise subprocess.CalledProcessError( + checkout_result.returncode, checkout_result.stderr + ) + + hash_cmd = ["git", "rev-parse", "HEAD"] + hash_result = subprocess.run(hash_cmd, cwd=dir, capture_output=True, text=True) + + if hash_result.returncode != 0: + raise subprocess.CalledProcessError + + return hash + + +if __name__ == "__main__": + clone_git_repository( + "https://github.com/usnistgov/dioptra.git;plugins#dev", Path("dioptra-plugins") + ) diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 92ea28ec7..60fad7d28 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -17,7 +17,9 @@ """The schemas for serializing/deserializing Workflow resources.""" from enum import Enum -from marshmallow import Schema, fields +from marshmallow import Schema, ValidationError, fields, validates_schema + +from dioptra.restapi.custom_schema_fields import FileUpload class FileTypes(Enum): @@ -41,3 +43,86 @@ class JobFilesDownloadQueryParametersSchema(Schema): by_value=True, default=FileTypes.TAR_GZ.value, ) + + +class ResourceImportSourceTypes(Enum): + GIT = "git" + UPLOAD = "upload" + + +class ResourceImportResolveNameConflictsStrategy(Enum): + FAIL = "fail" + OVERWRITE = "overwrite" + + +class ResourceImportSchema(Schema): + """The request schema for importing resources""" + + groupId = fields.Integer( + attribute="group_id", + # data_key="group", + metadata=dict( + description="ID of the Group that will own the imported resources." + ), + required=True, + ) + sourceType = fields.Enum( + ResourceImportSourceTypes, + attribute="source_type", + metadata=dict(description="The source of the resources to import."), + by_value=True, + required=True, + ) + gitUrl = fields.String( + attribute="git_url", + metadata=dict( + description="The URL of the git repository containing resources to import. " + "A git branch can optionally be specified by appending #BRANCH_NAME. " + "Used when sourceType is 'git'." + ), + required=False, + ) + archiveFile = FileUpload( + attribute="archive_file", + metadata=dict( + type="file", + format="binary", + description="The archive file containing resources to import (.tar.gz). " + "Used when sourceType is 'upload'.", + ), + required=False, + ) + configPath = fields.String( + attribute="config_path", + metdata=dict(description="The path to the toml configuration file."), + load_default="dioptra.toml", + ) + readOnly = fields.Bool( + attribute="read_only", + metadata=dict(description="Whether imported resources should be readonly."), + load_default=False, + ) + resolveNameConflictsStrategy = fields.Enum( + ResourceImportResolveNameConflictsStrategy, + attribute="resolve_name_conflicts_strategy", + metadata=dict(description="Strategy for resolving resource name conflicts"), + by_value=True, + load_default=ResourceImportResolveNameConflictsStrategy.FAIL.value, + ) + + @validates_schema + def validate_source(self, data, **kwargs): + if ( + data["source_type"] == ResourceImportSourceTypes.GIT + and "git_url" not in data + ): + raise ValidationError({"gitUrl": "field required when sourceType is 'git'"}) + + # 'upload' is not in data + # if ( + # data["source_type"] == ResourceImportSourceTypes.UPLOAD + # and "data" not in data + # ): + # raise ValidationError( + # {"data": "field required when sourceType is 'upload'"} + # ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index d5769e274..dc7c714e7 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,14 +15,31 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" -from typing import IO, Final +import tarfile +from collections import defaultdict +from hashlib import sha256 +from io import BytesIO +from pathlib import Path +from tempfile import TemporaryDirectory +from typing import IO, Any, Final import structlog +import toml +from injector import inject from structlog.stdlib import BoundLogger +from werkzeug.datastructures import FileStorage -from .lib import views -from .lib.package_job_files import package_job_files -from .schema import FileTypes +from dioptra.restapi.db import db +from dioptra.restapi.v1.entrypoints.service import EntrypointService +from dioptra.restapi.v1.plugin_parameter_types.service import ( + BuiltinPluginParameterTypeService, + PluginParameterTypeService, +) +from dioptra.restapi.v1.plugins.service import PluginIdFileService, PluginService +from dioptra.sdk.utilities.paths import set_cwd + +from .lib import clone_git_repository, package_job_files, views +from .schema import FileTypes, ResourceImportSourceTypes LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -65,3 +82,184 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]: file_type=file_type, logger=log, ) + + +class ResourceImportService(object): + """The service methods for packaging job files for download.""" + + @inject + def __init__( + self, + plugin_service: PluginService, + plugin_id_file_service: PluginIdFileService, + plugin_parameter_type_service: PluginParameterTypeService, + builtin_plugin_parameter_type_service: BuiltinPluginParameterTypeService, + entrypoint_service: EntrypointService, + ) -> None: + """Initialize the resource import service. + + All arguments are provided via dependency injection. + + Args: + plugin_name_service: A PluginNameService object. + plugin_id_file_service: A PluginIdFileService object. + plugin_parameter_type_service: A PluginParameterTypeService object. + builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService object. + entrypoint_service: A EntrypointService object. + """ + self._plugin_service = plugin_service + self._plugin_id_file_service = plugin_id_file_service + self._plugin_parameter_type_service = plugin_parameter_type_service + self._builtin_plugin_parameter_type_service = ( + builtin_plugin_parameter_type_service + ) + self._entrypoint_service = entrypoint_service + + def import_resources( + self, + group_id: int, + source_type: str, + git_url: str | None, + archive_file: FileStorage | None, + config_path: str, + read_only: bool, + resolve_name_conflicts_strategy: str, + **kwargs, + ) -> dict[str, Any]: + """Import resources from a archive file or git repository + + Args: + group_id: The group to import resources into + + Returns: + A message summarizing imported resources + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Import resources", group_id=group_id) + + with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir): + working_dir = Path(tmp_dir) + + if source_type == ResourceImportSourceTypes.UPLOAD: + bytes = archive_file.stream.read() + with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: + tar.extractall(path=working_dir, filter="data") + hash = sha256(bytes).hexdigest() + elif source_type == ResourceImportSourceTypes.GIT: + hash = clone_git_repository(git_url, working_dir) + + log.info(hash=hash, paths=list(working_dir.glob("*"))) + + config_path = working_dir / config_path + if not config_path.exists() or not config_path.is_file(): + raise Exception + + config = toml.load(config_path) + + # register new plugin param types + param_types = { + param_type["name"]: self._plugin_parameter_type_service.create( + name=param_type["name"], + description=param_type.get("description", None), + structure=param_type.get("structure", None), + group_id=group_id, + commit=False, + log=log, + ) + for param_type in config.get("plugin_param_types", []) + } + # retrieve built-ins + param_types.update( + { + param_type.name: param_type + for param_type in self._builtin_plugin_parameter_type_service.get( + group_id=group_id, error_if_not_found=False, log=log + ) + } + ) + db.session.flush() + + # register new plugins + plugin_ids = {} + for plugin in config.get("plugins", []): + plugin_dict = self._plugin_service.create( + name=Path(plugin["path"]).stem, + description=plugin.get("description", None), + group_id=group_id, + commit=False, + log=log, + ) + db.session.flush() + plugin_ids[plugin_dict["plugin"].name] = plugin_dict[ + "plugin" + ].resource_id + + tasks = _build_tasks(plugin.get("tasks", []), param_types) + for plugin_file_path in Path(plugin["path"]).rglob("*.py"): + filename = str(plugin_file_path.relative_to(plugin["path"])) + contents = plugin_file_path.read_text() + + plugin_file_dict = self._plugin_id_file_service.create( + filename, + contents=contents, + description=None, + tasks=tasks[filename], + plugin_id=plugin_dict["plugin"].resource_id, + commit=False, + ) + + # register new entrypoints + for entrypoint in config.get("entrypoints", []): + contents = Path(entrypoint["path"]).read_text() + params = [ + { + "name": param["name"], + "parameter_type": param["type"], + "default_value": param.get("default_value", None), + } + for param in entrypoint.get("params", []) + ] + self._entrypoint_service.create( + name=entrypoint.get("name", Path(entrypoint["path"]).stem), + description=entrypoint.get("description", None), + task_graph=contents, + parameters=params, + plugin_ids=[ + plugin_ids[plugin] for plugin in entrypoint.get("plugins", []) + ], + queue_ids=[], + group_id=group_id, + commit=False, + log=log, + ) + + db.session.commit() + + return {"message": "successfully imported"} + + +def _build_tasks(tasks_config, param_types): + tasks = defaultdict(list) + for task in tasks_config: + tasks[task["filename"]].append( + { + "name": task["name"], + "description": task.get("description", None), + "input_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + "required": param.get("required", False), + } + for param in task["input_params"] + ], + "output_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + } + for param in task["output_params"] + ], + } + ) + return tasks From fa23b1cd7272dd64b883c34c315343fdad0c0c6a Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Sep 2024 09:22:46 -0400 Subject: [PATCH 02/28] feat: add json schema to validate dioptra toml --- entrypoints/hello-world.yaml | 3 - .../workflows/dioptra-resources.schema.json | 104 ++++++++++++++++++ src/dioptra/restapi/v1/workflows/service.py | 21 +++- 3 files changed, 121 insertions(+), 7 deletions(-) delete mode 100644 entrypoints/hello-world.yaml create mode 100644 src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json diff --git a/entrypoints/hello-world.yaml b/entrypoints/hello-world.yaml deleted file mode 100644 index 49be5fe26..000000000 --- a/entrypoints/hello-world.yaml +++ /dev/null @@ -1,3 +0,0 @@ -message: - greet: - name: $name diff --git a/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json b/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json new file mode 100644 index 000000000..ea7ef143f --- /dev/null +++ b/src/dioptra/restapi/v1/workflows/dioptra-resources.schema.json @@ -0,0 +1,104 @@ +{ + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/usnistgov/dioptra", + "title": "Dioptra Resource Schema", + "description": "A schema defining objects that can be imported into a Dioptra instance.", + "type": "object", + "properties": { + "plugins": { + "type": "array", + "description": "An array of Dioptra plugins", + "items": { + "type": "object", + "description": "A Dioptra plugin", + "properties": { + "path": { "type": "string" }, + "description": { "type": "string" }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "filename": { "type": "string" }, + "name": { "type": "string" }, + "input_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "required": { "type": "boolean" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "output_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + } + }, + "required": [ "filename", "name", "input_params", "output_params" ], + "additionalProperties": false + } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + }, + "plugin_param_types": { + "type": "array", + "description": "An array of Dioptra plugin parameter types", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "structure": { "type": "object" } + }, + "required": [ "name" ], + "additionalProperties": false + } + }, + "entrypoints": { + "type": "array", + "description": "An array of Dioptra entrypoints", + "items": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "name": { "type": "string" }, + "description": { "type": "string" }, + "params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "default_value": { "type": [ "string", "number", "boolean", "null" ] } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "plugins": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + } + } +} diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index dc7c714e7..ef03ce5d2 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,6 +15,8 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" +import json +import jsonschema import tarfile from collections import defaultdict from hashlib import sha256 @@ -45,6 +47,8 @@ RESOURCE_TYPE: Final[str] = "workflow" +DIOPTRA_RESOURCES_SCHEMA_PATH: Final[str] = "dioptra-resources.schema.json" + class JobFilesDownloadService(object): """The service methods for packaging job files for download.""" @@ -130,6 +134,12 @@ def import_resources( Args: group_id: The group to import resources into + source_type: The source to import from (either "upload" or "git") + git_url: The url to the git repository if source_type is "git" + archive_file: The contents of the upload if source_type is "upload" + read_only: Whether to apply a readonly lock to all imported resources + resolve_name_conflicts_strategy: The strategy for resolving name conflicts. + Either "fail" or "overwrite" Returns: A message summarizing imported resources @@ -150,11 +160,14 @@ def import_resources( log.info(hash=hash, paths=list(working_dir.glob("*"))) - config_path = working_dir / config_path - if not config_path.exists() or not config_path.is_file(): - raise Exception + config = toml.load(working_dir / config_path) - config = toml.load(config_path) + # validate the config file + with open( + Path(__file__).resolve().parent / DIOPTRA_RESOURCES_SCHEMA_PATH, "rb" + ) as f: + schema = json.load(f) + jsonschema.validate(config, schema) # register new plugin param types param_types = { From f0a49e908470d058e49fed95b64e39f36dd77ddc Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Fri, 18 Oct 2024 09:21:11 -0400 Subject: [PATCH 03/28] feat: implement overwrite name conflict resolution strategy --- src/dioptra/restapi/v1/entrypoints/service.py | 20 ++++++++--- .../v1/plugin_parameter_types/service.py | 20 ++++++++--- src/dioptra/restapi/v1/plugins/service.py | 33 ++++++++++++++----- .../restapi/v1/workflows/controller.py | 1 - src/dioptra/restapi/v1/workflows/service.py | 15 ++++++++- 5 files changed, 68 insertions(+), 21 deletions(-) diff --git a/src/dioptra/restapi/v1/entrypoints/service.py b/src/dioptra/restapi/v1/entrypoints/service.py index a8f3cf711..a1681eb18 100644 --- a/src/dioptra/restapi/v1/entrypoints/service.py +++ b/src/dioptra/restapi/v1/entrypoints/service.py @@ -94,6 +94,7 @@ def create( plugin_ids: list[int], queue_ids: list[int], group_id: int, + replace_existing: bool = False, commit: bool = True, **kwargs, ) -> utils.EntrypointDict: @@ -104,6 +105,8 @@ def create( be unique. description: The description of the entrypoint. group_id: The group that will own the entrypoint. + replace_existing: If True and a resource already exists with this + name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -114,11 +117,18 @@ def create( """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - duplicate = self._entrypoint_name_service.get(name, group_id=group_id, log=log) - if duplicate is not None: - raise EntityExistsError( - RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id - ) + existing = self._entrypoint_name_service.get(name, group_id=group_id, log=log) + if existing is not None: + if replace_existing: + deleted_resource_lock = models.ResourceLock( + resource_lock_type=resource_lock_types.DELETE, + resource=existing.resource, + ) + db.session.add(deleted_resource_lock) + else: + raise EntityExistsError( + RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id + ) group = self._group_id_service.get(group_id, error_if_not_found=True) queues = self._queue_ids_service.get(queue_ids, error_if_not_found=True) diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index 728157f05..fb7613111 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -85,6 +85,7 @@ def create( structure: dict[str, Any], description: str, group_id: int, + replace_existing: bool = False, commit: bool = True, **kwargs, ) -> utils.PluginParameterTypeDict: @@ -98,6 +99,8 @@ def create( type's structure. description: The description of the plugin parameter type. group_id: The group that will own the plugin parameter type. + replace_existing: If True and a resource already exists with this + name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -119,13 +122,20 @@ def create( ) raise PluginParameterTypeMatchesBuiltinTypeError - duplicate = self._plugin_parameter_type_name_service.get( + existing = self._plugin_parameter_type_name_service.get( name, group_id=group_id, log=log ) - if duplicate is not None: - raise EntityExistsError( - RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id - ) + if existing is not None: + if replace_existing: + deleted_resource_lock = models.ResourceLock( + resource_lock_type=resource_lock_types.DELETE, + resource=existing.resource, + ) + db.session.add(deleted_resource_lock) + else: + raise EntityExistsError( + RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id + ) group = self._group_id_service.get(group_id, error_if_not_found=True) diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index fadb12ec4..43944edba 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -85,7 +85,13 @@ def __init__( self._group_id_service = group_id_service def create( - self, name: str, description: str, group_id: int, commit: bool = True, **kwargs + self, + name: str, + description: str, + group_id: int, + replace_existing: bool = False, + commit: bool = True, + **kwargs, ) -> utils.PluginWithFilesDict: """Create a new plugin. @@ -94,6 +100,8 @@ def create( unique. description: The description of the plugin. group_id: The group that will own the plugin. + replace_existing: If True and a resource already exists with this + name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -104,14 +112,21 @@ def create( """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - duplicate = self._plugin_name_service.get(name, group_id=group_id, log=log) - if duplicate is not None: - raise EntityExistsError( - PLUGIN_RESOURCE_TYPE, - duplicate.resource_id, - name=name, - group_id=group_id, - ) + existing = self._plugin_name_service.get(name, group_id=group_id, log=log) + if existing is not None: + if replace_existing: + deleted_resource_lock = models.ResourceLock( + resource_lock_type=resource_lock_types.DELETE, + resource=existing.resource, + ) + db.session.add(deleted_resource_lock) + else: + raise EntityExistsError( + PLUGIN_RESOURCE_TYPE, + duplicate.resource_id, + name=name, + group_id=group_id, + ) group = self._group_id_service.get(group_id, error_if_not_found=True) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 5265d8709..140074823 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -119,7 +119,6 @@ def post(self): ) parsed_form = request.parsed_form - log.info("HERE") return self._resource_import_service.import_resources( group_id=parsed_form["group_id"], source_type=parsed_form["source_type"], diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index ef03ce5d2..81fa68cf8 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -41,7 +41,11 @@ from dioptra.sdk.utilities.paths import set_cwd from .lib import clone_git_repository, package_job_files, views -from .schema import FileTypes, ResourceImportSourceTypes +from .schema import ( + FileTypes, + ResourceImportSourceTypes, + ResourceImportResolveNameConflictsStrategy, +) LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -147,6 +151,11 @@ def import_resources( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Import resources", group_id=group_id) + replace_existing = ( + resolve_name_conflicts_strategy + == ResourceImportResolveNameConflictsStrategy.OVERWRITE + ) + with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir): working_dir = Path(tmp_dir) @@ -176,6 +185,7 @@ def import_resources( description=param_type.get("description", None), structure=param_type.get("structure", None), group_id=group_id, + replace_existing=replace_existing, commit=False, log=log, ) @@ -199,6 +209,7 @@ def import_resources( name=Path(plugin["path"]).stem, description=plugin.get("description", None), group_id=group_id, + replace_existing=replace_existing, commit=False, log=log, ) @@ -219,6 +230,7 @@ def import_resources( tasks=tasks[filename], plugin_id=plugin_dict["plugin"].resource_id, commit=False, + log=log, ) # register new entrypoints @@ -242,6 +254,7 @@ def import_resources( ], queue_ids=[], group_id=group_id, + replace_existing=replace_existing, commit=False, log=log, ) From 919295e483a1edd75ead0f22dd737ff32c88c636 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Fri, 18 Oct 2024 09:25:55 -0400 Subject: [PATCH 04/28] feat: implement readOnly flag for import workflow --- src/dioptra/restapi/v1/entrypoints/service.py | 29 +++++++++-- .../v1/plugin_parameter_types/service.py | 9 ++++ src/dioptra/restapi/v1/plugins/service.py | 49 +++++++++++++++++++ src/dioptra/restapi/v1/workflows/service.py | 8 ++- 4 files changed, 90 insertions(+), 5 deletions(-) diff --git a/src/dioptra/restapi/v1/entrypoints/service.py b/src/dioptra/restapi/v1/entrypoints/service.py index a1681eb18..cf4d5ead7 100644 --- a/src/dioptra/restapi/v1/entrypoints/service.py +++ b/src/dioptra/restapi/v1/entrypoints/service.py @@ -94,6 +94,7 @@ def create( plugin_ids: list[int], queue_ids: list[int], group_id: int, + read_only: bool = False, replace_existing: bool = False, commit: bool = True, **kwargs, @@ -105,6 +106,7 @@ def create( be unique. description: The description of the entrypoint. group_id: The group that will own the entrypoint. + read_only: If True, apply a read only lock to the resource replace_existing: If True and a resource already exists with this name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. @@ -170,6 +172,13 @@ def create( queue_resources = [queue.resource for queue in queues] new_entrypoint.children.extend(plugin_resources + queue_resources) + if read_only: + db.session.add( + models.ResourceLock( + resource_lock_type=resource_lock_types.READONLY, + resource=resource, + ) + ) db.session.add(new_entrypoint) if commit: @@ -448,10 +457,17 @@ def modify( entrypoint = entrypoint_dict["entry_point"] group_id = entrypoint.resource.group_id - if name != entrypoint.name: - duplicate = self._entrypoint_name_service.get( - name, group_id=group_id, log=log + + if entrypoint.resource.is_readonly: + log.debug( + "The Entrypoint is read-only and cannot be modified", + entrypoint_id=entrypoint.resource_id, + name=entrypoint.name, ) + raise EntrypointReadOnlyLockError + + if name != entrypoint.name: + duplicate = self._entrypoint_name_service.get(name, group_id=group_id, log=log) if duplicate is not None: raise EntityExistsError( RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id @@ -529,6 +545,13 @@ def delete(self, entrypoint_id: int, **kwargs) -> dict[str, Any]: if entrypoint_resource is None: raise EntityDoesNotExistError(RESOURCE_TYPE, entrypoint_id=entrypoint_id) + if entrypoint_resource.is_readonly: + log.debug( + "The Entrypoint is read-only and cannot be deleted", + entrypoint_id=entrypoint_resource.resource_id, + ) + raise EntrypointReadOnlyLockError + deleted_resource_lock = models.ResourceLock( resource_lock_type=resource_lock_types.DELETE, resource=entrypoint_resource, diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index fb7613111..e73f5199e 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -85,6 +85,7 @@ def create( structure: dict[str, Any], description: str, group_id: int, + read_only: bool = False, replace_existing: bool = False, commit: bool = True, **kwargs, @@ -99,6 +100,7 @@ def create( type's structure. description: The description of the plugin parameter type. group_id: The group that will own the plugin parameter type. + read_only: If True, apply a read only lock to the resource replace_existing: If True and a resource already exists with this name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. @@ -147,6 +149,13 @@ def create( resource=resource, creator=current_user, ) + if read_only: + db.session.add( + models.ResourceLock( + resource_lock_type=resource_lock_types.READONLY, + resource=resource, + ) + ) db.session.add(new_plugin_parameter_type) if commit: diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index 43944edba..64919e5ec 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -89,6 +89,7 @@ def create( name: str, description: str, group_id: int, + read_only: bool = False, replace_existing: bool = False, commit: bool = True, **kwargs, @@ -100,6 +101,7 @@ def create( unique. description: The description of the plugin. group_id: The group that will own the plugin. + read_only: If True, apply a read only lock to the resource replace_existing: If True and a resource already exists with this name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. @@ -134,6 +136,13 @@ def create( new_plugin = models.Plugin( name=name, description=description, resource=resource, creator=current_user ) + if read_only: + db.session.add( + models.ResourceLock( + resource_lock_type=resource_lock_types.READONLY, + resource=resource, + ) + ) db.session.add(new_plugin) if commit: @@ -431,6 +440,14 @@ def modify( plugin_files = plugin_dict["plugin_files"] group_id = plugin.resource.group_id + if plugin.resource.is_readonly: + log.debug( + "The Plugin is read-only and cannot be modified", + plugin=plugin.resource_id, + name=plugin.name, + ) + raise PluginReadOnlyLockError + if name != plugin.name: duplicate = self._plugin_name_service.get(name, group_id=group_id, log=log) if duplicate is not None: @@ -481,6 +498,13 @@ def delete(self, plugin_id: int, **kwargs) -> dict[str, Any]: if plugin_resource is None: raise EntityDoesNotExistError(PLUGIN_RESOURCE_TYPE, plugin_id=plugin_id) + if plugin_resource.is_readonly: + log.debug( + "The Plugin is read-only and cannot be deleted", + plugin_id=plugin_resource.resource_id, + ) + raise PluginReadOnlyLockError + deleted_resource_lock = models.ResourceLock( resource_lock_type=resource_lock_types.DELETE, resource=plugin_resource, @@ -729,6 +753,7 @@ def create( description: str, tasks: list[dict[str, Any]], plugin_id: int, + read_only: bool = False, commit: bool = True, **kwargs, ) -> utils.PluginFileDict: @@ -743,6 +768,7 @@ def create( description: The description of the plugin file. tasks: The tasks associated with the plugin file. plugin_id: The unique id of the plugin containing the plugin file. + read_only: If True, apply a read only lock to the resource commit: If True, commit the transaction. Defaults to True. Returns: @@ -789,6 +815,14 @@ def create( ) new_plugin_file.parents.append(plugin.resource) + + if read_only: + db.session.add( + models.ResourceLock( + resource_lock_type=resource_lock_types.READONLY, + resource=resource, + ) + ) db.session.add(new_plugin_file) _add_plugin_tasks(tasks, plugin_file=new_plugin_file, log=log) @@ -1135,6 +1169,14 @@ def modify( plugin = plugin_file_dict["plugin"] plugin_file = plugin_file_dict["plugin_file"] + if plugin_file.resource.is_readonly: + log.debug( + "The Plugin is read-only and cannot be modified", + plugin=plugin_file.resource_id, + filename=plugin_file.filename, + ) + raise PluginFileReadOnlyLockError + if filename != plugin_file.filename: duplicate = self._plugin_file_name_service.get( filename, plugin_id=plugin_id, log=log @@ -1210,6 +1252,13 @@ def delete(self, plugin_id: int, plugin_file_id: int, **kwargs) -> dict[str, Any plugin_file_id=plugin_file_id, ) + if plugin_file.resource.is_readonly: + log.debug( + "The PluginFile is read-only and cannot be deleted", + plugin_id=plugin_file.resource_id, + ) + raise PluginFileReadOnlyLockError + plugin_file_id_to_return = plugin_file.resource_id # to return to user db.session.add( models.ResourceLock( diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 81fa68cf8..5dc1cb23b 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -16,7 +16,6 @@ # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" import json -import jsonschema import tarfile from collections import defaultdict from hashlib import sha256 @@ -25,6 +24,7 @@ from tempfile import TemporaryDirectory from typing import IO, Any, Final +import jsonschema import structlog import toml from injector import inject @@ -43,8 +43,8 @@ from .lib import clone_git_repository, package_job_files, views from .schema import ( FileTypes, - ResourceImportSourceTypes, ResourceImportResolveNameConflictsStrategy, + ResourceImportSourceTypes, ) LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -185,6 +185,7 @@ def import_resources( description=param_type.get("description", None), structure=param_type.get("structure", None), group_id=group_id, + read_only=read_only, replace_existing=replace_existing, commit=False, log=log, @@ -209,6 +210,7 @@ def import_resources( name=Path(plugin["path"]).stem, description=plugin.get("description", None), group_id=group_id, + read_only=read_only, replace_existing=replace_existing, commit=False, log=log, @@ -229,6 +231,7 @@ def import_resources( description=None, tasks=tasks[filename], plugin_id=plugin_dict["plugin"].resource_id, + read_only=read_only, commit=False, log=log, ) @@ -254,6 +257,7 @@ def import_resources( ], queue_ids=[], group_id=group_id, + read_only=read_only, replace_existing=replace_existing, commit=False, log=log, From 30d33cd5a076f4e1293d2ad7184ceeeb14d42005 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Sep 2024 12:49:27 -0400 Subject: [PATCH 05/28] docs: added resource import toml reference --- docs/source/index.rst | 1 + .../reference/resource-import-reference.rst | 209 ++++++++++++++++++ 2 files changed, 210 insertions(+) create mode 100644 docs/source/reference/resource-import-reference.rst diff --git a/docs/source/index.rst b/docs/source/index.rst index 35e692e60..3b2f55d96 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -65,6 +65,7 @@ Email us: dioptra@nist.gov :caption: Reference reference/task-engine-reference + reference/resource-import-reference reference/api-reference-restapi .. reference/api-reference-sdk reference/api-reference-client diff --git a/docs/source/reference/resource-import-reference.rst b/docs/source/reference/resource-import-reference.rst new file mode 100644 index 000000000..d2d59ddfc --- /dev/null +++ b/docs/source/reference/resource-import-reference.rst @@ -0,0 +1,209 @@ +=========================== + Resource Import Reference +=========================== + +This document describes the contract for importing resources into a Dioptra instance. + +.. contents:: + +Overview +======== + +Dioptra provides functionality for importing Plugins, Entrypoints, and +PluginParameterTypes. This allows users to easily publish and share +resources across Dioptra instances. It is also the mechanism used to +distribute core plugins and examples developed by the Dioptra maintainers. + +Resources are described via a combination of a TOML configuration file, +Python source code (for plugins), and YAML task graphs (for entrypoints). +The TOML file is the central configuration that references required sources +and includes metadata for fully registering the resources in Dioptra (such +as plugin task specifications, and entrypoint parameters). + +Collections of resources can be distributed via git repositories or by sharing +files manually. The resourceImport workflow can import from a repository or +archive file. See the API Reference for details. + +TOML Configuration Format +========================= + +The TOML format consists of three arrays of tables, one for each of the +importable resource types: Plugins, PluginParameterTypes, and Entrypoints. +This allows for importing of zero or more of each of these resources. + +Example +------- + +The following example illustrates how to configure a collection of resources +including a Plugin, PluginParameterType, and Entrypoint. + +.. code:: TOML + + # Plugins point to a python package and include metadata for registering them in Dioptra + [[plugins]] + # the path to the Python plugin relative to the root directory + path = "plugins/hello_world" + # an optional description + description = "A simple plugin used for testing and demonstration purposes." + + # an array of plugin task definitions + [[plugins.tasks]] + # the name of the file relative to the root plugin directory + filename = "tasks.py" + # the name must match the name of the function + name = "hello" + # input parameter names must match the Python function definition + # types are plugin parameter types and are matched by name + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "message", type = "string" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "message", type = "string" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "message", type = "string", required = true} ] + output_params = [ { name = "message", type = "string" } ] + + # PluginParameterTypes are fully described in the TOML + [[plugin_param_types]] + name = "image" + # an optional structure for the type + structure = { list = ["int", "int", "int"] } + + # Entrypoints point to a task graph yaml and include metadata for registering them in Dioptra + [[entrypoints]] + # path to the task graph yaml + path = "examples/hello-world.yaml" + # the name to register the entrypoint under (the task graph filename is use if not provided) + name = "Hello World" + # an optional description + description = "A simple example using the hello_world plugin." + # entrypoint parameters to register that should match the task graph + # here, type is an entrypoint parameter type, NOT a plugin parameter type + params = [ + { name = "name", type = "string", default_value = "World" } + ] + # plugins to register with the entrypoint (matched by name) + plugins = [ "hello_world" ] + + +JSON Schema +----------- + +The following JSON schema fully describes the Dioptra resource TOML. +It is used to validate Dioptra TOML files in the resource import workflow. + +.. code:: JSON + + { + "$schema": "https://json-schema.org/draft/2020-12/schema", + "$id": "https://github.com/usnistgov/dioptra", + "title": "Dioptra Resource Schema", + "description": "A schema defining objects that can be imported into a Dioptra instance.", + "type": "object", + "properties": { + "plugins": { + "type": "array", + "description": "An array of Dioptra plugins", + "items": { + "type": "object", + "description": "A Dioptra plugin", + "properties": { + "path": { "type": "string" }, + "description": { "type": "string" }, + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "filename": { "type": "string" }, + "name": { "type": "string" }, + "input_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "required": { "type": "boolean" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "output_params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + } + }, + "required": [ "filename", "name", "input_params", "output_params" ], + "additionalProperties": false + } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + }, + "plugin_param_types": { + "type": "array", + "description": "An array of Dioptra plugin parameter types", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "structure": { "type": "object" } + }, + "required": [ "name" ], + "additionalProperties": false + } + }, + "entrypoints": { + "type": "array", + "description": "An array of Dioptra entrypoints", + "items": { + "type": "object", + "properties": { + "path": { "type": "string" }, + "name": { "type": "string" }, + "description": { "type": "string" }, + "params": { + "type": "array", + "items": { + "type": "object", + "properties": { + "name": { "type": "string" }, + "type": { "type": "string" }, + "default_value": { "type": [ "string", "number", "boolean", "null" ] } + }, + "required": [ "name", "type" ], + "additionalProperties": false + } + }, + "plugins": { + "type": "array", + "items": { "type": "string" } + } + }, + "required": [ "path" ], + "additionalProperties": false + } + } + } + } From 42f081784f092b8fc662b86bcfbb13aa3b7fa026 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 25 Sep 2024 15:32:40 -0400 Subject: [PATCH 06/28] fix: bug related to using newly created plugin param type --- src/dioptra/restapi/v1/workflows/service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 5dc1cb23b..ad2fa78fc 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -189,7 +189,7 @@ def import_resources( replace_existing=replace_existing, commit=False, log=log, - ) + )["plugin_task_parameter_type"] for param_type in config.get("plugin_param_types", []) } # retrieve built-ins @@ -225,7 +225,7 @@ def import_resources( filename = str(plugin_file_path.relative_to(plugin["path"])) contents = plugin_file_path.read_text() - plugin_file_dict = self._plugin_id_file_service.create( + self._plugin_id_file_service.create( filename, contents=contents, description=None, From a87bda57f0b2673e02346435c93456eb0820eb1f Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 25 Sep 2024 15:32:54 -0400 Subject: [PATCH 07/28] feat: updated example to use new plugin param type --- dioptra.toml | 11 +++++------ examples/hello-world.yaml | 2 +- plugins/hello_world/tasks.py | 8 ++++---- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/dioptra.toml b/dioptra.toml index 295bc3f3e..ce3497057 100644 --- a/dioptra.toml +++ b/dioptra.toml @@ -7,7 +7,7 @@ description = "A simple plugin used for testing and demonstration purposes." filename = "tasks.py" name = "hello" input_params = [ { name = "name", type = "string", required = true} ] - output_params = [ { name = "message", type = "string" } ] + output_params = [ { name = "greeting", type = "message" } ] [[plugins.tasks]] filename = "tasks.py" @@ -16,17 +16,16 @@ description = "A simple plugin used for testing and demonstration purposes." { name = "greeting", type = "string", required = true }, { name = "name", type = "string", required = true }, ] - output_params = [ { name = "message", type = "string" } ] + output_params = [ { name = "greeting", type = "message" } ] [[plugins.tasks]] filename = "tasks.py" name = "shout" - input_params = [ { name = "message", type = "string", required = true} ] - output_params = [ { name = "message", type = "string" } ] + input_params = [ { name = "greeting", type = "message", required = true} ] + output_params = [ { name = "loud_greeting", type = "message" } ] [[plugin_param_types]] -name = "image" -structure = { list = ["int", "int", "int"] } +name = "message" [[entrypoints]] path = "examples/hello-world.yaml" diff --git a/examples/hello-world.yaml b/examples/hello-world.yaml index 3ab57da83..0b7609dcb 100644 --- a/examples/hello-world.yaml +++ b/examples/hello-world.yaml @@ -7,4 +7,4 @@ goodbye_step: name: $name shout_step: shout: - message: $goodbye_step.message + greeting: $goodbye_step.greeting diff --git a/plugins/hello_world/tasks.py b/plugins/hello_world/tasks.py index 45b9ccda2..875fb2b82 100644 --- a/plugins/hello_world/tasks.py +++ b/plugins/hello_world/tasks.py @@ -35,7 +35,7 @@ def greet(greeting: str, name: str) -> str: @pyplugs.register() -def shout(message: str) -> str: - message = message.upper() - LOGGER.info(message) - return message +def shout(greeting: str) -> str: + loud_greeting = greeting.upper() + LOGGER.info(loud_greeting) + return loud_greeting From 4d8364d69c2b8a39dc93e8cd47a8cc9dcd4a74c7 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Dec 2024 09:19:36 -0500 Subject: [PATCH 08/28] tests: test resource import workflow --- src/dioptra/restapi/v1/plugins/service.py | 2 +- src/dioptra/restapi/v1/workflows/service.py | 280 +++++++++++------- tests/unit/restapi/v1/conftest.py | 29 ++ .../v1/test_workflow_resource_import.py | 172 +++++++++++ 4 files changed, 375 insertions(+), 108 deletions(-) create mode 100644 tests/unit/restapi/v1/test_workflow_resource_import.py diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index 64919e5ec..dbad48ee5 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -125,7 +125,7 @@ def create( else: raise EntityExistsError( PLUGIN_RESOURCE_TYPE, - duplicate.resource_id, + existing.resource_id, name=name, group_id=group_id, ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index ad2fa78fc..3d8e1249a 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -151,7 +151,7 @@ def import_resources( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Import resources", group_id=group_id) - replace_existing = ( + overwrite = ( resolve_name_conflicts_strategy == ResourceImportResolveNameConflictsStrategy.OVERWRITE ) @@ -167,8 +167,6 @@ def import_resources( elif source_type == ResourceImportSourceTypes.GIT: hash = clone_git_repository(git_url, working_dir) - log.info(hash=hash, paths=list(working_dir.glob("*"))) - config = toml.load(working_dir / config_path) # validate the config file @@ -178,118 +176,186 @@ def import_resources( schema = json.load(f) jsonschema.validate(config, schema) - # register new plugin param types - param_types = { - param_type["name"]: self._plugin_parameter_type_service.create( - name=param_type["name"], - description=param_type.get("description", None), - structure=param_type.get("structure", None), - group_id=group_id, - read_only=read_only, - replace_existing=replace_existing, - commit=False, - log=log, - )["plugin_task_parameter_type"] - for param_type in config.get("plugin_param_types", []) - } - # retrieve built-ins - param_types.update( - { - param_type.name: param_type - for param_type in self._builtin_plugin_parameter_type_service.get( - group_id=group_id, error_if_not_found=False, log=log - ) - } + param_types = self._register_plugin_param_types( + group_id, config.get("plugin_param_types", []), overwrite, log=log + ) + plugins = self._register_plugins( + group_id, config.get("plugins", []), param_types, overwrite, log=log + ) + entrypoints = self._register_entrypoints( + group_id, config.get("entrypoints", []), plugins, overwrite, log=log ) - db.session.flush() - # register new plugins - plugin_ids = {} - for plugin in config.get("plugins", []): - plugin_dict = self._plugin_service.create( - name=Path(plugin["path"]).stem, - description=plugin.get("description", None), - group_id=group_id, - read_only=read_only, - replace_existing=replace_existing, - commit=False, - log=log, + db.session.commit() + + return { + "message": "successfully imported", + "hash": hash, + "resources": { + "plugins": [plugin.name for plugin in plugins], + "plugin_param_types": [param_type.name for param_type in param_types], + "entrypoints": [entrypoint.name for entrypoint in entrypoints], + }, + } + + def _register_plugin_param_types( + self, + group_id: int, + param_types_config: dict[str, Any], + overwrite: bool, + log: BoundLogger, + ) -> dict[str, Any]: + param_types = dict() + for param_type in param_types_config: + if overwrite: + existing = self._plugin_parameter_type_name_service.get( + param_type["name"], group_id=group_id, log=log ) - db.session.flush() - plugin_ids[plugin_dict["plugin"].name] = plugin_dict[ - "plugin" - ].resource_id - - tasks = _build_tasks(plugin.get("tasks", []), param_types) - for plugin_file_path in Path(plugin["path"]).rglob("*.py"): - filename = str(plugin_file_path.relative_to(plugin["path"])) - contents = plugin_file_path.read_text() - - self._plugin_id_file_service.create( - filename, - contents=contents, - description=None, - tasks=tasks[filename], - plugin_id=plugin_dict["plugin"].resource_id, - read_only=read_only, - commit=False, + if existing: + self._plugin_parameter_type_service.delete( + plugin_parameter_type_id=existing["id"], log=log, ) - # register new entrypoints - for entrypoint in config.get("entrypoints", []): - contents = Path(entrypoint["path"]).read_text() - params = [ - { - "name": param["name"], - "parameter_type": param["type"], - "default_value": param.get("default_value", None), - } - for param in entrypoint.get("params", []) - ] - self._entrypoint_service.create( - name=entrypoint.get("name", Path(entrypoint["path"]).stem), - description=entrypoint.get("description", None), - task_graph=contents, - parameters=params, - plugin_ids=[ - plugin_ids[plugin] for plugin in entrypoint.get("plugins", []) - ], - queue_ids=[], - group_id=group_id, - read_only=read_only, - replace_existing=replace_existing, + param_type["name"] = self._plugin_parameter_type_service.create( + name=param_type["name"], + description=param_type.get("description", None), + structure=param_type.get("structure", None), + group_id=group_id, + commit=False, + log=log, + )["plugin_task_parameter_type"] + + db.session.flush() + + return param_types + + def _register_plugins( + self, + group_id: int, + plugins_config: dict[str, Any], + param_types: Any, + overwrite: bool, + log: BoundLogger, + ): + builtin_param_types = self._builtin_plugin_parameter_type_service.get( + group_id=group_id, error_if_not_found=False, log=log + ) + param_types.update( + {param_type.name: param_type for param_type in builtin_param_types} + ) + + plugins = {} + for plugin in plugins_config: + if overwrite: + existing = self._plugin_name_service.get( + Path(plugin["path"]).stem, group_id=group_id + ) + if existing: + self._plugin_id_service.delete( + plugin_id=existing["id"], + log=log, + ) + + plugin_dict = self._plugin_service.create( + name=Path(plugin["path"]).stem, + description=plugin.get("description", None), + group_id=group_id, + commit=False, + log=log, + ) + plugins[plugin_dict["plugin"].name] = plugin_dict["plugin"] + db.session.flush() + + tasks = self._build_tasks(plugin.get("tasks", []), param_types) + for plugin_file_path in Path(plugin["path"]).rglob("*.py"): + filename = str(plugin_file_path.relative_to(plugin["path"])) + contents = plugin_file_path.read_text() + + self._plugin_id_file_service.create( + filename, + contents=contents, + description=None, + tasks=tasks[filename], + plugin_id=plugin_dict["plugin"].resource_id, commit=False, log=log, ) - db.session.commit() + db.session.flush() - return {"message": "successfully imported"} - - -def _build_tasks(tasks_config, param_types): - tasks = defaultdict(list) - for task in tasks_config: - tasks[task["filename"]].append( - { - "name": task["name"], - "description": task.get("description", None), - "input_params": [ - { - "name": param["name"], - "parameter_type_id": param_types[param["type"]].resource_id, - "required": param.get("required", False), - } - for param in task["input_params"] - ], - "output_params": [ - { - "name": param["name"], - "parameter_type_id": param_types[param["type"]].resource_id, - } - for param in task["output_params"] - ], - } - ) - return tasks + return plugins + + def _register_entrypoints( + self, + group_id: int, + entrypoints_config: dict[str, Any], + plugins, + overwrite: bool, + log: BoundLogger, + ): + entrypoints = dict() + for entrypoint in entrypoints_config: + if overwrite: + existing = self._entrypoint_name_service.get( + entrypoint["name"], group_id=group_id, log=log + ) + if existing is not None: + self.entrypoint_id_service.delete( + entrypoint_id=existing.resource_id + ) + + contents = Path(entrypoint["path"]).read_text() + params = [ + { + "name": param["name"], + "parameter_type": param["type"], + "default_value": param.get("default_value", None), + } + for param in entrypoint.get("params", []) + ] + plugin_ids = [ + plugins[plugin].resource_id for plugin in entrypoint.get("plugins", []) + ] + entrypoint_dict = self._entrypoint_service.create( + name=entrypoint.get("name", Path(entrypoint["path"]).stem), + description=entrypoint.get("description", None), + task_graph=contents, + parameters=params, + plugin_ids=plugin_ids, + queue_ids=[], + group_id=group_id, + commit=False, + log=log, + ) + entrypoints[entrypoint_dict["name"]] = entrypoint_dict["entrypoint"] + + db.session.flush() + + return entrypoints + + def _build_tasks(self, tasks_config, param_types): + tasks = defaultdict(list) + for task in tasks_config: + tasks[task["filename"]].append( + { + "name": task["name"], + "description": task.get("description", None), + "input_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + "required": param.get("required", False), + } + for param in task["input_params"] + ], + "output_params": [ + { + "name": param["name"], + "parameter_type_id": param_types[param["type"]].resource_id, + } + for param in task["output_params"] + ], + } + ) + return tasks diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 676725d26..df18ac330 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -21,15 +21,21 @@ from typing import Any, cast import pytest +import tarfile +import toml import uuid from flask import Flask from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy from injector import Injector from pytest import MonkeyPatch +from tempfile import NamedTemporaryFile from ..lib import actions, mock_rq +# TODO: figure out if thesre is a better way to do this +DIOPTRA_ROOT = Path(__file__).parent.parent.parent.parent.parent + @pytest.fixture def app(dependency_modules: list[Any]) -> Iterator[Flask]: @@ -720,3 +726,26 @@ def registered_mlflowrun_incomplete( ) return responses + + +def resources_tar_file(): + f = NamedTemporaryFile(suffix=".tar.gz") + with tarfile.open(fileobj=f, mode="w:gz") as tar: + tar.add(DIOPTRA_ROOT / "dioptra.toml", arcname="dioptra.toml") + tar.add( + DIOPTRA_ROOT / "plugins/hello_world", + arcname="plugins/hello_world", + recursive=True, + ) + tar.add( + DIOPTRA_ROOT / "examples/hello-world.yaml", + arcname="examples/hello-world.yaml", + ) + f.seek(0) + + return f + + +@pytest.fixture +def resources_import_config(): + return toml.load(DIOPTRA_ROOT / "dioptra.toml") diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py new file mode 100644 index 000000000..3cb94f32b --- /dev/null +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -0,0 +1,172 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for the resource import workflow + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the queue entity. The tests ensure that the queues can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" + +from typing import Any +from pathlib import Path +from tempfile import NamedTemporaryFile +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from werkzeug.test import TestResponse + +from dioptra.restapi.routes import V1_ROOT, V1_WORKFLOWS_ROUTE + + +# -- Actions --------------------------------------------------------------------------- + + +def resource_import( + client: FlaskClient, + resources_tar_file: NamedTemporaryFile, + group_id: int, + read_only: bool, + resolve_name_conflict_strategy: str, +) -> TestResponse: + """Import resources into Dioptra + + Args: + client: The Flask test client. + group_id: The id of the group to import resources into. + + Returns: + The response from the API. + """ + + payload = { + "groupId": group_id, + "sourceType": "upload", + "archiveFile": (resources_tar_file, "upload.tar.gz"), + "configPath": "dioptra.toml", + "readOnly": read_only, + "resolveNameConflictsStrategy": resolve_name_conflict_strategy, + } + + return client.post( + f"/{V1_ROOT}/{V1_WORKFLOWS_ROUTE}/resourceImport", + data=payload, + content_type="multipart/form-data", + follow_redirects=True, + ) + + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_imported_resources_match_expected( + client: FlaskClient, + expected: dict[str, Any], +): + response = client.get(f"/{V1_ROOT}/plugins", follow_redirects=True) + response_plugins = set(plugin["name"] for plugin in response.get_json()["data"]) + expected_plugins = set(Path(plugin["path"]).stem for plugin in expected["plugins"]) + assert response.status_code == 200 and response_plugins == expected_plugins + + response = client.get(f"/{V1_ROOT}/pluginParameterTypes", follow_redirects=True) + response_types = set(param["name"] for param in response.get_json()["data"]) + expected_types = set(param["name"] for param in expected["plugin_param_types"]) + assert ( + response.status_code == 200 + and response_types & expected_types == expected_types + ) + + response = client.get(f"/{V1_ROOT}/entrypoints", follow_redirects=True) + response_entrypoints = set(ep["name"] for ep in response.get_json()["data"]) + expected_entrypoints = set(ep["name"] for ep in expected["entrypoints"]) + assert response.status_code == 200 and response_entrypoints == expected_entrypoints + + +def assert_resource_import_fails_due_to_name_clash( + client: FlaskClient, + group_id: int, + resources_tar_file: NamedTemporaryFile, +): + payload = {"name": "hello_world", "group": group_id} + client.post(f"/{V1_ROOT}/plugins", json=payload, follow_redirects=True) + + response = resource_import( + client, + resources_tar_file, + group_id, + read_only=True, + resolve_name_conflict_strategy="fail", + ) + assert response.status_code == 409 + + +def assert_resource_import_overwrite_works( + client: FlaskClient, + group_id: int, + resources_tar_file: NamedTemporaryFile, +): + payload = {"name": "hello_world", "group": group_id} + client.post(f"/{V1_ROOT}/plugins", json=payload, follow_redirects=True) + + response = resource_import( + client, + resources_tar_file, + group_id, + read_only=True, + resolve_name_conflict_strategy="overwrite", + ) + assert response.status_code == 200 + + +# -- Tests ----------------------------------------------------------------------------- + + +def test_resource_import( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: NamedTemporaryFile, + resources_import_config: dict[str, Any], +): + group_id = auth_account["groups"][0]["id"] + resource_import( + client, + resources_tar_file, + group_id, + read_only=True, + resolve_name_conflict_strategy="fail", + ) + + assert_imported_resources_match_expected(client, resources_import_config) + + +def test_resource_import_fails_from_name_clash( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: NamedTemporaryFile, +): + group_id = auth_account["groups"][0]["id"] + assert_resource_import_fails_due_to_name_clash(client, group_id, resources_tar_file) + + +def test_resource_import_overwrite( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_tar_file: NamedTemporaryFile, +): + group_id = auth_account["groups"][0]["id"] + assert_resource_import_overwrite_works(client, group_id, resources_tar_file) From e02f9430b490fff058a2e62d6d9fcf1e43807583 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Fri, 25 Oct 2024 15:42:07 -0400 Subject: [PATCH 09/28] fix: revert overwrite and readonly --- src/dioptra/restapi/v1/entrypoints/service.py | 49 ++--------- .../v1/plugin_parameter_types/service.py | 34 ++------ src/dioptra/restapi/v1/plugins/service.py | 82 ++----------------- 3 files changed, 25 insertions(+), 140 deletions(-) diff --git a/src/dioptra/restapi/v1/entrypoints/service.py b/src/dioptra/restapi/v1/entrypoints/service.py index cf4d5ead7..a8f3cf711 100644 --- a/src/dioptra/restapi/v1/entrypoints/service.py +++ b/src/dioptra/restapi/v1/entrypoints/service.py @@ -94,8 +94,6 @@ def create( plugin_ids: list[int], queue_ids: list[int], group_id: int, - read_only: bool = False, - replace_existing: bool = False, commit: bool = True, **kwargs, ) -> utils.EntrypointDict: @@ -106,9 +104,6 @@ def create( be unique. description: The description of the entrypoint. group_id: The group that will own the entrypoint. - read_only: If True, apply a read only lock to the resource - replace_existing: If True and a resource already exists with this - name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -119,18 +114,11 @@ def create( """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - existing = self._entrypoint_name_service.get(name, group_id=group_id, log=log) - if existing is not None: - if replace_existing: - deleted_resource_lock = models.ResourceLock( - resource_lock_type=resource_lock_types.DELETE, - resource=existing.resource, - ) - db.session.add(deleted_resource_lock) - else: - raise EntityExistsError( - RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id - ) + duplicate = self._entrypoint_name_service.get(name, group_id=group_id, log=log) + if duplicate is not None: + raise EntityExistsError( + RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id + ) group = self._group_id_service.get(group_id, error_if_not_found=True) queues = self._queue_ids_service.get(queue_ids, error_if_not_found=True) @@ -172,13 +160,6 @@ def create( queue_resources = [queue.resource for queue in queues] new_entrypoint.children.extend(plugin_resources + queue_resources) - if read_only: - db.session.add( - models.ResourceLock( - resource_lock_type=resource_lock_types.READONLY, - resource=resource, - ) - ) db.session.add(new_entrypoint) if commit: @@ -457,17 +438,10 @@ def modify( entrypoint = entrypoint_dict["entry_point"] group_id = entrypoint.resource.group_id - - if entrypoint.resource.is_readonly: - log.debug( - "The Entrypoint is read-only and cannot be modified", - entrypoint_id=entrypoint.resource_id, - name=entrypoint.name, - ) - raise EntrypointReadOnlyLockError - if name != entrypoint.name: - duplicate = self._entrypoint_name_service.get(name, group_id=group_id, log=log) + duplicate = self._entrypoint_name_service.get( + name, group_id=group_id, log=log + ) if duplicate is not None: raise EntityExistsError( RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id @@ -545,13 +519,6 @@ def delete(self, entrypoint_id: int, **kwargs) -> dict[str, Any]: if entrypoint_resource is None: raise EntityDoesNotExistError(RESOURCE_TYPE, entrypoint_id=entrypoint_id) - if entrypoint_resource.is_readonly: - log.debug( - "The Entrypoint is read-only and cannot be deleted", - entrypoint_id=entrypoint_resource.resource_id, - ) - raise EntrypointReadOnlyLockError - deleted_resource_lock = models.ResourceLock( resource_lock_type=resource_lock_types.DELETE, resource=entrypoint_resource, diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index e73f5199e..1bbe4e13c 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -85,8 +85,6 @@ def create( structure: dict[str, Any], description: str, group_id: int, - read_only: bool = False, - replace_existing: bool = False, commit: bool = True, **kwargs, ) -> utils.PluginParameterTypeDict: @@ -100,9 +98,6 @@ def create( type's structure. description: The description of the plugin parameter type. group_id: The group that will own the plugin parameter type. - read_only: If True, apply a read only lock to the resource - replace_existing: If True and a resource already exists with this - name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -124,20 +119,13 @@ def create( ) raise PluginParameterTypeMatchesBuiltinTypeError - existing = self._plugin_parameter_type_name_service.get( + duplicate = self._plugin_parameter_type_name_service.get( name, group_id=group_id, log=log ) - if existing is not None: - if replace_existing: - deleted_resource_lock = models.ResourceLock( - resource_lock_type=resource_lock_types.DELETE, - resource=existing.resource, - ) - db.session.add(deleted_resource_lock) - else: - raise EntityExistsError( - RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id - ) + if duplicate is not None: + raise EntityExistsError( + RESOURCE_TYPE, duplicate.resource_id, name=name, group_id=group_id + ) group = self._group_id_service.get(group_id, error_if_not_found=True) @@ -149,13 +137,6 @@ def create( resource=resource, creator=current_user, ) - if read_only: - db.session.add( - models.ResourceLock( - resource_lock_type=resource_lock_types.READONLY, - resource=resource, - ) - ) db.session.add(new_plugin_parameter_type) if commit: @@ -630,8 +611,9 @@ def get( retrieved_names = {param_type.name for param_type in plugin_parameter_types} missing_names = set(builtin_types) - retrieved_names if error_if_not_found: - log.debug("Plugin Parameter Type(s) not found", names=missing_names) - raise PluginParameterTypeDoesNotExistError + raise EntityDoesNotExistError( + RESOURCE_TYPE, missing_names=missing_names + ) return plugin_parameter_types diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index dbad48ee5..fadb12ec4 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -85,14 +85,7 @@ def __init__( self._group_id_service = group_id_service def create( - self, - name: str, - description: str, - group_id: int, - read_only: bool = False, - replace_existing: bool = False, - commit: bool = True, - **kwargs, + self, name: str, description: str, group_id: int, commit: bool = True, **kwargs ) -> utils.PluginWithFilesDict: """Create a new plugin. @@ -101,9 +94,6 @@ def create( unique. description: The description of the plugin. group_id: The group that will own the plugin. - read_only: If True, apply a read only lock to the resource - replace_existing: If True and a resource already exists with this - name, delete it instead of raising an exception commit: If True, commit the transaction. Defaults to True. Returns: @@ -114,21 +104,14 @@ def create( """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - existing = self._plugin_name_service.get(name, group_id=group_id, log=log) - if existing is not None: - if replace_existing: - deleted_resource_lock = models.ResourceLock( - resource_lock_type=resource_lock_types.DELETE, - resource=existing.resource, - ) - db.session.add(deleted_resource_lock) - else: - raise EntityExistsError( - PLUGIN_RESOURCE_TYPE, - existing.resource_id, - name=name, - group_id=group_id, - ) + duplicate = self._plugin_name_service.get(name, group_id=group_id, log=log) + if duplicate is not None: + raise EntityExistsError( + PLUGIN_RESOURCE_TYPE, + duplicate.resource_id, + name=name, + group_id=group_id, + ) group = self._group_id_service.get(group_id, error_if_not_found=True) @@ -136,13 +119,6 @@ def create( new_plugin = models.Plugin( name=name, description=description, resource=resource, creator=current_user ) - if read_only: - db.session.add( - models.ResourceLock( - resource_lock_type=resource_lock_types.READONLY, - resource=resource, - ) - ) db.session.add(new_plugin) if commit: @@ -440,14 +416,6 @@ def modify( plugin_files = plugin_dict["plugin_files"] group_id = plugin.resource.group_id - if plugin.resource.is_readonly: - log.debug( - "The Plugin is read-only and cannot be modified", - plugin=plugin.resource_id, - name=plugin.name, - ) - raise PluginReadOnlyLockError - if name != plugin.name: duplicate = self._plugin_name_service.get(name, group_id=group_id, log=log) if duplicate is not None: @@ -498,13 +466,6 @@ def delete(self, plugin_id: int, **kwargs) -> dict[str, Any]: if plugin_resource is None: raise EntityDoesNotExistError(PLUGIN_RESOURCE_TYPE, plugin_id=plugin_id) - if plugin_resource.is_readonly: - log.debug( - "The Plugin is read-only and cannot be deleted", - plugin_id=plugin_resource.resource_id, - ) - raise PluginReadOnlyLockError - deleted_resource_lock = models.ResourceLock( resource_lock_type=resource_lock_types.DELETE, resource=plugin_resource, @@ -753,7 +714,6 @@ def create( description: str, tasks: list[dict[str, Any]], plugin_id: int, - read_only: bool = False, commit: bool = True, **kwargs, ) -> utils.PluginFileDict: @@ -768,7 +728,6 @@ def create( description: The description of the plugin file. tasks: The tasks associated with the plugin file. plugin_id: The unique id of the plugin containing the plugin file. - read_only: If True, apply a read only lock to the resource commit: If True, commit the transaction. Defaults to True. Returns: @@ -815,14 +774,6 @@ def create( ) new_plugin_file.parents.append(plugin.resource) - - if read_only: - db.session.add( - models.ResourceLock( - resource_lock_type=resource_lock_types.READONLY, - resource=resource, - ) - ) db.session.add(new_plugin_file) _add_plugin_tasks(tasks, plugin_file=new_plugin_file, log=log) @@ -1169,14 +1120,6 @@ def modify( plugin = plugin_file_dict["plugin"] plugin_file = plugin_file_dict["plugin_file"] - if plugin_file.resource.is_readonly: - log.debug( - "The Plugin is read-only and cannot be modified", - plugin=plugin_file.resource_id, - filename=plugin_file.filename, - ) - raise PluginFileReadOnlyLockError - if filename != plugin_file.filename: duplicate = self._plugin_file_name_service.get( filename, plugin_id=plugin_id, log=log @@ -1252,13 +1195,6 @@ def delete(self, plugin_id: int, plugin_file_id: int, **kwargs) -> dict[str, Any plugin_file_id=plugin_file_id, ) - if plugin_file.resource.is_readonly: - log.debug( - "The PluginFile is read-only and cannot be deleted", - plugin_id=plugin_file.resource_id, - ) - raise PluginFileReadOnlyLockError - plugin_file_id_to_return = plugin_file.resource_id # to return to user db.session.add( models.ResourceLock( From c97fe207896b70d698081fc1340f984cf423b04e Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Dec 2024 09:20:32 -0500 Subject: [PATCH 10/28] tests: added tests --- tests/unit/restapi/v1/conftest.py | 5 +++-- tests/unit/restapi/v1/test_workflow_resource_import.py | 4 ++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index df18ac330..abf4cec01 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -15,13 +15,15 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Fixtures representing resources needed for test suites""" +import tarfile import textwrap from collections.abc import Iterator from http import HTTPStatus +from pathlib import Path +from tempfile import NamedTemporaryFile from typing import Any, cast import pytest -import tarfile import toml import uuid from flask import Flask @@ -29,7 +31,6 @@ from flask_sqlalchemy import SQLAlchemy from injector import Injector from pytest import MonkeyPatch -from tempfile import NamedTemporaryFile from ..lib import actions, mock_rq diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py index 3cb94f32b..bdd7c4620 100644 --- a/tests/unit/restapi/v1/test_workflow_resource_import.py +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -21,16 +21,16 @@ registered, renamed, deleted, and locked/unlocked as expected through the REST API. """ -from typing import Any from pathlib import Path from tempfile import NamedTemporaryFile +from typing import Any + from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy from werkzeug.test import TestResponse from dioptra.restapi.routes import V1_ROOT, V1_WORKFLOWS_ROUTE - # -- Actions --------------------------------------------------------------------------- From f5ad53abd6aa580e83835cf0c00232ba58107e60 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 30 Oct 2024 14:28:29 -0400 Subject: [PATCH 11/28] refactor: refactor resource import workflow --- .../v1/plugin_parameter_types/service.py | 4 +- src/dioptra/restapi/v1/plugins/service.py | 2 +- .../restapi/v1/workflows/controller.py | 1 - .../v1/workflows/lib/clone_git_repository.py | 8 +- src/dioptra/restapi/v1/workflows/schema.py | 7 +- src/dioptra/restapi/v1/workflows/service.py | 94 ++++++++++++++----- .../v1/test_workflow_resource_import.py | 3 +- 7 files changed, 81 insertions(+), 38 deletions(-) diff --git a/src/dioptra/restapi/v1/plugin_parameter_types/service.py b/src/dioptra/restapi/v1/plugin_parameter_types/service.py index 1bbe4e13c..9034f2669 100644 --- a/src/dioptra/restapi/v1/plugin_parameter_types/service.py +++ b/src/dioptra/restapi/v1/plugin_parameter_types/service.py @@ -571,8 +571,8 @@ def get( group_id: int, error_if_not_found: bool = False, **kwargs, - ) -> models.PluginTaskParameterType | None: - """Fetch a list of plugin parameter types by their names. + ) -> list[models.PluginTaskParameterType]: + """Fetch a list of builtin plugin parameter types. Args: group_id: The the group id of the plugin parameter type. diff --git a/src/dioptra/restapi/v1/plugins/service.py b/src/dioptra/restapi/v1/plugins/service.py index fadb12ec4..5e6da90f2 100644 --- a/src/dioptra/restapi/v1/plugins/service.py +++ b/src/dioptra/restapi/v1/plugins/service.py @@ -711,7 +711,7 @@ def create( self, filename: str, contents: str, - description: str, + description: str | None, tasks: list[dict[str, Any]], plugin_id: int, commit: bool = True, diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 140074823..fd5cd9b6f 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -125,7 +125,6 @@ def post(self): git_url=parsed_form.get("git_url", None), archive_file=request.files.get("archiveFile", None), config_path=parsed_form["config_path"], - read_only=parsed_form["read_only"], resolve_name_conflicts_strategy=parsed_form[ "resolve_name_conflicts_strategy" ], diff --git a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py index 3fa682bf1..4ebb1068a 100644 --- a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py +++ b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py @@ -3,7 +3,7 @@ from urllib.parse import urlparse -def clone_git_repository(url: str, dir: Path): +def clone_git_repository(url: str, dir: Path) -> str: parsed_url = urlparse(url) git_branch = parsed_url.fragment or None git_paths = parsed_url.params or None @@ -11,7 +11,7 @@ def clone_git_repository(url: str, dir: Path): git_sparse_args = ["--filter=blob:none", "--no-checkout", "--depth=1"] git_branch_args = ["-b", git_branch] if git_branch else [] - clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, dir] + clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, str(dir)] clone_result = subprocess.run(clone_cmd, capture_output=True, text=True) if clone_result.returncode != 0: @@ -45,9 +45,9 @@ def clone_git_repository(url: str, dir: Path): hash_result = subprocess.run(hash_cmd, cwd=dir, capture_output=True, text=True) if hash_result.returncode != 0: - raise subprocess.CalledProcessError + raise subprocess.CalledProcessError(hash_result.returncode, hash_result.stderr) - return hash + return str(hash) if __name__ == "__main__": diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 60fad7d28..f975c32eb 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -60,7 +60,7 @@ class ResourceImportSchema(Schema): groupId = fields.Integer( attribute="group_id", - # data_key="group", + data_key="group", metadata=dict( description="ID of the Group that will own the imported resources." ), @@ -97,11 +97,6 @@ class ResourceImportSchema(Schema): metdata=dict(description="The path to the toml configuration file."), load_default="dioptra.toml", ) - readOnly = fields.Bool( - attribute="read_only", - metadata=dict(description="Whether imported resources should be readonly."), - load_default=False, - ) resolveNameConflictsStrategy = fields.Enum( ResourceImportResolveNameConflictsStrategy, attribute="resolve_name_conflicts_strategy", diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 3d8e1249a..dce4f8514 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -32,12 +32,24 @@ from werkzeug.datastructures import FileStorage from dioptra.restapi.db import db -from dioptra.restapi.v1.entrypoints.service import EntrypointService +from dioptra.restapi.errors import DioptraError +from dioptra.restapi.v1.entrypoints.service import ( + EntrypointIdService, + EntrypointNameService, + EntrypointService, +) from dioptra.restapi.v1.plugin_parameter_types.service import ( BuiltinPluginParameterTypeService, + PluginParameterTypeIdService, + PluginParameterTypeNameService, PluginParameterTypeService, ) -from dioptra.restapi.v1.plugins.service import PluginIdFileService, PluginService +from dioptra.restapi.v1.plugins.service import ( + PluginIdFileService, + PluginIdService, + PluginNameService, + PluginService, +) from dioptra.sdk.utilities.paths import set_cwd from .lib import clone_git_repository, package_job_files, views @@ -99,29 +111,48 @@ class ResourceImportService(object): def __init__( self, plugin_service: PluginService, + plugin_id_service: PluginIdService, + plugin_name_service: PluginNameService, plugin_id_file_service: PluginIdFileService, plugin_parameter_type_service: PluginParameterTypeService, + plugin_parameter_type_id_service: PluginParameterTypeIdService, + plugin_parameter_type_name_service: PluginParameterTypeNameService, builtin_plugin_parameter_type_service: BuiltinPluginParameterTypeService, entrypoint_service: EntrypointService, + entrypoint_id_service: EntrypointIdService, + entrypoint_name_service: EntrypointNameService, ) -> None: """Initialize the resource import service. All arguments are provided via dependency injection. Args: + plugin_service: A PluginService object, plugin_name_service: A PluginNameService object. + plugin_id_service: A PluginIdService object. plugin_id_file_service: A PluginIdFileService object. plugin_parameter_type_service: A PluginParameterTypeService object. - builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService object. - entrypoint_service: A EntrypointService object. + plugin_parameter_type_id_service: A PluginParameterTypeIdService object. + plugin_parameter_type_name_service: A PluginParameterTypeNameService object. + builtin_plugin_parameter_type_service: A BuiltinPluginParameterTypeService + object. + entrypoint_service: An EntrypointService object. + entrypoint_id_service: An EntrypointIdService object. + entrypoint_name_service: An EntrypointNameService object. """ self._plugin_service = plugin_service + self._plugin_id_service = plugin_id_service + self._plugin_name_service = plugin_name_service self._plugin_id_file_service = plugin_id_file_service self._plugin_parameter_type_service = plugin_parameter_type_service + self._plugin_parameter_type_id_service = plugin_parameter_type_id_service + self._plugin_parameter_type_name_service = plugin_parameter_type_name_service self._builtin_plugin_parameter_type_service = ( builtin_plugin_parameter_type_service ) self._entrypoint_service = entrypoint_service + self._entrypoint_id_service = entrypoint_id_service + self._entrypoint_name_service = entrypoint_name_service def import_resources( self, @@ -130,7 +161,6 @@ def import_resources( git_url: str | None, archive_file: FileStorage | None, config_path: str, - read_only: bool, resolve_name_conflicts_strategy: str, **kwargs, ) -> dict[str, Any]: @@ -141,7 +171,6 @@ def import_resources( source_type: The source to import from (either "upload" or "git") git_url: The url to the git repository if source_type is "git" archive_file: The contents of the upload if source_type is "upload" - read_only: Whether to apply a readonly lock to all imported resources resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" @@ -163,11 +192,16 @@ def import_resources( bytes = archive_file.stream.read() with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: tar.extractall(path=working_dir, filter="data") - hash = sha256(bytes).hexdigest() + hash = str(sha256(bytes).hexdigest()) elif source_type == ResourceImportSourceTypes.GIT: hash = clone_git_repository(git_url, working_dir) - config = toml.load(working_dir / config_path) + try: + config = toml.load(working_dir / config_path) + except Exception as e: + raise DioptraError( + f"Failed to load resource import config from {config_path}." + ) from e # validate the config file with open( @@ -192,19 +226,27 @@ def import_resources( "message": "successfully imported", "hash": hash, "resources": { - "plugins": [plugin.name for plugin in plugins], - "plugin_param_types": [param_type.name for param_type in param_types], - "entrypoints": [entrypoint.name for entrypoint in entrypoints], + "plugins": list(plugins.keys()), + "plugin_param_types": list(param_types.keys()), + "entrypoints": list(entrypoints.keys()), }, } def _register_plugin_param_types( self, group_id: int, - param_types_config: dict[str, Any], + param_types_config: list[dict[str, Any]], overwrite: bool, log: BoundLogger, ) -> dict[str, Any]: + """ + Registers a list PluginParameterTypes. + + Args: + group_id: The identifier + Returns: + A dictionary mapping PluginParameterType name to the ORM object + """ param_types = dict() for param_type in param_types_config: if overwrite: @@ -212,19 +254,22 @@ def _register_plugin_param_types( param_type["name"], group_id=group_id, log=log ) if existing: - self._plugin_parameter_type_service.delete( - plugin_parameter_type_id=existing["id"], + self._plugin_parameter_type_id_service.delete( + plugin_parameter_type_id=existing.resource_id, log=log, ) - param_type["name"] = self._plugin_parameter_type_service.create( + param_type_dict = self._plugin_parameter_type_service.create( name=param_type["name"], description=param_type.get("description", None), structure=param_type.get("structure", None), group_id=group_id, commit=False, log=log, - )["plugin_task_parameter_type"] + ) + param_types[param_type["name"]] = param_type_dict[ + "plugin_task_parameter_type" + ] db.session.flush() @@ -233,11 +278,12 @@ def _register_plugin_param_types( def _register_plugins( self, group_id: int, - plugins_config: dict[str, Any], + plugins_config: list[dict[str, Any]], param_types: Any, overwrite: bool, log: BoundLogger, ): + param_types = param_types.copy() builtin_param_types = self._builtin_plugin_parameter_type_service.get( group_id=group_id, error_if_not_found=False, log=log ) @@ -253,7 +299,7 @@ def _register_plugins( ) if existing: self._plugin_id_service.delete( - plugin_id=existing["id"], + plugin_id=existing.resource_id, log=log, ) @@ -289,7 +335,7 @@ def _register_plugins( def _register_entrypoints( self, group_id: int, - entrypoints_config: dict[str, Any], + entrypoints_config: list[dict[str, Any]], plugins, overwrite: bool, log: BoundLogger, @@ -301,7 +347,7 @@ def _register_entrypoints( entrypoint["name"], group_id=group_id, log=log ) if existing is not None: - self.entrypoint_id_service.delete( + self._entrypoint_id_service.delete( entrypoint_id=existing.resource_id ) @@ -328,13 +374,17 @@ def _register_entrypoints( commit=False, log=log, ) - entrypoints[entrypoint_dict["name"]] = entrypoint_dict["entrypoint"] + entrypoints[entrypoint_dict["entry_point"].name] = entrypoint_dict[ + "entry_point" + ] db.session.flush() return entrypoints - def _build_tasks(self, tasks_config, param_types): + def _build_tasks( + self, tasks_config: list[dict[str, Any]], param_types: list[dict[str, str]] + ) -> dict[str, list]: tasks = defaultdict(list) for task in tasks_config: tasks[task["filename"]].append( diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py index bdd7c4620..af930042f 100644 --- a/tests/unit/restapi/v1/test_workflow_resource_import.py +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -52,11 +52,10 @@ def resource_import( """ payload = { - "groupId": group_id, + "group": group_id, "sourceType": "upload", "archiveFile": (resources_tar_file, "upload.tar.gz"), "configPath": "dioptra.toml", - "readOnly": read_only, "resolveNameConflictsStrategy": resolve_name_conflict_strategy, } From 0210432792280bb4a20533df4ed30a05da87b520 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 31 Oct 2024 15:32:11 -0400 Subject: [PATCH 12/28] docs: added docstrings and type annotations --- src/dioptra/restapi/v1/workflows/service.py | 65 ++++++++++++++++++--- 1 file changed, 56 insertions(+), 9 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index dce4f8514..55f913e76 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -31,7 +31,7 @@ from structlog.stdlib import BoundLogger from werkzeug.datastructures import FileStorage -from dioptra.restapi.db import db +from dioptra.restapi.db import db, models from dioptra.restapi.errors import DioptraError from dioptra.restapi.v1.entrypoints.service import ( EntrypointIdService, @@ -238,15 +238,21 @@ def _register_plugin_param_types( param_types_config: list[dict[str, Any]], overwrite: bool, log: BoundLogger, - ) -> dict[str, Any]: + ) -> dict[str, models.PluginTaskParameterType]: """ - Registers a list PluginParameterTypes. + Registers a list of PluginParameterTypes. Args: - group_id: The identifier + group_id: The identifier of the group that will manage imported resources + param_types_config: A list of dictionaries describing a plugin param types + overwrite: Whether imported resources should replace existing resources with + a conflicting name + Returns: - A dictionary mapping PluginParameterType name to the ORM object + A dictionary mapping newly registered PluginParameterType names to the ORM + object """ + param_types = dict() for param_type in param_types_config: if overwrite: @@ -279,10 +285,24 @@ def _register_plugins( self, group_id: int, plugins_config: list[dict[str, Any]], - param_types: Any, + param_types: dict[str, models.PluginTaskParameterType], overwrite: bool, log: BoundLogger, - ): + ) -> dict[str, models.PluginTaskParameterType]: + """ + Registers a list of Plugins and their PluginFiles. + + Args: + group_id: The identifier of the group that will manage imported resources + plugins_config: A list of dictionaries describing a plugin and its tasks + param_types: A dictionary mapping param type name to the ORM object + overwrite: Whether imported resources should replace existing resources with + a conflicting name + + Returns: + A dictionary mapping newly registered Plugin names to the ORM objects + """ + param_types = param_types.copy() builtin_param_types = self._builtin_plugin_parameter_type_service.get( group_id=group_id, error_if_not_found=False, log=log @@ -339,7 +359,21 @@ def _register_entrypoints( plugins, overwrite: bool, log: BoundLogger, - ): + ) -> dict[str, models.EntryPoint]: + """ + Registers a list of Entrypoints + + Args: + group_id: The identifier of the group that will manage imported resources + entrypoints_config: A list of dictionaries describing entrypoints + plugins: A dictionary mapping Plugin names to the ORM objects + overwrite: Whether imported resources should replace existing resources with + a conflicting name + + Returns: + A dictionary mapping newly registered Entrypoint names to ORM object + """ + entrypoints = dict() for entrypoint in entrypoints_config: if overwrite: @@ -383,8 +417,21 @@ def _register_entrypoints( return entrypoints def _build_tasks( - self, tasks_config: list[dict[str, Any]], param_types: list[dict[str, str]] + self, + tasks_config: list[dict[str, Any]], + param_types: dict[str, models.PluginTaskParameterType], ) -> dict[str, list]: + """ + Builds dictionaries describing plugin tasks from a configuration file + + Args: + tasks_config: A list of dictionaries describing plugin tasks + param_types: A dictionary mapping param type name to the ORM object + + Returns: + A dictionary mapping PluginFile name to a list of tasks + """ + tasks = defaultdict(list) for task in tasks_config: tasks[task["filename"]].append( From c5a6a36ffe5a000fcaca367be18c82bad1930e47 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 31 Oct 2024 15:33:19 -0400 Subject: [PATCH 13/28] refactor: placed files needed for tests under tests dir --- tests/unit/restapi/v1/conftest.py | 21 ++++------ .../v1/resource_import_files/dioptra.toml | 37 +++++++++++++++++ .../v1/resource_import_files/hello-world.yaml | 10 +++++ .../hello_world/__init__.py | 16 ++++++++ .../hello_world/tasks.py | 41 +++++++++++++++++++ 5 files changed, 111 insertions(+), 14 deletions(-) create mode 100644 tests/unit/restapi/v1/resource_import_files/dioptra.toml create mode 100644 tests/unit/restapi/v1/resource_import_files/hello-world.yaml create mode 100644 tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py create mode 100644 tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index abf4cec01..6f85f4428 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -34,9 +34,6 @@ from ..lib import actions, mock_rq -# TODO: figure out if thesre is a better way to do this -DIOPTRA_ROOT = Path(__file__).parent.parent.parent.parent.parent - @pytest.fixture def app(dependency_modules: list[Any]) -> Iterator[Flask]: @@ -730,18 +727,13 @@ def registered_mlflowrun_incomplete( def resources_tar_file(): + root_dir = Path(__file__).absolute().parent / "resource_import_files" + f = NamedTemporaryFile(suffix=".tar.gz") with tarfile.open(fileobj=f, mode="w:gz") as tar: - tar.add(DIOPTRA_ROOT / "dioptra.toml", arcname="dioptra.toml") - tar.add( - DIOPTRA_ROOT / "plugins/hello_world", - arcname="plugins/hello_world", - recursive=True, - ) - tar.add( - DIOPTRA_ROOT / "examples/hello-world.yaml", - arcname="examples/hello-world.yaml", - ) + tar.add(root_dir / "dioptra.toml", arcname="dioptra.toml") + tar.add(root_dir / "hello_world", arcname="plugins/hello_world", recursive=True) + tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml") f.seek(0) return f @@ -749,4 +741,5 @@ def resources_tar_file(): @pytest.fixture def resources_import_config(): - return toml.load(DIOPTRA_ROOT / "dioptra.toml") + root_dir = Path(__file__).absolute().parent / "resource_import_files" + return toml.load(root_dir / "dioptra.toml") diff --git a/tests/unit/restapi/v1/resource_import_files/dioptra.toml b/tests/unit/restapi/v1/resource_import_files/dioptra.toml new file mode 100644 index 000000000..ce3497057 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/dioptra.toml @@ -0,0 +1,37 @@ +[[plugins]] +path = "plugins/hello_world" +description = "A simple plugin used for testing and demonstration purposes." + + + [[plugins.tasks]] + filename = "tasks.py" + name = "hello" + input_params = [ { name = "name", type = "string", required = true} ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "greet" + input_params = [ + { name = "greeting", type = "string", required = true }, + { name = "name", type = "string", required = true }, + ] + output_params = [ { name = "greeting", type = "message" } ] + + [[plugins.tasks]] + filename = "tasks.py" + name = "shout" + input_params = [ { name = "greeting", type = "message", required = true} ] + output_params = [ { name = "loud_greeting", type = "message" } ] + +[[plugin_param_types]] +name = "message" + +[[entrypoints]] +path = "examples/hello-world.yaml" +name = "Hello World" +description = "A simple example using the hello_world plugin." +params = [ + { name = "name", type = "string", default_value = "World" } +] +plugins = [ "hello_world" ] diff --git a/tests/unit/restapi/v1/resource_import_files/hello-world.yaml b/tests/unit/restapi/v1/resource_import_files/hello-world.yaml new file mode 100644 index 000000000..0b7609dcb --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/hello-world.yaml @@ -0,0 +1,10 @@ +hello_step: + hello: + name: $name +goodbye_step: + greet: + greeting: Goodbye + name: $name +shout_step: + shout: + greeting: $goodbye_step.greeting diff --git a/tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py b/tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py new file mode 100644 index 000000000..ab0a41a34 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py @@ -0,0 +1,16 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py b/tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py new file mode 100644 index 000000000..875fb2b82 --- /dev/null +++ b/tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import structlog +from dioptra import pyplugs + +LOGGER = structlog.get_logger() + + +@pyplugs.register() +def hello(name: str) -> str: + message = f"Hello, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def greet(greeting: str, name: str) -> str: + message = f"{greeting}, {name}" + LOGGER.info(message) + return message + + +@pyplugs.register() +def shout(greeting: str) -> str: + loud_greeting = greeting.upper() + LOGGER.info(loud_greeting) + return loud_greeting From f641e2aff842cb6c412a95b12b877dcdc5770e63 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 31 Oct 2024 15:33:44 -0400 Subject: [PATCH 14/28] bug: properly handle temporary file and file cleanup --- tests/unit/restapi/v1/conftest.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 6f85f4428..a5f5173a2 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -736,7 +736,9 @@ def resources_tar_file(): tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml") f.seek(0) - return f + yield f + + f.close() @pytest.fixture From e665d8a92f03a8f74f3fdc3143d5f76208adc811 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 13 Nov 2024 09:32:42 -0500 Subject: [PATCH 15/28] fix: bug where data_key was not used for form data --- src/dioptra/restapi/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dioptra/restapi/utils.py b/src/dioptra/restapi/utils.py index 7236d875f..4da0a5f5d 100644 --- a/src/dioptra/restapi/utils.py +++ b/src/dioptra/restapi/utils.py @@ -151,7 +151,7 @@ def create_parameters_schema( location = "files" parameters_schema = ParametersSchema( - name=cast(str, field.name), + name=cast(str, field.data_key or field.name), type=parameter_type, location=location, required=field.required, From 832bedb97ad1f1721e0f499136e892d4db188ee4 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 13 Nov 2024 09:33:07 -0500 Subject: [PATCH 16/28] docs: updated import workflow reference to match example --- docs/source/reference/resource-import-reference.rst | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/docs/source/reference/resource-import-reference.rst b/docs/source/reference/resource-import-reference.rst index d2d59ddfc..ece031dc9 100644 --- a/docs/source/reference/resource-import-reference.rst +++ b/docs/source/reference/resource-import-reference.rst @@ -55,7 +55,7 @@ including a Plugin, PluginParameterType, and Entrypoint. # input parameter names must match the Python function definition # types are plugin parameter types and are matched by name input_params = [ { name = "name", type = "string", required = true} ] - output_params = [ { name = "message", type = "string" } ] + output_params = [ { name = "message", type = "message" } ] [[plugins.tasks]] filename = "tasks.py" @@ -64,19 +64,17 @@ including a Plugin, PluginParameterType, and Entrypoint. { name = "greeting", type = "string", required = true }, { name = "name", type = "string", required = true }, ] - output_params = [ { name = "message", type = "string" } ] + output_params = [ { name = "message", type = "message" } ] [[plugins.tasks]] filename = "tasks.py" name = "shout" - input_params = [ { name = "message", type = "string", required = true} ] - output_params = [ { name = "message", type = "string" } ] + input_params = [ { name = "message", type = "message", required = true} ] + output_params = [ { name = "message", type = "message" } ] # PluginParameterTypes are fully described in the TOML [[plugin_param_types]] - name = "image" - # an optional structure for the type - structure = { list = ["int", "int", "int"] } + name = "message" # Entrypoints point to a task graph yaml and include metadata for registering them in Dioptra [[entrypoints]] From 6001ea3e1e504d9475436dfd121f44bb9ae0218a Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 13 Nov 2024 12:13:31 -0500 Subject: [PATCH 17/28] chore: removed commented code --- src/dioptra/restapi/v1/workflows/schema.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index f975c32eb..6103e6123 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -112,12 +112,3 @@ def validate_source(self, data, **kwargs): and "git_url" not in data ): raise ValidationError({"gitUrl": "field required when sourceType is 'git'"}) - - # 'upload' is not in data - # if ( - # data["source_type"] == ResourceImportSourceTypes.UPLOAD - # and "data" not in data - # ): - # raise ValidationError( - # {"data": "field required when sourceType is 'upload'"} - # ) From 27366444fe0e9d101703462bbc08a9fb2e55f8de Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 13 Nov 2024 12:44:28 -0500 Subject: [PATCH 18/28] fix: prevent git clone from prompting for credentials --- src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py index 4ebb1068a..f4a1105c5 100644 --- a/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py +++ b/src/dioptra/restapi/v1/workflows/lib/clone_git_repository.py @@ -9,10 +9,11 @@ def clone_git_repository(url: str, dir: Path) -> str: git_paths = parsed_url.params or None git_url = parsed_url._replace(fragment="", params="").geturl() + git_env = ["GIT_TERMINAL_PROMPT=0"] git_sparse_args = ["--filter=blob:none", "--no-checkout", "--depth=1"] git_branch_args = ["-b", git_branch] if git_branch else [] clone_cmd = ["git", "clone", *git_sparse_args, *git_branch_args, git_url, str(dir)] - clone_result = subprocess.run(clone_cmd, capture_output=True, text=True) + clone_result = subprocess.run(git_env + clone_cmd, capture_output=True, text=True) if clone_result.returncode != 0: raise subprocess.CalledProcessError( From 318e7abfae7f4495c6be3f30943af772880a1ba1 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 13 Nov 2024 12:45:53 -0500 Subject: [PATCH 19/28] fix: handle file upload and git clone errors --- src/dioptra/restapi/v1/workflows/service.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 55f913e76..24e758e13 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -190,11 +190,17 @@ def import_resources( if source_type == ResourceImportSourceTypes.UPLOAD: bytes = archive_file.stream.read() - with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: - tar.extractall(path=working_dir, filter="data") + try: + with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: + tar.extractall(path=working_dir, filter="data") + except Exception as e: + raise DioptraError("Failed to read uploaded tarfile") from e hash = str(sha256(bytes).hexdigest()) - elif source_type == ResourceImportSourceTypes.GIT: - hash = clone_git_repository(git_url, working_dir) + else: + try: + hash = clone_git_repository(git_url, working_dir) + except Exception as e: + raise DioptraError("Failed to clone repository") from e try: config = toml.load(working_dir / config_path) From 3d75674c0efa55cf42feaffb7b5df5be02e4368a Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Dec 2024 11:27:34 -0500 Subject: [PATCH 20/28] tests: fix missing decorator after rebase --- tests/unit/restapi/v1/conftest.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index a5f5173a2..89261e274 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -726,6 +726,7 @@ def registered_mlflowrun_incomplete( return responses +@pytest.fixture def resources_tar_file(): root_dir = Path(__file__).absolute().parent / "resource_import_files" From 96fb3f2d8155380920dbd3642529e5f89770e840 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Dec 2024 11:28:07 -0500 Subject: [PATCH 21/28] feat(restapi): add custom errors for import workflow --- src/dioptra/restapi/errors.py | 24 +++++++++++++++++++++ src/dioptra/restapi/v1/workflows/service.py | 8 +++---- 2 files changed, 28 insertions(+), 4 deletions(-) diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index 1bb547ec9..1616532c2 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -339,6 +339,20 @@ def __init__(self, message: str): super().__init__(message) +class GitError(DioptraError): + """Git Error.""" + + def __init__(self, message: str): + super().__init__(message) + + +class ImportFailedError(DioptraError): + """Import failed Error.""" + + def __init__(self, message: str): + super().__init__(message) + + def error_result( error: DioptraError, status: http.HTTPStatus, detail: dict[str, typing.Any] ) -> tuple[dict[str, typing.Any], int]: @@ -443,6 +457,16 @@ def handle_mlflow_error(error: MLFlowError): log.debug(error.to_message()) return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {}) + @api.errorhandler(GitError) + def handle_git_error(error: GitError): + log.debug(error.to_message()) + return error_result(error, http.HTTPStatus.INTERNAL_SERVER_ERROR, {}) + + @api.errorhandler(GitError) + def handle_import_failed_error(error: ImportFailedError): + log.debug(error.to_message()) + return error_result(error, http.HTTPStatus.BAD_REQUEST, {}) + @api.errorhandler(DioptraError) def handle_base_error(error: DioptraError): log.debug(error.to_message()) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 24e758e13..1640b14e9 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -32,7 +32,7 @@ from werkzeug.datastructures import FileStorage from dioptra.restapi.db import db, models -from dioptra.restapi.errors import DioptraError +from dioptra.restapi.errors import GitError, ImportFailedError from dioptra.restapi.v1.entrypoints.service import ( EntrypointIdService, EntrypointNameService, @@ -194,18 +194,18 @@ def import_resources( with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: tar.extractall(path=working_dir, filter="data") except Exception as e: - raise DioptraError("Failed to read uploaded tarfile") from e + raise ImportFailedError("Failed to read uploaded tarfile") from e hash = str(sha256(bytes).hexdigest()) else: try: hash = clone_git_repository(git_url, working_dir) except Exception as e: - raise DioptraError("Failed to clone repository") from e + raise GitError("Failed to clone repository: {git_url}") from e try: config = toml.load(working_dir / config_path) except Exception as e: - raise DioptraError( + raise ImportFailedError( f"Failed to load resource import config from {config_path}." ) from e From 32e06d9125483158d9fe9eb4f6f3e0f4b0a83620 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Tue, 24 Dec 2024 11:28:41 -0500 Subject: [PATCH 22/28] feat(client): wip on updating client for import workflow --- src/dioptra/client/workflows.py | 50 ++++++++++++++++++++++++++++++++- 1 file changed, 49 insertions(+), 1 deletion(-) diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 8dfa4f6c6..f836764ca 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -15,13 +15,14 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode from pathlib import Path -from typing import ClassVar, Final, TypeVar +from typing import ClassVar, Final, Literal, TypeVar from .base import CollectionClient, IllegalArgumentError T = TypeVar("T") JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload" +RESOURCE_IMPORT: Final[str] = "resourceImport" class WorkflowsCollectionClient(CollectionClient[T]): @@ -86,3 +87,50 @@ def download_job_files( return self._session.download( self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params ) + + def import_resources_from_git( + self, + group_id: int, + git_url: str, + config_path: str | None = "dioptra.toml", + resolve_name_conflict_strategy: Literal["fail", "overwrite"] | None = "fail", + ): + """ """ + + json_ = { + "group": group_id, + "sourceType": "git", + "gitUrl": git_url, + } + + if config_path is not None: + json_["configPath"] = config_path + + if resolve_name_conflict_strategy is not None: + json_["resolveNameConflictStrategy"] = resolve_name_conflict_strategy + + # need to update post to specify that json should be sent as form data + self._session.post(self.url, RESOURCE_IMPORT, json_=json_) + + def import_resources_from_archive( + self, + group_id: int, + archive_file_path: Path, + config_path: str = "dioptra.toml", + resolve_name_conflict_strategy: Literal["fail", "overwrite"] = "fail", + ): + """ """ + + json_ = { + "group": group_id, + "sourceType": "upload", + } + + if config_path is not None: + json_["configPath"] = config_path + + if resolve_name_conflict_strategy is not None: + json_["resolveNameConflictStrategy"] = resolve_name_conflict_strategy + + # need to update post to specify that json should be sent as form data + return self._session.post(self.url, RESOURCE_IMPORT, json_=json_) From 9fdc6d61b6f051bbcd5f16d3a22a22da032956de Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 15 Jan 2025 14:20:24 -0500 Subject: [PATCH 23/28] build: add types-toml to mypy tox environment --- tox.ini | 1 + 1 file changed, 1 insertion(+) diff --git a/tox.ini b/tox.ini index 6df5e06f7..bb0291a61 100644 --- a/tox.ini +++ b/tox.ini @@ -186,6 +186,7 @@ deps = types-PyYAML types-redis types-requests + types-toml typing-extensions>=3.7.4.3 skip_install = false commands = mypy {posargs:"{tox_root}{/}src{/}dioptra" "{tox_root}{/}task-plugins{/}dioptra_builtins"} From 51c022b40b81edc65671564b1c8060ba9e66e245 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 15 Jan 2025 14:22:04 -0500 Subject: [PATCH 24/28] feat(client): add resource import workflow to client --- src/dioptra/client/workflows.py | 100 ++++++++++----- src/dioptra/restapi/v1/workflows/service.py | 9 +- tests/unit/restapi/v1/conftest.py | 8 +- .../v1/test_workflow_resource_import.py | 120 ++++++------------ 4 files changed, 116 insertions(+), 121 deletions(-) diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index f836764ca..86a869145 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -15,9 +15,14 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode from pathlib import Path -from typing import ClassVar, Final, Literal, TypeVar +from typing import Any, ClassVar, Final, Literal, TypeVar, overload -from .base import CollectionClient, IllegalArgumentError +from .base import ( + CollectionClient, + DioptraFile, + DioptraResponseProtocol, + IllegalArgumentError, +) T = TypeVar("T") @@ -88,49 +93,76 @@ def download_job_files( self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params ) - def import_resources_from_git( + @overload + def import_resources( self, - group_id: int, git_url: str, config_path: str | None = "dioptra.toml", - resolve_name_conflict_strategy: Literal["fail", "overwrite"] | None = "fail", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from git repo""" + ... # pragma: nocover + + @overload + def import_resources( + self, + archive_file: DioptraFile, + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from archive file""" + ... # pragma: nocover + + def import_resources( + self, + group_id, + git_url=None, + archive_file=None, + config_path="dioptra.toml", + resolve_name_conflicts_strategy="fail", ): - """ """ + """ + Import resources from a archive file or git repository - json_ = { - "group": group_id, - "sourceType": "git", - "gitUrl": git_url, - } + Args: + group_id: The group to import resources into + source_type: The source to import from (either "upload" or "git") + git_url: The url to the git repository if source_type is "git" + archive_file: The contents of the upload if source_type is "upload" + config_path: The path to the toml configuration file in the import source. + resolve_name_conflicts_strategy: The strategy for resolving name conflicts. + Either "fail" or "overwrite" + Raises: + IllegalArgumentError: If only one of archive_file + """ - if config_path is not None: - json_["configPath"] = config_path + if archive_file is None and git_url is None: + raise IllegalArgumentError( + "One of 'archive_file' and 'git_url' must be provided" + ) - if resolve_name_conflict_strategy is not None: - json_["resolveNameConflictStrategy"] = resolve_name_conflict_strategy + if archive_file is not None and git_url is not None: + raise IllegalArgumentError( + "Only one of 'archive_file' and 'git_url' can be provided" + ) - # need to update post to specify that json should be sent as form data - self._session.post(self.url, RESOURCE_IMPORT, json_=json_) + data: dict[str, Any] = {"group": group_id} + files: dict[str, DioptraFile | list[DioptraFile]] = {} - def import_resources_from_archive( - self, - group_id: int, - archive_file_path: Path, - config_path: str = "dioptra.toml", - resolve_name_conflict_strategy: Literal["fail", "overwrite"] = "fail", - ): - """ """ + if git_url is not None: + data["sourceType"] = "git" + data["gitUrl"] = git_url - json_ = { - "group": group_id, - "sourceType": "upload", - } + if archive_file is not None: + data["sourceType"] = "upload" + files["archiveFile"] = archive_file if config_path is not None: - json_["configPath"] = config_path + data["configPath"] = config_path - if resolve_name_conflict_strategy is not None: - json_["resolveNameConflictStrategy"] = resolve_name_conflict_strategy + if resolve_name_conflicts_strategy is not None: + data["resolveNameConflictsStrategy"] = resolve_name_conflicts_strategy - # need to update post to specify that json should be sent as form data - return self._session.post(self.url, RESOURCE_IMPORT, json_=json_) + return self._session.post( + self.url, RESOURCE_IMPORT, data=data, files=files or None + ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 1640b14e9..ab9012260 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -22,7 +22,7 @@ from io import BytesIO from pathlib import Path from tempfile import TemporaryDirectory -from typing import IO, Any, Final +from typing import IO, Any, Final, cast import jsonschema import structlog @@ -52,7 +52,9 @@ ) from dioptra.sdk.utilities.paths import set_cwd -from .lib import clone_git_repository, package_job_files, views +from .lib import views +from .lib.clone_git_repository import clone_git_repository +from .lib.package_job_files import package_job_files from .schema import ( FileTypes, ResourceImportResolveNameConflictsStrategy, @@ -171,6 +173,7 @@ def import_resources( source_type: The source to import from (either "upload" or "git") git_url: The url to the git repository if source_type is "git" archive_file: The contents of the upload if source_type is "upload" + config_path: The path to the toml configuration file in the import source. resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" @@ -198,7 +201,7 @@ def import_resources( hash = str(sha256(bytes).hexdigest()) else: try: - hash = clone_git_repository(git_url, working_dir) + hash = clone_git_repository(cast(str, git_url), working_dir) except Exception as e: raise GitError("Failed to clone repository: {git_url}") from e diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 89261e274..2e7f24f94 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -32,6 +32,8 @@ from injector import Injector from pytest import MonkeyPatch +from dioptra.client import DioptraFile, select_one_or_more_files + from ..lib import actions, mock_rq @@ -727,7 +729,7 @@ def registered_mlflowrun_incomplete( @pytest.fixture -def resources_tar_file(): +def resources_tar_file() -> DioptraFile: root_dir = Path(__file__).absolute().parent / "resource_import_files" f = NamedTemporaryFile(suffix=".tar.gz") @@ -737,12 +739,12 @@ def resources_tar_file(): tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml") f.seek(0) - yield f + yield select_one_or_more_files([f.name])[0] f.close() @pytest.fixture -def resources_import_config(): +def resources_import_config() -> dict[str, Any]: root_dir = Path(__file__).absolute().parent / "resource_import_files" return toml.load(root_dir / "dioptra.toml") diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py index af930042f..82437c4fd 100644 --- a/tests/unit/restapi/v1/test_workflow_resource_import.py +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -25,107 +25,65 @@ from tempfile import NamedTemporaryFile from typing import Any -from flask.testing import FlaskClient from flask_sqlalchemy import SQLAlchemy -from werkzeug.test import TestResponse -from dioptra.restapi.routes import V1_ROOT, V1_WORKFLOWS_ROUTE - -# -- Actions --------------------------------------------------------------------------- - - -def resource_import( - client: FlaskClient, - resources_tar_file: NamedTemporaryFile, - group_id: int, - read_only: bool, - resolve_name_conflict_strategy: str, -) -> TestResponse: - """Import resources into Dioptra - - Args: - client: The Flask test client. - group_id: The id of the group to import resources into. - - Returns: - The response from the API. - """ - - payload = { - "group": group_id, - "sourceType": "upload", - "archiveFile": (resources_tar_file, "upload.tar.gz"), - "configPath": "dioptra.toml", - "resolveNameConflictsStrategy": resolve_name_conflict_strategy, - } - - return client.post( - f"/{V1_ROOT}/{V1_WORKFLOWS_ROUTE}/resourceImport", - data=payload, - content_type="multipart/form-data", - follow_redirects=True, - ) +from dioptra.client import DioptraClient, DioptraFile +from dioptra.client.base import DioptraResponseProtocol # -- Assertions ------------------------------------------------------------------------ def assert_imported_resources_match_expected( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], expected: dict[str, Any], ): - response = client.get(f"/{V1_ROOT}/plugins", follow_redirects=True) - response_plugins = set(plugin["name"] for plugin in response.get_json()["data"]) + response = dioptra_client.plugins.get() + response_plugins = set(plugin["name"] for plugin in response.json()["data"]) expected_plugins = set(Path(plugin["path"]).stem for plugin in expected["plugins"]) assert response.status_code == 200 and response_plugins == expected_plugins - response = client.get(f"/{V1_ROOT}/pluginParameterTypes", follow_redirects=True) - response_types = set(param["name"] for param in response.get_json()["data"]) + response = dioptra_client.plugin_parameter_types.get() + response_types = set(param["name"] for param in response.json()["data"]) expected_types = set(param["name"] for param in expected["plugin_param_types"]) assert ( response.status_code == 200 and response_types & expected_types == expected_types ) - response = client.get(f"/{V1_ROOT}/entrypoints", follow_redirects=True) - response_entrypoints = set(ep["name"] for ep in response.get_json()["data"]) + response = dioptra_client.entrypoints.get() + response_entrypoints = set(ep["name"] for ep in response.json()["data"]) expected_entrypoints = set(ep["name"] for ep in expected["entrypoints"]) assert response.status_code == 200 and response_entrypoints == expected_entrypoints def assert_resource_import_fails_due_to_name_clash( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], group_id: int, - resources_tar_file: NamedTemporaryFile, + archive_file: DioptraFile, ): - payload = {"name": "hello_world", "group": group_id} - client.post(f"/{V1_ROOT}/plugins", json=payload, follow_redirects=True) - - response = resource_import( - client, - resources_tar_file, - group_id, - read_only=True, - resolve_name_conflict_strategy="fail", + dioptra_client.plugins.create(group_id=group_id, name="hello_world") + response = dioptra_client.workflows.import_resources( + group_id=group_id, + archive_file=archive_file, + resolve_name_conflicts_strategy="fail", ) + assert response.status_code == 409 def assert_resource_import_overwrite_works( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], group_id: int, - resources_tar_file: NamedTemporaryFile, + archive_file: DioptraFile, ): - payload = {"name": "hello_world", "group": group_id} - client.post(f"/{V1_ROOT}/plugins", json=payload, follow_redirects=True) - - response = resource_import( - client, - resources_tar_file, - group_id, - read_only=True, - resolve_name_conflict_strategy="overwrite", + dioptra_client.plugins.create(group_id=group_id, name="hello_world") + response = dioptra_client.workflows.import_resources( + group_id=group_id, + archive_file=archive_file, + resolve_name_conflicts_strategy="overwrite", ) + assert response.status_code == 200 @@ -133,39 +91,39 @@ def assert_resource_import_overwrite_works( def test_resource_import( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, auth_account: dict[str, Any], - resources_tar_file: NamedTemporaryFile, + resources_tar_file: DioptraFile, resources_import_config: dict[str, Any], ): group_id = auth_account["groups"][0]["id"] - resource_import( - client, - resources_tar_file, - group_id, - read_only=True, - resolve_name_conflict_strategy="fail", - ) + dioptra_client.workflows.import_resources(group_id, archive_file=resources_tar_file) - assert_imported_resources_match_expected(client, resources_import_config) + assert_imported_resources_match_expected(dioptra_client, resources_import_config) def test_resource_import_fails_from_name_clash( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, auth_account: dict[str, Any], resources_tar_file: NamedTemporaryFile, ): group_id = auth_account["groups"][0]["id"] - assert_resource_import_fails_due_to_name_clash(client, group_id, resources_tar_file) + + assert_resource_import_fails_due_to_name_clash( + dioptra_client, group_id=group_id, archive_file=resources_tar_file + ) def test_resource_import_overwrite( - client: FlaskClient, + dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, auth_account: dict[str, Any], resources_tar_file: NamedTemporaryFile, ): group_id = auth_account["groups"][0]["id"] - assert_resource_import_overwrite_works(client, group_id, resources_tar_file) + + assert_resource_import_overwrite_works( + dioptra_client, group_id=group_id, archive_file=resources_tar_file + ) From fc97c56defb6057915a4649135aa7b07b57c81d5 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Wed, 15 Jan 2025 16:41:33 -0500 Subject: [PATCH 25/28] feat: add support for multi-file uploads to resource import workflow --- src/dioptra/client/workflows.py | 47 ++++++++++++++----- .../restapi/v1/workflows/controller.py | 1 + src/dioptra/restapi/v1/workflows/schema.py | 26 +++++++--- src/dioptra/restapi/v1/workflows/service.py | 32 +++++++++++-- tests/unit/restapi/v1/conftest.py | 22 +++++++-- .../v1/test_workflow_resource_import.py | 15 +++++- 6 files changed, 113 insertions(+), 30 deletions(-) diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 86a869145..f31f3ce46 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -113,11 +113,22 @@ def import_resources( """Signature for using import_resource from archive file""" ... # pragma: nocover + @overload + def import_resources( + self, + files: list[DioptraFile], + config_path: str | None = "dioptra.toml", + resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail", + ) -> DioptraResponseProtocol: + """Signature for using import_resource from archive file""" + ... # pragma: nocover + def import_resources( self, group_id, git_url=None, archive_file=None, + files=None, config_path="dioptra.toml", resolve_name_conflicts_strategy="fail", ): @@ -126,36 +137,46 @@ def import_resources( Args: group_id: The group to import resources into - source_type: The source to import from (either "upload" or "git") + source_type: The source to import from git_url: The url to the git repository if source_type is "git" - archive_file: The contents of the upload if source_type is "upload" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" config_path: The path to the toml configuration file in the import source. resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" Raises: - IllegalArgumentError: If only one of archive_file + IllegalArgumentError: If more than one import source is provided or if no + import source is provided. """ - if archive_file is None and git_url is None: + import_source_args = [git_url, archive_file, files] + num_provided_import_source_args = sum( + arg is not None for arg in import_source_args + ) + + if num_provided_import_source_args == 0: raise IllegalArgumentError( - "One of 'archive_file' and 'git_url' must be provided" + "One of (git_url, archive_file, or files) must be provided" ) - - if archive_file is not None and git_url is not None: + elif num_provided_import_source_args > 1: raise IllegalArgumentError( - "Only one of 'archive_file' and 'git_url' can be provided" + "Only one of (git_url, archive_file and files) can be provided" ) - data: dict[str, Any] = {"group": group_id} - files: dict[str, DioptraFile | list[DioptraFile]] = {} + data: dict[str, Any] = {"group": str(group_id)} + files_: dict[str, DioptraFile | list[DioptraFile]] = {} if git_url is not None: data["sourceType"] = "git" data["gitUrl"] = git_url if archive_file is not None: - data["sourceType"] = "upload" - files["archiveFile"] = archive_file + data["sourceType"] = "upload_archive" + files_["archiveFile"] = archive_file + + if files is not None: + data["sourceType"] = "upload_files" + files_["files"] = files if config_path is not None: data["configPath"] = config_path @@ -164,5 +185,5 @@ def import_resources( data["resolveNameConflictsStrategy"] = resolve_name_conflicts_strategy return self._session.post( - self.url, RESOURCE_IMPORT, data=data, files=files or None + self.url, RESOURCE_IMPORT, data=data, files=files_ or None ) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index fd5cd9b6f..e469b6e23 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -124,6 +124,7 @@ def post(self): source_type=parsed_form["source_type"], git_url=parsed_form.get("git_url", None), archive_file=request.files.get("archiveFile", None), + files=request.files.getlist("files", None), config_path=parsed_form["config_path"], resolve_name_conflicts_strategy=parsed_form[ "resolve_name_conflicts_strategy" diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 6103e6123..4b7fcda5e 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -19,7 +19,7 @@ from marshmallow import Schema, ValidationError, fields, validates_schema -from dioptra.restapi.custom_schema_fields import FileUpload +from dioptra.restapi.custom_schema_fields import FileUpload, MultiFileUpload class FileTypes(Enum): @@ -47,7 +47,8 @@ class JobFilesDownloadQueryParametersSchema(Schema): class ResourceImportSourceTypes(Enum): GIT = "git" - UPLOAD = "upload" + UPLOAD_ARCHIVE = "upload_archive" + UPLOAD_FILES = "upload_files" class ResourceImportResolveNameConflictsStrategy(Enum): @@ -69,14 +70,17 @@ class ResourceImportSchema(Schema): sourceType = fields.Enum( ResourceImportSourceTypes, attribute="source_type", - metadata=dict(description="The source of the resources to import."), + metadata=dict( + description="The source of the resources to import" + "('upload_archive', 'upload_files', or 'git'." + ), by_value=True, required=True, ) gitUrl = fields.String( attribute="git_url", metadata=dict( - description="The URL of the git repository containing resources to import. " + description="The URL of the git repository containing resources to import." "A git branch can optionally be specified by appending #BRANCH_NAME. " "Used when sourceType is 'git'." ), @@ -87,8 +91,18 @@ class ResourceImportSchema(Schema): metadata=dict( type="file", format="binary", - description="The archive file containing resources to import (.tar.gz). " - "Used when sourceType is 'upload'.", + description="The archive file containing resources to import (.tar.gz)." + "Used when sourceType is 'upload_archive'.", + ), + required=False, + ) + files = MultiFileUpload( + attribute="files", + metadata=dict( + type="file", + format="binary", + description="The files containing the resources to import." + "Used when sourceType is 'upload_files'.", ), required=False, ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index ab9012260..8ea0182b7 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -162,6 +162,7 @@ def import_resources( source_type: str, git_url: str | None, archive_file: FileStorage | None, + files: list[FileStorage] | None, config_path: str, resolve_name_conflicts_strategy: str, **kwargs, @@ -172,7 +173,8 @@ def import_resources( group_id: The group to import resources into source_type: The source to import from (either "upload" or "git") git_url: The url to the git repository if source_type is "git" - archive_file: The contents of the upload if source_type is "upload" + archive_file: The contents of the upload if source_type is "upload_archive" + files: The contents of the upload if source_type is "upload_files" config_path: The path to the toml configuration file in the import source. resolve_name_conflicts_strategy: The strategy for resolving name conflicts. Either "fail" or "overwrite" @@ -191,7 +193,7 @@ def import_resources( with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir): working_dir = Path(tmp_dir) - if source_type == ResourceImportSourceTypes.UPLOAD: + if source_type == ResourceImportSourceTypes.UPLOAD_ARCHIVE: bytes = archive_file.stream.read() try: with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar: @@ -199,11 +201,20 @@ def import_resources( except Exception as e: raise ImportFailedError("Failed to read uploaded tarfile") from e hash = str(sha256(bytes).hexdigest()) + elif source_type == ResourceImportSourceTypes.UPLOAD_FILES: + hashes = b"" + for file in files: + Path(file.filename).parent.mkdir(parents=True, exist_ok=True) + bytes = file.stream.read() + with open(working_dir / file.filename, "wb") as f: + f.write(bytes) + hashes = hashes + sha256(bytes).digest() + hash = str(sha256(hashes).hexdigest()) else: try: hash = clone_git_repository(cast(str, git_url), working_dir) except Exception as e: - raise GitError("Failed to clone repository: {git_url}") from e + raise GitError(f"Failed to clone repository: {git_url}") from e try: config = toml.load(working_dir / config_path) @@ -345,7 +356,12 @@ def _register_plugins( tasks = self._build_tasks(plugin.get("tasks", []), param_types) for plugin_file_path in Path(plugin["path"]).rglob("*.py"): filename = str(plugin_file_path.relative_to(plugin["path"])) - contents = plugin_file_path.read_text() + try: + contents = plugin_file_path.read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {plugin_file_path}" + ) from e self._plugin_id_file_service.create( filename, @@ -394,7 +410,13 @@ def _register_entrypoints( entrypoint_id=existing.resource_id ) - contents = Path(entrypoint["path"]).read_text() + try: + contents = Path(entrypoint["path"]).read_text() + except FileNotFoundError as e: + raise ImportFailedError( + f"Failed to read plugin file from {entrypoint['path']}" + ) from e + params = [ { "name": param["name"], diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 2e7f24f94..f1b09d941 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -15,6 +15,7 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """Fixtures representing resources needed for test suites""" +import os import tarfile import textwrap from collections.abc import Iterator @@ -32,7 +33,11 @@ from injector import Injector from pytest import MonkeyPatch -from dioptra.client import DioptraFile, select_one_or_more_files +from dioptra.client import ( + DioptraFile, + select_files_in_directory, + select_one_or_more_files, +) from ..lib import actions, mock_rq @@ -730,13 +735,13 @@ def registered_mlflowrun_incomplete( @pytest.fixture def resources_tar_file() -> DioptraFile: - root_dir = Path(__file__).absolute().parent / "resource_import_files" + os.chdir(Path(__file__).absolute().parent / "resource_import_files") f = NamedTemporaryFile(suffix=".tar.gz") with tarfile.open(fileobj=f, mode="w:gz") as tar: - tar.add(root_dir / "dioptra.toml", arcname="dioptra.toml") - tar.add(root_dir / "hello_world", arcname="plugins/hello_world", recursive=True) - tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml") + tar.add("dioptra.toml") + tar.add("plugins", recursive=True) + tar.add("examples/hello-world.yaml") f.seek(0) yield select_one_or_more_files([f.name])[0] @@ -744,6 +749,13 @@ def resources_tar_file() -> DioptraFile: f.close() +@pytest.fixture +def resources_files() -> DioptraFile: + os.chdir(Path(__file__).absolute().parent / "resource_import_files") + + return select_files_in_directory(".", recursive=True) + + @pytest.fixture def resources_import_config() -> dict[str, Any]: root_dir = Path(__file__).absolute().parent / "resource_import_files" diff --git a/tests/unit/restapi/v1/test_workflow_resource_import.py b/tests/unit/restapi/v1/test_workflow_resource_import.py index 82437c4fd..f997c3d2c 100644 --- a/tests/unit/restapi/v1/test_workflow_resource_import.py +++ b/tests/unit/restapi/v1/test_workflow_resource_import.py @@ -90,7 +90,7 @@ def assert_resource_import_overwrite_works( # -- Tests ----------------------------------------------------------------------------- -def test_resource_import( +def test_resource_import_from_archive_file( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, auth_account: dict[str, Any], @@ -103,6 +103,19 @@ def test_resource_import( assert_imported_resources_match_expected(dioptra_client, resources_import_config) +def test_resource_import_from_files( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], + resources_files: list[DioptraFile], + resources_import_config: dict[str, Any], +): + group_id = auth_account["groups"][0]["id"] + dioptra_client.workflows.import_resources(group_id, files=resources_files) + + assert_imported_resources_match_expected(dioptra_client, resources_import_config) + + def test_resource_import_fails_from_name_clash( dioptra_client: DioptraClient[DioptraResponseProtocol], db: SQLAlchemy, From 81b398a71fd6668d4f68365bdb3739faf04341f7 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 16 Jan 2025 14:58:50 -0500 Subject: [PATCH 26/28] fix: updated paths to resource files used in tests --- .../v1/resource_import_files/{ => examples}/hello-world.yaml | 0 .../resource_import_files/{ => plugins}/hello_world/__init__.py | 0 .../v1/resource_import_files/{ => plugins}/hello_world/tasks.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename tests/unit/restapi/v1/resource_import_files/{ => examples}/hello-world.yaml (100%) rename tests/unit/restapi/v1/resource_import_files/{ => plugins}/hello_world/__init__.py (100%) rename tests/unit/restapi/v1/resource_import_files/{ => plugins}/hello_world/tasks.py (100%) diff --git a/tests/unit/restapi/v1/resource_import_files/hello-world.yaml b/tests/unit/restapi/v1/resource_import_files/examples/hello-world.yaml similarity index 100% rename from tests/unit/restapi/v1/resource_import_files/hello-world.yaml rename to tests/unit/restapi/v1/resource_import_files/examples/hello-world.yaml diff --git a/tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/__init__.py similarity index 100% rename from tests/unit/restapi/v1/resource_import_files/hello_world/__init__.py rename to tests/unit/restapi/v1/resource_import_files/plugins/hello_world/__init__.py diff --git a/tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py b/tests/unit/restapi/v1/resource_import_files/plugins/hello_world/tasks.py similarity index 100% rename from tests/unit/restapi/v1/resource_import_files/hello_world/tasks.py rename to tests/unit/restapi/v1/resource_import_files/plugins/hello_world/tasks.py From b53f9f79fbfb3a43fdd2d66ed315ed122a9b3368 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 16 Jan 2025 15:11:41 -0500 Subject: [PATCH 27/28] fix: fix path for windows --- tests/unit/restapi/v1/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index f1b09d941..9a86fa7fe 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -741,7 +741,7 @@ def resources_tar_file() -> DioptraFile: with tarfile.open(fileobj=f, mode="w:gz") as tar: tar.add("dioptra.toml") tar.add("plugins", recursive=True) - tar.add("examples/hello-world.yaml") + tar.add(Path("examples", "hello-world.yaml")) f.seek(0) yield select_one_or_more_files([f.name])[0] From c9e02a34cd3c46e5dd5b15ffdbf2a0029d4ab458 Mon Sep 17 00:00:00 2001 From: Keith Manville Date: Thu, 16 Jan 2025 15:47:04 -0500 Subject: [PATCH 28/28] fix: handle tmp file deletion manually --- tests/unit/restapi/v1/conftest.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/unit/restapi/v1/conftest.py b/tests/unit/restapi/v1/conftest.py index 9a86fa7fe..44290f0e6 100644 --- a/tests/unit/restapi/v1/conftest.py +++ b/tests/unit/restapi/v1/conftest.py @@ -737,16 +737,15 @@ def registered_mlflowrun_incomplete( def resources_tar_file() -> DioptraFile: os.chdir(Path(__file__).absolute().parent / "resource_import_files") - f = NamedTemporaryFile(suffix=".tar.gz") - with tarfile.open(fileobj=f, mode="w:gz") as tar: - tar.add("dioptra.toml") - tar.add("plugins", recursive=True) - tar.add(Path("examples", "hello-world.yaml")) - f.seek(0) + with NamedTemporaryFile(suffix=".tar.gz", delete=False) as f: + with tarfile.open(fileobj=f, mode="w:gz") as tar: + tar.add("dioptra.toml") + tar.add("plugins", recursive=True) + tar.add(Path("examples", "hello-world.yaml")) yield select_one_or_more_files([f.name])[0] - f.close() + os.unlink(f.name) @pytest.fixture