diff --git a/diracx-core/src/diracx/core/models.py b/diracx-core/src/diracx/core/models.py
index 3aae477ce..243068bca 100644
--- a/diracx-core/src/diracx/core/models.py
+++ b/diracx-core/src/diracx/core/models.py
@@ -120,3 +120,8 @@ class SandboxInfo(BaseModel):
     checksum: str = Field(pattern=r"^[0-f]{64}$")
     size: int = Field(ge=1)
     format: SandboxFormat
+
+
+class SandboxType(StrEnum):
+    Input = "Input"
+    Output = "Output"
diff --git a/diracx-db/src/diracx/db/sql/jobs/status_utility.py b/diracx-db/src/diracx/db/sql/jobs/status_utility.py
index 09b00edfd..0451cc6bf 100644
--- a/diracx-db/src/diracx/db/sql/jobs/status_utility.py
+++ b/diracx-db/src/diracx/db/sql/jobs/status_utility.py
@@ -272,7 +272,7 @@ async def remove_jobs(
 
     # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent
     # I think it should be done here as well
-    await sandbox_metadata_db.unassign_sandbox_from_jobs(job_ids)
+    await sandbox_metadata_db.unassign_sandboxes_to_jobs(job_ids)
 
     # Remove the job from TaskQueueDB
     await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task)
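A side note on the new `SandboxType`: since it subclasses `StrEnum` (available from Python 3.11), its members are genuine strings. That is what lets the router further down use the members directly as JSON object keys, and the tests compare them against a bare `"Output"`. A minimal sketch of that behaviour, reusing the member names from the patch:

```python
from enum import StrEnum  # Python >= 3.11


class SandboxType(StrEnum):
    Input = "Input"
    Output = "Output"


# Members are real strings: they compare, hash, and serialize as their value.
assert SandboxType.Output == "Output"
assert str(SandboxType.Input) == "Input"
assert SandboxType("Output") is SandboxType.Output  # round-trip from a raw string
```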
diff --git a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
index e7715d1ee..47a9e79ec 100644
--- a/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
+++ b/diracx-db/src/diracx/db/sql/sandbox_metadata/db.py
@@ -1,9 +1,10 @@
 from __future__ import annotations
 
+from typing import Any
+
 import sqlalchemy
-from sqlalchemy import delete
 
-from diracx.core.models import SandboxInfo, UserInfo
+from diracx.core.models import SandboxInfo, SandboxType, UserInfo
 from diracx.db.sql.utils import BaseSQLDB, utcnow
 
 from .schema import Base as SandboxMetadataDBBase
@@ -76,7 +77,7 @@ async def update_sandbox_last_access_time(self, se_name: str, pfn: str) -> None:
         result = await self.conn.execute(stmt)
         assert result.rowcount == 1
 
-    async def sandbox_is_assigned(self, se_name: str, pfn: str) -> bool:
+    async def sandbox_is_assigned(self, pfn: str, se_name: str) -> bool:
         """Checks if a sandbox exists and has been assigned."""
         stmt: sqlalchemy.Executable = sqlalchemy.select(sb_SandBoxes.Assigned).where(
             sb_SandBoxes.SEName == se_name, sb_SandBoxes.SEPFN == pfn
@@ -84,13 +85,85 @@ async def sandbox_is_assigned(self, se_name: str, pfn: str) -> bool:
         result = await self.conn.execute(stmt)
         is_assigned = result.scalar_one()
         return is_assigned
-        return True
-
-    async def unassign_sandbox_from_jobs(self, job_ids: list[int]):
-        """
-        Unassign sandbox from jobs
-        """
-        stmt = delete(sb_EntityMapping).where(
-            sb_EntityMapping.EntityId.in_(f"Job:{job_id}" for job_id in job_ids)
+
+    @staticmethod
+    def jobid_to_entity_id(job_id: int) -> str:
+        """Build the entity id as 'Job:<job_id>', the form expected by the DB schema."""
+        return f"Job:{job_id}"
+
+    async def get_sandbox_assigned_to_job(
+        self, job_id: int, sb_type: SandboxType
+    ) -> list[Any]:
+        """Get the sandbox assigned to the job."""
+        entity_id = self.jobid_to_entity_id(job_id)
+        stmt = (
+            sqlalchemy.select(sb_SandBoxes.SEPFN)
+            .where(sb_SandBoxes.SBId == sb_EntityMapping.SBId)
+            .where(
+                sb_EntityMapping.EntityId == entity_id,
+                sb_EntityMapping.Type == sb_type,
+            )
         )
-        await self.conn.execute(stmt)
+        result = await self.conn.execute(stmt)
+        return [result.scalar()]
+
+    async def assign_sandbox_to_jobs(
+        self,
+        jobs_ids: list[int],
+        pfn: str,
+        sb_type: SandboxType,
+        se_name: str,
+    ) -> None:
+        """Map a sandbox to jobs."""
+        for job_id in jobs_ids:
+            # The entity id has the 'Job:<job_id>' form required by the DB schema:
+            entity_id = self.jobid_to_entity_id(job_id)
+            select_sb_id = sqlalchemy.select(
+                sb_SandBoxes.SBId,
+                sqlalchemy.literal(entity_id).label("EntityId"),
+                sqlalchemy.literal(sb_type).label("Type"),
+            ).where(
+                sb_SandBoxes.SEName == se_name,
+                sb_SandBoxes.SEPFN == pfn,
+            )
+            stmt = sqlalchemy.insert(sb_EntityMapping).from_select(
+                ["SBId", "EntityId", "Type"], select_sb_id
+            )
+            await self.conn.execute(stmt)
+
+            stmt = (
+                sqlalchemy.update(sb_SandBoxes)
+                .where(sb_SandBoxes.SEPFN == pfn)
+                .values(Assigned=True)
+            )
+            result = await self.conn.execute(stmt)
+            assert result.rowcount == 1
+
+    async def unassign_sandboxes_to_jobs(self, jobs_ids: list[int]) -> None:
+        """Delete the mapping between jobs and sandboxes."""
+        for job_id in jobs_ids:
+            entity_id = self.jobid_to_entity_id(job_id)
+            sb_sel_stmt = sqlalchemy.select(
+                sb_EntityMapping.SBId,
+            ).where(sb_EntityMapping.EntityId == entity_id)
+
+            result = await self.conn.execute(sb_sel_stmt)
+            sb_ids = [row.SBId for row in result]
+
+            del_stmt = sqlalchemy.delete(sb_EntityMapping).where(
+                sb_EntityMapping.EntityId == entity_id
+            )
+            await self.conn.execute(del_stmt)
+
+            sb_entity_sel_stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
+                sb_EntityMapping.SBId.in_(sb_ids)
+            )
+            result = await self.conn.execute(sb_entity_sel_stmt)
+            remaining_sb_ids = [row.SBId for row in result]
+            if not remaining_sb_ids:
+                unassign_stmt = (
+                    sqlalchemy.update(sb_SandBoxes)
+                    .where(sb_SandBoxes.SBId.in_(sb_ids))
+                    .values(Assigned=False)
+                )
+                await self.conn.execute(unassign_stmt)
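The assignment above avoids a separate lookup round trip: the sandbox id is selected by `(SEName, SEPFN)` and written into the mapping table together with literal `EntityId`/`Type` columns via `INSERT ... FROM SELECT`. A self-contained sketch of that SQLAlchemy pattern, using toy tables and values rather than the real DiracX schema:

```python
import sqlalchemy
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine

metadata = MetaData()
sandboxes = Table(
    "sandboxes",
    metadata,
    Column("SBId", Integer, primary_key=True),
    Column("SEPFN", String, unique=True),
)
mapping = Table(
    "mapping",
    metadata,
    Column("SBId", Integer),
    Column("EntityId", String),
    Column("Type", String),
)

engine = create_engine("sqlite+pysqlite:///:memory:")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sandboxes.insert().values(SBId=1, SEPFN="/S3/bucket/sb.tar.bz2"))
    # Select the sandbox id and attach the entity id / type as literal columns,
    # then insert the result straight into the mapping table.
    select_sb = sqlalchemy.select(
        sandboxes.c.SBId,
        sqlalchemy.literal("Job:666").label("EntityId"),
        sqlalchemy.literal("Output").label("Type"),
    ).where(sandboxes.c.SEPFN == "/S3/bucket/sb.tar.bz2")
    conn.execute(mapping.insert().from_select(["SBId", "EntityId", "Type"], select_sb))
    assert conn.execute(mapping.select()).all() == [(1, "Job:666", "Output")]
```

Note that the original `unassign_sandboxes_to_jobs` selected `sb_SandBoxes.SBId` while filtering only on `sb_EntityMapping`, which produces an implicit cross join; selecting `sb_EntityMapping.SBId` directly (as above) returns just the sandboxes mapped to the job.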
diff --git a/diracx-db/tests/test_sandbox_metadata.py b/diracx-db/tests/test_sandbox_metadata.py
index 9d19f73cd..5a43d8b9b 100644
--- a/diracx-db/tests/test_sandbox_metadata.py
+++ b/diracx-db/tests/test_sandbox_metadata.py
@@ -9,7 +9,7 @@
 from diracx.core.models import SandboxInfo, UserInfo
 from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB
-from diracx.db.sql.sandbox_metadata.schema import sb_SandBoxes
+from diracx.db.sql.sandbox_metadata.schema import sb_EntityMapping, sb_SandBoxes
 
 
 @pytest.fixture
@@ -46,7 +46,7 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
     assert pfn1 not in db_contents
     async with sandbox_metadata_db:
         with pytest.raises(sqlalchemy.exc.NoResultFound):
-            await sandbox_metadata_db.sandbox_is_assigned("SandboxSE", pfn1)
+            await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE")
 
     # Insert the sandbox
     async with sandbox_metadata_db:
@@ -65,7 +65,7 @@ async def test_insert_sandbox(sandbox_metadata_db: SandboxMetadataDB):
 
     # The sandbox still hasn't been assigned
     async with sandbox_metadata_db:
-        assert not await sandbox_metadata_db.sandbox_is_assigned("SandboxSE", pfn1)
+        assert not await sandbox_metadata_db.sandbox_is_assigned(pfn1, "SandboxSE")
 
     # Inserting again should update the last access time
     await asyncio.sleep(1)  # The timestamp only has second precision
@@ -90,3 +90,84 @@ async def _dump_db(
     )
     res = await sandbox_metadata_db.conn.execute(stmt)
     return {row.SEPFN: (row.OwnerId, row.LastAccessTime) for row in res}
+
+
+async def test_assign_and_unassign_sandbox_to_jobs(
+    sandbox_metadata_db: SandboxMetadataDB,
+):
+    pfn = secrets.token_hex()
+    user_info = UserInfo(
+        sub="vo:sub", preferred_username="user1", dirac_group="group1", vo="vo"
+    )
+    dummy_jobid = 666
+    sandbox_se = "SandboxSE"
+    # Insert the sandbox
+    async with sandbox_metadata_db:
+        await sandbox_metadata_db.insert_sandbox(sandbox_se, user_info, pfn, 100)
+
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        db_contents = {row.SEPFN: row.SBId for row in res}
+    sb_id_1 = db_contents[pfn]
+    # The sandbox still hasn't been assigned
+    async with sandbox_metadata_db:
+        assert not await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se)
+
+    # Check there is no mapping
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(
+            sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
+        )
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
+    assert db_contents == {}
+
+    # Assign the sandbox to the dummy job id
+    async with sandbox_metadata_db:
+        await sandbox_metadata_db.assign_sandbox_to_jobs(
+            jobs_ids=[dummy_jobid], pfn=pfn, sb_type="Output", se_name=sandbox_se
+        )
+    # Check that the sandbox and the job are now mapped
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(
+            sb_EntityMapping.SBId, sb_EntityMapping.EntityId, sb_EntityMapping.Type
+        )
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        db_contents = {row.SBId: (row.EntityId, row.Type) for row in res}
+
+    entity_id_1, sb_type = db_contents[sb_id_1]
+    assert entity_id_1 == f"Job:{dummy_jobid}"
+    assert sb_type == "Output"
+
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(sb_SandBoxes.SBId, sb_SandBoxes.SEPFN)
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        db_contents = {row.SEPFN: row.SBId for row in res}
+    sb_id_1 = db_contents[pfn]
+    # The sandbox should be assigned
+    async with sandbox_metadata_db:
+        assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se)
+
+    # Unassign the sandbox from the job
+    async with sandbox_metadata_db:
+        await sandbox_metadata_db.unassign_sandboxes_to_jobs([dummy_jobid])
+
+    # The mapping should not exist anymore
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(sb_EntityMapping.SBId).where(
+            sb_EntityMapping.EntityId == entity_id_1
+        )
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        entity_sb_id = [row.SBId for row in res]
+    assert entity_sb_id == []
+
+    # The sandbox should not be assigned anymore
+    async with sandbox_metadata_db:
+        assert await sandbox_metadata_db.sandbox_is_assigned(pfn, sandbox_se) is False
+    # Check the mapping has been deleted
+    async with sandbox_metadata_db:
+        stmt = sqlalchemy.select(sb_EntityMapping.SBId)
+        res = await sandbox_metadata_db.conn.execute(stmt)
+        res_sb_id = [row.SBId for row in res]
+    assert sb_id_1 not in res_sb_id
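Worth spelling out the rule this test exercises: `unassign_sandboxes_to_jobs` only flips `Assigned` back to false once no mapping row references the sandboxes anymore, so a sandbox shared between jobs stays assigned until its last mapping goes. A rough pure-Python model of that rule (toy data structures, not SQL; checked per sandbox here, whereas the patch checks the whole `sb_ids` batch at once):

```python
# (EntityId, SBId) mapping rows: sandbox 42 is shared by two jobs.
mappings = {("Job:1", 42), ("Job:2", 42)}
assigned = {42: True}


def unassign(job_ids: list[int]) -> None:
    for job_id in job_ids:
        entity_id = f"Job:{job_id}"
        to_drop = {(ent, sb) for ent, sb in mappings if ent == entity_id}
        mappings.difference_update(to_drop)
        for sb in {sb for _, sb in to_drop}:
            # Only unassign when no other job still maps to this sandbox.
            if all(other != sb for _, other in mappings):
                assigned[sb] = False


unassign([1])
assert assigned[42] is True  # Job:2 still references sandbox 42
unassign([2])
assert assigned[42] is False  # last mapping removed -> unassigned
```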
diff --git a/diracx-routers/src/diracx/routers/job_manager/sandboxes.py b/diracx-routers/src/diracx/routers/job_manager/sandboxes.py
index a9ff1b148..45ed35d67 100644
--- a/diracx-routers/src/diracx/routers/job_manager/sandboxes.py
+++ b/diracx-routers/src/diracx/routers/job_manager/sandboxes.py
@@ -2,17 +2,18 @@
 
 import contextlib
 from http import HTTPStatus
-from typing import TYPE_CHECKING, Annotated, AsyncIterator
+from typing import TYPE_CHECKING, Annotated, Any, AsyncIterator, Literal
 
 from aiobotocore.session import get_session
 from botocore.config import Config
 from botocore.errorfactory import ClientError
-from fastapi import Depends, HTTPException, Query
+from fastapi import Body, Depends, HTTPException, Query
 from pydantic import BaseModel, PrivateAttr
 from sqlalchemy.exc import NoResultFound
 
 from diracx.core.models import (
     SandboxInfo,
+    SandboxType,
 )
 from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER
 from diracx.core.s3 import (
@@ -104,7 +106,7 @@ async def initiate_sandbox_upload(
     try:
         exists_and_assigned = await sandbox_metadata_db.sandbox_is_assigned(
-            settings.se_name, pfn
+            pfn, settings.se_name
        )
     except NoResultFound:
         # The sandbox doesn't exist in the database
@@ -194,3 +196,72 @@ async def get_sandbox_file(
     return SandboxDownloadResponse(
         url=presigned_url, expires_in=settings.url_validity_seconds
     )
+
+
+@router.get("/{job_id}/sandbox")
+async def get_job_sandboxes(
+    job_id: int,
+    sandbox_metadata_db: SandboxMetadataDB,
+) -> dict[str, list[Any]]:
+    """Get the input and output sandboxes of a given job."""
+    # TODO: check that the user has created the job or is an admin
+    input_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job(
+        job_id, SandboxType.Input
+    )
+    output_sb = await sandbox_metadata_db.get_sandbox_assigned_to_job(
+        job_id, SandboxType.Output
+    )
+    return {SandboxType.Input: input_sb, SandboxType.Output: output_sb}
+
+
+@router.get("/{job_id}/sandbox/{sandbox_type}")
+async def get_job_sandbox(
+    job_id: int,
+    sandbox_metadata_db: SandboxMetadataDB,
+    sandbox_type: Literal["input", "output"],
+) -> list[Any]:
+    """Get the input or output sandbox of a given job."""
+    # TODO: check that the user has created the job or is an admin
+    job_sb_pfns = await sandbox_metadata_db.get_sandbox_assigned_to_job(
+        job_id, SandboxType(sandbox_type.capitalize())
+    )
+
+    return job_sb_pfns
+
+
+@router.patch("/{job_id}/sandbox/output")
+async def assign_sandbox_to_job(
+    job_id: int,
+    pfn: Annotated[str, Body(max_length=256, pattern=SANDBOX_PFN_REGEX)],
+    sandbox_metadata_db: SandboxMetadataDB,
+    settings: SandboxStoreSettings,
+) -> None:
+    """Map the pfn as an output sandbox of the job."""
+    # TODO: check that the user has created the job or is an admin
+    short_pfn = pfn.split("|", 1)[-1]
+    await sandbox_metadata_db.assign_sandbox_to_jobs(
+        jobs_ids=[job_id],
+        pfn=short_pfn,
+        sb_type=SandboxType.Output,
+        se_name=settings.se_name,
+    )
+
+
+@router.delete("/{job_id}/sandbox")
+async def unassign_job_sandboxes(
+    job_id: int,
+    sandbox_metadata_db: SandboxMetadataDB,
+) -> None:
+    """Delete the sandbox mappings of a single job."""
+    # TODO: check that the user has created the job or is an admin
+    await sandbox_metadata_db.unassign_sandboxes_to_jobs([job_id])
+
+
+@router.delete("/sandbox")
+async def unassign_bulk_jobs_sandboxes(
+    jobs_ids: Annotated[list[int], Query()],
+    sandbox_metadata_db: SandboxMetadataDB,
+) -> None:
+    """Delete the sandbox mappings of multiple jobs."""
+    # TODO: check that the user has created the job or is an admin
+    await sandbox_metadata_db.unassign_sandboxes_to_jobs(jobs_ids)
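From a client's perspective the new routes compose as follows; a hypothetical walk-through with `httpx`, where the base URL, token, job id, and PFN are placeholders and authentication details depend on the deployment:

```python
import httpx

BASE = "https://diracx.invalid/api"  # placeholder deployment URL
headers = {"Authorization": "Bearer <token>"}  # placeholder credentials

with httpx.Client(base_url=BASE, headers=headers) as client:
    job_id = 666  # assumed to be an existing job
    pfn = "SB:SandboxSE|/S3/bucket/sb.tar.bz2"  # placeholder PFN

    # Attach an already-uploaded sandbox as the job's output sandbox.
    client.patch(f"/jobs/{job_id}/sandbox/output", json=pfn).raise_for_status()

    # Both sandbox types at once: {"Input": [...], "Output": [...]}.
    print(client.get(f"/jobs/{job_id}/sandbox").json())

    # Only the output sandbox PFNs.
    print(client.get(f"/jobs/{job_id}/sandbox/output").json())

    # Drop the sandbox mappings of several jobs in one call.
    client.delete("/jobs/sandbox", params={"jobs_ids": [job_id]}).raise_for_status()
```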
diff --git a/diracx-routers/tests/jobs/test_sandboxes.py b/diracx-routers/tests/jobs/test_sandboxes.py
index c012b87cf..ca8a74bda 100644
--- a/diracx-routers/tests/jobs/test_sandboxes.py
+++ b/diracx-routers/tests/jobs/test_sandboxes.py
@@ -13,7 +13,13 @@
 from diracx.routers.auth.utils import AuthSettings
 
 pytestmark = pytest.mark.enabled_dependencies(
-    ["AuthSettings", "SandboxMetadataDB", "SandboxStoreSettings"]
+    [
+        "AuthSettings",
+        "JobDB",
+        "JobLoggingDB",
+        "SandboxMetadataDB",
+        "SandboxStoreSettings",
+    ]
 )
 
 
@@ -92,3 +98,87 @@ def test_upload_oversized(normal_user_client: TestClient):
     )
     assert r.status_code == 400, r.text
     assert "Sandbox too large" in r.json()["detail"], r.text
+
+
+TEST_JDL = """
+    Arguments = "jobDescription.xml -o LogLevel=INFO";
+    Executable = "dirac-jobexec";
+    JobGroup = jobGroup;
+    JobName = jobName;
+    JobType = User;
+    LogLevel = INFO;
+    OutputSandbox =
+        {
+            Script1_CodeOutput.log,
+            std.err,
+            std.out
+        };
+    Priority = 1;
+    Site = ANY;
+    StdError = std.err;
+    StdOutput = std.out;
+"""
+
+
+def test_assign_then_unassign_sandboxes_to_jobs(normal_user_client: TestClient):
+    data = secrets.token_bytes(512)
+    checksum = hashlib.sha256(data).hexdigest()
+
+    # Upload a sandbox:
+    r = normal_user_client.post(
+        "/api/jobs/sandbox",
+        json={
+            "checksum_algorithm": "sha256",
+            "checksum": checksum,
+            "size": len(data),
+            "format": "tar.bz2",
+        },
+    )
+    assert r.status_code == 200, r.text
+    upload_info = r.json()
+    assert upload_info["url"]
+    sandbox_pfn = upload_info["pfn"]
+    assert sandbox_pfn.startswith("SB:SandboxSE|/S3/")
+
+    # Submit a job:
+    job_definitions = [TEST_JDL]
+    r = normal_user_client.post("/api/jobs/", json=job_definitions)
+    assert r.status_code == 200, r.json()
+    assert len(r.json()) == len(job_definitions)
+    job_id = r.json()[0]["JobID"]
+
+    # Get the job's output sandbox:
+    r = normal_user_client.get(f"/api/jobs/{job_id}/sandbox/output")
+    assert r.status_code == 200
+    # Should be empty
+    assert r.json()[0] is None
+
+    # Assign the sandbox to the job:
+    r = normal_user_client.patch(
+        f"/api/jobs/{job_id}/sandbox/output",
+        json=sandbox_pfn,
+    )
+    assert r.status_code == 200
+
+    # Get the sandbox again:
+    short_pfn = sandbox_pfn.split("|", 1)[-1]
+    r = normal_user_client.get(f"/api/jobs/{job_id}/sandbox")
+    assert r.status_code == 200
+    assert r.json()["Input"] == [None]
+    assert r.json()["Output"] == [short_pfn]
+
+    r = normal_user_client.get(f"/api/jobs/{job_id}/sandbox/output")
+    assert r.status_code == 200
+    assert r.json()[0] == short_pfn
+
+    # Unassign the sandbox from the job:
+    job_ids = [job_id]
+    r = normal_user_client.delete("/api/jobs/sandbox", params={"jobs_ids": job_ids})
+    assert r.status_code == 200
+
+    # Get the sandbox again; it shouldn't be there anymore:
+    short_pfn = sandbox_pfn.split("|", 1)[-1]
+    r = normal_user_client.get(f"/api/jobs/{job_id}/sandbox")
+    assert r.status_code == 200
+    assert r.json()["Input"] == [None]
+    assert r.json()["Output"] == [None]