Skip to content

Commit

Permalink
feat: implement bulk delete
Browse files Browse the repository at this point in the history
  • Loading branch information
martynia committed Jan 8, 2025
1 parent fff5434 commit 2c59c0c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 25 deletions.
12 changes: 7 additions & 5 deletions diracx-routers/src/diracx/routers/pilots/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ async def get_logs(

@router.delete("/logs")
async def delete(
pilot_id: int,
pilot_ids: list[int],
data: DateRange,
db: PilotLogsDB,
check_permissions: CheckPilotLogsPolicyCallable,
Expand All @@ -125,17 +125,19 @@ async def delete(
non_privil_params = {"parameter": "VO", "operator": "eq", "value": user_info.vo}

# if pilot_ids is provided we ignore data.min and data.max
if data.min and data.max and not pilot_id:
if not pilot_ids and data.min and data.max:
raise InvalidQueryError(
"This query requires a range operator definition in DiracX"
)

if pilot_id:
search_params = [{"parameter": "PilotID", "operator": "eq", "value": pilot_id}]
if pilot_ids:
search_params = [
{"parameter": "PilotID", "operator": "in", "values": pilot_ids}
]
if _non_privileged(user_info):
search_params.append(non_privil_params)
await db.delete(search_params)
message = f"Logs for pilot ID '{pilot_id}' successfully deleted"
message = f"Logs for pilot IDs '{pilot_ids}' successfully deleted"

elif data.min:
logger.warning(f"Deleting logs for pilots with submission data >='{data.min}'")
Expand Down
45 changes: 28 additions & 17 deletions diracx-routers/tests/pilots/test_pilot_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import pytest
from sqlalchemy import inspect, update

from diracx.core.properties import PILOT
from diracx.core.properties import OPERATOR, PILOT
from diracx.db.os import PilotLogsDB
from diracx.db.sql import PilotAgentsDB
from diracx.db.sql.pilot_agents.schema import PilotAgents
from diracx.routers.pilots.logging import (
LogLine,
LogMessage,
delete,
get_logs,
send_message,
)
Expand All @@ -35,7 +36,7 @@ async def pilot_agents_db(tmp_path) -> PilotAgentsDB:
@pytest.fixture
async def pilot_logs_db():
# create a class that has sqlite backend replacing OpenSearch PilotLogsDB
m_pilot_logs_db = type("JobParametersDB", (MockOSDBMixin, PilotLogsDB), {})
m_pilot_logs_db = type("PilotLogsDB", (MockOSDBMixin, PilotLogsDB), {})

db = m_pilot_logs_db(
connection_kwargs={"sqlalchemy_dsn": "sqlite+aiosqlite:///:memory:"}
Expand Down Expand Up @@ -76,25 +77,35 @@ async def test_logging(
.values(SubmissionTime=sub_time)
)
await db.conn.execute(stmt)
# 4 message records for the first pilot.
line = [{"Message": f"Message_no_{i}"} for i in range(1, 4)]
log_lines = [LogLine(line_no=i + 1, line=line[i]["Message"]) for i in range(3)]
message = LogMessage(pilot_stamp="stamp_1", lines=log_lines, vo="gridpp")

check_permissions_mock = AsyncMock()
check_permissions_mock.return_value.vo = "gridpp"
check_permissions_mock.return_value.vo = "test_vo"
# TODO add user properties dict return_value above
mock_url.return_value = {"PilotAgentsDB": "sqlite+aiosqlite:///:memory:"}
# use the existing context (we have a DB already):
pilot_agents_db.engine_context = nullcontext
mock_impl.return_value = [lambda x: pilot_agents_db]
# send logs for stamp_1, pilot id = 1
pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock)
assert pilot_id == 1
# get logs for pilot_id=1
log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock)
assert log_records == line
# delete logs for pilot_id = 1
check_permissions_mock.return_value.properties = [PILOT]
# TODO: await mock_osdb delete implementation...
# res = await delete(pilot_id, DateRange(), pilot_logs_db, check_permissions_mock)

# 3 message records for each pilot.
for pilot in range(1, upper_limit):
line = [{"Message": f"Message_no_{i}_pilot_no_{pilot}"} for i in range(1, 4)]
log_lines = [LogLine(line_no=i + 1, line=line[i]["Message"]) for i in range(3)]
message = LogMessage(pilot_stamp=f"stamp_{pilot}", lines=log_lines, vo="gridpp")
# send the message:
pilot_id = await send_message(message, pilot_logs_db, check_permissions_mock)
assert pilot_id == pilot
# get logs for the current pilot
log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock)
assert log_records == line

# bulk delete logs for pilot_id = 1 and 2
check_permissions_mock.return_value.properties = [PILOT, OPERATOR]
del_ids = [1, 2]
res = await delete(del_ids, None, pilot_logs_db, check_permissions_mock)
assert res == f"Logs for pilot IDs '{del_ids}' successfully deleted"
# when it's gone, it's gone:
for pilot_id in [1, 2]:
log_records = await get_logs(pilot_id, pilot_logs_db, check_permissions_mock)
assert log_records == [{"Message": f"No logs for pilot ID = {pilot_id}"}]
# time restricted delete will come here:
# TODO: await mock_osdb date range delete implementation...
15 changes: 12 additions & 3 deletions diracx-testing/src/diracx/testing/mock_osdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from functools import partial
from typing import Any, AsyncIterator

from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

from diracx.core.models import SearchSpec, SortSpec
Expand Down Expand Up @@ -99,8 +99,9 @@ async def upsert(self, doc_id, document) -> None:
async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None:
async with self:
rows = []
for item, doc in enumerate(docs):
values = {"doc_id": item + 1}
for doc in docs:
# don't use doc_id column explicitly. This ensures that doc_id is unique.
values = {}
for key, value in doc.items():
if key in self.fields:
values[key] = value
Expand Down Expand Up @@ -163,6 +164,14 @@ async def search(
results.append(result)
return results

async def delete(self, query: list[dict[str, Any]]) -> None:
    """Run a filtered DELETE against the mock's backing SQL table.

    A SQLAlchemy ``delete`` statement is built for ``self._table``, handed
    to ``sql_utils.apply_search_filters`` together with ``query`` so the
    search specs are applied to it, and then executed on the wrapped SQL
    database connection.

    Args:
        query: Search specifications forwarded unchanged to
            ``sql_utils.apply_search_filters``.
    """
    async with self:
        table = self._table
        # The filter helper resolves column names through this lookup.
        column_lookup = table.columns.__getitem__
        # NOTE: ``delete`` here is the module-level SQLAlchemy construct,
        # not this method (class attributes are not in function scope).
        filtered_stmt = sql_utils.apply_search_filters(
            column_lookup, delete(table), query
        )
        await self._sql_db.conn.execute(filtered_stmt)

async def ping(self):
return await self._sql_db.ping()

Expand Down

0 comments on commit 2c59c0c

Please sign in to comment.