From 820556f633bf6dadfa9bd5a554ffe51eadd3d4d3 Mon Sep 17 00:00:00 2001 From: Antoine Nguyen Date: Tue, 10 Dec 2024 15:33:55 +0100 Subject: [PATCH 1/3] Various improvements about subqueries. --- doc/source/advanced_query_operations.rst | 14 +++-- neomodel/async_/match.py | 75 +++++++++++++++++++----- neomodel/sync_/match.py | 75 +++++++++++++++++++----- neomodel/typing.py | 13 +++- test/async_/test_match_api.py | 29 +++++++++ test/sync_/test_match_api.py | 29 +++++++++ 6 files changed, 198 insertions(+), 37 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index 73c5bbd6..de1c8c61 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -60,7 +60,7 @@ As discussed in the note above, this is for example useful when you need to orde Options for `intermediate_transform` *variables* are: -- `source`: `string`or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below). +- `source`: `string` or `Resolver` - the variable to use as source for the transformation. Works with resolvers (see below). - `source_prop`: `string` - optionally, a property of the source variable to use as source for the transformation. - `include_in_return`: `bool` - whether to include the variable in the return statement. Defaults to False. @@ -95,7 +95,7 @@ Subqueries The `subquery` method allows you to perform a `Cypher subquery `_ inside your query. 
This allows you to perform operations in isolation to the rest of your query:: from neomodel.sync_match import Collect, Last - + # This will create a CALL{} subquery # And return a variable named supps usable in the rest of your query Coffee.nodes.filter(name="Espresso") @@ -106,12 +106,18 @@ The `subquery` method allows you to perform a `Cypher subquery None: + def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._subquery_context: bool = subquery_context + self._subquery_namespace: TOptional[str] = subquery_namespace async def build_ast(self) -> "AsyncQueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -558,7 +557,7 @@ def build_traversal_from_path( # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - if self._subquery_context: + if self._subquery_namespace: # Don't include label in identifier if we are in a subquery lhs_ident = lhs_name elif relation["include_in_return"]: @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] += 1 else: self._place_holder_registry[key] = 1 - return key + "_" + str(self._place_holder_registry[key]) + place_holder = f"{key}_{self._place_holder_registry[key]}" + if self._subquery_namespace: + place_holder = f"{self._subquery_namespace}_{place_holder}" + return place_holder def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop @@ -879,10 +881,21 @@ def build_query(self) -> str: query += ",".join(ordering) if hasattr(self.node_set, "_subqueries"): - for subquery, return_set in self.node_set._subqueries: - outer_primary_var = self._ast.return_clause - query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " - for varname in return_set: + for subquery in self.node_set._subqueries: + query += " CALL 
{" + if subquery["initial_context"]: + query += " WITH " + context: List[str] = [] + for var in subquery["initial_context"]: + if isinstance(var, (NodeNameResolver, RelationNameResolver)): + context.append(var.resolve(self)) + else: + context.append(var) + query += ",".join(context) + + query += f"{subquery['query']} }} " + self._query_params.update(subquery["query_params"]) + for varname in subquery["return_set"]: # We declare the returned variables as "virtual" relations of the # root node class to make sure they will be translated by a call to # resolve_subgraph() (otherwise, they will be lost). @@ -893,10 +906,10 @@ def build_query(self) -> str: "variable_name": varname, "rel_variable_name": varname, } - returned_items += return_set + returned_items += subquery["return_set"] query += " RETURN " - if self._ast.return_clause and not self._subquery_context: + if self._ast.return_clause and not self._subquery_namespace: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -1120,6 +1133,8 @@ class NodeNameResolver: node: str def resolve(self, qbuilder: AsyncQueryBuilder) -> str: + if self.node == "self" and qbuilder._ast.return_clause: + return qbuilder._ast.return_clause result = qbuilder.lookup_query_variable(self.node) if result is None: raise ValueError(f"Unable to resolve variable name for node {self.node}") @@ -1238,7 +1253,7 @@ def __init__(self, source) -> None: self.relations_to_fetch: List = [] self._extra_results: List = [] - self._subqueries: list[Tuple[str, list[str]]] = [] + self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] def __await__(self): @@ -1525,7 +1540,10 @@ async def resolve_subgraph(self) -> list: return results async def subquery( - self, nodeset: "AsyncNodeSet", return_set: List[str] + self, + nodeset: "AsyncNodeSet", + return_set: List[str], + initial_context: TOptional[List[str]] = None, ) -> "AsyncNodeSet": """Add a subquery to this 
node set. @@ -1534,16 +1552,41 @@ async def subquery( declared inside return_set variable in order to be included in the final RETURN statement. """ - qbuilder = await nodeset.query_cls(nodeset, subquery_context=True).build_ast() + namespace = f"sq{len(self._subqueries) + 1}" + qbuilder = await nodeset.query_cls( + nodeset, subquery_namespace=namespace + ).build_ast() for var in return_set: if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] + and var + not in [ + varname + for tr in nodeset._intermediate_transforms + for varname, vardef in tr["vars"].items() + if vardef.get("include_in_return") + ] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") - self._subqueries.append((qbuilder.build_query(), return_set)) + if initial_context: + for var in initial_context: + if type(var) is not str and not isinstance( + var, (NodeNameResolver, RelationNameResolver, RawCypher) + ): + raise ValueError( + f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver" + ) + self._subqueries.append( + { + "query": qbuilder.build_query(), + "query_params": qbuilder._query_params, + "return_set": return_set, + "initial_context": initial_context, + } + ) return self def intermediate_transform( diff --git a/neomodel/sync_/match.py b/neomodel/sync_/match.py index 15a49cfb..b26714ee 100644 --- a/neomodel/sync_/match.py +++ b/neomodel/sync_/match.py @@ -1,7 +1,6 @@ import inspect import re import string -import warnings from dataclasses import dataclass from typing import Any, Dict, List from typing import Optional as TOptional @@ -13,7 +12,7 @@ from neomodel.sync_ import relationship_manager from neomodel.sync_.core import StructuredNode, db from neomodel.sync_.relationship import StructuredRel -from neomodel.typing import Transformation +from neomodel.typing import Subquery, Transformation 
from neomodel.util import INCOMING, OUTGOING CYPHER_ACTIONS_WITH_SIDE_EFFECT_EXPR = re.compile(r"(?i:MERGE|CREATE|DELETE|DETACH)") @@ -414,13 +413,13 @@ def __init__( class QueryBuilder: - def __init__(self, node_set, subquery_context: bool = False) -> None: + def __init__(self, node_set, subquery_namespace: TOptional[str] = None) -> None: self.node_set = node_set self._ast = QueryAST() self._query_params: Dict = {} self._place_holder_registry: Dict = {} self._ident_count: int = 0 - self._subquery_context: bool = subquery_context + self._subquery_namespace: TOptional[str] = subquery_namespace def build_ast(self) -> "QueryBuilder": if hasattr(self.node_set, "relations_to_fetch"): @@ -558,7 +557,7 @@ def build_traversal_from_path( # contains the primary node so _contains() works # as usual self._ast.return_clause = lhs_name - if self._subquery_context: + if self._subquery_namespace: # Don't include label in identifier if we are in a subquery lhs_ident = lhs_name elif relation["include_in_return"]: @@ -672,7 +671,10 @@ def _register_place_holder(self, key: str) -> str: self._place_holder_registry[key] += 1 else: self._place_holder_registry[key] = 1 - return key + "_" + str(self._place_holder_registry[key]) + place_holder = f"{key}_{self._place_holder_registry[key]}" + if self._subquery_namespace: + place_holder = f"{self._subquery_namespace}_{place_holder}" + return place_holder def _parse_path(self, source_class, prop: str) -> Tuple[str, str, str, Any]: is_rel_filter = "|" in prop @@ -879,10 +881,21 @@ def build_query(self) -> str: query += ",".join(ordering) if hasattr(self.node_set, "_subqueries"): - for subquery, return_set in self.node_set._subqueries: - outer_primary_var = self._ast.return_clause - query += f" CALL {{ WITH {outer_primary_var} {subquery} }} " - for varname in return_set: + for subquery in self.node_set._subqueries: + query += " CALL {" + if subquery["initial_context"]: + query += " WITH " + context: List[str] = [] + for var in 
subquery["initial_context"]: + if isinstance(var, (NodeNameResolver, RelationNameResolver)): + context.append(var.resolve(self)) + else: + context.append(var) + query += ",".join(context) + + query += f"{subquery['query']} }} " + self._query_params.update(subquery["query_params"]) + for varname in subquery["return_set"]: # We declare the returned variables as "virtual" relations of the # root node class to make sure they will be translated by a call to # resolve_subgraph() (otherwise, they will be lost). @@ -893,10 +906,10 @@ def build_query(self) -> str: "variable_name": varname, "rel_variable_name": varname, } - returned_items += return_set + returned_items += subquery["return_set"] query += " RETURN " - if self._ast.return_clause and not self._subquery_context: + if self._ast.return_clause and not self._subquery_namespace: returned_items.append(self._ast.return_clause) if self._ast.additional_return: returned_items += self._ast.additional_return @@ -1118,6 +1131,8 @@ class NodeNameResolver: node: str def resolve(self, qbuilder: QueryBuilder) -> str: + if self.node == "self" and qbuilder._ast.return_clause: + return qbuilder._ast.return_clause result = qbuilder.lookup_query_variable(self.node) if result is None: raise ValueError(f"Unable to resolve variable name for node {self.node}") @@ -1236,7 +1251,7 @@ def __init__(self, source) -> None: self.relations_to_fetch: List = [] self._extra_results: List = [] - self._subqueries: list[Tuple[str, list[str]]] = [] + self._subqueries: list[Subquery] = [] self._intermediate_transforms: list = [] def __await__(self): @@ -1522,7 +1537,12 @@ def resolve_subgraph(self) -> list: ) return results - def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": + def subquery( + self, + nodeset: "NodeSet", + return_set: List[str], + initial_context: TOptional[List[str]] = None, + ) -> "NodeSet": """Add a subquery to this node set. 
A subquery is a regular cypher query but executed within the context of a CALL @@ -1530,16 +1550,39 @@ def subquery(self, nodeset: "NodeSet", return_set: List[str]) -> "NodeSet": declared inside return_set variable in order to be included in the final RETURN statement. """ - qbuilder = nodeset.query_cls(nodeset, subquery_context=True).build_ast() + namespace = f"sq{len(self._subqueries) + 1}" + qbuilder = nodeset.query_cls(nodeset, subquery_namespace=namespace).build_ast() for var in return_set: if ( var != qbuilder._ast.return_clause and var not in qbuilder._ast.additional_return and var not in [res["alias"] for res in nodeset._extra_results if res["alias"]] + and var + not in [ + varname + for tr in nodeset._intermediate_transforms + for varname, vardef in tr["vars"].items() + if vardef.get("include_in_return") + ] ): raise RuntimeError(f"Variable '{var}' is not returned by subquery.") - self._subqueries.append((qbuilder.build_query(), return_set)) + if initial_context: + for var in initial_context: + if type(var) is not str and not isinstance( + var, (NodeNameResolver, RelationNameResolver, RawCypher) + ): + raise ValueError( + f"Wrong variable specified in initial context, should be a string or an instance of NodeNameResolver or RelationNameResolver" + ) + self._subqueries.append( + { + "query": qbuilder.build_query(), + "query_params": qbuilder._query_params, + "return_set": return_set, + "initial_context": initial_context, + } + ) return self def intermediate_transform( diff --git a/neomodel/typing.py b/neomodel/typing.py index 9438bd54..f0558096 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,6 +1,6 @@ """Custom types used for annotations.""" -from typing import Any, Optional, TypedDict +from typing import Any, Dict, List, Optional, TypedDict Transformation = TypedDict( "Transformation", @@ -10,3 +10,14 @@ "include_in_return": Optional[bool], }, ) + + +Subquery = TypedDict( + "Subquery", + { + "query": str, + "query_params": Dict, + 
"return_set": List[str], + "initial_context": Optional[List[Any]], + }, +) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index 2dff91c0..a494ae42 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -887,6 +887,7 @@ async def test_subquery(): ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], + [NodeNameResolver("self")], ) result = await result.all() assert len(result) == 1 @@ -905,6 +906,34 @@ async def test_subquery(): ) +@mark_async_test +async def test_subquery_other_node(): + arabica = await Species(name="Arabica").save() + nescafe = await Coffee(name="Nescafe", price=99).save() + supplier1 = await Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = await Supplier(name="Supplier 2", delivery_cost=20).save() + + await nescafe.suppliers.connect(supplier1) + await nescafe.suppliers.connect(supplier2) + await nescafe.species.connect(arabica) + + result = await Coffee.nodes.subquery( + Supplier.nodes.filter(name="Supplier 2").intermediate_transform( + { + "cost": { + "source": "supplier", + "source_prop": "delivery_cost", + "include_in_return": True, + } + } + ), + ["cost"], + ) + result = await result.all() + assert len(result) == 1 + assert result[0][0] == 20 + + @mark_async_test async def test_intermediate_transform(): arabica = await Species(name="Arabica").save() diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 4df51866..0bf69b7f 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -871,6 +871,7 @@ def test_subquery(): ) .annotate(supps=Last(Collect("suppliers"))), ["supps"], + [NodeNameResolver("self")], ) result = result.all() assert len(result) == 1 @@ -889,6 +890,34 @@ def test_subquery(): ) +@mark_sync_test +def test_subquery_other_node(): + arabica = Species(name="Arabica").save() + nescafe = Coffee(name="Nescafe", price=99).save() + supplier1 = Supplier(name="Supplier 1", delivery_cost=3).save() + supplier2 = 
Supplier(name="Supplier 2", delivery_cost=20).save() + + nescafe.suppliers.connect(supplier1) + nescafe.suppliers.connect(supplier2) + nescafe.species.connect(arabica) + + result = Coffee.nodes.subquery( + Supplier.nodes.filter(name="Supplier 2").intermediate_transform( + { + "cost": { + "source": "supplier", + "source_prop": "delivery_cost", + "include_in_return": True, + } + } + ), + ["cost"], + ) + result = result.all() + assert len(result) == 1 + assert result[0][0] == 20 + + @mark_sync_test def test_intermediate_transform(): arabica = Species(name="Arabica").save() From 058f5b8f9691af1b69b22cce28625f4210d800f5 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 10 Dec 2024 16:47:43 +0100 Subject: [PATCH 2/3] Increase test coverage --- doc/source/advanced_query_operations.rst | 2 +- neomodel/_version.py | 2 +- neomodel/typing.py | 8 +++---- test/async_/test_match_api.py | 29 ++++++++++++++++++++++-- test/sync_/test_match_api.py | 29 ++++++++++++++++++++++-- 5 files changed, 60 insertions(+), 10 deletions(-) diff --git a/doc/source/advanced_query_operations.rst b/doc/source/advanced_query_operations.rst index de1c8c61..74c15683 100644 --- a/doc/source/advanced_query_operations.rst +++ b/doc/source/advanced_query_operations.rst @@ -117,7 +117,7 @@ Options for `subquery` calls are: .. note:: In the example above, we reference `self` to be included in the initial context. It will actually inject the outer variable corresponding to `Coffee` node. - We know this is confusing to read, but have not found a better wat to do this yet. If you have any suggestions, please let us know. + We know this is confusing to read, but have not found a better way to do this yet. If you have any suggestions, please let us know. 
Helpers ------- diff --git a/neomodel/_version.py b/neomodel/_version.py index 1e41bf8f..cfda0f8e 100644 --- a/neomodel/_version.py +++ b/neomodel/_version.py @@ -1 +1 @@ -__version__ = "5.4.1" +__version__ = "5.4.2" diff --git a/neomodel/typing.py b/neomodel/typing.py index f0558096..a23f88eb 100644 --- a/neomodel/typing.py +++ b/neomodel/typing.py @@ -1,6 +1,6 @@ """Custom types used for annotations.""" -from typing import Any, Dict, List, Optional, TypedDict +from typing import Any, Optional, TypedDict Transformation = TypedDict( "Transformation", @@ -16,8 +16,8 @@ "Subquery", { "query": str, - "query_params": Dict, - "return_set": List[str], - "initial_context": Optional[List[Any]], + "query_params": dict, + "return_set": list[str], + "initial_context": Optional[list[Any]], }, ) diff --git a/test/async_/test_match_api.py b/test/async_/test_match_api.py index a494ae42..70c7f351 100644 --- a/test/async_/test_match_api.py +++ b/test/async_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_async_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -880,7 +881,7 @@ async def test_subquery(): await nescafe.suppliers.connect(supplier2) await nescafe.species.connect(arabica) - result = await Coffee.nodes.subquery( + subquery = await Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -889,7 +890,7 @@ async def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = await result.all() + result = await subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -905,6 +906,30 @@ async def test_subquery(): ["unknown"], ) + result_string_context = await subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + 
result_string_context = await result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with raises(ValueError, match=r"Wrong variable specified in initial context"): + result = await Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_async_test async def test_subquery_other_node(): diff --git a/test/sync_/test_match_api.py b/test/sync_/test_match_api.py index 0bf69b7f..94465db2 100644 --- a/test/sync_/test_match_api.py +++ b/test/sync_/test_match_api.py @@ -2,6 +2,7 @@ from datetime import datetime from test._async_compat import mark_sync_test +import numpy as np from pytest import raises, skip, warns from neomodel import ( @@ -864,7 +865,7 @@ def test_subquery(): nescafe.suppliers.connect(supplier2) nescafe.species.connect(arabica) - result = Coffee.nodes.subquery( + subquery = Coffee.nodes.subquery( Coffee.nodes.traverse_relations(suppliers="suppliers") .intermediate_transform( {"suppliers": {"source": "suppliers"}}, ordering=["suppliers.delivery_cost"] @@ -873,7 +874,7 @@ def test_subquery(): ["supps"], [NodeNameResolver("self")], ) - result = result.all() + result = subquery.all() assert len(result) == 1 assert len(result[0]) == 2 assert result[0][0] == supplier2 @@ -889,6 +890,30 @@ def test_subquery(): ["unknown"], ) + result_string_context = subquery.subquery( + Coffee.nodes.traverse_relations(supps2="suppliers").annotate( + supps2=Collect("supps") + ), + ["supps2"], + ["supps"], + ) + result_string_context = result_string_context.all() + assert len(result) == 1 + additional_elements = [ + item for item in result_string_context[0] if item not in result[0] + ] + assert len(additional_elements) == 1 + assert isinstance(additional_elements[0], list) + + with 
raises(ValueError, match=r"Wrong variable specified in initial context"): + result = Coffee.nodes.subquery( + Coffee.nodes.traverse_relations(suppliers="suppliers").annotate( + supps=Collect("suppliers") + ), + ["supps"], + [2], + ) + @mark_sync_test def test_subquery_other_node(): From 6c8a1f16b7a2be9b0d5aab55134dda4cb8d0d689 Mon Sep 17 00:00:00 2001 From: Marius Conjeaud Date: Tue, 10 Dec 2024 16:58:53 +0100 Subject: [PATCH 3/3] Update changelog --- Changelog | 2 ++ 1 file changed, 2 insertions(+) diff --git a/Changelog b/Changelog index 4f78f667..a93e806e 100644 --- a/Changelog +++ b/Changelog @@ -1,5 +1,7 @@ Vesion 5.4.2 2024-12 * Add support for Neo4j Rust driver extension : pip install neomodel[rust-driver-ext] +* Add initial_context parameter to subqueries +* NodeNameResolver("self") can be used to reference the top-level node Version 5.4.1 2024-11 * Add support for Cypher parallel runtime