Skip to content

Commit ab42456

Browse files
committed
Fix psycopg3 tests
Several tests (such as SQLPanelTestCase.test_cursor_wrapper_singleton) are written to ensure that only a single cursor wrapper is instantiated during a test. However, this fails when using psycopg3, since the .last_executed_query() call in NormalCursorWrapper._record() ends up creating an additional cursor (via [1]). To avoid this, use a ._djdt_in_record attribute on the database wrapper. Make the NormalCursorWrapper._record() method set ._djdt_in_record to True on entry and reset it to False on exit. Then in the overridden database wrapper .cursor() and .chunked_cursor() methods, check the ._djdt_in_record attribute and return the original cursor without wrapping if the attribute is True. [1] https://github.com/django/django/blob/4.2.1/django/db/backends/postgresql/psycopg_any.py#L21
1 parent 5ff7ca2 commit ab42456

File tree

1 file changed

+88
-70
lines changed

1 file changed

+88
-70
lines changed

debug_toolbar/panels/sql/tracking.py

Lines changed: 88 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -43,16 +43,23 @@ def cursor(*args, **kwargs):
4343
# See:
4444
# https://github.com/jazzband/django-debug-toolbar/pull/615
4545
# https://github.com/jazzband/django-debug-toolbar/pull/896
46+
cursor = connection._djdt_cursor(*args, **kwargs)
47+
# Do not wrap cursors that are created during post-processing in ._record()
48+
if connection._djdt_in_record:
49+
return cursor
4650
if allow_sql.get():
4751
wrapper = NormalCursorWrapper
4852
else:
4953
wrapper = ExceptionCursorWrapper
50-
return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
54+
return wrapper(cursor, connection, panel)
5155

5256
def chunked_cursor(*args, **kwargs):
5357
# prevent double wrapping
5458
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
5559
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
60+
# Do not wrap cursors that are created during post-processing in ._record()
61+
if connection._djdt_in_record:
62+
return cursor
5663
if not isinstance(cursor, BaseCursorWrapper):
5764
if allow_sql.get():
5865
wrapper = NormalCursorWrapper
@@ -63,6 +70,7 @@ def chunked_cursor(*args, **kwargs):
6370

6471
connection.cursor = cursor
6572
connection.chunked_cursor = chunked_cursor
73+
connection._djdt_in_record = False
6674

6775

6876
def unwrap_cursor(connection):
@@ -166,78 +174,88 @@ def _record(self, method, sql, params):
166174
try:
167175
return method(sql, params)
168176
finally:
169-
stop_time = time()
170-
duration = (stop_time - start_time) * 1000
171-
_params = ""
177+
# In certain cases the following code can cause Django to create additional
178+
# CursorWrapper instances (in particular, the
179+
# self.db.ops.last_executed_query() call with psycopg3). However, we do not
180+
# want to wrap such cursors, so set the following flag to avoid that.
181+
self.db._djdt_in_record = True
172182
try:
173-
_params = json.dumps(self._decode(params))
174-
except TypeError:
175-
pass # object not JSON serializable
176-
template_info = get_template_info()
177-
178-
# Sql might be an object (such as psycopg Composed).
179-
# For logging purposes, make sure it's str.
180-
if vendor == "postgresql" and not isinstance(sql, str):
181-
sql = sql.as_string(conn)
182-
else:
183-
sql = str(sql)
184-
185-
params = {
186-
"vendor": vendor,
187-
"alias": alias,
188-
"sql": self.db.ops.last_executed_query(
189-
self.cursor, sql, self._quote_params(params)
190-
),
191-
"duration": duration,
192-
"raw_sql": sql,
193-
"params": _params,
194-
"raw_params": params,
195-
"stacktrace": get_stack_trace(skip=2),
196-
"start_time": start_time,
197-
"stop_time": stop_time,
198-
"is_slow": duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"],
199-
"is_select": sql.lower().strip().startswith("select"),
200-
"template_info": template_info,
201-
}
202-
203-
if vendor == "postgresql":
204-
# If an erroneous query was ran on the connection, it might
205-
# be in a state where checking isolation_level raises an
206-
# exception.
183+
stop_time = time()
184+
duration = (stop_time - start_time) * 1000
185+
_params = ""
207186
try:
208-
iso_level = conn.isolation_level
209-
except conn.InternalError:
210-
iso_level = "unknown"
211-
# PostgreSQL does not expose any sort of transaction ID, so it is
212-
# necessary to generate synthetic transaction IDs here. If the
213-
# connection was not in a transaction when the query started, and was
214-
# after the query finished, a new transaction definitely started, so get
215-
# a new transaction ID from logger.new_transaction_id(). If the query
216-
# was in a transaction both before and after executing, make the
217-
# assumption that it is the same transaction and get the current
218-
# transaction ID from logger.current_transaction_id(). There is an edge
219-
# case where Django can start a transaction before the first query
220-
# executes, so in that case logger.current_transaction_id() will
221-
# generate a new transaction ID since one does not already exist.
222-
final_conn_status = conn.info.transaction_status
223-
if final_conn_status == STATUS_IN_TRANSACTION:
224-
if initial_conn_status == STATUS_IN_TRANSACTION:
225-
trans_id = self.logger.current_transaction_id(alias)
226-
else:
227-
trans_id = self.logger.new_transaction_id(alias)
187+
_params = json.dumps(self._decode(params))
188+
except TypeError:
189+
pass # object not JSON serializable
190+
template_info = get_template_info()
191+
192+
# Sql might be an object (such as psycopg Composed).
193+
# For logging purposes, make sure it's str.
194+
if vendor == "postgresql" and not isinstance(sql, str):
195+
sql = sql.as_string(conn)
228196
else:
229-
trans_id = None
230-
231-
params.update(
232-
{
233-
"trans_id": trans_id,
234-
"trans_status": conn.info.transaction_status,
235-
"iso_level": iso_level,
236-
}
237-
)
238-
239-
# We keep `sql` to maintain backwards compatibility
240-
self.logger.record(**params)
197+
sql = str(sql)
198+
199+
params = {
200+
"vendor": vendor,
201+
"alias": alias,
202+
"sql": self.db.ops.last_executed_query(
203+
self.cursor, sql, self._quote_params(params)
204+
),
205+
"duration": duration,
206+
"raw_sql": sql,
207+
"params": _params,
208+
"raw_params": params,
209+
"stacktrace": get_stack_trace(skip=2),
210+
"start_time": start_time,
211+
"stop_time": stop_time,
212+
"is_slow": (
213+
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
214+
),
215+
"is_select": sql.lower().strip().startswith("select"),
216+
"template_info": template_info,
217+
}
218+
219+
if vendor == "postgresql":
220+
# If an erroneous query was ran on the connection, it might
221+
# be in a state where checking isolation_level raises an
222+
# exception.
223+
try:
224+
iso_level = conn.isolation_level
225+
except conn.InternalError:
226+
iso_level = "unknown"
227+
# PostgreSQL does not expose any sort of transaction ID, so it is
228+
# necessary to generate synthetic transaction IDs here. If the
229+
# connection was not in a transaction when the query started, and was
230+
# after the query finished, a new transaction definitely started, so get
231+
# a new transaction ID from logger.new_transaction_id(). If the query
232+
# was in a transaction both before and after executing, make the
233+
# assumption that it is the same transaction and get the current
234+
# transaction ID from logger.current_transaction_id(). There is an edge
235+
# case where Django can start a transaction before the first query
236+
# executes, so in that case logger.current_transaction_id() will
237+
# generate a new transaction ID since one does not already exist.
238+
final_conn_status = conn.info.transaction_status
239+
if final_conn_status == STATUS_IN_TRANSACTION:
240+
if initial_conn_status == STATUS_IN_TRANSACTION:
241+
trans_id = self.logger.current_transaction_id(alias)
242+
else:
243+
trans_id = self.logger.new_transaction_id(alias)
244+
else:
245+
trans_id = None
246+
247+
params.update(
248+
{
249+
"trans_id": trans_id,
250+
"trans_status": conn.info.transaction_status,
251+
"iso_level": iso_level,
252+
}
253+
)
254+
255+
# We keep `sql` to maintain backwards compatibility
256+
self.logger.record(**params)
257+
finally:
258+
self.db._djdt_in_record = False
241259

242260
def callproc(self, procname, params=None):
243261
return self._record(self.cursor.callproc, procname, params)

0 commit comments

Comments
 (0)