From 76105cc6de9ce6d074540c9b5c82c16c953fdbed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 09:19:09 +0200 Subject: [PATCH] Add type annotations to `_create_ssl_connection` --- asyncpg/connect_utils.py | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 6cefc020..b1451f3f 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -7,6 +7,7 @@ import asyncio import collections +from collections.abc import Callable import enum import functools import getpass @@ -803,8 +804,23 @@ def connection_lost(self, exc: typing.Optional[Exception]) -> None: self.on_data.set_exception(exc) -async def _create_ssl_connection(protocol_factory, host, port, *, - loop, ssl_context, ssl_is_advisory=False): +_ProctolFactoryR = typing.TypeVar( + "_ProctolFactoryR", bound=asyncio.protocols.Protocol +) + + +async def _create_ssl_connection( + # TODO: The return type is a specific combination of subclasses of + # asyncio.protocols.Protocol that we can't express. For now, having the + # return type be dependent on signature of the factory is an improvement + protocol_factory: "Callable[[], _ProctolFactoryR]", + host: str, + port: int, + *, + loop: asyncio.AbstractEventLoop, + ssl_context: ssl_module.SSLContext, + ssl_is_advisory: bool = False, +) -> typing.Tuple[asyncio.Transport, _ProctolFactoryR]: tr, pr = await loop.create_connection( lambda: TLSUpgradeProto(loop, host, port, @@ -824,6 +840,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *, try: new_tr = await loop.start_tls( tr, pr, ssl_context, server_hostname=host) + assert new_tr is not None except (Exception, asyncio.CancelledError): tr.close() raise