Skip to content

Commit

Permalink
Implement GSSAPI authentication (#1122)
Browse files Browse the repository at this point in the history
Most commonly used with Kerberos.

Closes: #769
  • Loading branch information
eltoder authored Mar 4, 2024
1 parent c2c8d20 commit 1d4e568
Show file tree
Hide file tree
Showing 10 changed files with 230 additions and 53 deletions.
10 changes: 10 additions & 0 deletions .github/workflows/install-krb5.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
#!/bin/bash

set -Eexuo pipefail

if [ "$RUNNER_OS" == "Linux" ]; then
# Assume Ubuntu since this is the only Linux used in CI.
sudo apt-get update
sudo apt-get install -y --no-install-recommends \
libkrb5-dev krb5-user krb5-kdc krb5-admin-server
fi
2 changes: 2 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ jobs:
- name: Install Python Deps
if: steps.release.outputs.version == 0
run: |
.github/workflows/install-krb5.sh
python -m pip install -U pip setuptools wheel
python -m pip install -e .[test]
Expand Down Expand Up @@ -122,6 +123,7 @@ jobs:
- name: Install Python Deps
if: steps.release.outputs.version == 0
run: |
.github/workflows/install-krb5.sh
python -m pip install -U pip setuptools wheel
python -m pip install -e .[test]
Expand Down
19 changes: 15 additions & 4 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def parse(cls, sslmode):
'direct_tls',
'server_settings',
'target_session_attrs',
'krbsrvname',
])


Expand Down Expand Up @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
password, passfile, database, ssl,
direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
# `auth_hosts` is the version of host information for the purposes
# of reading the pgpass file.
auth_hosts = None
Expand Down Expand Up @@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
if target_session_attrs is None:
target_session_attrs = dsn_target_session_attrs

if 'krbsrvname' in query:
val = query.pop('krbsrvname')
if krbsrvname is None:
krbsrvname = val

if query:
if server_settings is None:
server_settings = query
Expand Down Expand Up @@ -650,11 +656,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
)
) from None

if krbsrvname is None:
krbsrvname = os.getenv('PGKRBSRVNAME')

params = _ConnectionParameters(
user=user, password=password, database=database, ssl=ssl,
sslmode=sslmode, direct_tls=direct_tls,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

return addrs, params

Expand All @@ -665,7 +675,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
max_cached_statement_lifetime,
max_cacheable_statement_size,
ssl, direct_tls, server_settings,
target_session_attrs):
target_session_attrs, krbsrvname):
local_vars = locals()
for var_name in {'max_cacheable_statement_size',
'max_cached_statement_lifetime',
Expand Down Expand Up @@ -694,7 +704,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
password=password, passfile=passfile, ssl=ssl,
direct_tls=direct_tls, database=database,
server_settings=server_settings,
target_session_attrs=target_session_attrs)
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname)

config = _ClientConfiguration(
command_timeout=command_timeout,
Expand Down
13 changes: 11 additions & 2 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *,
connection_class=Connection,
record_class=protocol.Record,
server_settings=None,
target_session_attrs=None):
target_session_attrs=None,
krbsrvname=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
The connection parameters may be specified either as a connection
Expand Down Expand Up @@ -2235,6 +2236,10 @@ async def connect(dsn=None, *,
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
or ``"any"`` if neither is specified.
:param str krbsrvname:
Kerberos service name to use when authenticating with GSSAPI. This
must match the server configuration. Defaults to 'postgres'.
:return: A :class:`~asyncpg.connection.Connection` instance.
Example:
Expand Down Expand Up @@ -2303,6 +2308,9 @@ async def connect(dsn=None, *,
.. versionchanged:: 0.28.0
Added the *target_session_attrs* parameter.
.. versionchanged:: 0.30.0
Added the *krbsrvname* parameter.
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
.. _create_default_context:
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
Expand Down Expand Up @@ -2344,7 +2352,8 @@ async def connect(dsn=None, *,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size,
target_session_attrs=target_session_attrs
target_session_attrs=target_session_attrs,
krbsrvname=krbsrvname,
)


Expand Down
15 changes: 5 additions & 10 deletions asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
AUTH_SASL_FINAL = 12


AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef enum ResultType:
RESULT_OK = 1
RESULT_FAILED = 2
Expand Down Expand Up @@ -96,10 +86,13 @@ cdef class CoreProtocol:

object transport

object address
# Instance of _ConnectionParameters
object con_params
# Instance of SCRAMAuthentication
SCRAMAuthentication scram
# Instance of gssapi.SecurityContext
object gss_ctx

readonly int32_t backend_pid
readonly int32_t backend_secret
Expand Down Expand Up @@ -145,6 +138,8 @@ cdef class CoreProtocol:
cdef _auth_password_message_md5(self, bytes salt)
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
cdef _auth_password_message_sasl_continue(self, bytes server_response)
cdef _auth_gss_init(self)
cdef _auth_gss_step(self, bytes server_response)

cdef _write(self, buf)
cdef _writelines(self, list buffers)
Expand Down
63 changes: 60 additions & 3 deletions asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,26 @@


import hashlib
import socket


include "scram.pyx"


cdef dict AUTH_METHOD_NAME = {
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
AUTH_REQUIRED_PASSWORD: 'password',
AUTH_REQUIRED_PASSWORDMD5: 'md5',
AUTH_REQUIRED_GSS: 'gss',
AUTH_REQUIRED_SASL: 'scram-sha-256',
AUTH_REQUIRED_SSPI: 'sspi',
}


cdef class CoreProtocol:

def __init__(self, con_params):
def __init__(self, addr, con_params):
self.address = addr
# type of `con_params` is `_ConnectionParameters`
self.buffer = ReadBuffer()
self.user = con_params.user
Expand All @@ -26,6 +38,8 @@ cdef class CoreProtocol:
self.encoding = 'utf-8'
# type of `scram` is `SCRAMAuthentcation`
self.scram = None
# type of `gss_ctx` is `gssapi.SecurityContext`
self.gss_ctx = None

self._reset_result()

Expand Down Expand Up @@ -619,9 +633,17 @@ cdef class CoreProtocol:
'could not verify server signature for '
'SCRAM authentciation: scram-sha-256',
)
self.scram = None

elif status == AUTH_REQUIRED_GSS:
self._auth_gss_init()
self.auth_msg = self._auth_gss_step(None)

elif status == AUTH_REQUIRED_GSS_CONTINUE:
server_response = self.buffer.consume_message()
self.auth_msg = self._auth_gss_step(server_response)

elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
AUTH_REQUIRED_SSPI):
self.result_type = RESULT_FAILED
self.result = apg_exc.InterfaceError(
Expand All @@ -634,7 +656,8 @@ cdef class CoreProtocol:
'unsupported authentication method requested by the '
'server: {}'.format(status))

if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
AUTH_REQUIRED_GSS_CONTINUE]:
self.buffer.discard_message()

cdef _auth_password_message_cleartext(self):
Expand Down Expand Up @@ -691,6 +714,40 @@ cdef class CoreProtocol:

return msg

cdef _auth_gss_init(self):
try:
import gssapi
except ModuleNotFoundError:
raise RuntimeError(
'gssapi module not found; please install asyncpg[gssapi] to '
'use asyncpg with Kerberos or GSSAPI authentication'
) from None

service_name = self.con_params.krbsrvname or 'postgres'
# find the canonical name of the server host
if isinstance(self.address, str):
raise RuntimeError('GSSAPI authentication is only supported for '
'TCP/IP connections')

host = self.address[0]
host_cname = socket.gethostbyname_ex(host)[0]
gss_name = gssapi.Name(f'{service_name}/{host_cname}')
self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate')

cdef _auth_gss_step(self, bytes server_response):
cdef:
WriteBuffer msg

token = self.gss_ctx.step(server_response)
if not token:
self.gss_ctx = None
return None
msg = WriteBuffer.new_message(b'p')
msg.write_bytes(token)
msg.end_message()

return msg

cdef _parse_msg_ready_for_query(self):
cdef char status = self.buffer.read_byte()

Expand Down
1 change: 0 additions & 1 deletion asyncpg/protocol/protocol.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):

cdef:
object loop
object address
ConnectionSettings settings
object cancel_sent_waiter
object cancel_waiter
Expand Down
5 changes: 2 additions & 3 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,15 @@ NO_TIMEOUT = object()
cdef class BaseProtocol(CoreProtocol):
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
# type of `con_params` is `_ConnectionParameters`
CoreProtocol.__init__(self, con_params)
CoreProtocol.__init__(self, addr, con_params)

self.loop = loop
self.transport = None
self.waiter = connected_fut
self.cancel_waiter = None
self.cancel_sent_waiter = None

self.address = addr
self.settings = ConnectionSettings((self.address, con_params.database))
self.settings = ConnectionSettings((addr, con_params.database))
self.record_class = record_class

self.statement = None
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,14 @@ dependencies = [
github = "https://github.com/MagicStack/asyncpg"

[project.optional-dependencies]
gssapi = [
'gssapi',
]
test = [
'flake8~=6.1',
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
'gssapi; platform_system == "Linux"',
'k5test; platform_system == "Linux"',
]
docs = [
'Sphinx~=5.3.0',
Expand Down
Loading

0 comments on commit 1d4e568

Please sign in to comment.