Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable remote pilot logging system [MISSING AUTH] #269

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions diracx-db/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ TaskQueueDB = "diracx.db.sql:TaskQueueDB"

[project.entry-points."diracx.db.os"]
JobParametersDB = "diracx.db.os:JobParametersDB"
PilotLogsDB = "diracx.db.os:PilotLogsDB"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
6 changes: 5 additions & 1 deletion diracx-db/src/diracx/db/os/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from __future__ import annotations

__all__ = ("JobParametersDB",)
__all__ = (
"JobParametersDB",
"PilotLogsDB",
)

from .job_parameters import JobParametersDB
from .pilot_logs import PilotLogsDB
21 changes: 21 additions & 0 deletions diracx-db/src/diracx/db/os/pilot_logs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from __future__ import annotations

from diracx.db.os.utils import BaseOSDB


class PilotLogsDB(BaseOSDB):
    """OpenSearch database holding the log lines sent by pilots."""

    # Document field name -> mapping definition used for the index.
    fields = {
        "PilotStamp": {"type": "keyword"},
        "PilotID": {"type": "long"},
        "SubmissionTime": {"type": "date"},
        "LineNumber": {"type": "long"},
        "Message": {"type": "text"},
        "VO": {"type": "keyword"},
        "timestamp": {"type": "date"},
    }
    # Common prefix of every index name produced by ``index_name``.
    index_prefix = "pilot_logs"

    def index_name(self, doc_id: int) -> str:
        """Return the name of the index that stores documents for ``doc_id``.

        IDs are grouped into buckets of one million, so IDs 0..999_999 map to
        ``pilot_logs_0``, 1_000_000..1_999_999 to ``pilot_logs_1`` and so on.

        :param doc_id: Integer identifier used to pick the bucket.
        :return: ``"<index_prefix>_<bucket>"``.
        """
        # TODO decide how to define the index name
        # use pilot ID
        # Integer floor division keeps the bucket exact for arbitrarily large
        # IDs; dividing by the float 1e6 rounds once the ID exceeds 2**53.
        return f"{self.index_prefix}_{doc_id // 1_000_000}"
19 changes: 19 additions & 0 deletions diracx-db/src/diracx/db/os/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Any, Self

from opensearchpy import AsyncOpenSearch
from opensearchpy.helpers import async_bulk

from diracx.core.exceptions import InvalidQueryError
from diracx.core.extensions import select_from_extension
Expand Down Expand Up @@ -190,6 +191,13 @@
)
print(f"{response=}")

async def bulk_insert(self, index_name: str, docs: list[dict[str, Any]]) -> None:
    """Insert ``docs`` into ``index_name`` using one bulk request.

    :param index_name: Name of the index every document is written to.
    :param docs: Documents to insert. Each one is merged into a new dict
        carrying the ``_index`` key, so the caller's dicts are not modified.
    """
    actions = [doc | {"_index": index_name} for doc in docs]
    # async_bulk returns a ``(successes, errors)`` pair rather than a bare
    # count; unpack it so the ``%d`` placeholder below receives an integer.
    n_inserted, _errors = await async_bulk(self.client, actions=actions)
    logger.info("Inserted %d documents to %r", n_inserted, index_name)

Check warning on line 199 in diracx-db/src/diracx/db/os/utils.py

View check run for this annotation

Codecov / codecov/patch

diracx-db/src/diracx/db/os/utils.py#L199

Added line #L199 was not covered by tests

async def search(
self, parameters, search, sorts, *, per_page: int = 100, page: int | None = None
) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -231,6 +239,17 @@

return hits

async def delete(self, query: list[dict[str, Any]]) -> dict:
    """Delete every document matching ``query`` from this database's indices.

    :param query: Search specifications, converted to an OpenSearch query with
        ``apply_search_filters`` and ``self.fields``.
    :return: The response of ``delete_by_query``, or an empty dict when
        ``query`` is empty (in which case nothing is sent to the server).
    """
    # An empty query is a no-op rather than a "delete everything" request.
    if not query:
        return {}

    request_body = {"query": apply_search_filters(self.fields, query)}
    # The wildcard targets every index sharing this database's prefix.
    return await self.client.delete_by_query(
        body=request_body, index=f"{self.index_prefix}*"
    )

Check warning on line 251 in diracx-db/src/diracx/db/os/utils.py

View check run for this annotation

Codecov / codecov/patch

diracx-db/src/diracx/db/os/utils.py#L251

Added line #L251 was not covered by tests


def require_type(operator, field_name, field_type, allowed_types):
if field_type not in allowed_types:
Expand Down
19 changes: 4 additions & 15 deletions diracx-db/src/diracx/db/sql/job/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
SortSpec,
)

from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints
from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns
from .schema import (
InputData,
JobCommands,
Expand All @@ -26,17 +26,6 @@
)


def _get_columns(table, parameters):
    """Return the columns of ``table`` restricted to ``parameters``.

    When ``parameters`` is empty or ``None`` every column is returned. Columns
    are always returned in table order, whatever the order of ``parameters``.

    :raises InvalidQueryError: if a requested name is not a key of
        ``table.columns``.
    """
    all_columns = list(table.columns)
    if not parameters:
        return all_columns

    unrecognised_parameters = set(parameters) - set(table.columns.keys())
    if unrecognised_parameters:
        raise InvalidQueryError(
            f"Unrecognised parameters requested {unrecognised_parameters}"
        )
    return [column for column in all_columns if column.name in parameters]


class JobDB(BaseSQLDB):
metadata = JobDBBase.metadata

Expand All @@ -46,7 +35,7 @@
jdl_2_db_parameters = ["JobName", "JobType", "JobGroup"]

async def summary(self, group_by, search) -> list[dict[str, str | int]]:
columns = _get_columns(Jobs.__table__, group_by)
columns = get_columns(Jobs.__table__, group_by)

stmt = select(*columns, func.count(Jobs.job_id).label("count"))
stmt = apply_search_filters(Jobs.__table__.columns.__getitem__, stmt, search)
Expand All @@ -70,7 +59,7 @@
page: int | None = None,
) -> tuple[int, list[dict[Any, Any]]]:
# Find which columns to select
columns = _get_columns(Jobs.__table__, parameters)
columns = get_columns(Jobs.__table__, parameters)

stmt = select(*columns)

Expand Down Expand Up @@ -328,7 +317,7 @@
required_parameters = list(required_parameters_set)[0]
update_parameters = [{"job_id": k, **v} for k, v in properties.items()]

columns = _get_columns(Jobs.__table__, required_parameters)
columns = get_columns(Jobs.__table__, required_parameters)

Check warning on line 320 in diracx-db/src/diracx/db/sql/job/db.py

View check run for this annotation

Codecov / codecov/patch

diracx-db/src/diracx/db/sql/job/db.py#L320

Added line #L320 was not covered by tests
values: dict[str, BindParameter[Any] | datetime] = {
c.name: bindparam(c.name) for c in columns
}
Expand Down
54 changes: 52 additions & 2 deletions diracx-db/src/diracx/db/sql/pilot_agents/db.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from __future__ import annotations

from datetime import datetime, timezone
from typing import Any

from sqlalchemy import insert
from sqlalchemy import func, insert, select

from ..utils import BaseSQLDB
from diracx.core.exceptions import InvalidQueryError
from diracx.core.models import (
SearchSpec,
SortSpec,
)

from ..utils import BaseSQLDB, apply_search_filters, apply_sort_constraints, get_columns
from .schema import PilotAgents, PilotAgentsDBBase


Expand Down Expand Up @@ -44,3 +51,46 @@
stmt = insert(PilotAgents).values(values)
await self.conn.execute(stmt)
return

async def search(
    self,
    parameters: list[str] | None,
    search: list[SearchSpec],
    sorts: list[SortSpec],
    *,
    distinct: bool = False,
    per_page: int = 100,
    page: int | None = None,
) -> tuple[int, list[dict[Any, Any]]]:
    """Search the pilot agents table.

    :param parameters: Names of the columns to return; all columns are
        selected when this is empty or ``None``.
    :param search: Filters applied to the query.
    :param sorts: Sort constraints applied to the query.
    :param distinct: When true, return only distinct rows.
    :param per_page: Page size, only used when ``page`` is given.
    :param page: 1-based page number; ``None`` disables pagination.
    :return: ``(total, rows)`` where ``total`` counts the matching rows before
        pagination and ``rows`` holds one dict per returned row.
    :raises InvalidQueryError: if pagination is requested with ``page`` or
        ``per_page`` lower than 1.
    """
    # Validate pagination first so that an invalid request is rejected before
    # any statement (in particular the COUNT query) is sent to the database.
    if page is not None:
        if page < 1:
            raise InvalidQueryError("Page must be a positive integer")
        if per_page < 1:
            raise InvalidQueryError("Per page must be a positive integer")

    # Find which columns to select
    columns = get_columns(PilotAgents.__table__, parameters)

    stmt = select(*columns)

    stmt = apply_search_filters(
        PilotAgents.__table__.columns.__getitem__, stmt, search
    )
    stmt = apply_sort_constraints(
        PilotAgents.__table__.columns.__getitem__, stmt, sorts
    )

    if distinct:
        stmt = stmt.distinct()

    # Calculate total count before applying pagination
    total_count_subquery = stmt.alias()
    total_count_stmt = select(func.count()).select_from(total_count_subquery)
    total = (await self.conn.execute(total_count_stmt)).scalar_one()

    # Apply pagination
    if page is not None:
        stmt = stmt.offset((page - 1) * per_page).limit(per_page)

    # Execute the query
    return total, [
        dict(row._mapping) async for row in (await self.conn.stream(stmt))
    ]
2 changes: 2 additions & 0 deletions diracx-db/src/diracx/db/sql/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
SQLDBUnavailableError,
apply_search_filters,
apply_sort_constraints,
get_columns,
)
from .functions import substract_date, utcnow
from .types import Column, DateNowColumn, EnumBackedBool, EnumColumn, NullColumn
Expand All @@ -19,6 +20,7 @@
"EnumColumn",
"apply_search_filters",
"apply_sort_constraints",
"get_columns",
"substract_date",
"SQLDBUnavailableError",
)
11 changes: 11 additions & 0 deletions diracx-db/src/diracx/db/sql/utils/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,17 @@ def find_time_resolution(value):
raise InvalidQueryError(f"Cannot parse {value=}")


def get_columns(table, parameters):
    """Select the columns of ``table`` named in ``parameters``.

    An empty or ``None`` ``parameters`` selects every column. The result
    follows the table's column order, not the order of ``parameters``.

    :raises InvalidQueryError: if ``parameters`` contains a name that is not
        a key of ``table.columns``.
    """
    if not parameters:
        return list(table.columns)

    unknown = set(parameters) - set(table.columns.keys())
    if unknown:
        raise InvalidQueryError(f"Unrecognised parameters requested {unknown}")

    return [col for col in table.columns if col.name in parameters]


def apply_search_filters(column_mapping, stmt, search):
for query in search:
try:
Expand Down
2 changes: 2 additions & 0 deletions diracx-routers/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ types = [
]

[project.entry-points."diracx.services"]
pilots = "diracx.routers.pilots:router"
jobs = "diracx.routers.jobs:router"
config = "diracx.routers.configuration:router"
auth = "diracx.routers.auth:router"
Expand All @@ -55,6 +56,7 @@ auth = "diracx.routers.auth:router"
[project.entry-points."diracx.access_policies"]
WMSAccessPolicy = "diracx.routers.jobs.access_policies:WMSAccessPolicy"
SandboxAccessPolicy = "diracx.routers.jobs.access_policies:SandboxAccessPolicy"
PilotLogsAccessPolicy = "diracx.routers.pilots.access_policies:PilotLogsAccessPolicy"

# Minimum version of the client supported
[project.entry-points."diracx.min_client_version"]
Expand Down
8 changes: 5 additions & 3 deletions diracx-routers/src/diracx/routers/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
"SandboxMetadataDB",
"TaskQueueDB",
"PilotAgentsDB",
"PilotLogsDB",
"add_settings_annotation",
"AvailableSecurityProperties",
)
Expand All @@ -21,6 +22,7 @@
from diracx.core.properties import SecurityProperty
from diracx.core.settings import DevelopmentSettings as _DevelopmentSettings
from diracx.db.os import JobParametersDB as _JobParametersDB
from diracx.db.os import PilotLogsDB as _PilotLogsDB
from diracx.db.sql import AuthDB as _AuthDB
from diracx.db.sql import JobDB as _JobDB
from diracx.db.sql import JobLoggingDB as _JobLoggingDB
Expand All @@ -36,7 +38,7 @@ def add_settings_annotation(cls: T) -> T:
return Annotated[cls, Depends(cls.create)] # type: ignore


# Databases
# SQL Databases
AuthDB = Annotated[_AuthDB, Depends(_AuthDB.transaction)]
JobDB = Annotated[_JobDB, Depends(_JobDB.transaction)]
JobLoggingDB = Annotated[_JobLoggingDB, Depends(_JobLoggingDB.transaction)]
Expand All @@ -46,9 +48,9 @@ def add_settings_annotation(cls: T) -> T:
]
TaskQueueDB = Annotated[_TaskQueueDB, Depends(_TaskQueueDB.transaction)]

# Opensearch databases
# OpenSearch Databases
JobParametersDB = Annotated[_JobParametersDB, Depends(_JobParametersDB.session)]

PilotLogsDB = Annotated[_PilotLogsDB, Depends(_PilotLogsDB.session)]

# Miscellaneous
Config = Annotated[_Config, Depends(ConfigSource.create)]
Expand Down
11 changes: 11 additions & 0 deletions diracx-routers/src/diracx/routers/pilots/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
from __future__ import annotations

from logging import getLogger

from ..fastapi_classes import DiracxRouter
from .logging import router as logging_router

logger = getLogger(__name__)

router = DiracxRouter()
router.include_router(logging_router)
89 changes: 89 additions & 0 deletions diracx-routers/src/diracx/routers/pilots/access_policies.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
from __future__ import annotations

import logging
from enum import StrEnum, auto
from typing import Annotated, Callable

from fastapi import Depends, HTTPException, status

from diracx.core.models import ScalarSearchOperator, ScalarSearchSpec
from diracx.core.properties import (
NORMAL_USER,
)
from diracx.routers.access_policies import BaseAccessPolicy

from ..dependencies import PilotAgentsDB
from ..utils.users import AuthorizedUserInfo

logger = logging.getLogger(__name__)


class ActionType(StrEnum):
#: Create/update pilot log records
CREATE = auto()
#: Search
QUERY = auto()


class PilotLogsAccessPolicy(BaseAccessPolicy):
    """Access policy for pilot log records.

    Querying is granted to any member of the ``diracAdmin`` VO, and to users
    holding ``NORMAL_USER`` whose VO equals the VO of the requested pilot.
    Every other user is refused with HTTP 403. Actions other than
    ``ActionType.QUERY`` are not implemented yet.
    """

    @staticmethod
    async def policy(
        policy_name: str,
        user_info: AuthorizedUserInfo,
        /,
        *,
        action: ActionType | None = None,
        pilot_agents_db: PilotAgentsDB | None = None,
        pilot_id: int | None = None,
    ):
        """Check whether ``user_info`` may perform ``action`` on a pilot's logs.

        :param policy_name: Name of the policy (not used by this check).
        :param user_info: Identity of the caller; its ``vo`` and
            ``properties`` drive the decision.
        :param action: Requested action; mandatory.
        :param pilot_agents_db: Database used to look up the pilot's VO.
        :param pilot_id: ID of the pilot whose logs are requested.
        :raises HTTPException: 400 when ``action`` or ``pilot_id`` is missing
            or the pilot cannot be resolved to exactly one row, 403 when the
            caller is not allowed to read the pilot's logs.
        :raises NotImplementedError: for any action other than ``QUERY``.
        """
        assert pilot_agents_db

        if action is None:
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail="Action is a mandatory argument"
            )
        if action != ActionType.QUERY:
            raise NotImplementedError(action)

        if pilot_id is None:
            logger.error("Pilot ID value is not provided (None)")
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST,
                detail=f"PilotID not provided: {pilot_id}",
            )

        # Resolve the VO owning the pilot: the decision below depends on it.
        pilot_filter = ScalarSearchSpec(
            parameter="PilotID",
            operator=ScalarSearchOperator.EQUAL,
            value=pilot_id,
        )
        total, result = await pilot_agents_db.search(["VO"], [pilot_filter], [])

        # we expect exactly one row.
        if total != 1:
            logger.error(
                "Cannot determine VO for requested PilotID: %d, found %d candidates.",
                pilot_id,
                total,
            )
            raise HTTPException(
                status.HTTP_400_BAD_REQUEST, detail=f"PilotID not found: {pilot_id}"
            )
        vo = result[0]["VO"]

        # Administrators may read the logs of any VO.
        if user_info.vo == "diracAdmin":
            return

        # Regular users are limited to pilots of their own VO.
        same_vo_user = NORMAL_USER in user_info.properties and user_info.vo == vo
        if same_vo_user:
            return

        raise HTTPException(
            status.HTTP_403_FORBIDDEN,
            detail="You don't have permission to access this pilot's log.",
        )

Check warning on line 86 in diracx-routers/src/diracx/routers/pilots/access_policies.py

View check run for this annotation

Codecov / codecov/patch

diracx-routers/src/diracx/routers/pilots/access_policies.py#L86

Added line #L86 was not covered by tests


CheckPilotLogsPolicyCallable = Annotated[Callable, Depends(PilotLogsAccessPolicy.check)]
Loading
Loading