From 7870b7e0c048f80d776cbda2e9e5386a9e9dc879 Mon Sep 17 00:00:00 2001 From: Tobias Alex-Petersen Date: Tue, 20 Dec 2022 20:20:30 +0100 Subject: [PATCH] Allow with_variant to take TypeEngine instance like Column. (#245) Fixes Fixes: #244 --- sqlalchemy-stubs/sql/type_api.pyi | 4 ++-- test/files/column_arguments2.py | 8 ++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/sqlalchemy-stubs/sql/type_api.pyi b/sqlalchemy-stubs/sql/type_api.pyi index b1a5d04..46b0141 100644 --- a/sqlalchemy-stubs/sql/type_api.pyi +++ b/sqlalchemy-stubs/sql/type_api.pyi @@ -96,7 +96,7 @@ class TypeEngine(Traversible, Generic[_T]): @property def python_type(self) -> Type[_T]: ... def with_variant( - self, type_: Type[TypeEngine[_U]], dialect_name: str + self, type_: Union[Type[TypeEngine[_U]], TypeEngine[_U]], dialect_name: str ) -> Variant[_U]: ... def as_generic(self, allow_nulltype: bool = ...) -> TypeEngine[Any]: ... def dialect_impl(self, dialect: Dialect) -> Type[Any]: ... @@ -177,7 +177,7 @@ class Variant(TypeDecorator[_T]): ) -> Union[_VT, TypeEngine[Any]]: ... def load_dialect_impl(self, dialect: Dialect) -> TypeEngine[Any]: ... def with_variant( - self, type_: Type[TypeEngine[_U]], dialect_name: str + self, type_: Union[Type[TypeEngine[_U]], TypeEngine[_U]], dialect_name: str ) -> Variant[_U]: ... @property def comparator_factory(self) -> Type[Any]: ... # type: ignore[override] diff --git a/test/files/column_arguments2.py b/test/files/column_arguments2.py index fd40967..1861f3f 100644 --- a/test/files/column_arguments2.py +++ b/test/files/column_arguments2.py @@ -80,3 +80,11 @@ # These seems supported now Column(Integer, ForeignKey("a.id"), type_=String) Column("name", ForeignKey("a.id"), name="String") + + +# TypeEngine.with_variant should accept both a TypeEngine instance and the Concrete Type +Integer().with_variant(Integer, "mysql") +Integer().with_variant(Integer(), "mysql") +# Also test Variant.with_variant +Integer().with_variant(Integer, "mysql").with_variant(Integer, "mysql") +Integer().with_variant(Integer, "mysql").with_variant(Integer(), "mysql")