Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add typing to _ConnectionParameters and related functions #1199

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
97 changes: 60 additions & 37 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import asyncio
import collections
from collections.abc import Callable
from collections.abc import Callable, Sequence
import enum
import functools
import getpass
Expand Down Expand Up @@ -41,31 +41,28 @@ class SSLMode(enum.IntEnum):
verify_full = 5

@classmethod
def parse(cls, sslmode):
if isinstance(sslmode, cls):
return sslmode
return getattr(cls, sslmode.replace('-', '_'))
def parse(cls, sslmode: typing.Union[str, SSLMode]) -> SSLMode:
if isinstance(sslmode, str):
return getattr(cls, sslmode.replace('-', '_'))
return sslmode


class SSLNegotiation(compat.StrEnum):
postgres = "postgres"
direct = "direct"


_ConnectionParameters = collections.namedtuple(
'ConnectionParameters',
[
'user',
'password',
'database',
'ssl',
'sslmode',
'ssl_negotiation',
'server_settings',
'target_session_attrs',
'krbsrvname',
'gsslib',
])
class _ConnectionParameters(typing.NamedTuple):
user: str
password: typing.Optional[str]
database: str
ssl: typing.Union[ssl_module.SSLContext, bool, str, SSLMode, None]
sslmode: SSLMode
ssl_negotiation: SSLNegotiation
server_settings: typing.Optional[typing.Dict[str, str]]
target_session_attrs: SessionAttribute
krbsrvname: typing.Optional[str]
gsslib: str


_ClientConfiguration = collections.namedtuple(
Expand Down Expand Up @@ -131,11 +128,13 @@ def _read_password_file(passfile: pathlib.Path) \


def _read_password_from_pgpass(
*, passfile: typing.Optional[pathlib.Path],
hosts: typing.List[str],
ports: typing.List[int],
database: str,
user: str):
*,
passfile: pathlib.Path,
hosts: Sequence[str],
ports: typing.List[int],
database: str,
user: str
) -> typing.Optional[str]:
"""Parse the pgpass file and return the matching password.

:return:
Expand Down Expand Up @@ -167,7 +166,9 @@ def _read_password_from_pgpass(
return None


def _validate_port_spec(hosts, port):
def _validate_port_spec(
hosts: Sequence[object], port: typing.Union[int, typing.List[int]]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
hosts: Sequence[object], port: typing.Union[int, typing.List[int]]
hosts: typing.List[str], port: typing.Union[int, typing.List[int]]

Copy link
Contributor Author

@DanielNoord DanielNoord Dec 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this because of line 172. If your suggestion is correct we can remove the else branch. Since I didn't want to change the functionality of the code too much I just added what the code is able to handle instead of what it likely should be, if that makes sense.

Would you want me to change this? Or is keeping as is fine? If it is the latter, could you press the Merge button? :)

) -> typing.List[int]:
if isinstance(port, list):
# If there is a list of ports, its length must
# match that of the host list.
Expand All @@ -181,15 +182,20 @@ def _validate_port_spec(hosts, port):
return port


def _parse_hostlist(hostlist, port, *, unquote=False):
def _parse_hostlist(
hostlist: str,
port: typing.Union[int, typing.List[int]],
*,
unquote: bool = False,
) -> typing.Tuple[typing.List[str], typing.List[int]]:
if ',' in hostlist:
# A comma-separated list of host addresses.
hostspecs = hostlist.split(',')
else:
hostspecs = [hostlist]

hosts = []
hostlist_ports = []
hosts: typing.List[str] = []
hostlist_ports: typing.List[int] = []

if not port:
portspec = os.environ.get('PGPORT')
Expand Down Expand Up @@ -267,10 +273,25 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
return (homedir / '.postgresql' / filename).resolve()


def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs, krbsrvname, gsslib):
def _parse_connect_dsn_and_args(
*,
dsn: str,
host: typing.Union[str, typing.List[str], typing.Tuple[str]],
port: typing.Union[int, typing.List[int]],
user: typing.Optional[str],
password: typing.Optional[str],
passfile: typing.Union[str, pathlib.Path, None],
database: typing.Optional[str],
ssl: typing.Union[bool, None, str, SSLMode],
direct_tls: typing.Optional[bool],
server_settings: typing.Optional[typing.Dict[str, str]],
target_session_attrs: typing.Optional[str],
krbsrvname: typing.Optional[str],
gsslib: typing.Optional[str],
) -> typing.Tuple[
typing.List[typing.Union[str, typing.Tuple[str, int]]],
_ConnectionParameters,
]:
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -316,10 +337,12 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password = urllib.parse.unquote(dsn_password)

if parsed.query:
query = urllib.parse.parse_qs(parsed.query, strict_parsing=True)
for key, val in query.items():
if isinstance(val, list):
query[key] = val[-1]
query = {
key: val[-1]
for key, val in urllib.parse.parse_qs(
parsed.query, strict_parsing=True
).items()
}

if 'port' in query:
val = query.pop('port')
Expand Down Expand Up @@ -491,7 +514,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
database=database, user=user,
passfile=passfile)

addrs = []
addrs: typing.List[typing.Union[str, typing.Tuple[str, int]]] = []
have_tcp_addrs = False
for h, p in zip(host, port):
if h.startswith('/'):
Expand Down
Loading