diff --git a/diracx-routers/src/diracx/routers/pilots/logging.py b/diracx-routers/src/diracx/routers/pilots/logging.py index b182890d..d4358ea9 100644 --- a/diracx-routers/src/diracx/routers/pilots/logging.py +++ b/diracx-routers/src/diracx/routers/pilots/logging.py @@ -112,7 +112,7 @@ async def get_logs( @router.delete("/logs") async def delete( - pilot_id: int, + pilot_ids: list[int], data: DateRange, db: PilotLogsDB, check_permissions: CheckPilotLogsPolicyCallable, @@ -125,17 +125,19 @@ async def delete( non_privil_params = {"parameter": "VO", "operator": "eq", "value": user_info.vo} # id pilot_id is provided we ignore data.min and data.max - if data.min and data.max and not pilot_id: + if not pilot_ids and data.min and data.max: raise InvalidQueryError( "This query requires a range operator definition in DiracX" ) - if pilot_id: - search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}] + if pilot_ids: + search_params = [ + {"parameter": "PilotID", "operator": "in", "values": pilot_ids} + ] if _non_privileged(user_info): search_params.append(non_privil_params) await db.delete(search_params) - message = f"Logs for pilot ID '{pilot_id}' successfully deleted" + message = f"Logs for pilot IDs '{pilot_ids}' successfully deleted" elif data.min: logger.warning(f"Deleting logs for pilots with submission data >='{data.min}'") diff --git a/diracx-routers/tests/pilots/test_pilot_logger.py b/diracx-routers/tests/pilots/test_pilot_logger.py index a538df29..86594d37 100644 --- a/diracx-routers/tests/pilots/test_pilot_logger.py +++ b/diracx-routers/tests/pilots/test_pilot_logger.py @@ -5,13 +5,14 @@ import pytest from sqlalchemy import inspect, update -from diracx.core.properties import PILOT +from diracx.core.properties import OPERATOR, PILOT from diracx.db.os import PilotLogsDB from diracx.db.sql import PilotAgentsDB from diracx.db.sql.pilot_agents.schema import PilotAgents from diracx.routers.pilots.logging import ( LogLine, LogMessage, + delete, get_logs, send_message, ) @@ -35,7 +36,7 @@ async def pilot_agents_db(tmp_path) -> PilotAgentsDB: @pytest.fixture async def pilot_logs_db(): # create a class that has sqlite backend replacing OpenSearch PilotLogsDB - m_pilot_logs_db = type("JobParametersDB", (MockOSDBMixin, PilotLogsDB), {}) + m_pilot_logs_db = type("PilotLogsDB", (MockOSDBMixin, PilotLogsDB), {}) db = m_pilot_logs_db( connection_kwargs={"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"} @@ -76,25 +77,35 @@ async def test_logging( .values(SubmissionTime=sub_time) ) await db.conn.execute(stmt) - # 4 message records for the first pilot. - line = [{"Message": f"Message_no_{i}"} for i in range(1, 4)] - log_lines = [LogLine(line_no=i + 1, line=line[i]["Message"]) for i in range(3)] - message = LogMessage(pilot_stamp="stamp_1", lines=log_lines, vo="gridpp") check_permissions_mock = AsyncMock() - check_permissions_mock.return_value.vo = "gridpp" + check_permissions_mock.return_value.vo = "test_vo" # TODO add user properties dict return_value above mock_url.return_value = {"PilotAgentsDB": "sqlite+aiosqlite:///:memory:"} # use the existing context (we have a DB already): pilot_agents_db.engine_context = nullcontext mock_impl.return_value = [lambda x: pilot_agents_db] - # send logs for stamp_1, pilot id = 1 - pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock) - assert pilot_id == 1 - # get logs for pilot_id=1 - log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock) - assert log_records == line - # delete logs for pilot_id = 1 - check_permissions_mock.return_value.properties = [PILOT] - # TODO: await mock_osdb delete implementation... - # res = await delete(pilot_id, DateRange(), pilot_logs_db, check_permissions_mock) + + # 4 message records for each pilot. + for pilot in range(1, upper_limit): + line = [{"Message": f"Message_no_{i}_pilot_no_{pilot}"} for i in range(1, 4)] + log_lines = [LogLine(line_no=i + 1, line=line[i]["Message"]) for i in range(3)] + message = LogMessage(pilot_stamp=f"stamp_{pilot}", lines=log_lines, vo="gridpp") + # send the message: + pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock) + assert pilot_id == pilot + # get logs for pilot_id=1 + log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock) + assert log_records == line + + # bulk delete logs for pilot_id = 1 and 2 + check_permissions_mock.return_value.properties = [PILOT, OPERATOR] + del_ids = [1, 2] + res = await delete(del_ids, None, pilot_logs_db, check_permissions_mock) + assert res == f"Logs for pilot IDs '{del_ids}' successfully deleted" + # when it's gone, it's gone: + for pilot_id in [1, 2]: + log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock) + assert log_records == [{"Message": f"No logs for pilot ID = {pilot_id}"}] + # time restricted delete will come here: + # TODO: await mock_osdb date range delete implementation... diff --git a/diracx-testing/src/diracx/testing/mock_osdb.py b/diracx-testing/src/diracx/testing/mock_osdb.py index f92f3968..c2a08804 100644 --- a/diracx-testing/src/diracx/testing/mock_osdb.py +++ b/diracx-testing/src/diracx/testing/mock_osdb.py @@ -10,7 +10,7 @@ from functools import partial from typing import Any, AsyncIterator -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.dialects.sqlite import insert as sqlite_insert from diracx.core.models import SearchSpec, SortSpec @@ -99,8 +99,9 @@ async def upsert(self, doc_id, document) -> None: async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None: async with self: rows = [] - for item, doc in enumerate(docs): - values = {"doc_id": item + 1} + for doc in docs: + # don't use doc_id column explicitly. This ensures that doc_id is unique. + values = {} for key, value in doc.items(): if key in self.fields: values[key] = value @@ -163,6 +164,14 @@ async def search( results.append(result) return results + async def delete(self, query: list[dict[str, Any]]) -> None: + async with self: + stmt = delete(self._table) + stmt = sql_utils.apply_search_filters( + self._table.columns.__getitem__, stmt, query + ) + await self._sql_db.conn.execute(stmt) + async def ping(self): return await self._sql_db.ping()