diff --git a/run_local.sh b/run_local.sh index 749fc66e..a602e074 100755 --- a/run_local.sh +++ b/run_local.sh @@ -23,6 +23,7 @@ export DIRACX_DB_URL_AUTHDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_JOBLOGGINGDB="sqlite+aiosqlite:///:memory:" export DIRACX_DB_URL_SANDBOXMETADATADB="sqlite+aiosqlite:///:memory:" +export DIRACX_DB_URL_TASKQUEUEDB="sqlite+aiosqlite:///:memory:" export DIRACX_SERVICE_AUTH_TOKEN_KEY="file://${signing_key}" export DIRACX_SERVICE_AUTH_ALLOWED_REDIRECTS='["http://'$(hostname| tr -s '[:upper:]' '[:lower:]')':8000/docs/oauth2-redirect"]' export DIRACX_SANDBOX_STORE_BUCKET_NAME=sandboxes diff --git a/setup.cfg b/setup.cfg index 01c95b0f..75823123 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,7 +80,8 @@ diracx.db.sql = JobDB = diracx.db.sql:JobDB JobLoggingDB = diracx.db.sql:JobLoggingDB SandboxMetadataDB = diracx.db.sql:SandboxMetadataDB - #DummyDB = diracx.db:DummyDB + TaskQueueDB = diracx.db.sql:TaskQueueDB + #DummyDB = diracx.db.sql:DummyDB diracx.db.os = JobParametersDB = diracx.db.os:JobParametersDB diracx.services = diff --git a/src/diracx/cli/internal/__init__.py b/src/diracx/cli/internal/__init__.py index e7c046f4..cb60035c 100644 --- a/src/diracx/cli/internal/__init__.py +++ b/src/diracx/cli/internal/__init__.py @@ -47,9 +47,7 @@ def generate_cs( DefaultGroup=user_group, Users={}, Groups={ - user_group: GroupConfig( - JobShare=None, Properties={"NormalUser"}, Quota=None, Users=set() - ) + user_group: GroupConfig(Properties={"NormalUser"}, Quota=None, Users=set()) }, ) config = Config( diff --git a/src/diracx/client/aio/operations/_operations.py b/src/diracx/client/aio/operations/_operations.py index 41506aea..49e113bc 100644 --- a/src/diracx/client/aio/operations/_operations.py +++ b/src/diracx/client/aio/operations/_operations.py @@ -35,6 +35,7 @@ build_auth_userinfo_request, build_config_serve_config_request, build_jobs_delete_bulk_jobs_request, + build_jobs_delete_single_job_request, build_jobs_get_job_status_bulk_request, build_jobs_get_job_status_history_bulk_request, build_jobs_get_sandbox_file_request, @@ -43,6 +44,9 @@ build_jobs_get_single_job_status_request, build_jobs_initiate_sandbox_upload_request, build_jobs_kill_bulk_jobs_request, + build_jobs_kill_single_job_request, + build_jobs_remove_bulk_jobs_request, + build_jobs_remove_single_job_request, build_jobs_reschedule_bulk_jobs_request, build_jobs_reschedule_single_job_request, build_jobs_search_request, @@ -1283,6 +1287,64 @@ async def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: return deserialized + @distributed_trace_async + async def remove_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: + """Remove Bulk Jobs. + + Fully remove a list of jobs from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead for any other purpose. + + :keyword job_ids: Required. + :paramtype job_ids: list[int] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_remove_bulk_jobs_request( + job_ids=job_ids, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @distributed_trace_async async def get_job_status_bulk( self, *, job_ids: List[int], **kwargs: Any @@ -1948,6 +2010,172 @@ async def get_single_job(self, job_id: int, **kwargs: Any) -> Any: return deserialized + @distributed_trace_async + async def delete_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Delete Single Job. + + Delete a job by killing and setting the job status to DELETED. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_delete_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @distributed_trace_async + async def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Kill Single Job. + + Kill a job. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_kill_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @distributed_trace_async + async def remove_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Remove Single Job. + + Fully remove a job from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_remove_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + await self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @distributed_trace_async async def get_single_job_status( self, job_id: int, **kwargs: Any diff --git a/src/diracx/client/models/__init__.py b/src/diracx/client/models/__init__.py index 34678bc5..71619e80 100644 --- a/src/diracx/client/models/__init__.py +++ b/src/diracx/client/models/__init__.py @@ -22,6 +22,7 @@ from ._models import SandboxInfo from ._models import SandboxUploadResponse from ._models import ScalarSearchSpec +from ._models import ScalarSearchSpecValue from ._models import SetJobStatusReturn from ._models import SortSpec from ._models import SortSpecDirection @@ -32,6 +33,7 @@ from ._models import ValidationError from ._models import ValidationErrorLocItem from ._models import VectorSearchSpec +from ._models import VectorSearchSpecValues from ._enums import ChecksumAlgorithm from ._enums import Enum0 @@ -68,6 +70,7 @@ "SandboxInfo", "SandboxUploadResponse", "ScalarSearchSpec", + "ScalarSearchSpecValue", "SetJobStatusReturn", "SortSpec", "SortSpecDirection", @@ -78,6 +81,7 @@ "ValidationError", "ValidationErrorLocItem", "VectorSearchSpec", + "VectorSearchSpecValues", "ChecksumAlgorithm", "Enum0", "Enum1", diff --git a/src/diracx/client/models/_models.py b/src/diracx/client/models/_models.py index e3803575..7d997f8a 100644 --- a/src/diracx/client/models/_models.py +++ b/src/diracx/client/models/_models.py @@ -714,7 +714,7 @@ class ScalarSearchSpec(_serialization.Model): "like". :vartype operator: str or ~client.models.ScalarSearchOperator :ivar value: Value. Required. - :vartype value: str + :vartype value: ~client.models.ScalarSearchSpecValue """ _validation = { @@ -726,7 +726,7 @@ class ScalarSearchSpec(_serialization.Model): _attribute_map = { "parameter": {"key": "parameter", "type": "str"}, "operator": {"key": "operator", "type": "str"}, - "value": {"key": "value", "type": "str"}, + "value": {"key": "value", "type": "ScalarSearchSpecValue"}, } def __init__( @@ -734,7 +734,7 @@ def __init__( *, parameter: str, operator: Union[str, "_models.ScalarSearchOperator"], - value: str, + value: "_models.ScalarSearchSpecValue", **kwargs: Any ) -> None: """ @@ -744,7 +744,7 @@ def __init__( "like". :paramtype operator: str or ~client.models.ScalarSearchOperator :keyword value: Value. Required. - :paramtype value: str + :paramtype value: ~client.models.ScalarSearchSpecValue """ super().__init__(**kwargs) self.parameter = parameter @@ -752,6 +752,16 @@ def __init__( self.value = value +class ScalarSearchSpecValue(_serialization.Model): + """Value.""" + + _attribute_map = {} + + def __init__(self, **kwargs: Any) -> None: + """ """ + super().__init__(**kwargs) + + class SetJobStatusReturn(_serialization.Model): """SetJobStatusReturn. @@ -1093,7 +1103,7 @@ class VectorSearchSpec(_serialization.Model): :ivar operator: An enumeration. Required. Known values are: "in" and "not in". :vartype operator: str or ~client.models.VectorSearchOperator :ivar values: Values. Required. - :vartype values: list[str] + :vartype values: ~client.models.VectorSearchSpecValues """ _validation = { @@ -1105,7 +1115,7 @@ class VectorSearchSpec(_serialization.Model): _attribute_map = { "parameter": {"key": "parameter", "type": "str"}, "operator": {"key": "operator", "type": "str"}, - "values": {"key": "values", "type": "[str]"}, + "values": {"key": "values", "type": "VectorSearchSpecValues"}, } def __init__( @@ -1113,7 +1123,7 @@ def __init__( *, parameter: str, operator: Union[str, "_models.VectorSearchOperator"], - values: List[str], + values: "_models.VectorSearchSpecValues", **kwargs: Any ) -> None: """ @@ -1122,7 +1132,7 @@ def __init__( :keyword operator: An enumeration. Required. Known values are: "in" and "not in". :paramtype operator: str or ~client.models.VectorSearchOperator :keyword values: Values. Required. - :paramtype values: list[str] + :paramtype values: ~client.models.VectorSearchSpecValues """ super().__init__(**kwargs) self.parameter = parameter @@ -1130,6 +1140,16 @@ def __init__( self.values = values +class VectorSearchSpecValues(_serialization.Model): + """Values.""" + + _attribute_map = {} + + def __init__(self, **kwargs: Any) -> None: + """ """ + super().__init__(**kwargs) + + class VOInfo(_serialization.Model): """VOInfo. diff --git a/src/diracx/client/operations/_operations.py b/src/diracx/client/operations/_operations.py index 5cb6e97a..aa3101d6 100644 --- a/src/diracx/client/operations/_operations.py +++ b/src/diracx/client/operations/_operations.py @@ -410,6 +410,28 @@ def build_jobs_kill_bulk_jobs_request( ) +def build_jobs_remove_bulk_jobs_request( + *, job_ids: List[int], **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/jobs/remove" + + # Construct parameters + _params["job_ids"] = _SERIALIZER.query("job_ids", job_ids, "[int]") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest( + method="POST", url=_url, params=_params, headers=_headers, **kwargs + ) + + def build_jobs_get_job_status_bulk_request( *, job_ids: List[int], **kwargs: Any ) -> HttpRequest: @@ -597,6 +619,63 @@ def build_jobs_get_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_jobs_delete_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/jobs/{job_id}" + path_format_arguments = { + "job_id": _SERIALIZER.url("job_id", job_id, "int"), + } + + _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="DELETE", url=_url, headers=_headers, **kwargs) + + +def build_jobs_kill_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/jobs/{job_id}/kill" + path_format_arguments = { + "job_id": _SERIALIZER.url("job_id", job_id, "int"), + } + + _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_jobs_remove_single_job_request(job_id: int, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/jobs/{job_id}/remove" + path_format_arguments = { + "job_id": _SERIALIZER.url("job_id", job_id, "int"), + } + + _url: str = _format_url_section(_url, **path_format_arguments) # type: ignore + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + def build_jobs_get_single_job_status_request(job_id: int, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) @@ -1889,6 +1968,64 @@ def kill_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: return deserialized + @distributed_trace + def remove_bulk_jobs(self, *, job_ids: List[int], **kwargs: Any) -> Any: + """Remove Bulk Jobs. + + Fully remove a list of jobs from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead for any other purpose. + + :keyword job_ids: Required. + :paramtype job_ids: list[int] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_remove_bulk_jobs_request( + job_ids=job_ids, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @distributed_trace def get_job_status_bulk( self, *, job_ids: List[int], **kwargs: Any @@ -2552,6 +2689,172 @@ def get_single_job(self, job_id: int, **kwargs: Any) -> Any: return deserialized + @distributed_trace + def delete_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Delete Single Job. + + Delete a job by killing and setting the job status to DELETED. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_delete_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @distributed_trace + def kill_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Kill Single Job. + + Kill a job. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_kill_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + + @distributed_trace + def remove_single_job(self, job_id: int, **kwargs: Any) -> Any: + """Remove Single Job. + + Fully remove a job from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead. + + :param job_id: Required. + :type job_id: int + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[Any] = kwargs.pop("cls", None) + + request = build_jobs_remove_single_job_request( + job_id=job_id, + headers=_headers, + params=_params, + ) + request.url = self._client.format_url(request.url) + + _stream = False + pipeline_response: PipelineResponse = ( + self._client._pipeline.run( # pylint: disable=protected-access + request, stream=_stream, **kwargs + ) + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error( + status_code=response.status_code, response=response, error_map=error_map + ) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response) + + if cls: + return cls(pipeline_response, deserialized, {}) + + return deserialized + @distributed_trace def get_single_job_status( self, job_id: int, **kwargs: Any diff --git a/src/diracx/core/config/schema.py b/src/diracx/core/config/schema.py index 57d1f302..08b8dd8c 100644 --- a/src/diracx/core/config/schema.py +++ b/src/diracx/core/config/schema.py @@ -49,7 +49,7 @@ class GroupConfig(BaseModel): AutoAddVOMS: bool = False AutoUploadPilotProxy: bool = False AutoUploadProxy: bool = False - JobShare: Optional[int] + JobShare: int = 1000 Properties: set[SecurityProperty] Quota: Optional[int] Users: set[str] @@ -103,9 +103,14 @@ class JobMonitoringConfig(BaseModel): useESForJobParametersFlag: bool = False +class JobSchedulingConfig(BaseModel): + EnableSharesCorrection: bool = False + + class ServicesConfig(BaseModel): Catalogs: dict[str, Any] | None JobMonitoring: JobMonitoringConfig = JobMonitoringConfig() + JobScheduling: JobSchedulingConfig = JobSchedulingConfig() class OperationsConfig(BaseModel): diff --git a/src/diracx/core/exceptions.py b/src/diracx/core/exceptions.py index 4ca6eaa9..efb00f4c 100644 --- a/src/diracx/core/exceptions.py +++ b/src/diracx/core/exceptions.py @@ -36,3 +36,9 @@ class BadConfigurationVersion(ConfigurationError): class InvalidQueryError(DiracError): """It was not possible to build a valid database query from the given input""" + + +class JobNotFound(Exception): + def __init__(self, job_id: int): + self.job_id: int = job_id + super().__init__(f"Job {job_id} not found") diff --git a/src/diracx/core/models.py b/src/diracx/core/models.py index 20668652..0bc225d4 100644 --- a/src/diracx/core/models.py +++ b/src/diracx/core/models.py @@ -29,13 +29,13 @@ class SortSpec(TypedDict): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str + value: str | int class VectorSearchSpec(TypedDict): parameter: str operator: VectorSearchOperator - values: list[str] + values: list[str] | list[int] SearchSpec = ScalarSearchSpec | VectorSearchSpec diff --git a/src/diracx/db/__main__.py b/src/diracx/db/__main__.py index b79e0281..da36eace 100644 --- a/src/diracx/db/__main__.py +++ b/src/diracx/db/__main__.py @@ -35,6 +35,9 @@ async def init_sql(): db = BaseSQLDB.available_implementations(db_name)[0](db_url) async with db.engine_context(): async with db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(db.metadata.create_all) diff --git a/src/diracx/db/sql/__init__.py b/src/diracx/db/sql/__init__.py index 17b542a5..582509b1 100644 --- a/src/diracx/db/sql/__init__.py +++ b/src/diracx/db/sql/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations -__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB") +__all__ = ("AuthDB", "JobDB", "JobLoggingDB", "SandboxMetadataDB", "TaskQueueDB") from .auth.db import AuthDB -from .jobs.db import JobDB, JobLoggingDB +from .jobs.db import JobDB, JobLoggingDB, TaskQueueDB from .sandbox_metadata.db import SandboxMetadataDB diff --git a/src/diracx/db/sql/jobs/db.py b/src/diracx/db/sql/jobs/db.py index 329dd2bc..21635ea2 100644 --- a/src/diracx/db/sql/jobs/db.py +++ b/src/diracx/db/sql/jobs/db.py @@ -6,9 +6,9 @@ from typing import Any from sqlalchemy import delete, func, insert, select, update -from sqlalchemy.exc import NoResultFound +from sqlalchemy.exc import IntegrityError, NoResultFound -from diracx.core.exceptions import InvalidQueryError +from diracx.core.exceptions import InvalidQueryError, JobNotFound from diracx.core.models import ( JobMinorStatus, JobStatus, @@ -17,15 +17,26 @@ ScalarSearchOperator, ScalarSearchSpec, ) +from diracx.core.properties import JOB_SHARING, SecurityProperty from ..utils import BaseSQLDB, apply_search_filters from .schema import ( + BannedSitesQueue, + GridCEsQueue, InputData, + JobCommands, JobDBBase, JobJDLs, JobLoggingDBBase, Jobs, + JobsQueue, + JobTypesQueue, LoggingInfo, + PlatformsQueue, + SitesQueue, + TagsQueue, + TaskQueueDBBase, + TaskQueues, ) @@ -433,12 +444,35 @@ async def rescheduleJob(self, job_id) -> dict[str, Any]: return retVal async def get_job_status(self, job_id: int) -> LimitedJobStatusReturn: - stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( - Jobs.JobID == job_id - ) - return LimitedJobStatusReturn( - **dict((await self.conn.execute(stmt)).one()._mapping) - ) + try: + stmt = select(Jobs.Status, Jobs.MinorStatus, Jobs.ApplicationStatus).where( + Jobs.JobID == job_id + ) + return LimitedJobStatusReturn( + **dict((await self.conn.execute(stmt)).one()._mapping) + ) + except NoResultFound as e: + raise JobNotFound(job_id) from e + + async def set_job_command(self, job_id: int, command: str, arguments: str = ""): + """Store a command to be passed to the job together with the next heart beat""" + try: + stmt = insert(JobCommands).values( + JobID=job_id, + Command=command, + Arguments=arguments, + ReceptionTime=datetime.now(tz=timezone.utc), + ) + await self.conn.execute(stmt) + except IntegrityError as e: + raise JobNotFound(job_id) from e + + async def delete_jobs(self, job_ids: list[int]): + """ + Delete jobs from the database + """ + stmt = delete(JobJDLs).where(JobJDLs.JobID.in_(job_ids)) + await self.conn.execute(stmt) MAGIC_EPOC_NUMBER = 1270000000 @@ -560,10 +594,9 @@ async def get_records(self, job_id: int) -> list[JobStatusReturn]: return res - async def delete_records(self, job_id: int): + async def delete_records(self, job_ids: list[int]): """Delete logging records for given jobs""" - - stmt = delete(LoggingInfo).where(LoggingInfo.JobID == job_id) + stmt = delete(LoggingInfo).where(LoggingInfo.JobID.in_(job_ids)) await self.conn.execute(stmt) async def get_wms_time_stamps(self, job_id): @@ -578,9 +611,259 @@ async def get_wms_time_stamps(self, job_id): ).where(LoggingInfo.JobID == job_id) rows = await self.conn.execute(stmt) if not rows.rowcount: - raise NoResultFound(f"No Logging Info for job {job_id}") + raise JobNotFound(job_id) from None for event, etime in rows: result[event] = str(etime + MAGIC_EPOC_NUMBER) return result + + +class TaskQueueDB(BaseSQLDB): + metadata = TaskQueueDBBase.metadata + + async def get_tq_infos_for_jobs( + self, job_ids: list[int] + ) -> set[tuple[int, str, str, str]]: + """ + Get the task queue info for given jobs + """ + stmt = select( + TaskQueues.TQId, TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO + ).where(JobsQueue.JobId.in_(job_ids)) + return set( + (int(row[0]), str(row[1]), str(row[2]), str(row[3])) + for row in (await self.conn.execute(stmt)).all() + ) + + async def get_owner_for_task_queue(self, tq_id: int) -> dict[str, str]: + """ + Get the owner and owner group for a task queue + """ + stmt = select(TaskQueues.Owner, TaskQueues.OwnerGroup, TaskQueues.VO).where( + TaskQueues.TQId == tq_id + ) + return dict((await self.conn.execute(stmt)).one()._mapping) + + async def remove_job(self, job_id: int): + """ + Remove a job from the task queues + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId == job_id) + await self.conn.execute(stmt) + + async def remove_jobs(self, job_ids: list[int]): + """ + Remove jobs from the task queues + """ + stmt = delete(JobsQueue).where(JobsQueue.JobId.in_(job_ids)) + await self.conn.execute(stmt) + + async def delete_task_queue_if_empty( + self, + tq_id: int, + tq_owner: str, + tq_group: str, + job_share: int, + group_properties: set[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Try to delete a task queue if it's empty + """ + # Check if the task queue is empty + stmt = ( + select(TaskQueues.TQId) + .where(TaskQueues.Enabled >= 1) + .where(TaskQueues.TQId == tq_id) + .where(~TaskQueues.TQId.in_(select(JobsQueue.TQId))) + ) + rows = await self.conn.execute(stmt) + if not rows.rowcount: + return + + # Deleting the task queue (the other tables will be deleted in cascade) + stmt = delete(TaskQueues).where(TaskQueues.TQId == tq_id) + await self.conn.execute(stmt) + + await self.recalculate_tq_shares_for_entity( + tq_owner, + tq_group, + job_share, + group_properties, + enable_shares_correction, + allow_background_tqs, + ) + + async def recalculate_tq_shares_for_entity( + self, + owner: str, + group: str, + job_share: int, + group_properties: set[SecurityProperty], + enable_shares_correction: bool, + allow_background_tqs: bool, + ): + """ + Recalculate the shares for a user/userGroup combo + """ + if JOB_SHARING in group_properties: + # If group has JobSharing just set prio for that entry, user is irrelevant + return await self.__set_priorities_for_entity( + owner, group, job_share, group_properties, allow_background_tqs + ) + + stmt = ( + select(TaskQueues.Owner, func.count(TaskQueues.Owner)) + .where(TaskQueues.OwnerGroup == group) + .group_by(TaskQueues.Owner) + ) + rows = await self.conn.execute(stmt) + # make the rows a list of tuples + # Get owners in this group and the amount of times they appear + # TODO: I guess the rows are already a list of tupes + # maybe refactor + data = [(r[0], r[1]) for r in rows if r] + numOwners = len(data) + # If there are no owners do now + if numOwners == 0: + return + # Split the share amongst the number of owners + entities_shares = {row[0]: job_share / numOwners for row in data} + + # TODO: implement the following + # If corrector is enabled let it work it's magic + # if enable_shares_correction: + # entities_shares = await self.__shares_corrector.correct_shares( + # entitiesShares, group=group + # ) + + # Keep updating + owners = dict(data) + # IF the user is already known and has more than 1 tq, the rest of the users don't need to be modified + # (The number of owners didn't change) + if owner in owners and owners[owner] > 1: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + return + # Oops the number of owners may have changed so we recalculate the prio for all owners in the group + for owner in owners: + await self.__set_priorities_for_entity( + owner, + group, + entities_shares[owner], + group_properties, + allow_background_tqs, + ) + + async def __set_priorities_for_entity( + self, + owner: str, + group: str, + share, + properties: set[SecurityProperty], + allow_background_tqs: bool, + ): + """ + Set the priority for a user/userGroup combo given a splitted share + """ + from DIRAC.WorkloadManagementSystem.DB.TaskQueueDB import calculate_priority + + stmt = ( + select( + TaskQueues.TQId, + func.sum(JobsQueue.RealPriority) / func.count(JobsQueue.RealPriority), + ) + .join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId) + .where(TaskQueues.OwnerGroup == group) + .group_by(TaskQueues.TQId) + ) + if JOB_SHARING not in properties: + stmt = stmt.where(TaskQueues.Owner == owner) + rows = await self.conn.execute(stmt) + tq_dict: dict[int, float] = {tq_id: priority for tq_id, priority in rows} + + if not tq_dict: + return + + rows = await self.retrieve_task_queues(list(tq_dict)) + + prio_dict = calculate_priority(tq_dict, rows, share, allow_background_tqs) + + # Execute updates + for prio, tqs in prio_dict.items(): + update_stmt = ( + update(TaskQueues).where(TaskQueues.TQId.in_(tqs)).values(Priority=prio) + ) + await self.conn.execute(update_stmt) + + async def retrieve_task_queues(self, tq_id_list=None): + """ + Get all the task queues + """ + if tq_id_list is not None and not tq_id_list: + # Empty list => Fast-track no matches + return {} + + stmt = ( + select( + TaskQueues.TQId, + TaskQueues.Priority, + func.count(JobsQueue.TQId).label("Jobs"), + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + .join(JobsQueue, TaskQueues.TQId == JobsQueue.TQId) + .join(SitesQueue, TaskQueues.TQId == SitesQueue.TQId) + .join(GridCEsQueue, TaskQueues.TQId == GridCEsQueue.TQId) + .group_by( + TaskQueues.TQId, + TaskQueues.Priority, + TaskQueues.Owner, + TaskQueues.OwnerGroup, + TaskQueues.VO, + TaskQueues.CPUTime, + ) + ) + if tq_id_list is not None: + stmt = stmt.where(TaskQueues.TQId.in_(tq_id_list)) + + tq_data: dict[int, dict[str, list[str]]] = dict( + dict(row._mapping) for row in await self.conn.execute(stmt) + ) + # TODO: the line above should be equivalent to the following commented code, check this is the case + # for record in rows: + # tqId = record[0] + # tqData[tqId] = { + # "Priority": record[1], + # "Jobs": record[2], + # "Owner": record[3], + # "OwnerGroup": record[4], + # "VO": record[5], + # "CPUTime": record[6], + # } + + for tq_id in tq_data: + # TODO: maybe factorize this handy tuple list + for table, field in { + (SitesQueue, "Sites"), + (GridCEsQueue, "GridCEs"), + (BannedSitesQueue, "BannedSites"), + (PlatformsQueue, "Platforms"), + (JobTypesQueue, "JobTypes"), + (TagsQueue, "Tags"), + }: + stmt = select(table.Value).where(table.TQId == tq_id) + tq_data[tq_id][field] = list( + row[0] for row in await self.conn.execute(stmt) + ) + + return tq_data diff --git a/src/diracx/db/sql/jobs/schema.py b/src/diracx/db/sql/jobs/schema.py index 038513eb..111e3aca 100644 --- a/src/diracx/db/sql/jobs/schema.py +++ b/src/diracx/db/sql/jobs/schema.py @@ -1,9 +1,11 @@ import sqlalchemy.types as types from sqlalchemy import ( + BigInteger, + Boolean, DateTime, Enum, + Float, ForeignKey, - ForeignKeyConstraint, Index, Integer, Numeric, @@ -17,6 +19,7 @@ JobDBBase = declarative_base() JobLoggingDBBase = declarative_base() +TaskQueueDBBase = declarative_base() class EnumBackedBool(types.TypeDecorator): @@ -45,19 +48,16 @@ def process_result_value(self, value, dialect) -> bool: raise NotImplementedError(f"Unknown {value=}") -class JobJDLs(JobDBBase): - __tablename__ = "JobJDLs" - JobID = Column(Integer, autoincrement=True) - JDL = Column(Text) - JobRequirements = Column(Text) - OriginalJDL = Column(Text) - __table_args__ = (PrimaryKeyConstraint("JobID"),) - - class Jobs(JobDBBase): __tablename__ = "Jobs" - JobID = Column("JobID", Integer, primary_key=True, default=0) + JobID = Column( + "JobID", + Integer, + ForeignKey("JobJDLs.JobID", ondelete="CASCADE"), + primary_key=True, + default=0, + ) JobType = Column("JobType", String(32), default="user") JobGroup = Column("JobGroup", String(32), default="00000000") JobSplitType = Column( @@ -95,7 +95,6 @@ class Jobs(JobDBBase): ) __table_args__ = ( - ForeignKeyConstraint(["JobID"], ["JobJDLs.JobID"]), Index("JobType", "JobType"), Index("JobGroup", "JobGroup"), Index("JobSplitType", "JobSplitType"), @@ -110,33 +109,46 @@ class Jobs(JobDBBase): ) +class JobJDLs(JobDBBase): + __tablename__ = "JobJDLs" + JobID = Column(Integer, autoincrement=True, primary_key=True) + JDL = Column(Text) + JobRequirements = Column(Text) + OriginalJDL = Column(Text) + + class InputData(JobDBBase): __tablename__ = "InputData" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) LFN = Column(String(255), default="", primary_key=True) Status = Column(String(32), default="AprioriGood") - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class JobParameters(JobDBBase): __tablename__ = "JobParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class OptimizerParameters(JobDBBase): __tablename__ = "OptimizerParameters" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) class AtticJobParameters(JobDBBase): __tablename__ = "AtticJobParameters" - JobID = Column(Integer, ForeignKey("Jobs.JobID"), primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) RescheduleCycle = Column(Integer) @@ -162,25 +174,25 @@ class SiteMaskLogging(JobDBBase): class HeartBeatLoggingInfo(JobDBBase): __tablename__ = "HeartBeatLoggingInfo" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Name = Column(String(100), primary_key=True) Value = Column(Text) HeartBeatTime = Column(DateTime, primary_key=True) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class JobCommands(JobDBBase): __tablename__ = "JobCommands" - JobID = Column(Integer, primary_key=True) + JobID = Column( + Integer, ForeignKey("Jobs.JobID", ondelete="CASCADE"), primary_key=True + ) Command = Column(String(100)) Arguments = Column(String(100)) Status = Column(String(64), default="Received") ReceptionTime = Column(DateTime, primary_key=True) ExecutionTime = NullColumn(DateTime) - __table_args__ = (ForeignKeyConstraint(["JobID"], ["Jobs.JobID"]),) - class LoggingInfo(JobLoggingDBBase): __tablename__ = "LoggingInfo" @@ -194,3 +206,99 @@ class LoggingInfo(JobLoggingDBBase): StatusTimeOrder = Column(Numeric(precision=12, scale=3), default=0) StatusSource = Column(String(32), default="Unknown") __table_args__ = (PrimaryKeyConstraint("JobID", "SeqNum"),) + + +class TaskQueues(TaskQueueDBBase): + __tablename__ = "tq_TaskQueues" + TQId = Column(Integer, primary_key=True) + Owner = Column(String(255), nullable=False) + OwnerDN = Column(String(255)) + OwnerGroup = Column(String(32), nullable=False) + VO = Column(String(32), nullable=False) + CPUTime = Column(BigInteger, nullable=False) + Priority = Column(Float, nullable=False) + Enabled = Column(Boolean, nullable=False, default=0) + __table_args__ = (Index("TQOwner", "Owner", "OwnerGroup", "CPUTime"),) + + +class JobsQueue(TaskQueueDBBase): + __tablename__ = "tq_Jobs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + JobId = Column(Integer, primary_key=True) + Priority = Column(Integer, nullable=False) + RealPriority = Column(Float, nullable=False) + __table_args__ = (Index("TaskIndex", "TQId"),) + + +class SitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("SitesTaskIndex", "TQId"), + Index("SitesIndex", "Value"), + ) + + +class GridCEsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToGridCEs" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("GridCEsTaskIndex", "TQId"), + Index("GridCEsValueIndex", "Value"), + ) + + +class BannedSitesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToBannedSites" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("BannedSitesTaskIndex", "TQId"), + Index("BannedSitesValueIndex", "Value"), + ) + + +class PlatformsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToPlatforms" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("PlatformsTaskIndex", "TQId"), + Index("PlatformsValueIndex", "Value"), + ) + + +class JobTypesQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToJobTypes" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("JobTypesTaskIndex", "TQId"), + Index("JobTypesValueIndex", "Value"), + ) + + +class TagsQueue(TaskQueueDBBase): + __tablename__ = "tq_TQToTags" + TQId = Column( + Integer, ForeignKey("tq_TaskQueues.TQId", ondelete="CASCADE"), primary_key=True + ) + Value = Column(String(64), primary_key=True) + __table_args__ = ( + Index("TagsTaskIndex", "TQId"), + Index("TagsValueIndex", "Value"), + ) diff --git a/src/diracx/db/sql/jobs/status_utility.py b/src/diracx/db/sql/jobs/status_utility.py index f7f20997..f2ba43f4 100644 --- a/src/diracx/db/sql/jobs/status_utility.py +++ b/src/diracx/db/sql/jobs/status_utility.py @@ -1,15 +1,19 @@ +import asyncio from datetime import datetime, timezone from unittest.mock import MagicMock -from sqlalchemy.exc import NoResultFound +from fastapi import BackgroundTasks +from diracx.core.config.schema import Config +from diracx.core.exceptions import JobNotFound from diracx.core.models import ( JobStatus, JobStatusUpdate, ScalarSearchOperator, SetJobStatusReturn, ) -from diracx.db.sql.jobs.db import JobDB, JobLoggingDB +from diracx.db.sql.jobs.db import JobDB, JobLoggingDB, TaskQueueDB +from diracx.db.sql.sandbox_metadata.db import SandboxMetadataDB async def set_job_status( @@ -21,8 +25,10 @@ async def set_job_status( ) -> SetJobStatusReturn: """Set various status fields for job specified by its jobId. Set only the last status in the JobDB, updating all the status - logging information in the JobLoggingDB. The statusDict has datetime + logging information in the JobLoggingDB. The status dict has datetime as a key and status information dictionary as values + + :raises: JobNotFound if the job is not found in one of the DBs """ from DIRAC.Core.Utilities import TimeUtilities @@ -49,7 +55,7 @@ async def set_job_status( sorts=[], ) if not res: - raise NoResultFound(f"Job {job_id} not found") + raise JobNotFound(job_id) from None currentStatus = res[0]["Status"] startTime = res[0]["StartExecTime"] @@ -60,10 +66,7 @@ async def set_job_status( currentStatus = JobStatus.RUNNING # Get the latest time stamps of major status updates - try: - result = await job_logging_db.get_wms_time_stamps(job_id) - except NoResultFound as e: - raise e + result = await job_logging_db.get_wms_time_stamps(job_id) ##################################################################################################### @@ -146,3 +149,167 @@ async def set_job_status( ) return SetJobStatusReturn(**job_data) + + +class ForgivingTaskGroup(asyncio.TaskGroup): + # Hacky way, check https://stackoverflow.com/questions/75250788/how-to-prevent-python3-11-taskgroup-from-canceling-all-the-tasks + # Basically e're using this because we want to wait for all tasks to finish, even if one of them raises an exception + def _abort(self): + return None + + +async def delete_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + "Delete" jobs by removing them from the task queues, set kill as a job command setting the job status to DELETED. + :raises: BaseExceptionGroup[JobNotFound] for every job that was not found + """ + + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) + + async with ForgivingTaskGroup() as task_group: + for job_id in job_ids: + task_group.create_task(job_db.set_job_command(job_id, "Kill")) + + task_group.create_task( + set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.DELETED, + MinorStatus="Checking accounting", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + ) + + +async def kill_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill jobs by removing them from the task queues, set kill as a job command and setting the job status to KILLED. + :raises: BaseExceptionGroup[JobNotFound] for every job that was not found + """ + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID(job_ids)) + + async with ForgivingTaskGroup() as task_group: + for job_id in job_ids: + task_group.create_task(job_db.set_job_command(job_id, "Kill")) + task_group.create_task( + set_job_status( + job_id, + { + datetime.now(timezone.utc): JobStatusUpdate( + Status=JobStatus.KILLED, + MinorStatus="Marked for termination", + StatusSource="job_manager", + ) + }, + job_db, + job_logging_db, + force=True, + ) + ) + + # TODO: Consider using the code below instead, probably more stable but less performant + # errors = [] + # for job_id in job_ids: + # try: + # await job_db.set_job_command(job_id, "Kill") + # await set_job_status( + # job_id, + # { + # datetime.now(timezone.utc): JobStatusUpdate( + # Status=JobStatus.KILLED, + # MinorStatus="Marked for termination", + # StatusSource="job_manager", + # ) + # }, + # job_db, + # job_logging_db, + # force=True, + # ) + # except JobNotFound as e: + # errors.append(e) + + # if errors: + # raise BaseExceptionGroup("Some job ids were not found", errors) + + +async def remove_jobs( + job_ids: list[int], + config: Config, + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + :raises: nothing + """ + + # Remove the staging task from the StorageManager + # TODO: this was not done in the JobManagerHandler, but it was done in the kill method + # I think it should be done here too + # TODO: implement StorageManagerClient + # returnValueOrRaise(StorageManagerClient().killTasksBySourceTaskID([job_id])) + + # TODO: this was also not done in the JobManagerHandler, but it was done in the JobCleaningAgent + # I think it should be done here as well + await sandbox_metadata_db.unassign_sandbox_from_jobs(job_ids) + + # Remove the job from TaskQueueDB + await _remove_jobs_from_task_queue(job_ids, config, task_queue_db, background_task) + + # Remove the job from JobLoggingDB + await job_logging_db.delete_records(job_ids) + + # Remove the job from JobDB + await job_db.delete_jobs(job_ids) + + +async def _remove_jobs_from_task_queue( + job_ids: list[int], + config: Config, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Remove the job from TaskQueueDB + """ + tq_infos = await task_queue_db.get_tq_infos_for_jobs(job_ids) + await task_queue_db.remove_jobs(job_ids) + for tq_id, owner, owner_group, vo in tq_infos: + # TODO: move to Celery + background_task.add_task( + task_queue_db.delete_task_queue_if_empty, + tq_id, + owner, + owner_group, + config.Registry[vo].Groups[owner_group].JobShare, + config.Registry[vo].Groups[owner_group].Properties, + config.Operations[vo].Services.JobScheduling.EnableSharesCorrection, + config.Registry[vo].Groups[owner_group].AllowBackgroundTQs, + ) diff --git a/src/diracx/db/sql/sandbox_metadata/db.py b/src/diracx/db/sql/sandbox_metadata/db.py index 8a5a4252..bfe39258 100644 --- a/src/diracx/db/sql/sandbox_metadata/db.py +++ b/src/diracx/db/sql/sandbox_metadata/db.py @@ -1,12 +1,13 @@ from __future__ import annotations import sqlalchemy +from sqlalchemy import delete from diracx.core.models import SandboxInfo, UserInfo from diracx.db.sql.utils import BaseSQLDB, utcnow from .schema import Base as SandboxMetadataDBBase -from .schema import sb_Owners, sb_SandBoxes +from .schema import sb_EntityMapping, sb_Owners, sb_SandBoxes class SandboxMetadataDB(BaseSQLDB): @@ -82,3 +83,13 @@ async def sandbox_is_assigned(self, se_name: str, pfn: str) -> bool: result = await self.conn.execute(stmt) is_assigned = result.scalar_one() return is_assigned + return True + + async def unassign_sandbox_from_jobs(self, job_ids: list[int]): + """ + Unassign sandbox from jobs + """ + stmt = delete(sb_EntityMapping).where( + sb_EntityMapping.EntityId.in_(f"Job:{job_id}" for job_id in job_ids) + ) + await self.conn.execute(stmt) diff --git a/src/diracx/routers/dependencies.py b/src/diracx/routers/dependencies.py index deb55167..7a67b94f 100644 --- a/src/diracx/routers/dependencies.py +++ b/src/diracx/routers/dependencies.py @@ -5,6 +5,8 @@ "AuthDB", "JobDB", "JobLoggingDB", + "SandboxMetadataDB", + "TaskQueueDB", "add_settings_annotation", "AvailableSecurityProperties", ) @@ -20,6 +22,7 @@ from diracx.db.sql import JobDB as _JobDB from diracx.db.sql import JobLoggingDB as _JobLoggingDB from diracx.db.sql import SandboxMetadataDB as _SandboxMetadataDB +from diracx.db.sql import TaskQueueDB as _TaskQueueDB T = TypeVar("T") @@ -36,6 +39,7 @@ def add_settings_annotation(cls: T) -> T: SandboxMetadataDB = Annotated[ _SandboxMetadataDB, Depends(_SandboxMetadataDB.transaction) ] +TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)] # Miscellaneous Config = Annotated[_Config, Depends(ConfigSource.create)] diff --git a/src/diracx/routers/job_manager/__init__.py b/src/diracx/routers/job_manager/__init__.py index baa0c134..3766875d 100644 --- a/src/diracx/routers/job_manager/__init__.py +++ b/src/diracx/routers/job_manager/__init__.py @@ -6,11 +6,12 @@ from http import HTTPStatus from typing import Annotated, Any, TypedDict -from fastapi import Body, Depends, HTTPException, Query +from fastapi import BackgroundTasks, Body, Depends, HTTPException, Query from pydantic import BaseModel, root_validator from sqlalchemy.exc import NoResultFound from diracx.core.config import Config, ConfigSource +from diracx.core.exceptions import JobNotFound from diracx.core.models import ( JobStatus, JobStatusReturn, @@ -23,11 +24,14 @@ ) from diracx.core.properties import JOB_ADMINISTRATOR, NORMAL_USER from diracx.db.sql.jobs.status_utility import ( + delete_jobs, + kill_jobs, + remove_jobs, set_job_status, ) from ..auth import AuthorizedUserInfo, has_properties, verify_dirac_access_token -from ..dependencies import JobDB, JobLoggingDB +from ..dependencies import JobDB, JobLoggingDB, SandboxMetadataDB, TaskQueueDB from ..fastapi_classes import DiracxRouter from .sandboxes import router as sandboxes_router @@ -236,14 +240,108 @@ def __init__(self, user_info: AuthorizedUserInfo, allInfo: bool = True): @router.delete("/") -async def delete_bulk_jobs(job_ids: Annotated[list[int], Query()]): +async def delete_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + # TODO: implement job policy + + try: + await delete_jobs( + job_ids, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as group_exc: + failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions}) # type: ignore + + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={ + "message": f"Failed to delete {len(failed_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": list(set(job_ids) - set(failed_job_ids)), + "failed_job_ids": failed_job_ids, + }, + ) from group_exc + return job_ids @router.post("/kill") async def kill_bulk_jobs( job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + # TODO: implement job policy + try: + await kill_jobs( + job_ids, + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as group_exc: + failed_job_ids: list[int] = list({e.job_id for e in group_exc.exceptions}) # type: ignore + + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, + detail={ + "message": f"Failed to kill {len(failed_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": list(set(job_ids) - set(failed_job_ids)), + "failed_job_ids": failed_job_ids, + }, + ) from group_exc + + return job_ids + + +@router.post("/remove") +async def remove_bulk_jobs( + job_ids: Annotated[list[int], Query()], + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, ): + """ + Fully remove a list of jobs from the WMS databases. + + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead for any other purpose. + """ + # TODO: Remove once legacy DIRAC no longer needs this + + # TODO: implement job policy + # Some tests have already been written in the test_job_manager, + # but they need to be uncommented and are not complete + + await remove_jobs( + job_ids, + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + return job_ids @@ -256,7 +354,7 @@ async def get_job_status_bulk( *(job_db.get_job_status(job_id) for job_id in job_ids) ) return {job_id: status for job_id, status in zip(job_ids, result)} - except NoResultFound as e: + except JobNotFound as e: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e @@ -478,13 +576,105 @@ async def get_single_job(job_id: int): return f"This job {job_id}" +@router.delete("/{job_id}") +async def delete_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Delete a job by killing and setting the job status to DELETED. + """ + + # TODO: implement job policy + try: + await delete_jobs( + [job_id], + config, + job_db, + job_logging_db, + task_queue_db, + background_task, + ) + except* JobNotFound as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND.value, detail=str(e.exceptions[0]) + ) from e + + return f"Job {job_id} has been successfully deleted" + + +@router.post("/{job_id}/kill") +async def kill_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Kill a job. + """ + + # TODO: implement job policy + + try: + await kill_jobs( + [job_id], config, job_db, job_logging_db, task_queue_db, background_task + ) + except* JobNotFound as e: + raise HTTPException( + status_code=HTTPStatus.NOT_FOUND, detail=str(e.exceptions[0]) + ) from e + + return f"Job {job_id} has been successfully killed" + + +@router.post("/{job_id}/remove") +async def remove_single_job( + job_id: int, + config: Annotated[Config, Depends(ConfigSource.create)], + job_db: JobDB, + job_logging_db: JobLoggingDB, + sandbox_metadata_db: SandboxMetadataDB, + task_queue_db: TaskQueueDB, + background_task: BackgroundTasks, +): + """ + Fully remove a job from the WMS databases. + + WARNING: This endpoint has been implemented for the compatibility with the legacy DIRAC WMS + and the JobCleaningAgent. However, once this agent is ported to diracx, this endpoint should + be removed, and the delete endpoint should be used instead. + """ + # TODO: Remove once legacy DIRAC no longer needs this + + # TODO: implement job policy + + await remove_jobs( + [job_id], + config, + job_db, + job_logging_db, + sandbox_metadata_db, + task_queue_db, + background_task, + ) + + return f"Job {job_id} has been successfully removed" + + @router.get("/{job_id}/status") async def get_single_job_status( job_id: int, job_db: JobDB ) -> dict[int, LimitedJobStatusReturn]: try: status = await job_db.get_job_status(job_id) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail=f"Job {job_id} not found" ) from e @@ -511,7 +701,7 @@ async def set_single_job_status( latest_status = await set_job_status( job_id, status, job_db, job_logging_db, force ) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException(status_code=HTTPStatus.NOT_FOUND, detail=str(e)) from e return {job_id: latest_status} @@ -523,7 +713,7 @@ async def get_single_job_status_history( ) -> dict[int, list[JobStatusReturn]]: try: status = await job_logging_db.get_records(job_id) - except NoResultFound as e: + except JobNotFound as e: raise HTTPException( status_code=HTTPStatus.NOT_FOUND, detail="Job not found" ) from e diff --git a/tests/conftest.py b/tests/conftest.py index 6a067074..7bdea965 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -122,6 +122,7 @@ def with_app(test_auth_settings, test_sandbox_settings, with_config_repo): database_urls={ "JobDB": "sqlite+aiosqlite:///:memory:", "JobLoggingDB": "sqlite+aiosqlite:///:memory:", + "TaskQueueDB": "sqlite+aiosqlite:///:memory:", "AuthDB": "sqlite+aiosqlite:///:memory:", "SandboxMetadataDB": "sqlite+aiosqlite:///:memory:", }, @@ -147,6 +148,9 @@ async def create_db_schemas(app=app): assert isinstance(db, BaseSQLDB), (k, db) # Fill the DB schema async with db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(db.metadata.create_all) yield diff --git a/tests/db/jobs/test_jobDB.py b/tests/db/jobs/test_jobDB.py index 24eec16e..5e46352b 100644 --- a/tests/db/jobs/test_jobDB.py +++ b/tests/db/jobs/test_jobDB.py @@ -4,6 +4,7 @@ import pytest +from diracx.core.exceptions import JobNotFound from diracx.db.sql.jobs.db import JobDB @@ -12,6 +13,9 @@ async def job_db(tmp_path): job_db = JobDB("sqlite+aiosqlite:///:memory:") async with job_db.engine_context(): async with job_db.engine.begin() as conn: + # set PRAGMA foreign_keys=ON if sqlite + if job_db._db_url.startswith("sqlite"): + await conn.exec_driver_sql("PRAGMA foreign_keys=ON") await conn.run_sync(job_db.metadata.create_all) yield job_db @@ -38,3 +42,9 @@ async def test_some_asyncio_code(job_db): async with job_db as job_db: result = await job_db.search(["JobID"], [], []) assert result + + +async def test_set_job_command_invalid_job_id(job_db: JobDB): + async with job_db as job_db: + with pytest.raises(JobNotFound): + await job_db.set_job_command(123456, "test_command") diff --git a/tests/db/jobs/test_jobLoggingDB.py b/tests/db/jobs/test_jobLoggingDB.py index 1949cde1..2a089d35 100644 --- a/tests/db/jobs/test_jobLoggingDB.py +++ b/tests/db/jobs/test_jobLoggingDB.py @@ -23,7 +23,7 @@ async def test_insert_record(job_logging_db: JobLoggingDB): # Act await job_logging_db.insert_record( 1, - status=JobStatus.RECEIVED.value, + status=JobStatus.RECEIVED, minor_status="minor_status", application_status="application_status", date=date, diff --git a/tests/routers/test_job_manager.py b/tests/routers/test_job_manager.py index 0b2a335d..03a86676 100644 --- a/tests/routers/test_job_manager.py +++ b/tests/routers/test_job_manager.py @@ -248,83 +248,89 @@ def test_user_without_the_normal_user_property_cannot_submit_job(admin_user_clie assert res.status_code == HTTPStatus.FORBIDDEN, res.json() -def test_get_job_status(normal_user_client: TestClient): - """Test that the job status is returned correctly.""" - # Arrange +@pytest.fixture +def valid_job_id(normal_user_client: TestClient): job_definitions = [TEST_JDL] r = normal_user_client.post("/api/jobs/", json=job_definitions) assert r.status_code == 200, r.json() - assert len(r.json()) == 1 # Parameters.JOB_ID is 3 - job_id = r.json()[0]["JobID"] + assert len(r.json()) == 1 + return r.json()[0]["JobID"] + + +@pytest.fixture +def valid_job_ids(normal_user_client: TestClient): + job_definitions = [TEST_PARAMETRIC_JDL] + r = normal_user_client.post("/api/jobs/", json=job_definitions) + assert r.status_code == 200, r.json() + assert len(r.json()) == 3 + return sorted([job_dict["JobID"] for job_dict in r.json()]) + + +@pytest.fixture +def invalid_job_id(): + return 999999996 + +@pytest.fixture +def invalid_job_ids(): + return [999999997, 999999998, 999999999] + + +def test_get_job_status(normal_user_client: TestClient, valid_job_id: int): + """Test that the job status is returned correctly.""" # Act - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") # Assert assert r.status_code == 200, r.json() # TODO: should we return camel case here (and everywhere else) ? - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_get_status_of_nonexistent_job(normal_user_client: TestClient): +def test_get_status_of_nonexistent_job( + normal_user_client: TestClient, invalid_job_id: int +): """Test that the job status is returned correctly.""" # Act - r = normal_user_client.get("/api/jobs/1/status") + r = normal_user_client.get(f"/api/jobs/{invalid_job_id}/status") # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 1 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} -def test_get_job_status_in_bulk(normal_user_client: TestClient): +def test_get_job_status_in_bulk(normal_user_client: TestClient, valid_job_ids: list): """Test that we can get the status of multiple jobs in one request""" - # Arrange - job_definitions = [TEST_PARAMETRIC_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 3 # Parameters.JOB_ID is 3 - submitted_job_ids = sorted([job_dict["JobID"] for job_dict in r.json()]) - assert isinstance(submitted_job_ids, list) - assert (isinstance(submitted_job_id, int) for submitted_job_id in submitted_job_ids) - # Act - r = normal_user_client.get( - "/api/jobs/status", params={"job_ids": submitted_job_ids} - ) + r = normal_user_client.get("/api/jobs/status", params={"job_ids": valid_job_ids}) # Assert - print(r.json()) assert r.status_code == 200, r.json() assert len(r.json()) == 3 # Parameters.JOB_ID is 3 - for job_id in submitted_job_ids: + for job_id in valid_job_ids: assert str(job_id) in r.json() assert r.json()[str(job_id)]["Status"] == JobStatus.SUBMITTING.value assert r.json()[str(job_id)]["MinorStatus"] == "Bulk transaction confirmation" assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" -async def test_get_job_status_history(normal_user_client: TestClient): +async def test_get_job_status_history( + normal_user_client: TestClient, valid_job_id: int +): # Arrange - job_definitions = [TEST_JDL] - before = datetime.now(timezone.utc) - r = normal_user_client.post("/api/jobs/", json=job_definitions) - after = datetime.now(timezone.utc) + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/api/jobs/{job_id}/status") - assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" NEW_STATUS = JobStatus.CHECKING.value NEW_MINOR_STATUS = "JobPath" - beforebis = datetime.now(timezone.utc) + before = datetime.now(timezone.utc) r = normal_user_client.put( - f"/api/jobs/{job_id}/status", + f"/api/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -332,83 +338,74 @@ async def test_get_job_status_history(normal_user_client: TestClient): } }, ) - afterbis = datetime.now(timezone.utc) + after = datetime.now(timezone.utc) assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS # Act r = normal_user_client.get( - f"/api/jobs/{job_id}/status/history", + f"/api/jobs/{valid_job_id}/status/history", ) # Assert assert r.status_code == 200, r.json() assert len(r.json()) == 1 - assert len(r.json()[str(job_id)]) == 2 - assert r.json()[str(job_id)][0]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)][0]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)][0]["ApplicationStatus"] == "Unknown" - assert ( - before < datetime.fromisoformat(r.json()[str(job_id)][0]["StatusTime"]) < after - ) - assert r.json()[str(job_id)][0]["StatusSource"] == "JobManager" - - assert r.json()[str(job_id)][1]["Status"] == JobStatus.CHECKING.value - assert r.json()[str(job_id)][1]["MinorStatus"] == "JobPath" - assert r.json()[str(job_id)][1]["ApplicationStatus"] == "Unknown" + assert len(r.json()[str(valid_job_id)]) == 2 + assert r.json()[str(valid_job_id)][0]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)][0]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)][0]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)][0]["StatusSource"] == "JobManager" + + assert r.json()[str(valid_job_id)][1]["Status"] == JobStatus.CHECKING.value + assert r.json()[str(valid_job_id)][1]["MinorStatus"] == "JobPath" + assert r.json()[str(valid_job_id)][1]["ApplicationStatus"] == "Unknown" assert ( - beforebis - < datetime.fromisoformat(r.json()[str(job_id)][1]["StatusTime"]) - < afterbis + before + < datetime.fromisoformat(r.json()[str(valid_job_id)][1]["StatusTime"]) + < after ) - assert r.json()[str(job_id)][1]["StatusSource"] == "Unknown" + assert r.json()[str(valid_job_id)][1]["StatusSource"] == "Unknown" -def test_get_job_status_history_in_bulk(normal_user_client: TestClient): +def test_get_job_status_history_in_bulk( + normal_user_client: TestClient, valid_job_id: int +): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act - r = normal_user_client.get("/api/jobs/status/history", params={"job_ids": [job_id]}) + r = normal_user_client.get( + "/api/jobs/status/history", params={"job_ids": [valid_job_id]} + ) # Assert assert r.status_code == 200, r.json() assert len(r.json()) == 1 - assert r.json()[str(job_id)][0]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)][0]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)][0]["ApplicationStatus"] == "Unknown" - assert datetime.fromisoformat(r.json()[str(job_id)][0]["StatusTime"]) - assert r.json()[str(job_id)][0]["StatusSource"] == "JobManager" + assert r.json()[str(valid_job_id)][0]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)][0]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)][0]["ApplicationStatus"] == "Unknown" + assert datetime.fromisoformat(r.json()[str(valid_job_id)][0]["StatusTime"]) + assert r.json()[str(valid_job_id)][0]["StatusSource"] == "JobManager" -def test_set_job_status(normal_user_client: TestClient): +def test_set_job_status(normal_user_client: TestClient, valid_job_id: int): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.CHECKING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/api/jobs/{job_id}/status", + f"/api/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -419,20 +416,22 @@ def test_set_job_status(normal_user_client: TestClient): # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_invalid_job(normal_user_client: TestClient): +def test_set_job_status_invalid_job( + normal_user_client: TestClient, invalid_job_id: int +): # Act r = normal_user_client.put( - "/api/jobs/1/status", + f"/api/jobs/{invalid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": JobStatus.CHECKING.value, @@ -443,23 +442,17 @@ def test_set_job_status_invalid_job(normal_user_client: TestClient): # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 1 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} def test_set_job_status_offset_naive_datetime_return_bad_request( normal_user_client: TestClient, + valid_job_id: int, ): - # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - # Act date = datetime.utcnow().isoformat(sep=" ") r = normal_user_client.put( - f"/api/jobs/{job_id}/status", + f"/api/jobs/{valid_job_id}/status", json={ date: { "Status": JobStatus.CHECKING.value, @@ -474,25 +467,20 @@ def test_set_job_status_offset_naive_datetime_return_bad_request( def test_set_job_status_cannot_make_impossible_transitions( - normal_user_client: TestClient, + normal_user_client: TestClient, valid_job_id: int ): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/api/jobs/{job_id}/status") - assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.RUNNING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/api/jobs/{job_id}/status", + f"/api/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -503,34 +491,29 @@ def test_set_job_status_cannot_make_impossible_transitions( # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] != NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] != NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] != NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_force(normal_user_client: TestClient): +def test_set_job_status_force(normal_user_client: TestClient, valid_job_id: int): # Arrange - job_definitions = [TEST_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 1 - job_id = r.json()[0]["JobID"] - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == JobStatus.RECEIVED.value - assert r.json()[str(job_id)]["MinorStatus"] == "Job accepted" - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED.value + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Job accepted" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" # Act NEW_STATUS = JobStatus.RUNNING.value NEW_MINOR_STATUS = "JobPath" r = normal_user_client.put( - f"/api/jobs/{job_id}/status", + f"/api/jobs/{valid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": NEW_STATUS, @@ -542,25 +525,19 @@ def test_set_job_status_force(normal_user_client: TestClient): # Assert assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS - r = normal_user_client.get(f"/api/jobs/{job_id}/status") + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") assert r.status_code == 200, r.json() - assert r.json()[str(job_id)]["Status"] == NEW_STATUS - assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS - assert r.json()[str(job_id)]["ApplicationStatus"] == "Unknown" + assert r.json()[str(valid_job_id)]["Status"] == NEW_STATUS + assert r.json()[str(valid_job_id)]["MinorStatus"] == NEW_MINOR_STATUS + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_bulk(normal_user_client: TestClient): +def test_set_job_status_bulk(normal_user_client: TestClient, valid_job_ids): # Arrange - job_definitions = [TEST_PARAMETRIC_JDL] - r = normal_user_client.post("/api/jobs/", json=job_definitions) - assert r.status_code == 200, r.json() - assert len(r.json()) == 3 - job_ids = sorted([job_dict["JobID"] for job_dict in r.json()]) - - for job_id in job_ids: + for job_id in valid_job_ids: r = normal_user_client.get(f"/api/jobs/{job_id}/status") assert r.status_code == 200, r.json() assert r.json()[str(job_id)]["Status"] == JobStatus.SUBMITTING.value @@ -578,13 +555,13 @@ def test_set_job_status_bulk(normal_user_client: TestClient): "MinorStatus": NEW_MINOR_STATUS, } } - for job_id in job_ids + for job_id in valid_job_ids }, ) # Assert assert r.status_code == 200, r.json() - for job_id in job_ids: + for job_id in valid_job_ids: assert r.json()[str(job_id)]["Status"] == NEW_STATUS assert r.json()[str(job_id)]["MinorStatus"] == NEW_MINOR_STATUS @@ -595,10 +572,12 @@ def test_set_job_status_bulk(normal_user_client: TestClient): assert r_get.json()[str(job_id)]["ApplicationStatus"] == "Unknown" -def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient): +def test_set_job_status_with_invalid_job_id( + normal_user_client: TestClient, invalid_job_id: int +): # Act r = normal_user_client.put( - "/api/jobs/999999999/status", + f"/api/jobs/{invalid_job_id}/status", json={ datetime.now(tz=timezone.utc).isoformat(): { "Status": JobStatus.CHECKING.value, @@ -609,7 +588,7 @@ def test_set_job_status_with_invalid_job_id(normal_user_client: TestClient): # Assert assert r.status_code == 404, r.json() - assert r.json() == {"detail": "Job 999999999 not found"} + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} def test_insert_and_reschedule(normal_user_client: TestClient): @@ -626,3 +605,247 @@ def test_insert_and_reschedule(normal_user_client: TestClient): params={"job_ids": submitted_job_ids}, ) assert r.status_code == 200, r.json() + + +# Test delete job + + +def test_delete_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.delete(f"/api/jobs/{valid_job_id}") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.DELETED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Checking accounting" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_delete_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.delete(f"/api/jobs/{invalid_job_id}") + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} + + +def test_delete_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.delete("/api/jobs/", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.DELETED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Checking accounting" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_delete_bulk_jobs_invalid_job_ids( + normal_user_client: TestClient, invalid_job_ids: list[int] +): + # Act + r = normal_user_client.delete("/api/jobs/", params={"job_ids": invalid_job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to delete {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", + "valid_job_ids": [], + "failed_job_ids": invalid_job_ids, + } + } + + +def test_delete_bulk_jobs_mix_of_valid_and_invalid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +): + # Arrange + job_ids = valid_job_ids + invalid_job_ids + + # Act + r = normal_user_client.delete("/api/jobs/", params={"job_ids": job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to delete {len(invalid_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": valid_job_ids, + "failed_job_ids": invalid_job_ids, + } + } + for job_id in valid_job_ids: + r = normal_user_client.get(f"/api/jobs/{job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(job_id)]["Status"] != JobStatus.DELETED + + +# Test kill job + + +def test_kill_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.post(f"/api/jobs/{valid_job_id}/kill") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.KILLED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Marked for termination" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_kill_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.post(f"/api/jobs/{invalid_job_id}/kill") + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == {"detail": f"Job {invalid_job_id} not found"} + + +def test_kill_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/api/jobs/kill", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + assert r.json() == valid_job_ids + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + assert r.json()[str(valid_job_id)]["Status"] == JobStatus.KILLED + assert r.json()[str(valid_job_id)]["MinorStatus"] == "Marked for termination" + assert r.json()[str(valid_job_id)]["ApplicationStatus"] == "Unknown" + + +def test_kill_bulk_jobs_invalid_job_ids( + normal_user_client: TestClient, invalid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/api/jobs/kill", params={"job_ids": invalid_job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to kill {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", + "valid_job_ids": [], + "failed_job_ids": invalid_job_ids, + } + } + + +def test_kill_bulk_jobs_mix_of_valid_and_invalid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +): + # Arrange + job_ids = valid_job_ids + invalid_job_ids + + # Act + r = normal_user_client.post("/api/jobs/kill", params={"job_ids": job_ids}) + + # Assert + assert r.status_code == 404, r.json() + assert r.json() == { + "detail": { + "message": f"Failed to kill {len(invalid_job_ids)} jobs out of {len(job_ids)}", + "valid_job_ids": valid_job_ids, + "failed_job_ids": invalid_job_ids, + } + } + for valid_job_id in valid_job_ids: + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 200, r.json() + # assert the job is not killed + assert r.json()[str(valid_job_id)]["Status"] != JobStatus.KILLED + + +# Test remove job + + +def test_remove_job_valid_job_id(normal_user_client: TestClient, valid_job_id: int): + # Act + r = normal_user_client.post(f"/api/jobs/{valid_job_id}/remove") + + # Assert + assert r.status_code == 200, r.json() + r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status") + assert r.status_code == 404, r.json() + + +def test_remove_job_invalid_job_id(normal_user_client: TestClient, invalid_job_id: int): + # Act + r = normal_user_client.post(f"/api/jobs/{invalid_job_id}/remove") + + # Assert + assert r.status_code == 200, r.json() + + +def test_remove_bulk_jobs_valid_job_ids( + normal_user_client: TestClient, valid_job_ids: list[int] +): + # Act + r = normal_user_client.post("/api/jobs/remove", params={"job_ids": valid_job_ids}) + + # Assert + assert r.status_code == 200, r.json() + for job_id in valid_job_ids: + r = normal_user_client.get(f"/api/jobs/{job_id}/status") + assert r.status_code == 404, r.json() + + +# def test_remove_bulk_jobs_invalid_job_ids( +# normal_user_client: TestClient, invalid_job_ids: list[int] +# ): +# # Act +# r = normal_user_client.post("/api/jobs/remove", params={"job_ids": invalid_job_ids}) + +# # Assert +# assert r.status_code == 404, r.json() +# assert r.json() == { +# "detail": { +# "message": f"Failed to remove {len(invalid_job_ids)} jobs out of {len(invalid_job_ids)}", +# "failed_ids": { +# str(invalid_job_id): f"Job {invalid_job_id} not found" +# for invalid_job_id in invalid_job_ids +# }, +# } +# } + + +# def test_remove_bulk_jobs_mix_of_valid_and_invalid_job_ids( +# normal_user_client: TestClient, valid_job_ids: list[int], invalid_job_ids: list[int] +# ): +# # Arrange +# job_ids = valid_job_ids + invalid_job_ids + +# # Act +# r = normal_user_client.post("/api/jobs/remove", params={"job_ids": job_ids}) + +# # Assert +# assert r.status_code == 404, r.json() +# assert r.json() == { +# "detail": { +# "message": f"Failed to remove {len(invalid_job_ids)} jobs out of {len(job_ids)}", +# "failed_ids": { +# str(invalid_job_id): f"Job {invalid_job_id} not found" +# for invalid_job_id in invalid_job_ids +# }, +# } +# } +# for job_id in valid_job_ids: +# r = normal_user_client.get(f"/api/jobs/{job_id}/status") +# assert r.status_code == 404, r.json()