diff --git a/flask_restless/search/operators.py b/flask_restless/search/operators.py index 481ea38c..fbea587f 100644 --- a/flask_restless/search/operators.py +++ b/flask_restless/search/operators.py @@ -18,6 +18,8 @@ creating the expression. """ +from sqlalchemy import func + #: Special symbol that represents the absence of a `val` element in a #: dictionary representing a filter object. NO_ARGUMENT = object() @@ -118,6 +120,14 @@ def any_(arg1, arg2): return arg1.any(arg2) +def to_tsquery(arg1, arg2): + return arg1.match(arg2) + + +def plainto_tsquery(arg1, arg2): + return arg1.op('@@')(func.plainto_tsquery(arg2)) + + #: Operator functions keyed by name. #: #: Each of these functions accepts either one or two arguments. The @@ -167,6 +177,8 @@ def any_(arg1, arg2): # (Binary) relationship operators. 'has': has, 'any': any_, + 'to_tsquery': to_tsquery, + 'plainto_tsquery': plainto_tsquery, } diff --git a/tests/test_filtering_postgresql.py b/tests/test_filtering_postgresql.py index 1fe135ec..670f1071 100644 --- a/tests/test_filtering_postgresql.py +++ b/tests/test_filtering_postgresql.py @@ -18,9 +18,8 @@ except ImportError: from psycopg2cffi import compat compat.register() -from sqlalchemy import Column -from sqlalchemy import Integer -from sqlalchemy.dialects.postgresql import INET +from sqlalchemy import Column, func, Integer, String +from sqlalchemy.dialects.postgresql import INET, TSVECTOR from sqlalchemy.exc import OperationalError from .helpers import loads @@ -197,3 +196,99 @@ def test_contains_or_is_contained_by(self): document = loads(response.data) networks = document['data'] assert ['1', '2'] == sorted(network['id'] for network in networks) + + +class TestTSVectorOperators(SearchTestBase): + """Unit tests for the TSQuery operators in PostgreSQL. + + For more information, see `Text Search Types`_ + in the PostgreSQL documentation. + + .. _Text Search Types: + https://www.postgresql.org/docs/current/datatype-textsearch.html + + """ + + def setUp(self): + super(TestTSVectorOperators, self).setUp() + + class Product(self.Base): + __tablename__ = 'product' + id = Column(Integer, primary_key=True) + name = Column(String) + document = Column(TSVECTOR) + + self.Product = Product + # This try/except skips the tests if we are unable to create the + # tables in the PostgreSQL database. + try: + self.Base.metadata.create_all() + except OperationalError: + self.skipTest('error creating tables in PostgreSQL database') + self.manager.create_api(Product) + + # Create common records + self.product1 = self.Product( + id=1, name='Porsche 911', document=func.to_tsvector('Porsche 911')) + self.product2 = self.Product( + id=2, name='Porsche 918', document=func.to_tsvector('Porsche 918')) + self.session.add_all([self.product1, self.product2]) + self.session.commit() + + def database_uri(self): + """Return a PostgreSQL connection URI. + + Since this test case is for operators specific to PostgreSQL, we + return a PostgreSQL connection URI. The particular + Python-to-PostgreSQL adapter we are using is currently + `Psycopg`_. + + .. _Psycopg: http://initd.org/psycopg/ + + """ + return 'postgresql+psycopg2://postgres@localhost:5432/testdb' + + def test_to_tsquery(self): + """Test for the ``to_tsquery`` operator. + + ..warning:: + + This operation is only available on TSVECTOR fields. + + For example: + + .. sourcecode:: postgresql + + document @@ to_tsquery('Hello !world') + """ + # Search with the &. + filters = [ + dict(name='document', op='to_tsquery', val='911 & Porsche')] + response = self.search('/api/product', filters) + print(response.data) + document = loads(response.data) + products = document['data'] + assert [self.product1.id] == sorted( + int(product['id']) for product in products) + + def test_plainto_tsquery(self): + """Test for the ``plainto_tsquery`` operator. + + ..warning:: + + This operation is only available on TSVECTOR fields. + + For example: + + .. sourcecode:: postgresql + + document @@ plainto_tsquery('Hello !world') + """ + # Search without the &. + filters = [ + dict(name='document', op='plainto_tsquery', val='911 Porsche')] + response = self.search('/api/product', filters) + document = loads(response.data) + products = document['data'] + assert [self.product1.id] == sorted( + int(product['id']) for product in products)