Skip to content

Commit

Permalink
feat: add support for multi-file uploads to resource import workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
keithmanville committed Jan 16, 2025
1 parent 51c022b commit fc97c56
Show file tree
Hide file tree
Showing 6 changed files with 113 additions and 30 deletions.
47 changes: 34 additions & 13 deletions src/dioptra/client/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,11 +113,22 @@ def import_resources(
"""Signature for using import_resource from archive file"""
... # pragma: nocover

@overload
def import_resources(
self,
files: list[DioptraFile],
config_path: str | None = "dioptra.toml",
resolve_name_conflicts_strategy: Literal["fail", "overwrite"] | None = "fail",
) -> DioptraResponseProtocol:
"""Signature for using import_resource from archive file"""
... # pragma: nocover

def import_resources(
self,
group_id,
git_url=None,
archive_file=None,
files=None,
config_path="dioptra.toml",
resolve_name_conflicts_strategy="fail",
):
Expand All @@ -126,36 +137,46 @@ def import_resources(
Args:
group_id: The group to import resources into
source_type: The source to import from (either "upload" or "git")
source_type: The source to import from
git_url: The url to the git repository if source_type is "git"
archive_file: The contents of the upload if source_type is "upload"
archive_file: The contents of the upload if source_type is "upload_archive"
files: The contents of the upload if source_type is "upload_files"
config_path: The path to the toml configuration file in the import source.
resolve_name_conflicts_strategy: The strategy for resolving name conflicts.
Either "fail" or "overwrite"
Raises:
IllegalArgumentError: If only one of archive_file
IllegalArgumentError: If more than one import source is provided or if no
import source is provided.
"""

if archive_file is None and git_url is None:
import_source_args = [git_url, archive_file, files]
num_provided_import_source_args = sum(
arg is not None for arg in import_source_args
)

if num_provided_import_source_args == 0:
raise IllegalArgumentError(
"One of 'archive_file' and 'git_url' must be provided"
"One of (git_url, archive_file, or files) must be provided"
)

if archive_file is not None and git_url is not None:
elif num_provided_import_source_args > 1:
raise IllegalArgumentError(
"Only one of 'archive_file' and 'git_url' can be provided"
"Only one of (git_url, archive_file and files) can be provided"
)

data: dict[str, Any] = {"group": group_id}
files: dict[str, DioptraFile | list[DioptraFile]] = {}
data: dict[str, Any] = {"group": str(group_id)}
files_: dict[str, DioptraFile | list[DioptraFile]] = {}

if git_url is not None:
data["sourceType"] = "git"
data["gitUrl"] = git_url

if archive_file is not None:
data["sourceType"] = "upload"
files["archiveFile"] = archive_file
data["sourceType"] = "upload_archive"
files_["archiveFile"] = archive_file

if files is not None:
data["sourceType"] = "upload_files"
files_["files"] = files

if config_path is not None:
data["configPath"] = config_path
Expand All @@ -164,5 +185,5 @@ def import_resources(
data["resolveNameConflictsStrategy"] = resolve_name_conflicts_strategy

return self._session.post(
self.url, RESOURCE_IMPORT, data=data, files=files or None
self.url, RESOURCE_IMPORT, data=data, files=files_ or None
)
1 change: 1 addition & 0 deletions src/dioptra/restapi/v1/workflows/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,7 @@ def post(self):
source_type=parsed_form["source_type"],
git_url=parsed_form.get("git_url", None),
archive_file=request.files.get("archiveFile", None),
files=request.files.getlist("files", None),
config_path=parsed_form["config_path"],
resolve_name_conflicts_strategy=parsed_form[
"resolve_name_conflicts_strategy"
Expand Down
26 changes: 20 additions & 6 deletions src/dioptra/restapi/v1/workflows/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from marshmallow import Schema, ValidationError, fields, validates_schema

from dioptra.restapi.custom_schema_fields import FileUpload
from dioptra.restapi.custom_schema_fields import FileUpload, MultiFileUpload


class FileTypes(Enum):
Expand Down Expand Up @@ -47,7 +47,8 @@ class JobFilesDownloadQueryParametersSchema(Schema):

class ResourceImportSourceTypes(Enum):
GIT = "git"
UPLOAD = "upload"
UPLOAD_ARCHIVE = "upload_archive"
UPLOAD_FILES = "upload_files"


class ResourceImportResolveNameConflictsStrategy(Enum):
Expand All @@ -69,14 +70,17 @@ class ResourceImportSchema(Schema):
sourceType = fields.Enum(
ResourceImportSourceTypes,
attribute="source_type",
metadata=dict(description="The source of the resources to import."),
metadata=dict(
description="The source of the resources to import"
"('upload_archive', 'upload_files', or 'git'."
),
by_value=True,
required=True,
)
gitUrl = fields.String(
attribute="git_url",
metadata=dict(
description="The URL of the git repository containing resources to import. "
description="The URL of the git repository containing resources to import."
"A git branch can optionally be specified by appending #BRANCH_NAME. "
"Used when sourceType is 'git'."
),
Expand All @@ -87,8 +91,18 @@ class ResourceImportSchema(Schema):
metadata=dict(
type="file",
format="binary",
description="The archive file containing resources to import (.tar.gz). "
"Used when sourceType is 'upload'.",
description="The archive file containing resources to import (.tar.gz)."
"Used when sourceType is 'upload_archive'.",
),
required=False,
)
files = MultiFileUpload(
attribute="files",
metadata=dict(
type="file",
format="binary",
description="The files containing the resources to import."
"Used when sourceType is 'upload_files'.",
),
required=False,
)
Expand Down
32 changes: 27 additions & 5 deletions src/dioptra/restapi/v1/workflows/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def import_resources(
source_type: str,
git_url: str | None,
archive_file: FileStorage | None,
files: list[FileStorage] | None,
config_path: str,
resolve_name_conflicts_strategy: str,
**kwargs,
Expand All @@ -172,7 +173,8 @@ def import_resources(
group_id: The group to import resources into
source_type: The source to import from (either "upload" or "git")
git_url: The url to the git repository if source_type is "git"
archive_file: The contents of the upload if source_type is "upload"
archive_file: The contents of the upload if source_type is "upload_archive"
files: The contents of the upload if source_type is "upload_files"
config_path: The path to the toml configuration file in the import source.
resolve_name_conflicts_strategy: The strategy for resolving name conflicts.
Either "fail" or "overwrite"
Expand All @@ -191,19 +193,28 @@ def import_resources(
with TemporaryDirectory() as tmp_dir, set_cwd(tmp_dir):
working_dir = Path(tmp_dir)

if source_type == ResourceImportSourceTypes.UPLOAD:
if source_type == ResourceImportSourceTypes.UPLOAD_ARCHIVE:
bytes = archive_file.stream.read()
try:
with tarfile.open(fileobj=BytesIO(bytes), mode="r:*") as tar:
tar.extractall(path=working_dir, filter="data")
except Exception as e:
raise ImportFailedError("Failed to read uploaded tarfile") from e
hash = str(sha256(bytes).hexdigest())
elif source_type == ResourceImportSourceTypes.UPLOAD_FILES:
hashes = b""
for file in files:
Path(file.filename).parent.mkdir(parents=True, exist_ok=True)
bytes = file.stream.read()
with open(working_dir / file.filename, "wb") as f:
f.write(bytes)
hashes = hashes + sha256(bytes).digest()
hash = str(sha256(hashes).hexdigest())
else:
try:
hash = clone_git_repository(cast(str, git_url), working_dir)
except Exception as e:
raise GitError("Failed to clone repository: {git_url}") from e
raise GitError(f"Failed to clone repository: {git_url}") from e

try:
config = toml.load(working_dir / config_path)
Expand Down Expand Up @@ -345,7 +356,12 @@ def _register_plugins(
tasks = self._build_tasks(plugin.get("tasks", []), param_types)
for plugin_file_path in Path(plugin["path"]).rglob("*.py"):
filename = str(plugin_file_path.relative_to(plugin["path"]))
contents = plugin_file_path.read_text()
try:
contents = plugin_file_path.read_text()
except FileNotFoundError as e:
raise ImportFailedError(
f"Failed to read plugin file from {plugin_file_path}"
) from e

self._plugin_id_file_service.create(
filename,
Expand Down Expand Up @@ -394,7 +410,13 @@ def _register_entrypoints(
entrypoint_id=existing.resource_id
)

contents = Path(entrypoint["path"]).read_text()
try:
contents = Path(entrypoint["path"]).read_text()
except FileNotFoundError as e:
raise ImportFailedError(
f"Failed to read plugin file from {entrypoint['path']}"
) from e

params = [
{
"name": param["name"],
Expand Down
22 changes: 17 additions & 5 deletions tests/unit/restapi/v1/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""Fixtures representing resources needed for test suites"""
import os
import tarfile
import textwrap
from collections.abc import Iterator
Expand All @@ -32,7 +33,11 @@
from injector import Injector
from pytest import MonkeyPatch

from dioptra.client import DioptraFile, select_one_or_more_files
from dioptra.client import (
DioptraFile,
select_files_in_directory,
select_one_or_more_files,
)

from ..lib import actions, mock_rq

Expand Down Expand Up @@ -730,20 +735,27 @@ def registered_mlflowrun_incomplete(

@pytest.fixture
def resources_tar_file() -> DioptraFile:
root_dir = Path(__file__).absolute().parent / "resource_import_files"
os.chdir(Path(__file__).absolute().parent / "resource_import_files")

f = NamedTemporaryFile(suffix=".tar.gz")
with tarfile.open(fileobj=f, mode="w:gz") as tar:
tar.add(root_dir / "dioptra.toml", arcname="dioptra.toml")
tar.add(root_dir / "hello_world", arcname="plugins/hello_world", recursive=True)
tar.add(root_dir / "hello-world.yaml", arcname="examples/hello-world.yaml")
tar.add("dioptra.toml")
tar.add("plugins", recursive=True)
tar.add("examples/hello-world.yaml")
f.seek(0)

yield select_one_or_more_files([f.name])[0]

f.close()


@pytest.fixture
def resources_files() -> DioptraFile:
os.chdir(Path(__file__).absolute().parent / "resource_import_files")

return select_files_in_directory(".", recursive=True)


@pytest.fixture
def resources_import_config() -> dict[str, Any]:
root_dir = Path(__file__).absolute().parent / "resource_import_files"
Expand Down
15 changes: 14 additions & 1 deletion tests/unit/restapi/v1/test_workflow_resource_import.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def assert_resource_import_overwrite_works(
# -- Tests -----------------------------------------------------------------------------


def test_resource_import(
def test_resource_import_from_archive_file(
dioptra_client: DioptraClient[DioptraResponseProtocol],
db: SQLAlchemy,
auth_account: dict[str, Any],
Expand All @@ -103,6 +103,19 @@ def test_resource_import(
assert_imported_resources_match_expected(dioptra_client, resources_import_config)


def test_resource_import_from_files(
dioptra_client: DioptraClient[DioptraResponseProtocol],
db: SQLAlchemy,
auth_account: dict[str, Any],
resources_files: list[DioptraFile],
resources_import_config: dict[str, Any],
):
group_id = auth_account["groups"][0]["id"]
dioptra_client.workflows.import_resources(group_id, files=resources_files)

assert_imported_resources_match_expected(dioptra_client, resources_import_config)


def test_resource_import_fails_from_name_clash(
dioptra_client: DioptraClient[DioptraResponseProtocol],
db: SQLAlchemy,
Expand Down

0 comments on commit fc97c56

Please sign in to comment.