From 7d55da21a4f08f14cf7cc474c5c24c5879e6c5a9 Mon Sep 17 00:00:00 2001
From: Elvis Pranskevichus <elvis@edgedb.com>
Date: Sun, 7 Nov 2021 13:14:39 -0800
Subject: [PATCH] Make it possible to specify a statement name in
 Connection.prepare()

This adds the new `name` keyword argument to `Connection.prepare()` and
`PreparedStatement.get_name()` method returning the name of a statement.

Some users of asyncpg might find it useful to be able to control how
prepared statements are named, especially when a custom prepared
statement caching scheme is in use.  Specifically, This should help with
pgbouncer support in SQLAlchemy asyncpg dialect.

Fixes: #837.
---
 asyncpg/connection.py    | 27 ++++++++++++++++++++++-----
 asyncpg/prepared_stmt.py |  8 ++++++++
 tests/test_prepare.py    | 11 +++++++++++
 3 files changed, 41 insertions(+), 5 deletions(-)

diff --git a/asyncpg/connection.py b/asyncpg/connection.py
index a7ec7719..e54807fd 100644
--- a/asyncpg/connection.py
+++ b/asyncpg/connection.py
@@ -359,8 +359,8 @@ async def _get_statement(
         query,
         timeout,
         *,
-        named: bool=False,
-        use_cache: bool=True,
+        named=False,
+        use_cache=True,
         ignore_custom_codec=False,
         record_class=None
     ):
@@ -385,7 +385,9 @@ async def _get_statement(
                     len(query) > self._config.max_cacheable_statement_size):
                 use_cache = False
 
-        if use_cache or named:
+        if isinstance(named, str):
+            stmt_name = named
+        elif use_cache or named:
             stmt_name = self._get_unique_id('stmt')
         else:
             stmt_name = ''
@@ -526,11 +528,21 @@ def cursor(
             record_class,
         )
 
-    async def prepare(self, query, *, timeout=None, record_class=None):
+    async def prepare(
+        self,
+        query,
+        *,
+        name=None,
+        timeout=None,
+        record_class=None,
+    ):
         """Create a *prepared statement* for the specified query.
 
         :param str query:
             Text of the query to create a prepared statement for.
+        :param str name:
+            Optional name of the returned prepared statement.  If not
+            specified, the name is auto-generated.
         :param float timeout:
             Optional timeout value in seconds.
         :param type record_class:
@@ -544,9 +556,13 @@ async def prepare(self, query, *, timeout=None, record_class=None):
 
         .. versionchanged:: 0.22.0
             Added the *record_class* parameter.
+
+        .. versionchanged:: 0.25.0
+            Added the *name* parameter.
         """
         return await self._prepare(
             query,
+            name=name,
             timeout=timeout,
             use_cache=False,
             record_class=record_class,
@@ -556,6 +572,7 @@ async def _prepare(
         self,
         query,
         *,
+        name=None,
         timeout=None,
         use_cache: bool=False,
         record_class=None
@@ -564,7 +581,7 @@ async def _prepare(
         stmt = await self._get_statement(
             query,
             timeout,
-            named=True,
+            named=True if name is None else name,
             use_cache=use_cache,
             record_class=record_class,
         )
diff --git a/asyncpg/prepared_stmt.py b/asyncpg/prepared_stmt.py
index eeb45367..8e241d67 100644
--- a/asyncpg/prepared_stmt.py
+++ b/asyncpg/prepared_stmt.py
@@ -24,6 +24,14 @@ def __init__(self, connection, query, state):
         state.attach()
         self._last_status = None
 
+    @connresource.guarded
+    def get_name(self) -> str:
+        """Return the name of this prepared statement.
+
+        .. versionadded:: 0.25.0
+        """
+        return self._state.name
+
     @connresource.guarded
     def get_query(self) -> str:
         """Return the text of the query for this prepared statement.
diff --git a/tests/test_prepare.py b/tests/test_prepare.py
index c441b45a..5911ccf2 100644
--- a/tests/test_prepare.py
+++ b/tests/test_prepare.py
@@ -600,3 +600,14 @@ async def test_prepare_does_not_use_cache(self):
         # prepare with disabled cache
         await self.con.prepare('select 1')
         self.assertEqual(len(cache), 0)
+
+    async def test_prepare_explicitly_named(self):
+        ps = await self.con.prepare('select 1', name='foobar')
+        self.assertEqual(ps.get_name(), 'foobar')
+        self.assertEqual(await self.con.fetchval('EXECUTE foobar'), 1)
+
+        with self.assertRaisesRegex(
+            exceptions.DuplicatePreparedStatementError,
+            'prepared statement "foobar" already exists',
+        ):
+            await self.con.prepare('select 1', name='foobar')