Skip to content

Commit

Permalink
IT WORKS BUT....
Browse files Browse the repository at this point in the history
  • Loading branch information
chaen committed Jun 14, 2024
1 parent b56117e commit 24e10ac
Show file tree
Hide file tree
Showing 6 changed files with 116 additions and 2 deletions.
19 changes: 17 additions & 2 deletions diracx-routers/src/diracx/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,12 @@ def create_app_inner(
fail_startup = True
# Add the SQL DBs to the application
available_sql_db_classes: set[type[BaseSQLDB]] = set()

# SQL Classes that are overridden.
overriden_sql_db_classes: set[type[BaseSQLDB]] = set()

for db_name, db_url in database_urls.items():

try:
sql_db_classes = BaseSQLDB.available_implementations(db_name)

Expand All @@ -192,9 +197,16 @@ def create_app_inner(
app.lifetime_functions.append(sql_db.engine_context)
# Add overrides for all the DB classes, including those from extensions
# This means vanilla DiracX routers get an instance of the extension's DB
for sql_db_class in sql_db_classes:
for db_id, sql_db_class in enumerate(sql_db_classes):
assert sql_db_class.transaction not in app.dependency_overrides
available_sql_db_classes.add(sql_db_class)

# If this is not the highest priority one,

Check warning on line 204 in diracx-routers/src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

diracx-routers/src/diracx/routers/__init__.py#L204

Added line #L204 was not covered by tests
# add it to the list of overriden DBs

if db_id != 0:
overriden_sql_db_classes.add(sql_db_class)

app.dependency_overrides[sql_db_class.transaction] = partial(
db_transaction, sql_db
)
Expand Down Expand Up @@ -248,8 +260,11 @@ def create_app_inner(

# Ensure required DBs are available
missing_sql_dbs = (
set(find_dependents(router, BaseSQLDB)) - available_sql_db_classes
set(find_dependents(router, BaseSQLDB))
- available_sql_db_classes
- overriden_sql_db_classes

Check warning on line 265 in diracx-routers/src/diracx/routers/__init__.py

View check run for this annotation

Codecov / codecov/patch

diracx-routers/src/diracx/routers/__init__.py#L265

Added line #L265 was not covered by tests
)

if missing_sql_dbs:
raise NotImplementedError(
f"Cannot enable {system_name=} as it requires {missing_sql_dbs=}"
Expand Down
1 change: 1 addition & 0 deletions diracx-testing/src/diracx/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def enrich_tokens(access_payload: dict, refresh_payload: dict):

@contextlib.contextmanager
def configure(self, enabled_dependencies):

assert (
self.app.dependency_overrides == {} and self.app.lifetime_functions == []
), "configure cannot be nested"
Expand Down
4 changes: 4 additions & 0 deletions extensions/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ Do not forget to add the database in your `[project.entry-points."diracx.db.sql"

``GubbinsJobDB`` illustrate how to extend an existing ``diracx`` DB, add new methods, modify methods, add a table.

A router test exists, even though no router is redefined.

WARNING: in the dependency override, you need to specify both the original DiracX JobDB as well as the extended one GubbinsJobDB



## Routers
Expand Down
1 change: 1 addition & 0 deletions extensions/gubbins/gubbins-db/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ testing = ["diracx-testing"]

[project.entry-points."diracx.db.sql"]
LollygagDB = "gubbins.db.sql:LollygagDB"
JobDB = "gubbins.db.sql:GubbinsJobDB"

[tool.setuptools.packages.find]
where = ["src"]
Expand Down
12 changes: 12 additions & 0 deletions extensions/gubbins/gubbins-db/src/gubbins/db/sql/jobs/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,15 @@ async def getJobJDL( # type: ignore[override]

info = (await self.conn.execute(stmt)).scalar_one()
return {"JDL": jdl, "Info": info}

async def setJobAttributes(self, job_id, jobData):
"""
This method modified the one in the parent class,
without changing the argument nor the return type
Also, this method is called by the router via the status_utility
so we can test in test_gubbins_job_router that the behavior
is altered without even redefining a gubbins specific router
"""
# We do nothing
...
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import pytest
from diracx.core.models import JobStatus
from fastapi.testclient import TestClient

pytestmark = pytest.mark.enabled_dependencies(
[
"AuthSettings",
# CAUTION !!!
# You need to put both the original AND your extended one
"JobDB",
"GubbinsJobDB",
#
"JobLoggingDB",
"WMSAccessPolicy",
"ConfigSource",
"TaskQueueDB",
]
)


TEST_JDL = """
Arguments = "jobDescription.xml -o LogLevel=INFO";
Executable = "dirac-jobexec";
JobGroup = jobGroup;
JobName = jobName;
JobType = User;
LogLevel = INFO;
OutputSandbox =
{
Script1_CodeOutput.log,
std.err,
std.out
};
Priority = 1;
Site = ANY;
StdError = std.err;
StdOutput = std.out;
"""


@pytest.fixture
def normal_user_client(client_factory):
with client_factory.normal_user() as client:
yield client


@pytest.fixture
def valid_job_id(normal_user_client: TestClient):
"""
Copied from the vanila tests
This ensures that the submission route works
"""
job_definitions = [TEST_JDL]
r = normal_user_client.post("/api/jobs/", json=job_definitions)
assert r.status_code == 200, r.json()
assert len(r.json()) == 1
return r.json()[0]["JobID"]


def test_gubbins_job_router(normal_user_client, valid_job_id):
"""
Basically like diracx test_delete_job_valid_job_id
except that the job does not go into DELETED status,
as the method is intercepted by the DB
"""

# We search for the job
r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status")
assert r.status_code == 200, r.json()
assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED

# We delete the job, and here we expect that nothing
# actually happened
r = normal_user_client.delete(f"/api/jobs/{valid_job_id}")
assert r.status_code == 200, r.json()

r = normal_user_client.get(f"/api/jobs/{valid_job_id}/status")
assert r.status_code == 200, r.json()
# The job would normally be deleted
assert r.json()[str(valid_job_id)]["Status"] == JobStatus.RECEIVED

0 comments on commit 24e10ac

Please sign in to comment.