Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 1d4e568

Browse files
authoredMar 4, 2024
Implement GSSAPI authentication (#1122)
Most commonly used with Kerberos. Closes: #769
1 parent c2c8d20 commit 1d4e568

File tree

10 files changed

+230
-53
lines changed

10 files changed

+230
-53
lines changed
 

‎.github/workflows/install-krb5.sh

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/bin/bash
2+
3+
set -Eexuo pipefail
4+
5+
if [ "$RUNNER_OS" == "Linux" ]; then
6+
# Assume Ubuntu since this is the only Linux used in CI.
7+
sudo apt-get update
8+
sudo apt-get install -y --no-install-recommends \
9+
libkrb5-dev krb5-user krb5-kdc krb5-admin-server
10+
fi

‎.github/workflows/tests.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ jobs:
6262
- name: Install Python Deps
6363
if: steps.release.outputs.version == 0
6464
run: |
65+
.github/workflows/install-krb5.sh
6566
python -m pip install -U pip setuptools wheel
6667
python -m pip install -e .[test]
6768
@@ -122,6 +123,7 @@ jobs:
122123
- name: Install Python Deps
123124
if: steps.release.outputs.version == 0
124125
run: |
126+
.github/workflows/install-krb5.sh
125127
python -m pip install -U pip setuptools wheel
126128
python -m pip install -e .[test]
127129

‎asyncpg/connect_utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def parse(cls, sslmode):
5656
'direct_tls',
5757
'server_settings',
5858
'target_session_attrs',
59+
'krbsrvname',
5960
])
6061

6162

@@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]:
261262
def _parse_connect_dsn_and_args(*, dsn, host, port, user,
262263
password, passfile, database, ssl,
263264
direct_tls, server_settings,
264-
target_session_attrs):
265+
target_session_attrs, krbsrvname):
265266
# `auth_hosts` is the version of host information for the purposes
266267
# of reading the pgpass file.
267268
auth_hosts = None
@@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
383384
if target_session_attrs is None:
384385
target_session_attrs = dsn_target_session_attrs
385386

387+
if 'krbsrvname' in query:
388+
val = query.pop('krbsrvname')
389+
if krbsrvname is None:
390+
krbsrvname = val
391+
386392
if query:
387393
if server_settings is None:
388394
server_settings = query
@@ -650,11 +656,15 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user,
650656
)
651657
) from None
652658

659+
if krbsrvname is None:
660+
krbsrvname = os.getenv('PGKRBSRVNAME')
661+
653662
params = _ConnectionParameters(
654663
user=user, password=password, database=database, ssl=ssl,
655664
sslmode=sslmode, direct_tls=direct_tls,
656665
server_settings=server_settings,
657-
target_session_attrs=target_session_attrs)
666+
target_session_attrs=target_session_attrs,
667+
krbsrvname=krbsrvname)
658668

659669
return addrs, params
660670

@@ -665,7 +675,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
665675
max_cached_statement_lifetime,
666676
max_cacheable_statement_size,
667677
ssl, direct_tls, server_settings,
668-
target_session_attrs):
678+
target_session_attrs, krbsrvname):
669679
local_vars = locals()
670680
for var_name in {'max_cacheable_statement_size',
671681
'max_cached_statement_lifetime',
@@ -694,7 +704,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile,
694704
password=password, passfile=passfile, ssl=ssl,
695705
direct_tls=direct_tls, database=database,
696706
server_settings=server_settings,
697-
target_session_attrs=target_session_attrs)
707+
target_session_attrs=target_session_attrs,
708+
krbsrvname=krbsrvname)
698709

699710
config = _ClientConfiguration(
700711
command_timeout=command_timeout,

‎asyncpg/connection.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2007,7 +2007,8 @@ async def connect(dsn=None, *,
20072007
connection_class=Connection,
20082008
record_class=protocol.Record,
20092009
server_settings=None,
2010-
target_session_attrs=None):
2010+
target_session_attrs=None,
2011+
krbsrvname=None):
20112012
r"""A coroutine to establish a connection to a PostgreSQL server.
20122013
20132014
The connection parameters may be specified either as a connection
@@ -2235,6 +2236,10 @@ async def connect(dsn=None, *,
22352236
or the value of the ``PGTARGETSESSIONATTRS`` environment variable,
22362237
or ``"any"`` if neither is specified.
22372238
2239+
:param str krbsrvname:
2240+
Kerberos service name to use when authenticating with GSSAPI. This
2241+
must match the server configuration. Defaults to 'postgres'.
2242+
22382243
:return: A :class:`~asyncpg.connection.Connection` instance.
22392244
22402245
Example:
@@ -2303,6 +2308,9 @@ async def connect(dsn=None, *,
23032308
.. versionchanged:: 0.28.0
23042309
Added the *target_session_attrs* parameter.
23052310
2311+
.. versionchanged:: 0.30.0
2312+
Added the *krbsrvname* parameter.
2313+
23062314
.. _SSLContext: https://docs.python.org/3/library/ssl.html#ssl.SSLContext
23072315
.. _create_default_context:
23082316
https://docs.python.org/3/library/ssl.html#ssl.create_default_context
@@ -2344,7 +2352,8 @@ async def connect(dsn=None, *,
23442352
statement_cache_size=statement_cache_size,
23452353
max_cached_statement_lifetime=max_cached_statement_lifetime,
23462354
max_cacheable_statement_size=max_cacheable_statement_size,
2347-
target_session_attrs=target_session_attrs
2355+
target_session_attrs=target_session_attrs,
2356+
krbsrvname=krbsrvname,
23482357
)
23492358

23502359

‎asyncpg/protocol/coreproto.pxd

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,6 @@ cdef enum AuthenticationMessage:
5151
AUTH_SASL_FINAL = 12
5252

5353

54-
AUTH_METHOD_NAME = {
55-
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
56-
AUTH_REQUIRED_PASSWORD: 'password',
57-
AUTH_REQUIRED_PASSWORDMD5: 'md5',
58-
AUTH_REQUIRED_GSS: 'gss',
59-
AUTH_REQUIRED_SASL: 'scram-sha-256',
60-
AUTH_REQUIRED_SSPI: 'sspi',
61-
}
62-
63-
6454
cdef enum ResultType:
6555
RESULT_OK = 1
6656
RESULT_FAILED = 2
@@ -96,10 +86,13 @@ cdef class CoreProtocol:
9686

9787
object transport
9888

89+
object address
9990
# Instance of _ConnectionParameters
10091
object con_params
10192
# Instance of SCRAMAuthentication
10293
SCRAMAuthentication scram
94+
# Instance of gssapi.SecurityContext
95+
object gss_ctx
10396

10497
readonly int32_t backend_pid
10598
readonly int32_t backend_secret
@@ -145,6 +138,8 @@ cdef class CoreProtocol:
145138
cdef _auth_password_message_md5(self, bytes salt)
146139
cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods)
147140
cdef _auth_password_message_sasl_continue(self, bytes server_response)
141+
cdef _auth_gss_init(self)
142+
cdef _auth_gss_step(self, bytes server_response)
148143

149144
cdef _write(self, buf)
150145
cdef _writelines(self, list buffers)

‎asyncpg/protocol/coreproto.pyx

Lines changed: 60 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,26 @@
66

77

88
import hashlib
9+
import socket
910

1011

1112
include "scram.pyx"
1213

1314

15+
cdef dict AUTH_METHOD_NAME = {
16+
AUTH_REQUIRED_KERBEROS: 'kerberosv5',
17+
AUTH_REQUIRED_PASSWORD: 'password',
18+
AUTH_REQUIRED_PASSWORDMD5: 'md5',
19+
AUTH_REQUIRED_GSS: 'gss',
20+
AUTH_REQUIRED_SASL: 'scram-sha-256',
21+
AUTH_REQUIRED_SSPI: 'sspi',
22+
}
23+
24+
1425
cdef class CoreProtocol:
1526

16-
def __init__(self, con_params):
27+
def __init__(self, addr, con_params):
28+
self.address = addr
1729
# type of `con_params` is `_ConnectionParameters`
1830
self.buffer = ReadBuffer()
1931
self.user = con_params.user
@@ -26,6 +38,8 @@ cdef class CoreProtocol:
2638
self.encoding = 'utf-8'
2739
# type of `scram` is `SCRAMAuthentcation`
2840
self.scram = None
41+
# type of `gss_ctx` is `gssapi.SecurityContext`
42+
self.gss_ctx = None
2943

3044
self._reset_result()
3145

@@ -619,9 +633,17 @@ cdef class CoreProtocol:
619633
'could not verify server signature for '
620634
'SCRAM authentciation: scram-sha-256',
621635
)
636+
self.scram = None
637+
638+
elif status == AUTH_REQUIRED_GSS:
639+
self._auth_gss_init()
640+
self.auth_msg = self._auth_gss_step(None)
641+
642+
elif status == AUTH_REQUIRED_GSS_CONTINUE:
643+
server_response = self.buffer.consume_message()
644+
self.auth_msg = self._auth_gss_step(server_response)
622645

623646
elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED,
624-
AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE,
625647
AUTH_REQUIRED_SSPI):
626648
self.result_type = RESULT_FAILED
627649
self.result = apg_exc.InterfaceError(
@@ -634,7 +656,8 @@ cdef class CoreProtocol:
634656
'unsupported authentication method requested by the '
635657
'server: {}'.format(status))
636658

637-
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]:
659+
if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL,
660+
AUTH_REQUIRED_GSS_CONTINUE]:
638661
self.buffer.discard_message()
639662

640663
cdef _auth_password_message_cleartext(self):
@@ -691,6 +714,40 @@ cdef class CoreProtocol:
691714

692715
return msg
693716

717+
cdef _auth_gss_init(self):
718+
try:
719+
import gssapi
720+
except ModuleNotFoundError:
721+
raise RuntimeError(
722+
'gssapi module not found; please install asyncpg[gssapi] to '
723+
'use asyncpg with Kerberos or GSSAPI authentication'
724+
) from None
725+
726+
service_name = self.con_params.krbsrvname or 'postgres'
727+
# find the canonical name of the server host
728+
if isinstance(self.address, str):
729+
raise RuntimeError('GSSAPI authentication is only supported for '
730+
'TCP/IP connections')
731+
732+
host = self.address[0]
733+
host_cname = socket.gethostbyname_ex(host)[0]
734+
gss_name = gssapi.Name(f'{service_name}/{host_cname}')
735+
self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate')
736+
737+
cdef _auth_gss_step(self, bytes server_response):
738+
cdef:
739+
WriteBuffer msg
740+
741+
token = self.gss_ctx.step(server_response)
742+
if not token:
743+
self.gss_ctx = None
744+
return None
745+
msg = WriteBuffer.new_message(b'p')
746+
msg.write_bytes(token)
747+
msg.end_message()
748+
749+
return msg
750+
694751
cdef _parse_msg_ready_for_query(self):
695752
cdef char status = self.buffer.read_byte()
696753

‎asyncpg/protocol/protocol.pxd

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol):
3131

3232
cdef:
3333
object loop
34-
object address
3534
ConnectionSettings settings
3635
object cancel_sent_waiter
3736
object cancel_waiter

‎asyncpg/protocol/protocol.pyx

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ NO_TIMEOUT = object()
7575
cdef class BaseProtocol(CoreProtocol):
7676
def __init__(self, addr, connected_fut, con_params, record_class: type, loop):
7777
# type of `con_params` is `_ConnectionParameters`
78-
CoreProtocol.__init__(self, con_params)
78+
CoreProtocol.__init__(self, addr, con_params)
7979

8080
self.loop = loop
8181
self.transport = None
8282
self.waiter = connected_fut
8383
self.cancel_waiter = None
8484
self.cancel_sent_waiter = None
8585

86-
self.address = addr
87-
self.settings = ConnectionSettings((self.address, con_params.database))
86+
self.settings = ConnectionSettings((addr, con_params.database))
8887
self.record_class = record_class
8988

9089
self.statement = None

‎pyproject.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,14 @@ dependencies = [
3535
github = "https://github.com/MagicStack/asyncpg"
3636

3737
[project.optional-dependencies]
38+
gssapi = [
39+
'gssapi',
40+
]
3841
test = [
3942
'flake8~=6.1',
4043
'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',
44+
'gssapi; platform_system == "Linux"',
45+
'k5test; platform_system == "Linux"',
4146
]
4247
docs = [
4348
'Sphinx~=5.3.0',

‎tests/test_connect.py

Lines changed: 120 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -130,30 +130,22 @@ def test_server_version_02(self):
130130
CORRECT_PASSWORD = 'correct\u1680password'
131131

132132

133-
class TestAuthentication(tb.ConnectedTestCase):
133+
class BaseTestAuthentication(tb.ConnectedTestCase):
134+
USERS = []
135+
134136
def setUp(self):
135137
super().setUp()
136138

137139
if not self.cluster.is_managed():
138140
self.skipTest('unmanaged cluster')
139141

140-
methods = [
141-
('trust', None),
142-
('reject', None),
143-
('scram-sha-256', CORRECT_PASSWORD),
144-
('md5', CORRECT_PASSWORD),
145-
('password', CORRECT_PASSWORD),
146-
]
147-
148142
self.cluster.reset_hba()
149143

150144
create_script = []
151-
for method, password in methods:
145+
for username, method, password in self.USERS:
152146
if method == 'scram-sha-256' and self.server_version.major < 10:
153147
continue
154148

155-
username = method.replace('-', '_')
156-
157149
# if this is a SCRAM password, we need to set the encryption method
158150
# to "scram-sha-256" in order to properly hash the password
159151
if method == 'scram-sha-256':
@@ -162,7 +154,7 @@ def setUp(self):
162154
)
163155

164156
create_script.append(
165-
'CREATE ROLE {}_user WITH LOGIN{};'.format(
157+
'CREATE ROLE "{}" WITH LOGIN{};'.format(
166158
username,
167159
f' PASSWORD E{(password or "")!r}'
168160
)
@@ -175,20 +167,20 @@ def setUp(self):
175167
"SET password_encryption = 'md5';"
176168
)
177169

178-
if _system != 'Windows':
170+
if _system != 'Windows' and method != 'gss':
179171
self.cluster.add_hba_entry(
180172
type='local',
181-
database='postgres', user='{}_user'.format(username),
173+
database='postgres', user=username,
182174
auth_method=method)
183175

184176
self.cluster.add_hba_entry(
185177
type='host', address=ipaddress.ip_network('127.0.0.0/24'),
186-
database='postgres', user='{}_user'.format(username),
178+
database='postgres', user=username,
187179
auth_method=method)
188180

189181
self.cluster.add_hba_entry(
190182
type='host', address=ipaddress.ip_network('::1/128'),
191-
database='postgres', user='{}_user'.format(username),
183+
database='postgres', user=username,
192184
auth_method=method)
193185

194186
# Put hba changes into effect
@@ -201,28 +193,28 @@ def tearDown(self):
201193
# Reset cluster's pg_hba.conf since we've meddled with it
202194
self.cluster.trust_local_connections()
203195

204-
methods = [
205-
'trust',
206-
'reject',
207-
'scram-sha-256',
208-
'md5',
209-
'password',
210-
]
211-
212196
drop_script = []
213-
for method in methods:
197+
for username, method, _ in self.USERS:
214198
if method == 'scram-sha-256' and self.server_version.major < 10:
215199
continue
216200

217-
username = method.replace('-', '_')
218-
219-
drop_script.append('DROP ROLE {}_user;'.format(username))
201+
drop_script.append('DROP ROLE "{}";'.format(username))
220202

221203
drop_script = '\n'.join(drop_script)
222204
self.loop.run_until_complete(self.con.execute(drop_script))
223205

224206
super().tearDown()
225207

208+
209+
class TestAuthentication(BaseTestAuthentication):
210+
USERS = [
211+
('trust_user', 'trust', None),
212+
('reject_user', 'reject', None),
213+
('scram_sha_256_user', 'scram-sha-256', CORRECT_PASSWORD),
214+
('md5_user', 'md5', CORRECT_PASSWORD),
215+
('password_user', 'password', CORRECT_PASSWORD),
216+
]
217+
226218
async def _try_connect(self, **kwargs):
227219
# On Windows the server sometimes just closes
228220
# the connection sooner than we receive the
@@ -388,6 +380,62 @@ async def test_auth_md5_unsupported(self, _):
388380
await self.connect(user='md5_user', password=CORRECT_PASSWORD)
389381

390382

383+
class TestGssAuthentication(BaseTestAuthentication):
384+
@classmethod
385+
def setUpClass(cls):
386+
try:
387+
from k5test.realm import K5Realm
388+
except ModuleNotFoundError:
389+
raise unittest.SkipTest('k5test not installed')
390+
391+
cls.realm = K5Realm()
392+
cls.addClassCleanup(cls.realm.stop)
393+
# Setup environment before starting the cluster.
394+
patch = unittest.mock.patch.dict(os.environ, cls.realm.env)
395+
patch.start()
396+
cls.addClassCleanup(patch.stop)
397+
# Add credentials.
398+
cls.realm.addprinc('postgres/localhost')
399+
cls.realm.extract_keytab('postgres/localhost', cls.realm.keytab)
400+
401+
cls.USERS = [(cls.realm.user_princ, 'gss', None)]
402+
super().setUpClass()
403+
404+
cls.cluster.override_connection_spec(host='localhost')
405+
406+
@classmethod
407+
def get_server_settings(cls):
408+
settings = super().get_server_settings()
409+
settings['krb_server_keyfile'] = f'FILE:{cls.realm.keytab}'
410+
return settings
411+
412+
@classmethod
413+
def setup_cluster(cls):
414+
cls.cluster = cls.new_cluster(pg_cluster.TempCluster)
415+
cls.start_cluster(
416+
cls.cluster, server_settings=cls.get_server_settings())
417+
418+
async def test_auth_gssapi(self):
419+
conn = await self.connect(user=self.realm.user_princ)
420+
await conn.close()
421+
422+
# Service name mismatch.
423+
with self.assertRaisesRegex(
424+
exceptions.InternalClientError,
425+
'Server .* not found'
426+
):
427+
await self.connect(user=self.realm.user_princ, krbsrvname='wrong')
428+
429+
# Credentials mismatch.
430+
self.realm.addprinc('wrong_user', 'password')
431+
self.realm.kinit('wrong_user', 'password')
432+
with self.assertRaisesRegex(
433+
exceptions.InvalidAuthorizationSpecificationError,
434+
'GSSAPI authentication failed for user'
435+
):
436+
await self.connect(user=self.realm.user_princ)
437+
438+
391439
class TestConnectParams(tb.TestCase):
392440

393441
TESTS = [
@@ -600,6 +648,46 @@ class TestConnectParams(tb.TestCase):
600648
})
601649
},
602650

651+
{
652+
'name': 'krbsrvname',
653+
'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs',
654+
'env': {
655+
'PGKRBSRVNAME': 'srv_env',
656+
},
657+
'result': ([('host', 5432)], {
658+
'database': 'db',
659+
'user': 'user',
660+
'target_session_attrs': 'any',
661+
'krbsrvname': 'srv_qs',
662+
})
663+
},
664+
665+
{
666+
'name': 'krbsrvname_2',
667+
'dsn': 'postgresql://user@host/db?krbsrvname=srv_qs',
668+
'krbsrvname': 'srv_kws',
669+
'result': ([('host', 5432)], {
670+
'database': 'db',
671+
'user': 'user',
672+
'target_session_attrs': 'any',
673+
'krbsrvname': 'srv_kws',
674+
})
675+
},
676+
677+
{
678+
'name': 'krbsrvname_3',
679+
'dsn': 'postgresql://user@host/db',
680+
'env': {
681+
'PGKRBSRVNAME': 'srv_env',
682+
},
683+
'result': ([('host', 5432)], {
684+
'database': 'db',
685+
'user': 'user',
686+
'target_session_attrs': 'any',
687+
'krbsrvname': 'srv_env',
688+
})
689+
},
690+
603691
{
604692
'name': 'dsn_ipv6_multi_host',
605693
'dsn': 'postgresql://user@[2001:db8::1234%25eth0],[::1]/db',
@@ -883,6 +971,7 @@ def run_testcase(self, testcase):
883971
sslmode = testcase.get('ssl')
884972
server_settings = testcase.get('server_settings')
885973
target_session_attrs = testcase.get('target_session_attrs')
974+
krbsrvname = testcase.get('krbsrvname')
886975

887976
expected = testcase.get('result')
888977
expected_error = testcase.get('error')
@@ -907,7 +996,8 @@ def run_testcase(self, testcase):
907996
passfile=passfile, database=database, ssl=sslmode,
908997
direct_tls=False,
909998
server_settings=server_settings,
910-
target_session_attrs=target_session_attrs)
999+
target_session_attrs=target_session_attrs,
1000+
krbsrvname=krbsrvname)
9111001

9121002
params = {
9131003
k: v for k, v in params._asdict().items()

0 commit comments

Comments
 (0)
Please sign in to comment.