Skip to content

Commit

Permalink
Merge pull request #845 from neo4j-contrib/fix/subquery_filter_values
Browse files Browse the repository at this point in the history
Various improvements about subqueries.
  • Loading branch information
mariusconjeaud authored Dec 10, 2024
2 parents 60f84d1 + 6c8a1f1 commit a0c0c00
Show file tree
Hide file tree
Showing 8 changed files with 254 additions and 41 deletions.
2 changes: 2 additions & 0 deletions Changelog
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
Version 5.4.2 2024-12
* Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext]
* Add initial_context parameter to subqueries
* NodeNameResolver can call self to reference top-level node

Version 5.4.1 2024-11
* Add support for Cypher parallel runtime
Expand Down
14 changes: 10 additions & 4 deletions doc/source/advanced_query_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ As discussed in the note above, this is for example useful when you need to orde

Options for `intermediate_transform` *variables* are:

- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below).
- `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation.
- `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False.

Expand Down Expand Up @@ -95,7 +95,7 @@ Subqueries
The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.com/docs/cypher-manual/current/subqueries/call-subquery/>`_ inside your query. This allows you to perform operations in isolation from the rest of your query::

from neomodel.sync_match import Collect, Last

# This will create a CALL{} subquery
# And return a variable named supps usable in the rest of your query
Coffee.nodes.filter(name="Espresso")
Expand All @@ -106,12 +106,18 @@ The `subquery` method allows you to perform a `Cypher subquery <https://neo4j.co
)
.annotate(supps=Last(Collect("suppliers"))),
["supps"],
[NodeNameResolver("self")]
)

Options for `subquery` calls are:

- `return_set`: list of `string` - the subquery variables that should be included in the outer query result
- `initial_context`: optional list of `string` or `Resolver` - the outer query variables that will be injected at the beginning of the subquery

.. note::
Notice the subquery starts with Coffee.nodes ; neomodel will use this to know it needs to inject the source "coffee" variable generated by the outer query into the subquery. This means only Espresso coffee nodes will be considered in the subquery.
In the example above, we reference `self` to be included in the initial context. It will actually inject the outer variable corresponding to the `Coffee` node.

We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know.
We know this is confusing to read, but have not found a better way to do this yet. If you have any suggestions, please let us know.

Helpers
-------
Expand Down
2 changes: 1 addition & 1 deletion neomodel/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "5.4.1"
__version__ = "5.4.2"
75 changes: 59 additions & 16 deletions neomodel/async_/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand All @@ -13,7 +12,7 @@
from neomodel.exceptions import MultipleNodesReturned
from neomodel.match_q import Q, QBase
from neomodel.properties import AliasProperty, ArrayProperty, Property
from neomodel.typing import Transformation
from neomodel.typing import Subquery, Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -414,13 +413,13 @@ def __init__(


class AsyncQueryBuilder:
def __init__(self, node_set, subquery_context: bool = False) -> None:
def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._subquery_context: bool = subquery_context
self._subquery_namespace: TOptional[str] = subquery_namespace

async def build_ast(self) -> "AsyncQueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
Expand Down Expand Up @@ -558,7 +557,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
Expand Down Expand Up @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
place_holder = f"{key}_{self._place_holder_registry[key]}"
if self._subquery_namespace:
place_holder = f"{self._subquery_namespace}_{place_holder}"
return place_holder

def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
Expand Down Expand Up @@ -879,10 +881,21 @@ def build_query(self) -> str:
query += ",".join(ordering)

if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
for varname in return_set:
for subquery in self.node_set._subqueries:
query += " CALL {"
if subquery["initial_context"]:
query += " WITH "
context: List[str] = []
for var in subquery["initial_context"]:
if isinstance(var, (NodeNameResolver, RelationNameResolver)):
context.append(var.resolve(self))
else:
context.append(var)
query += ",".join(context)

query += f"{subquery['query']} }} "
self._query_params.update(subquery["query_params"])
for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
Expand All @@ -893,10 +906,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
returned_items += return_set
returned_items += subquery["return_set"]

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -1120,6 +1133,8 @@ class NodeNameResolver:
node: str

def resolve(self, qbuilder: AsyncQueryBuilder) -> str:
if self.node == "self" and qbuilder._ast.return_clause:
return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
Expand Down Expand Up @@ -1238,7 +1253,7 @@ def __init__(self, source) -> None:

self.relations_to_fetch: List = []
self._extra_results: List = []
self._subqueries: list[Tuple[str, list[str]]] = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []

def __await__(self):
Expand Down Expand Up @@ -1525,7 +1540,10 @@ async def resolve_subgraph(self) -> list:
return results

async def subquery(
self, nodeset: "AsyncNodeSet", return_set: List[str]
self,
nodeset: "AsyncNodeSet",
return_set: List[str],
initial_context: TOptional[List[str]] = None,
) -> "AsyncNodeSet":
"""Add a subquery to this node set.
Expand All @@ -1534,16 +1552,41 @@ async def subquery(
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast()
namespace = f"sq{len(self._subqueries) + 1}"
qbuilder = await nodeset.query_cls(
nodeset, subquery_namespace=namespace
).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
and var
not in [
varname
for tr in nodeset._intermediate_transforms
for varname, vardef in tr["vars"].items()
if vardef.get("include_in_return")
]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
if initial_context:
for var in initial_context:
if type(var) is not str and not isinstance(
var, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(
f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._subqueries.append(
{
"query": qbuilder.build_query(),
"query_params": qbuilder._query_params,
"return_set": return_set,
"initial_context": initial_context,
}
)
return self

def intermediate_transform(
Expand Down
75 changes: 59 additions & 16 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import inspect
import re
import string
import warnings
from dataclasses import dataclass
from typing import Any, Dict, List
from typing import Optional as TOptional
Expand All @@ -13,7 +12,7 @@
from neomodel.sync_ import relationship_manager
from neomodel.sync_.core import StructuredNode, db
from neomodel.sync_.relationship import StructuredRel
from neomodel.typing import Transformation
from neomodel.typing import Subquery, Transformation
from neomodel.util import INCOMING, OUTGOING

CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)")
Expand Down Expand Up @@ -414,13 +413,13 @@ def __init__(


class QueryBuilder:
def __init__(self, node_set, subquery_context: bool = False) -> None:
def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None:
self.node_set = node_set
self._ast = QueryAST()
self._query_params: Dict = {}
self._place_holder_registry: Dict = {}
self._ident_count: int = 0
self._subquery_context: bool = subquery_context
self._subquery_namespace: TOptional[str] = subquery_namespace

def build_ast(self) -> "QueryBuilder":
if hasattr(self.node_set, "relations_to_fetch"):
Expand Down Expand Up @@ -558,7 +557,7 @@ def build_traversal_from_path(
# contains the primary node so _contains() works
# as usual
self._ast.return_clause = lhs_name
if self._subquery_context:
if self._subquery_namespace:
# Don't include label in identifier if we are in a subquery
lhs_ident = lhs_name
elif relation["include_in_return"]:
Expand Down Expand Up @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str:
self._place_holder_registry[key] += 1
else:
self._place_holder_registry[key] = 1
return key + "_" + str(self._place_holder_registry[key])
place_holder = f"{key}_{self._place_holder_registry[key]}"
if self._subquery_namespace:
place_holder = f"{self._subquery_namespace}_{place_holder}"
return place_holder

def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]:
is_rel_filter = "|" in prop
Expand Down Expand Up @@ -879,10 +881,21 @@ def build_query(self) -> str:
query += ",".join(ordering)

if hasattr(self.node_set, "_subqueries"):
for subquery, return_set in self.node_set._subqueries:
outer_primary_var = self._ast.return_clause
query += f" CALL {{ WITH {outer_primary_var} {subquery} }} "
for varname in return_set:
for subquery in self.node_set._subqueries:
query += " CALL {"
if subquery["initial_context"]:
query += " WITH "
context: List[str] = []
for var in subquery["initial_context"]:
if isinstance(var, (NodeNameResolver, RelationNameResolver)):
context.append(var.resolve(self))
else:
context.append(var)
query += ",".join(context)

query += f"{subquery['query']} }} "
self._query_params.update(subquery["query_params"])
for varname in subquery["return_set"]:
# We declare the returned variables as "virtual" relations of the
# root node class to make sure they will be translated by a call to
# resolve_subgraph() (otherwise, they will be lost).
Expand All @@ -893,10 +906,10 @@ def build_query(self) -> str:
"variable_name": varname,
"rel_variable_name": varname,
}
returned_items += return_set
returned_items += subquery["return_set"]

query += " RETURN "
if self._ast.return_clause and not self._subquery_context:
if self._ast.return_clause and not self._subquery_namespace:
returned_items.append(self._ast.return_clause)
if self._ast.additional_return:
returned_items += self._ast.additional_return
Expand Down Expand Up @@ -1118,6 +1131,8 @@ class NodeNameResolver:
node: str

def resolve(self, qbuilder: QueryBuilder) -> str:
if self.node == "self" and qbuilder._ast.return_clause:
return qbuilder._ast.return_clause
result = qbuilder.lookup_query_variable(self.node)
if result is None:
raise ValueError(f"Unable to resolve variable name for node {self.node}")
Expand Down Expand Up @@ -1236,7 +1251,7 @@ def __init__(self, source) -> None:

self.relations_to_fetch: List = []
self._extra_results: List = []
self._subqueries: list[Tuple[str, list[str]]] = []
self._subqueries: list[Subquery] = []
self._intermediate_transforms: list = []

def __await__(self):
Expand Down Expand Up @@ -1522,24 +1537,52 @@ def resolve_subgraph(self) -> list:
)
return results

def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet":
def subquery(
self,
nodeset: "NodeSet",
return_set: List[str],
initial_context: TOptional[List[str]] = None,
) -> "NodeSet":
"""Add a subquery to this node set.
A subquery is a regular cypher query but executed within the context of a CALL
statement. Such query will generally fetch additional variables which must be
declared inside return_set variable in order to be included in the final RETURN
statement.
"""
qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast()
namespace = f"sq{len(self._subqueries) + 1}"
qbuilder = nodeset.query_cls(nodeset, subquery_namespace=namespace).build_ast()
for var in return_set:
if (
var != qbuilder._ast.return_clause
and var not in qbuilder._ast.additional_return
and var
not in [res["alias"] for res in nodeset._extra_results if res["alias"]]
and var
not in [
varname
for tr in nodeset._intermediate_transforms
for varname, vardef in tr["vars"].items()
if vardef.get("include_in_return")
]
):
raise RuntimeError(f"Variable '{var}' is not returned by subquery.")
self._subqueries.append((qbuilder.build_query(), return_set))
if initial_context:
for var in initial_context:
if type(var) is not str and not isinstance(
var, (NodeNameResolver, RelationNameResolver, RawCypher)
):
raise ValueError(
f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver"
)
self._subqueries.append(
{
"query": qbuilder.build_query(),
"query_params": qbuilder._query_params,
"return_set": return_set,
"initial_context": initial_context,
}
)
return self

def intermediate_transform(
Expand Down
11 changes: 11 additions & 0 deletions neomodel/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,14 @@
"include_in_return": Optional[bool],
},
)


# Description of one subquery registered on a node set through `subquery()`.
# The query builder reads these entries when it emits CALL { ... } blocks.
Subquery = TypedDict(
"Subquery",
{
# Cypher text of the already-built subquery, placed inside CALL { ... }.
"query": str,
# Parameters collected while building the subquery; merged into the
# outer query's parameters when the outer query is built.
"query_params": dict,
# Names of the subquery variables added to the outer RETURN clause.
"return_set": list[str],
# Outer-query variables (plain strings or resolver objects) injected
# through a leading WITH inside the CALL block; None when not provided.
"initial_context": Optional[list[Any]],
},
)
Loading

0 comments on commit a0c0c00

Please sign in to comment.