From 7d55da21a4f08f14cf7cc474c5c24c5879e6c5a9 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus <elvis@edgedb.com> Date: Sun, 7 Nov 2021 13:14:39 -0800 Subject: [PATCH] Make it possible to specify a statement name in Connection.prepare() This adds the new `name` keyword argument to `Connection.prepare()` and `PreparedStatement.get_name()` method returning the name of a statement. Some users of asyncpg might find it useful to be able to control how prepared statements are named, especially when a custom prepared statement caching scheme is in use. Specifically, This should help with pgbouncer support in SQLAlchemy asyncpg dialect. Fixes: #837. --- asyncpg/connection.py | 27 ++++++++++++++++++++++----- asyncpg/prepared_stmt.py | 8 ++++++++ tests/test_prepare.py | 11 +++++++++++ 3 files changed, 41 insertions(+), 5 deletions(-) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index a7ec7719..e54807fd 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -359,8 +359,8 @@ async def _get_statement( query, timeout, *, - named: bool=False, - use_cache: bool=True, + named=False, + use_cache=True, ignore_custom_codec=False, record_class=None ): @@ -385,7 +385,9 @@ async def _get_statement( len(query) > self._config.max_cacheable_statement_size): use_cache = False - if use_cache or named: + if isinstance(named, str): + stmt_name = named + elif use_cache or named: stmt_name = self._get_unique_id('stmt') else: stmt_name = '' @@ -526,11 +528,21 @@ def cursor( record_class, ) - async def prepare(self, query, *, timeout=None, record_class=None): + async def prepare( + self, + query, + *, + name=None, + timeout=None, + record_class=None, + ): """Create a *prepared statement* for the specified query. :param str query: Text of the query to create a prepared statement for. + :param str name: + Optional name of the returned prepared statement. If not + specified, the name is auto-generated. :param float timeout: Optional timeout value in seconds. :param type record_class: @@ -544,9 +556,13 @@ async def prepare(self, query, *, timeout=None, record_class=None): .. versionchanged:: 0.22.0 Added the *record_class* parameter. + + .. versionchanged:: 0.25.0 + Added the *name* parameter. """ return await self._prepare( query, + name=name, timeout=timeout, use_cache=False, record_class=record_class, @@ -556,6 +572,7 @@ async def _prepare( self, query, *, + name=None, timeout=None, use_cache: bool=False, record_class=None @@ -564,7 +581,7 @@ async def _prepare( stmt = await self._get_statement( query, timeout, - named=True, + named=True if name is None else name, use_cache=use_cache, record_class=record_class, ) diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py index eeb45367..8e241d67 100644 --- a/asyncpg/prepared_stmt.py +++ b/asyncpg/prepared_stmt.py @@ -24,6 +24,14 @@ def __init__(self, connection, query, state): state.attach() self._last_status = None + @connresource.guarded + def get_name(self) -> str: + """Return the name of this prepared statement. + + .. versionadded:: 0.25.0 + """ + return self._state.name + @connresource.guarded def get_query(self) -> str: """Return the text of the query for this prepared statement. diff --git a/tests/test_prepare.py b/tests/test_prepare.py index c441b45a..5911ccf2 100644 --- a/tests/test_prepare.py +++ b/tests/test_prepare.py @@ -600,3 +600,14 @@ async def test_prepare_does_not_use_cache(self): # prepare with disabled cache await self.con.prepare('select 1') self.assertEqual(len(cache), 0) + + async def test_prepare_explicitly_named(self): + ps = await self.con.prepare('select 1', name='foobar') + self.assertEqual(ps.get_name(), 'foobar') + self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1) + + with self.assertRaisesRegex( + exceptions.DuplicatePreparedStatementError, + 'prepared statement "foobar" already exists', + ): + await self.con.prepare('select 1', name='foobar')