Skip to content

When prepared statements are disabled, avoid relying on them harder #1065

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Aug 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 31 additions & 15 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class Connection(metaclass=ConnectionMeta):
__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_stmt_cache_enabled',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
Expand Down Expand Up @@ -79,6 +80,7 @@ def __init__(self, protocol, transport, loop,
max_lifetime=config.max_cached_statement_lifetime)

self._stmts_to_close = set()
self._stmt_cache_enabled = config.statement_cache_size > 0

self._listeners = {}
self._log_listeners = set()
Expand Down Expand Up @@ -381,11 +383,13 @@ async def _get_statement(
# Only use the cache when:
# * `statement_cache_size` is greater than 0;
# * query size is less than `max_cacheable_statement_size`.
use_cache = self._stmt_cache.get_max_size() > 0
if (use_cache and
self._config.max_cacheable_statement_size and
len(query) > self._config.max_cacheable_statement_size):
use_cache = False
use_cache = (
self._stmt_cache_enabled
and (
not self._config.max_cacheable_statement_size
or len(query) <= self._config.max_cacheable_statement_size
)
)

if isinstance(named, str):
stmt_name = named
Expand Down Expand Up @@ -434,14 +438,16 @@ async def _get_statement(
# for the statement.
statement._init_codecs()

if need_reprepare:
await self._protocol.prepare(
stmt_name,
query,
timeout,
state=statement,
record_class=record_class,
)
if (
need_reprepare
or (not statement.name and not self._stmt_cache_enabled)
):
# Mark this anonymous prepared statement as "unprepared",
# causing it to get re-Parsed in next bind_execute.
# We always do this when stmt_cache_size is set to 0 assuming
# people are running PgBouncer which is mishandling implicit
# transactions.
statement.mark_unprepared()

if use_cache:
self._stmt_cache.put(
Expand Down Expand Up @@ -1679,7 +1685,13 @@ async def __execute(
record_class=None
):
executor = lambda stmt, timeout: self._protocol.bind_execute(
stmt, args, '', limit, return_status, timeout)
state=stmt,
args=args,
portal_name='',
limit=limit,
return_extra=return_status,
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
Expand All @@ -1691,7 +1703,11 @@ async def __execute(

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
stmt, args, '', timeout)
state=stmt,
args=args,
portal_name='',
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
Expand Down
3 changes: 2 additions & 1 deletion asyncpg/protocol/coreproto.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ cdef class CoreProtocol:


cdef _connect(self)
cdef _prepare(self, str stmt_name, str query)
cdef _prepare_and_describe(self, str stmt_name, str query)
cdef _send_parse_message(self, str stmt_name, str query)
cdef _send_bind_message(self, str portal_name, str stmt_name,
WriteBuffer bind_data, int32_t limit)
cdef _bind_execute(self, str portal_name, str stmt_name,
Expand Down
18 changes: 17 additions & 1 deletion asyncpg/protocol/coreproto.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
Expand Down Expand Up @@ -269,6 +273,10 @@ cdef class CoreProtocol:
# ErrorResponse
self._parse_msg_error_response(True)

elif mtype == b'1':
# ParseComplete, in case `_bind_execute_many()` is reparsing
self.buffer.discard_message()

elif mtype == b'2':
# BindComplete
self.buffer.discard_message()
Expand Down Expand Up @@ -874,7 +882,15 @@ cdef class CoreProtocol:
outbuf.write_buffer(buf)
self._write(outbuf)

cdef _prepare(self, str stmt_name, str query):
cdef _send_parse_message(self, str stmt_name, str query):
cdef:
WriteBuffer msg

self._ensure_connected()
msg = self._build_parse_message(stmt_name, query)
self._write(msg)

cdef _prepare_and_describe(self, str stmt_name, str query):
cdef:
WriteBuffer packet
WriteBuffer buf
Expand Down
1 change: 1 addition & 0 deletions asyncpg/protocol/prepared_stmt.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ cdef class PreparedStatementState:
readonly str name
readonly str query
readonly bint closed
readonly bint prepared
readonly int refs
readonly type record_class
readonly bint ignore_custom_codec
Expand Down
7 changes: 7 additions & 0 deletions asyncpg/protocol/prepared_stmt.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ cdef class PreparedStatementState:
self.args_num = self.cols_num = 0
self.cols_desc = None
self.closed = False
self.prepared = True
self.refs = 0
self.record_class = record_class
self.ignore_custom_codec = ignore_custom_codec
Expand Down Expand Up @@ -101,6 +102,12 @@ cdef class PreparedStatementState:
def mark_closed(self):
self.closed = True

def mark_unprepared(self):
if self.name:
raise exceptions.InternalClientError(
"named prepared statements cannot be marked unprepared")
self.prepared = False

cdef _encode_bind_msg(self, args, int seqno = -1):
cdef:
int idx
Expand Down
31 changes: 23 additions & 8 deletions asyncpg/protocol/protocol.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
self._prepare(stmt_name, query) # network op
self._prepare_and_describe(stmt_name, query) # network op
self.last_query = query
if state is None:
state = PreparedStatementState(
Expand All @@ -168,10 +168,15 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute(self, PreparedStatementState state, args,
str portal_name, int limit, return_extra,
timeout):

async def bind_execute(
self,
state: PreparedStatementState,
args,
portal_name: str,
limit: int,
return_extra: bool,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -184,6 +189,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

self._bind_execute(
portal_name,
state.name,
Expand All @@ -201,9 +209,13 @@ cdef class BaseProtocol(CoreProtocol):
return await waiter

@cython.iterable_coroutine
async def bind_execute_many(self, PreparedStatementState state, args,
str portal_name, timeout):

async def bind_execute_many(
self,
state: PreparedStatementState,
args,
portal_name: str,
timeout,
):
if self.cancel_waiter is not None:
await self.cancel_waiter
if self.cancel_sent_waiter is not None:
Expand All @@ -222,6 +234,9 @@ cdef class BaseProtocol(CoreProtocol):

waiter = self._new_waiter(timeout)
try:
if not state.prepared:
self._send_parse_message(state.name, state.query)

more = self._bind_execute_many(
portal_name,
state.name,
Expand Down