Skip to content

When prepared statements are disabled, avoid relying on them harder #1065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
@@ -47,6 +47,7 @@ class Connection(metaclass=ConnectionMeta):
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_stmt_cache_enabled',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
@@ -79,6 +80,7 @@ def __init__(self, protocol, transport, loop,
max_lifetime=config.max_cached_statement_lifetime)

self._stmts_to_close = set()
self._stmt_cache_enabled = config.statement_cache_size > 0

self._listeners = {}
self._log_listeners = set()
@@ -381,11 +383,13 @@ async def _get_statement(
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
use_cache = self._stmt_cache.get_max_size() > 0
if (use_cache and
self._config.max_cacheable_statement_size and
len(query) > self._config.max_cacheable_statement_size):
use_cache = False
use_cache = (
self._stmt_cache_enabled
and (
not self._config.max_cacheable_statement_size
or len(query) <= self._config.max_cacheable_statement_size
)
)

if isinstance(named, str):
stmt_name = named
@@ -434,14 +438,16 @@ async def _get_statement(
# for the statement.
statement._init_codecs()

if need_reprepare:
await self._protocol.prepare(
stmt_name,
query,
timeout,
state=statement,
record_class=record_class,
)
if (
need_reprepare
or (not statement.name and not self._stmt_cache_enabled)
):
# Mark this anonymous prepared statement as "unprepared",
# causing it to get re-Parsed in next bind_execute.
# We always do this when stmt_cache_size is set to 0 assuming
# people are running PgBouncer which is mishandling implicit
# transactions.
statement.mark_unprepared()

if use_cache:
self._stmt_cache.put(
@@ -1679,7 +1685,13 @@ async def __execute(
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
state=stmt,
args=args,
portal_name='',
limit=limit,
return_extra=return_status,
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
@@ -1691,7 +1703,11 @@ async def __execute(

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
state=stmt,
args=args,
portal_name='',
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
3 changes: 2 additions & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
@@ -167,7 +167,8 @@ cdef class CoreProtocol:


cdef _connect(self)
cdef _prepare(self, str stmt_name, str query)
cdef _prepare_and_describe(self, str stmt_name, str query)
cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
18 changes: 17 additions & 1 deletion asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
@@ -237,6 +237,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -269,6 +273,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute_many()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
@@ -874,7 +882,15 @@ cdef class CoreProtocol:
outbuf.write_buffer(buf)
self._write(outbuf)

cdef _prepare(self, str stmt_name, str query):
cdef _send_parse_message(self, str stmt_name, str query):
cdef:
WriteBuffer msg

self._ensure_connected()
msg = self._build_parse_message(stmt_name, query)
self._write(msg)

cdef _prepare_and_describe(self, str stmt_name, str query):
cdef:
WriteBuffer packet
WriteBuffer buf
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
@@ -10,6 +10,7 @@ cdef class PreparedStatementState:
readonly str name
readonly str query
readonly bint closed
readonly bint prepared
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec
7 changes: 7 additions & 0 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
@@ -27,6 +27,7 @@ cdef class PreparedStatementState:
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
self.prepared = True
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec
@@ -101,6 +102,12 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True

def mark_unprepared(self):
if self.name:
raise exceptions.InternalClientError(
"named prepared statements cannot be marked unprepared")
self.prepared = False

cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
31 changes: 23 additions & 8 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
@@ -155,7 +155,7 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
self._prepare(stmt_name, query) # network op
self._prepare_and_describe(stmt_name, query) # network op
self.last_query = query
if state is None:
state = PreparedStatementState(
@@ -168,10 +168,15 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute(self, PreparedStatementState state, args,
str portal_name, int limit, return_extra,
timeout):

async def bind_execute(
self,
state: PreparedStatementState,
args,
portal_name: str,
limit: int,
return_extra: bool,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -184,6 +189,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

self._bind_execute(
portal_name,
state.name,
@@ -201,9 +209,13 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute_many(self, PreparedStatementState state, args,
str portal_name, timeout):

async def bind_execute_many(
self,
state: PreparedStatementState,
args,
portal_name: str,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
@@ -222,6 +234,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

more = self._bind_execute_many(
portal_name,
state.name,