From fc8c6d65a5863bd7dd44c4227c0495ecaef536f7 Mon Sep 17 00:00:00 2001 From: Marti Raudsepp Date: Fri, 10 Dec 2021 17:51:07 +0200 Subject: [PATCH] Support new X | Y union syntax of Python 3.10 (PEP 604) --- src/drf_yasg/inspectors/field.py | 8 +++++++- tests/test_get_basic_type_info_from_hint.py | 15 ++++++++++++++- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/src/drf_yasg/inspectors/field.py b/src/drf_yasg/inspectors/field.py index 09386363..d267f5a6 100644 --- a/src/drf_yasg/inspectors/field.py +++ b/src/drf_yasg/inspectors/field.py @@ -502,10 +502,16 @@ def inspect_collection_hint_class(hint_class): hinting_type_info.append(((typing.Sequence, typing.AbstractSet), inspect_collection_hint_class)) +# typing.UnionType was added in Python 3.10 for new PEP 604 pipe union syntax +try: + from types import UnionType +except ImportError: + UnionType = None + def _get_union_types(hint_class): origin_type = get_origin_type(hint_class) - if origin_type is typing.Union: + if origin_type is typing.Union or (UnionType is not None and origin_type is UnionType): return hint_class.__args__ diff --git a/tests/test_get_basic_type_info_from_hint.py b/tests/test_get_basic_type_info_from_hint.py index ecd87c08..eac39e93 100644 --- a/tests/test_get_basic_type_info_from_hint.py +++ b/tests/test_get_basic_type_info_from_hint.py @@ -15,6 +15,19 @@ ] +python310_union_tests = [] +if sys.version_info >= (3, 10): + # # New PEP 604 union syntax in Python 3.10+ + python310_union_tests = [ + (bool | None, {'type': openapi.TYPE_BOOLEAN, 'format': None, 'x-nullable': True}), + (list[int] | None, { + 'type': openapi.TYPE_ARRAY, 'items': openapi.Items(openapi.TYPE_INTEGER), 'x-nullable': True + }), + # Following cases are not 100% correct, but it should work somehow and not crash. + (int | float, None), + ] + + @pytest.mark.parametrize('hint_class, expected_swagger_type_info', [ (int, {'type': openapi.TYPE_INTEGER, 'format': None}), (str, {'type': openapi.TYPE_STRING, 'format': None}), @@ -41,7 +54,7 @@ (type('SomeType', (object,), {}), None), (None, None), (6, None), -] + python39_generics_tests) +] + python39_generics_tests + python310_union_tests) def test_get_basic_type_info_from_hint(hint_class, expected_swagger_type_info): type_info = get_basic_type_info_from_hint(hint_class) assert type_info == expected_swagger_type_info