From 92f3ae68df7cdb18141412c8ae462a74e032f4ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:45:14 +0200 Subject: [PATCH 1/8] Add type annotations to `_validate_port_spec` --- asyncpg/connect_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index c65f68a6..b84e0624 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -8,7 +8,7 @@ import asyncio import collections -from collections.abc import Callable +from collections.abc import Callable, Sequence import enum import functools import getpass @@ -167,7 +167,9 @@ def _read_password_from_pgpass( return None -def _validate_port_spec(hosts, port): +def _validate_port_spec( + hosts: "Sequence[object]", port: typing.Union[int, typing.List[int]] +) -> typing.List[int]: if isinstance(port, list): # If there is a list of ports, its length must # match that of the host list. From 098330658bf73ae763ced4e6b15fd93036de206b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 18:45:45 +0200 Subject: [PATCH 2/8] Add type annotations to `_parse_hostlist` --- asyncpg/connect_utils.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index b84e0624..679de06e 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -183,15 +183,20 @@ def _validate_port_spec( return port -def _parse_hostlist(hostlist, port, *, unquote=False): +def _parse_hostlist( + hostlist: str, + port: typing.Union[int, typing.List[int]], + *, + unquote: bool = False, +) -> typing.Tuple[typing.List[str], typing.List[int]]: if ',' in hostlist: # A comma-separated list of host addresses. hostspecs = hostlist.split(',') else: hostspecs = [hostlist] - hosts = [] - hostlist_ports = [] + hosts: typing.List[str] = [] + hostlist_ports: typing.List[int] = [] if not port: portspec = os.environ.get('PGPORT') From ffd80874ac80194aaaf091d4ade51242e6c961e1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:17:00 +0200 Subject: [PATCH 3/8] Fix and improve typing of `_read_password_from_pgpass` --- asyncpg/connect_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 679de06e..43aa5c0b 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -131,11 +131,13 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( - *, passfile: typing.Optional[pathlib.Path], - hosts: typing.List[str], - ports: typing.List[int], - database: str, - user: str): + *, + passfile: pathlib.Path, + hosts: "Sequence[str]", + ports: typing.List[int], + database: str, + user: str +) -> typing.Optional[str]: """Parse the pgpass file and return the matching password. :return: From ea9cb59d5a8b79e7e6f92d5e09042da70165b1aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:17:13 +0200 Subject: [PATCH 4/8] Ensure `query` is always `dict[str, str]` --- asyncpg/connect_utils.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 43aa5c0b..1ba44507 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -325,10 +325,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, password = urllib.parse.unquote(dsn_password) if parsed.query: - query = urllib.parse.parse_qs(parsed.query, strict_parsing=True) - for key, val in query.items(): - if isinstance(val, list): - query[key] = val[-1] + query = { + key: val[-1] + for key, val in urllib.parse.parse_qs( + parsed.query, strict_parsing=True + ).items() + } if 'port' in query: val = query.pop('port') From d27cbcdb5fb34eff6ca025509eb262d5900c047b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:17:30 +0200 Subject: [PATCH 5/8] Refactor `SSLMode.parse` so type checkers undertstand it --- asyncpg/connect_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 1ba44507..7a8a61f9 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -41,10 +41,10 @@ class SSLMode(enum.IntEnum): verify_full = 5 @classmethod - def parse(cls, sslmode): - if isinstance(sslmode, cls): - return sslmode - return getattr(cls, sslmode.replace('-', '_')) + def parse(cls, sslmode: typing.Union[str, "SSLMode"]) -> "SSLMode": + if isinstance(sslmode, str): + return getattr(cls, sslmode.replace('-', '_')) + return sslmode class SSLNegotiation(compat.StrEnum): From 4ba71c86c41381d414c94094e5cc81c829fff25d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Wed, 23 Oct 2024 20:17:38 +0200 Subject: [PATCH 6/8] Add type annotations to `_ConnectionParameters` and its constructing function --- asyncpg/connect_utils.py | 50 +++++++++++++++++++++++++--------------- 1 file changed, 31 insertions(+), 19 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 7a8a61f9..fc68a24d 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -52,20 +52,17 @@ class SSLNegotiation(compat.StrEnum): direct = "direct" -_ConnectionParameters = collections.namedtuple( - 'ConnectionParameters', - [ - 'user', - 'password', - 'database', - 'ssl', - 'sslmode', - 'ssl_negotiation', - 'server_settings', - 'target_session_attrs', - 'krbsrvname', - 'gsslib', - ]) +class _ConnectionParameters(typing.NamedTuple): + user: str + password: typing.Optional[str] + database: str + ssl: typing.Union[ssl_module.SSLContext, bool, str, SSLMode, None] + sslmode: SSLMode + ssl_negotiation: SSLNegotiation + server_settings: typing.Optional[typing.Dict[str, str]] + target_session_attrs: "SessionAttribute" + krbsrvname: typing.Optional[str] + gsslib: str _ClientConfiguration = collections.namedtuple( @@ -276,10 +273,25 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: return (homedir / '.postgresql' / filename).resolve() -def _parse_connect_dsn_and_args(*, dsn, host, port, user, - password, passfile, database, ssl, - direct_tls, server_settings, - target_session_attrs, krbsrvname, gsslib): +def _parse_connect_dsn_and_args( + *, + dsn: str, + host: typing.Union[str, typing.List[str], typing.Tuple[str]], + port: typing.Union[int, typing.List[int]], + user: typing.Optional[str], + password: typing.Optional[str], + passfile: typing.Union[str, pathlib.Path, None], + database: typing.Optional[str], + ssl: typing.Union[bool, None, str, SSLMode], + direct_tls: typing.Optional[bool], + server_settings: typing.Optional[typing.Dict[str, str]], + target_session_attrs: typing.Optional[str], + krbsrvname: typing.Optional[str], + gsslib: typing.Optional[str], +) -> typing.Tuple[ + typing.List[typing.Union[str, typing.Tuple[str, int]]], + _ConnectionParameters, +]: # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -502,7 +514,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, database=database, user=user, passfile=passfile) - addrs = [] + addrs: typing.List[typing.Union[str, typing.Tuple[str, int]]] = [] have_tcp_addrs = False for h, p in zip(host, port): if h.startswith('/'): From 294e6bce8e6ec21a8654a88fd316016b8e7f5794 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:18:01 +0100 Subject: [PATCH 7/8] Remove stringified annotations after future import --- asyncpg/connect_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index fc68a24d..a536e9b8 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -60,7 +60,7 @@ class _ConnectionParameters(typing.NamedTuple): sslmode: SSLMode ssl_negotiation: SSLNegotiation server_settings: typing.Optional[typing.Dict[str, str]] - target_session_attrs: "SessionAttribute" + target_session_attrs: SessionAttribute krbsrvname: typing.Optional[str] gsslib: str @@ -130,7 +130,7 @@ def _read_password_file(passfile: pathlib.Path) \ def _read_password_from_pgpass( *, passfile: pathlib.Path, - hosts: "Sequence[str]", + hosts: Sequence[str], ports: typing.List[int], database: str, user: str @@ -167,7 +167,7 @@ def _read_password_from_pgpass( def _validate_port_spec( - hosts: "Sequence[object]", port: typing.Union[int, typing.List[int]] + hosts: Sequence[object], port: typing.Union[int, typing.List[int]] ) -> typing.List[int]: if isinstance(port, list): # If there is a list of ports, its length must From ccf181af3f2c11bccf7edfff4155234a515159eb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Tue, 29 Oct 2024 21:26:32 +0100 Subject: [PATCH 8/8] Remove more stringified annotations --- asyncpg/connect_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index a536e9b8..5dd486d9 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -41,7 +41,7 @@ class SSLMode(enum.IntEnum): verify_full = 5 @classmethod - def parse(cls, sslmode: typing.Union[str, "SSLMode"]) -> "SSLMode": + def parse(cls, sslmode: typing.Union[str, SSLMode]) -> SSLMode: if isinstance(sslmode, str): return getattr(cls, sslmode.replace('-', '_')) return sslmode