Skip to content

Commit

Permalink
refactor: Use make-unasync to generate /sync_ scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
wbenbihi committed Jul 30, 2024
1 parent 724e80b commit a9b4b29
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 22 deletions.
129 changes: 107 additions & 22 deletions neomodel/sync_/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ def _rel_merge_helper(
_SPECIAL_OPERATOR_ISNULL = "IS NULL"
_SPECIAL_OPERATOR_ISNOTNULL = "IS NOT NULL"
_SPECIAL_OPERATOR_REGEX = "=~"
_SPECIAL_OPERATOR_INCLUDES = "{val} IN {ident}.{prop}"
_SPECIAL_OPERATOR_INCLUDES_ALL = "all(x IN {val} WHERE x IN {ident}.{prop})"
_SPECIAL_OPERATOR_INCLUDES_ANY = "any(x IN {val} WHERE x IN {ident}.{prop})"

_UNARY_OPERATORS = (_SPECIAL_OPERATOR_ISNULL, _SPECIAL_OPERATOR_ISNOTNULL)

Expand Down Expand Up @@ -190,6 +193,9 @@ def _rel_merge_helper(
"isnull": _SPECIAL_OPERATOR_ISNULL,
"regex": _SPECIAL_OPERATOR_REGEX,
"exact": "=",
"includes": _SPECIAL_OPERATOR_INCLUDES,
"includes_all": _SPECIAL_OPERATOR_INCLUDES_ALL,
"includes_any": _SPECIAL_OPERATOR_INCLUDES_ANY,
}
# add all regex operators
OPERATOR_TABLE.update(_REGEX_OPERATOR_TABLE)
Expand Down Expand Up @@ -256,6 +262,61 @@ def process_filter_args(cls, kwargs):
return output


def transform_includes_operator_to_filter(
operator, filter_key, filter_value, property_obj
):
"""
Transform includes operator to a cypher filter
Args:
operator (str): operator to transform
filter_key (str): filter key
filter_value (str): filter value
property_obj (object): property object
Returns:
tuple: operator, deflated_value
"""
if not isinstance(filter_value, str):
raise ValueError(
f"Value must be a string for INCLUDES operation {filter_key}={filter_value}"
)
if not isinstance(property_obj, ArrayProperty):
raise ValueError(
f"Property {filter_key} must be an ArrayProperty to use INCLUDES operation"
)
deflated_value = filter_value
return operator, deflated_value


def transform_includes_all_any_operator_to_filter(
operator, filter_key, filter_value, property_obj
):
"""
Transform includes__all/any operator to a cypher filter
Args:
operator (str): operator to transform
filter_key (str): filter key
filter_value (str): filter value
property_obj (object): property object
Returns:
tuple: operator, deflated_value
"""
if not isinstance(filter_value, (tuple, list)):
raise ValueError(
f"Value must be an iterable for INCLUDES operation {filter_key}={filter_value}"
)
if not isinstance(property_obj, ArrayProperty):
raise ValueError(
f"Property {filter_key} must be an ArrayProperty to use INCLUDES operation"
)
deflated_value = property_obj.deflate(filter_value)
selected_operator = (
_SPECIAL_OPERATOR_INCLUDES_ANY
if operator == _SPECIAL_OPERATOR_INCLUDES_ANY
else _SPECIAL_OPERATOR_INCLUDES_ALL
)
return selected_operator, deflated_value


def transform_in_operator_to_filter(operator, filter_key, filter_value, property_obj):
"""
Transform in operator to a cypher filter
Expand All @@ -280,7 +341,7 @@ def transform_in_operator_to_filter(operator, filter_key, filter_value, property
return operator, deflated_value


def transform_null_operator_to_filter(filter_key, filter_value):
def transform_null_operator_to_filter(filter_key, filter_value, **kwargs):
"""
Transform null operator to a cypher filter
Args:
Expand Down Expand Up @@ -320,28 +381,29 @@ def transform_regex_operator_to_filter(
return operator, deflated_value


TRANSFORM_TABLE = {
(_SPECIAL_OPERATOR_IN,): transform_in_operator_to_filter,
(_SPECIAL_OPERATOR_ISNULL,): transform_null_operator_to_filter,
tuple(_REGEX_OPERATOR_TABLE.values()): transform_regex_operator_to_filter,
(_SPECIAL_OPERATOR_INCLUDES,): transform_includes_operator_to_filter,
(
_SPECIAL_OPERATOR_INCLUDES_ALL,
_SPECIAL_OPERATOR_INCLUDES_ANY,
): transform_includes_all_any_operator_to_filter,
}


def transform_operator_to_filter(operator, filter_key, filter_value, property_obj):
if operator == _SPECIAL_OPERATOR_IN:
operator, deflated_value = transform_in_operator_to_filter(
operator=operator,
filter_key=filter_key,
filter_value=filter_value,
property_obj=property_obj,
)
elif operator == _SPECIAL_OPERATOR_ISNULL:
operator, deflated_value = transform_null_operator_to_filter(
filter_key=filter_key, filter_value=filter_value
)
elif operator in _REGEX_OPERATOR_TABLE.values():
operator, deflated_value = transform_regex_operator_to_filter(
operator=operator,
filter_key=filter_key,
filter_value=filter_value,
property_obj=property_obj,
)
else:
deflated_value = property_obj.deflate(filter_value)
for ops_it, transform in TRANSFORM_TABLE.items():
if operator in ops_it:
return transform(
operator=operator,
filter_key=filter_key,
filter_value=filter_value,
property_obj=property_obj,
)

deflated_value = property_obj.deflate(filter_value)
return operator, deflated_value


Expand Down Expand Up @@ -634,7 +696,12 @@ def _parse_q_filters(self, ident, q, source_class):
statement = f"{ident}.{prop} {operator}"
else:
place_holder = self._register_place_holder(ident + "_" + prop)
if operator == _SPECIAL_OPERATOR_ARRAY_IN:
if operator in [
_SPECIAL_OPERATOR_ARRAY_IN,
_SPECIAL_OPERATOR_INCLUDES,
_SPECIAL_OPERATOR_INCLUDES_ALL,
_SPECIAL_OPERATOR_INCLUDES_ANY,
]:
statement = operator.format(
ident=ident,
prop=prop,
Expand Down Expand Up @@ -674,6 +741,21 @@ def build_where_stmt(self, ident, filters, q_filters=None, source_class=None):
statement = (
f"{'NOT' if negate else ''} {ident}.{prop} {operator}"
)
# Fix IN operator for Traversal Sets
# Potential bug: Must be investigated if it is really an issue
elif operator in [
_SPECIAL_OPERATOR_ARRAY_IN,
_SPECIAL_OPERATOR_INCLUDES,
_SPECIAL_OPERATOR_INCLUDES_ALL,
_SPECIAL_OPERATOR_INCLUDES_ANY,
]:
place_holder = self._register_place_holder(ident + "_" + prop)
self._query_params[place_holder] = val
statement = operator.format(
ident=ident,
prop=prop,
val=f"${place_holder}",
)
else:
place_holder = self._register_place_holder(ident + "_" + prop)
statement = f"{'NOT' if negate else ''} {ident}.{prop} {operator} ${place_holder}"
Expand Down Expand Up @@ -1000,6 +1082,9 @@ def filter(self, *args, **kwargs):
* 'istartswith': case insensitive string starts with
* 'endswith': string ends with
* 'iendswith': case insensitive string ends with
* 'includes': array contains value
* 'includes_all': array contains all values
* 'includes_any': array contains any of the values
:return: self
"""
Expand Down
106 changes: 106 additions & 0 deletions test/sync_/test_match_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
StructuredNode,
StructuredRel,
UniqueIdProperty,
ZeroOrOne,
)
from neomodel._async_compat.util import Util
from neomodel.exceptions import MultipleNodesReturned, RelationshipClassNotDefined
Expand Down Expand Up @@ -72,6 +73,23 @@ class PersonX(StructuredNode):
city = RelationshipTo(CityX, "LIVES_IN")


class MemberOfRelationship(StructuredRel):
tags = ArrayProperty(StringProperty(), required=True)


class Player(StructuredNode):
name = StringProperty(unique_index=True, required=True)
tags = ArrayProperty(StringProperty(), required=True)
club = RelationshipTo(
"Club", "MEMBER_OF", model=MemberOfRelationship, cardinality=ZeroOrOne
)


class Club(StructuredNode):
name = StringProperty(unique_index=True, required=True)
members = RelationshipFrom("Player", "MEMBER_OF", model=MemberOfRelationship)


@mark_sync_test
def test_filter_exclude_via_labels():
Coffee(name="Java", price=99).save()
Expand Down Expand Up @@ -654,3 +672,91 @@ def test_async_iterator():

# assert that generator runs loop above
assert counter == n


@mark_sync_test
def test_includes_filter_with_nodeset():
ronaldo = Player(name="Ronaldo", tags=["player", "striker", "portugal"]).save()
messi = Player(name="Messi", tags=["player", "striker", "argentina"]).save()

# Assert __includes works with Q object
assert ronaldo in Player.nodes.filter(Q(tags__includes="striker"))

# Assert that includes filter works with nodeset
assert ronaldo in Player.nodes.filter(tags__includes="striker")
assert messi in Player.nodes.filter(tags__includes="striker")

# Assert that includes filter works with nodeset and excluding values
assert ronaldo in Player.nodes.filter(tags__includes="portugal")
assert messi not in Player.nodes.filter(tags__includes="portugal")

# Assert __includes_any works with Q object
assert ronaldo in Player.nodes.filter(
Q(tags__includes_any=["portugal", "argentina"])
)

# Assert that includes_any filter works with nodeset and multiple values
assert ronaldo in Player.nodes.filter(tags__includes_any=["portugal", "argentina"])
assert messi in Player.nodes.filter(tags__includes_any=["portugal", "argentina"])
assert ronaldo in Player.nodes.filter(tags__includes_any=["portugal", "striker"])
assert messi in Player.nodes.filter(tags__includes_any=["portugal", "striker"])

# Assert __includes_all works with Q object
assert ronaldo in Player.nodes.filter(Q(tags__includes_all=["player", "striker"]))

# Assert that includes filter works with nodeset and all values
assert ronaldo in Player.nodes.filter(tags__includes_all=["player", "striker"])
assert messi in Player.nodes.filter(tags__includes_all=["player", "striker"])
assert ronaldo in Player.nodes.filter(
tags__includes_all=["player", "striker", "portugal"]
)
assert ronaldo not in Player.nodes.filter(
tags__includes_all=["player", "striker", "argentina"]
)
assert messi not in Player.nodes.filter(
tags__includes_all=["player", "striker", "portugal"]
)


@mark_sync_test
def test_includes_filter_with_traversal():
# Create the nodes
enrique = Player(name="Enrique", tags=["spain"]).save()
donnarumma = Player(name="Donnarumma", tags=["italia", "right-foot"]).save()
dembele = Player(name="Dembele", tags=["france", "both-feet"]).save()
marquinhos = Player(name="Marquinhos", tags=["brasil", "right-foot"]).save()
kolomuani = Player(name="Kolo Muani", tags=["congo", "right-foot"]).save()

psg = Club(name="PSG").save()
# Creates the edges
psg.members.connect(enrique, properties={"tags": ["coach"]})
psg.members.connect(donnarumma, properties={"tags": ["player", "goalkeeper"]})
psg.members.connect(marquinhos, properties={"tags": ["player", "defender"]})
psg.members.connect(dembele, properties={"tags": ["player", "forward"]})
psg.members.connect(kolomuani, properties={"tags": ["player", "forward"]})

# Assert __includes
players = psg.members.match(tags__includes="player").all()
assert donnarumma in players
assert dembele in players
assert marquinhos in players
assert kolomuani in players
assert enrique not in players
assert donnarumma in psg.members.match(tags__includes="goalkeeper").all()
assert dembele not in psg.members.match(tags__includes="goalkeeper").all()

# Assert __includes_any
players = psg.members.match(tags__includes_any=["defender", "forward"]).all()
assert donnarumma not in players
assert dembele in players
assert marquinhos in players
assert kolomuani in players
assert enrique not in players

# Assert __includes_all
players = psg.members.match(tags__includes_all=["player", "forward"]).all()
assert donnarumma not in players
assert marquinhos not in players
assert enrique not in players
assert dembele in players
assert kolomuani in players

0 comments on commit a9b4b29

Please sign in to comment.