Skip to content

Commit

Permalink
Add support for parallel runtime
Browse files Browse the repository at this point in the history
  • Loading branch information
mariusconjeaud committed Nov 25, 2024
1 parent f866109 commit 1355fa1
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 7 deletions.
12 changes: 12 additions & 0 deletions neomodel/async_/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -598,6 +598,18 @@ async def edition_is_enterprise(self) -> bool:
edition = await self.database_edition
return edition == "enterprise"

@ensure_connection
async def parallel_runtime_available(self) -> bool:
"""Returns true if the database supports parallel runtime
Returns:
bool: True if the database supports parallel runtime
"""
return (
await self.version_is_higher_than("5.13")
and await self.edition_is_enterprise()
)

async def change_neo4j_password(self, user, new_password):
await self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'")

Expand Down
36 changes: 32 additions & 4 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand Down Expand Up @@ -396,6 +397,7 @@ def __init__(
lookup: TOptional[str] = None,
additional_return: TOptional[List[str]] = None,
is_count: TOptional[bool] = False,
use_parallel_runtime: TOptional[bool] = False,
) -> None:
self.match = match if match else []
self.optional_match = optional_match if optional_match else []
Expand All @@ -409,6 +411,7 @@ def __init__(
self.lookup = lookup
self.additional_return = additional_return if additional_return else []
self.is_count = is_count
self.use_parallel_runtime = use_parallel_runtime
self.subgraph: Dict = {}


Expand All @@ -432,6 +435,19 @@ async def build_ast(self) -> "AsyncQueryBuilder":
self._ast.skip = self.node_set.skip
if hasattr(self.node_set, "limit"):
self._ast.limit = self.node_set.limit
if hasattr(self.node_set, "use_parallel_runtime"):
if (
self.node_set.use_parallel_runtime
and not await adb.parallel_runtime_available()
):
warnings.warn(
"Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. "
"Reverting to default runtime.",
UserWarning,
)
self.node_set.use_parallel_runtime = False
else:
self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime

return self

Expand Down Expand Up @@ -589,9 +605,11 @@ def build_traversal_from_path(
}
else:
existing_rhs_name = subgraph[part][
"rel_variable_name"
if relation.get("relation_filtering")
else "variable_name"
(
"rel_variable_name"
if relation.get("relation_filtering")
else "variable_name"
)
]
if relation["include_in_return"] and not already_present:
self._additional_return(rel_ident)
Expand Down Expand Up @@ -812,6 +830,8 @@ def lookup_query_variable(
def build_query(self) -> str:
query: str = ""

if self._ast.use_parallel_runtime:
query += "CYPHER runtime=parallel "

Check warning on line 834 in neomodel/async_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/async_/match.py#L834

Added line #L834 was not covered by tests
if self._ast.lookup:
query += self._ast.lookup

Expand Down Expand Up @@ -973,7 +993,9 @@ async def _execute(self, lazy: bool = False, dict_output: bool = False):
]
query = self.build_query()
results, prop_names = await adb.cypher_query(
query, self._query_params, resolve_objects=True
query,
self._query_params,
resolve_objects=True,
)
if dict_output:
for item in results:
Expand Down Expand Up @@ -1236,6 +1258,8 @@ def __init__(self, source) -> None:
self._subqueries: list[Tuple[str, list[str]]] = []
self._intermediate_transforms: list = []

self.use_parallel_runtime = False

def __await__(self):
return self.all().__await__()

Expand Down Expand Up @@ -1564,6 +1588,10 @@ def intermediate_transform(
)
return self

def parallel_runtime(self) -> "AsyncNodeSet":
self.use_parallel_runtime = True
return self


class AsyncTraversal(AsyncBaseSet):
"""
Expand Down
9 changes: 9 additions & 0 deletions neomodel/sync_/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,6 +596,15 @@ def edition_is_enterprise(self) -> bool:
edition = self.database_edition
return edition == "enterprise"

@ensure_connection
def parallel_runtime_available(self) -> bool:
"""Returns true if the database supports parallel runtime
Returns:
bool: True if the database supports parallel runtime
"""
return self.version_is_higher_than("5.13") and self.edition_is_enterprise()

def change_neo4j_password(self, user, new_password):
self.cypher_query(f"ALTER USER {user} SET PASSWORD '{new_password}'")

Expand Down
28 changes: 27 additions & 1 deletion neomodel/sync_/match.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand Down Expand Up @@ -396,6 +397,7 @@ def __init__(
lookup: TOptional[str] = None,
additional_return: TOptional[List[str]] = None,
is_count: TOptional[bool] = False,
use_parallel_runtime: TOptional[bool] = False,
) -> None:
self.match = match if match else []
self.optional_match = optional_match if optional_match else []
Expand All @@ -409,6 +411,7 @@ def __init__(
self.lookup = lookup
self.additional_return = additional_return if additional_return else []
self.is_count = is_count
self.use_parallel_runtime = use_parallel_runtime
self.subgraph: Dict = {}


Expand All @@ -432,6 +435,19 @@ def build_ast(self) -> "QueryBuilder":
self._ast.skip = self.node_set.skip
if hasattr(self.node_set, "limit"):
self._ast.limit = self.node_set.limit
if hasattr(self.node_set, "use_parallel_runtime"):
if (
self.node_set.use_parallel_runtime
and not db.parallel_runtime_available()
):
warnings.warn(
"Parallel runtime is only available in Neo4j Enterprise Edition 5.13 and above. "
"Reverting to default runtime.",
UserWarning,
)
self.node_set.use_parallel_runtime = False
else:
self._ast.use_parallel_runtime = self.node_set.use_parallel_runtime

return self

Expand Down Expand Up @@ -814,6 +830,8 @@ def lookup_query_variable(
def build_query(self) -> str:
query: str = ""

if self._ast.use_parallel_runtime:
query += "CYPHER runtime=parallel "

Check warning on line 834 in neomodel/sync_/match.py

View check run for this annotation

Codecov / codecov/patch

neomodel/sync_/match.py#L834

Added line #L834 was not covered by tests
if self._ast.lookup:
query += self._ast.lookup

Expand Down Expand Up @@ -973,7 +991,9 @@ def _execute(self, lazy: bool = False, dict_output: bool = False):
]
query = self.build_query()
results, prop_names = db.cypher_query(
query, self._query_params, resolve_objects=True
query,
self._query_params,
resolve_objects=True,
)
if dict_output:
for item in results:
Expand Down Expand Up @@ -1236,6 +1256,8 @@ def __init__(self, source) -> None:
self._subqueries: list[Tuple[str, list[str]]] = []
self._intermediate_transforms: list = []

self.use_parallel_runtime = False

def __await__(self):
return self.all().__await__()

Expand Down Expand Up @@ -1562,6 +1584,10 @@ def intermediate_transform(
)
return self

def parallel_runtime(self) -> "NodeSet":
self.use_parallel_runtime = True
return self


class Traversal(BaseSet):
"""
Expand Down
33 changes: 32 additions & 1 deletion test/async_/test_match_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import warnings
from datetime import datetime
from test._async_compat import mark_async_test

from pytest import raises
from pytest import raises, warns

from neomodel import (
INCOMING,
Expand Down Expand Up @@ -1113,3 +1114,33 @@ async def test_async_iterator():

# assert that generator runs loop above
assert counter == n


@mark_async_test
async def test_parallel_runtime():
await Coffee(name="Java", price=99).save()

node_set = AsyncNodeSet(Coffee).parallel_runtime()

assert node_set.use_parallel_runtime

if (
not await adb.version_is_higher_than("5.13")
or not await adb.edition_is_enterprise()
):
assert not await adb.parallel_runtime_available()
with warns(
UserWarning,
match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13",
):
qb = await AsyncQueryBuilder(node_set).build_ast()
assert not qb._ast.use_parallel_runtime
assert not qb.build_query().startswith("CYPHER runtime=parallel")
else:
assert await adb.parallel_runtime_available()
qb = await AsyncQueryBuilder(node_set).build_ast()
assert qb._ast.use_parallel_runtime
assert qb.build_query().startswith("CYPHER runtime=parallel")

results = [node async for node in qb._execute()]
assert len(results) == 1
30 changes: 29 additions & 1 deletion test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import re
import warnings
from datetime import datetime
from test._async_compat import mark_sync_test

from pytest import raises
from pytest import raises, warns

from neomodel import (
INCOMING,
Expand Down Expand Up @@ -1097,3 +1098,30 @@ def test_async_iterator():

# assert that generator runs loop above
assert counter == n


@mark_sync_test
def test_parallel_runtime():
Coffee(name="Java", price=99).save()

node_set = NodeSet(Coffee).parallel_runtime()

assert node_set.use_parallel_runtime

if not db.version_is_higher_than("5.13") or not db.edition_is_enterprise():
assert not db.parallel_runtime_available()
with warns(
UserWarning,
match="Parallel runtime is only available in Neo4j Enterprise Edition 5.13",
):
qb = QueryBuilder(node_set).build_ast()
assert not qb._ast.use_parallel_runtime
assert not qb.build_query().startswith("CYPHER runtime=parallel")
else:
assert db.parallel_runtime_available()
qb = QueryBuilder(node_set).build_ast()
assert qb._ast.use_parallel_runtime
assert qb.build_query().startswith("CYPHER runtime=parallel")

results = [node for node in qb._execute()]
assert len(results) == 1

0 comments on commit 1355fa1

Please sign in to comment.