Skip to content

Commit 5a7d015

Browse files
tim-schilling authored and living180 committed
Utilize the allow_sql context var to record SQL queries.
This also switches our CursorWrappers to inherit from the django class django.db.backends.utils.CursorWrapper. This reduces some of the code we need. This also explicitly disallows specific cursor methods from being used. This is because the psycopg3 backend's mogrify function creates a new cursor in order to determine the last used connection. This means that the ExceptionCursorWrapper technically needs to access both connection and cursor. Rather than defining an odd allow list, an explicit deny list made more sense. One area of concern is that this wouldn't cover the __iter__ function.
1 parent ab42456 commit 5a7d015

File tree

1 file changed

+102
-114
lines changed

1 file changed

+102
-114
lines changed

debug_toolbar/panels/sql/tracking.py

Lines changed: 102 additions & 114 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import json
44
from time import time
55

6+
from django.db.backends.utils import CursorWrapper
67
from django.utils.encoding import force_str
78

89
from debug_toolbar import settings as dt_settings
@@ -43,23 +44,16 @@ def cursor(*args, **kwargs):
4344
# See:
4445
# https://github.com/jazzband/django-debug-toolbar/pull/615
4546
# https://github.com/jazzband/django-debug-toolbar/pull/896
46-
cursor = connection._djdt_cursor(*args, **kwargs)
47-
# Do not wrap cursors that are created during post-processing in ._record()
48-
if connection._djdt_in_record:
49-
return cursor
5047
if allow_sql.get():
5148
wrapper = NormalCursorWrapper
5249
else:
5350
wrapper = ExceptionCursorWrapper
54-
return wrapper(cursor, connection, panel)
51+
return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
5552

5653
def chunked_cursor(*args, **kwargs):
5754
# prevent double wrapping
5855
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
5956
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
60-
# Do not wrap cursors that are created during post-processing in ._record()
61-
if connection._djdt_in_record:
62-
return cursor
6357
if not isinstance(cursor, BaseCursorWrapper):
6458
if allow_sql.get():
6559
wrapper = NormalCursorWrapper
@@ -70,7 +64,6 @@ def chunked_cursor(*args, **kwargs):
7064

7165
connection.cursor = cursor
7266
connection.chunked_cursor = chunked_cursor
73-
connection._djdt_in_record = False
7467

7568

7669
def unwrap_cursor(connection):
@@ -93,8 +86,11 @@ def unwrap_cursor(connection):
9386
del connection._djdt_chunked_cursor
9487

9588

96-
class BaseCursorWrapper:
97-
pass
89+
class BaseCursorWrapper(CursorWrapper):
90+
def __init__(self, cursor, db, logger):
91+
super().__init__(cursor, db)
92+
# logger must implement a ``record`` method
93+
self.logger = logger
9894

9995

10096
class ExceptionCursorWrapper(BaseCursorWrapper):
@@ -103,25 +99,28 @@ class ExceptionCursorWrapper(BaseCursorWrapper):
10399
Used in Templates panel.
104100
"""
105101

106-
def __init__(self, cursor, db, logger):
107-
pass
108-
109102
def __getattr__(self, attr):
110-
raise SQLQueryTriggered()
103+
# This allows the cursor to access connection and close which
104+
# are needed in psycopg to determine the last_executed_query via
105+
# the mogrify function.
106+
if attr in (
107+
"callproc",
108+
"execute",
109+
"executemany",
110+
"fetchone",
111+
"fetchmany",
112+
"fetchall",
113+
"nextset",
114+
):
115+
raise SQLQueryTriggered(f"Attr: {attr} was accessed")
116+
return super().__getattr__(attr)
111117

112118

113119
class NormalCursorWrapper(BaseCursorWrapper):
114120
"""
115121
Wraps a cursor and logs queries.
116122
"""
117123

118-
def __init__(self, cursor, db, logger):
119-
self.cursor = cursor
120-
# Instance of a BaseDatabaseWrapper subclass
121-
self.db = db
122-
# logger must implement a ``record`` method
123-
self.logger = logger
124-
125124
def _quote_expr(self, element):
126125
if isinstance(element, str):
127126
return "'%s'" % element.replace("'", "''")
@@ -161,6 +160,17 @@ def _decode(self, param):
161160
except UnicodeDecodeError:
162161
return "(encoded string)"
163162

163+
def _get_last_executed_query(self, sql, params):
164+
"""Get the last executed query from the connection."""
165+
# The pyscopg3 backend uses a mogrify function which creates a new cursor.
166+
# We need to avoid hooking into that cursor.
167+
reset_token = allow_sql.set(False)
168+
sql_query = self.db.ops.last_executed_query(
169+
self.cursor, sql, self._quote_params(params)
170+
)
171+
allow_sql.reset(reset_token)
172+
return sql_query
173+
164174
def _record(self, method, sql, params):
165175
alias = self.db.alias
166176
vendor = self.db.vendor
@@ -174,106 +184,84 @@ def _record(self, method, sql, params):
174184
try:
175185
return method(sql, params)
176186
finally:
177-
# In certain cases the following code can cause Django to create additional
178-
# CursorWrapper instances (in particular, the
179-
# self.db.ops.last_executed_query() call with psycopg3). However, we do not
180-
# want to wrap such cursors, so set the following flag to avoid that.
181-
self.db._djdt_in_record = True
187+
stop_time = time()
188+
duration = (stop_time - start_time) * 1000
189+
_params = ""
182190
try:
183-
stop_time = time()
184-
duration = (stop_time - start_time) * 1000
185-
_params = ""
191+
_params = json.dumps(self._decode(params))
192+
except TypeError:
193+
pass # object not JSON serializable
194+
template_info = get_template_info()
195+
196+
# Sql might be an object (such as psycopg Composed).
197+
# For logging purposes, make sure it's str.
198+
if vendor == "postgresql" and not isinstance(sql, str):
199+
sql = sql.as_string(conn)
200+
else:
201+
sql = str(sql)
202+
203+
params = {
204+
"vendor": vendor,
205+
"alias": alias,
206+
"sql": self._get_last_executed_query(sql, params),
207+
"duration": duration,
208+
"raw_sql": sql,
209+
"params": _params,
210+
"raw_params": params,
211+
"stacktrace": get_stack_trace(skip=2),
212+
"start_time": start_time,
213+
"stop_time": stop_time,
214+
"is_slow": (
215+
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
216+
),
217+
"is_select": sql.lower().strip().startswith("select"),
218+
"template_info": template_info,
219+
}
220+
221+
if vendor == "postgresql":
222+
# If an erroneous query was ran on the connection, it might
223+
# be in a state where checking isolation_level raises an
224+
# exception.
186225
try:
187-
_params = json.dumps(self._decode(params))
188-
except TypeError:
189-
pass # object not JSON serializable
190-
template_info = get_template_info()
191-
192-
# Sql might be an object (such as psycopg Composed).
193-
# For logging purposes, make sure it's str.
194-
if vendor == "postgresql" and not isinstance(sql, str):
195-
sql = sql.as_string(conn)
196-
else:
197-
sql = str(sql)
198-
199-
params = {
200-
"vendor": vendor,
201-
"alias": alias,
202-
"sql": self.db.ops.last_executed_query(
203-
self.cursor, sql, self._quote_params(params)
204-
),
205-
"duration": duration,
206-
"raw_sql": sql,
207-
"params": _params,
208-
"raw_params": params,
209-
"stacktrace": get_stack_trace(skip=2),
210-
"start_time": start_time,
211-
"stop_time": stop_time,
212-
"is_slow": (
213-
duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
214-
),
215-
"is_select": sql.lower().strip().startswith("select"),
216-
"template_info": template_info,
217-
}
218-
219-
if vendor == "postgresql":
220-
# If an erroneous query was ran on the connection, it might
221-
# be in a state where checking isolation_level raises an
222-
# exception.
223-
try:
224-
iso_level = conn.isolation_level
225-
except conn.InternalError:
226-
iso_level = "unknown"
227-
# PostgreSQL does not expose any sort of transaction ID, so it is
228-
# necessary to generate synthetic transaction IDs here. If the
229-
# connection was not in a transaction when the query started, and was
230-
# after the query finished, a new transaction definitely started, so get
231-
# a new transaction ID from logger.new_transaction_id(). If the query
232-
# was in a transaction both before and after executing, make the
233-
# assumption that it is the same transaction and get the current
234-
# transaction ID from logger.current_transaction_id(). There is an edge
235-
# case where Django can start a transaction before the first query
236-
# executes, so in that case logger.current_transaction_id() will
237-
# generate a new transaction ID since one does not already exist.
238-
final_conn_status = conn.info.transaction_status
239-
if final_conn_status == STATUS_IN_TRANSACTION:
240-
if initial_conn_status == STATUS_IN_TRANSACTION:
241-
trans_id = self.logger.current_transaction_id(alias)
242-
else:
243-
trans_id = self.logger.new_transaction_id(alias)
226+
iso_level = conn.isolation_level
227+
except conn.InternalError:
228+
iso_level = "unknown"
229+
# PostgreSQL does not expose any sort of transaction ID, so it is
230+
# necessary to generate synthetic transaction IDs here. If the
231+
# connection was not in a transaction when the query started, and was
232+
# after the query finished, a new transaction definitely started, so get
233+
# a new transaction ID from logger.new_transaction_id(). If the query
234+
# was in a transaction both before and after executing, make the
235+
# assumption that it is the same transaction and get the current
236+
# transaction ID from logger.current_transaction_id(). There is an edge
237+
# case where Django can start a transaction before the first query
238+
# executes, so in that case logger.current_transaction_id() will
239+
# generate a new transaction ID since one does not already exist.
240+
final_conn_status = conn.info.transaction_status
241+
if final_conn_status == STATUS_IN_TRANSACTION:
242+
if initial_conn_status == STATUS_IN_TRANSACTION:
243+
trans_id = self.logger.current_transaction_id(alias)
244244
else:
245-
trans_id = None
245+
trans_id = self.logger.new_transaction_id(alias)
246+
else:
247+
trans_id = None
246248

247-
params.update(
248-
{
249-
"trans_id": trans_id,
250-
"trans_status": conn.info.transaction_status,
251-
"iso_level": iso_level,
252-
}
253-
)
249+
params.update(
250+
{
251+
"trans_id": trans_id,
252+
"trans_status": conn.info.transaction_status,
253+
"iso_level": iso_level,
254+
}
255+
)
254256

255-
# We keep `sql` to maintain backwards compatibility
256-
self.logger.record(**params)
257-
finally:
258-
self.db._djdt_in_record = False
257+
# We keep `sql` to maintain backwards compatibility
258+
self.logger.record(**params)
259259

260260
def callproc(self, procname, params=None):
261-
return self._record(self.cursor.callproc, procname, params)
261+
return self._record(super().callproc, procname, params)
262262

263263
def execute(self, sql, params=None):
264-
return self._record(self.cursor.execute, sql, params)
264+
return self._record(super().execute, sql, params)
265265

266266
def executemany(self, sql, param_list):
267-
return self._record(self.cursor.executemany, sql, param_list)
268-
269-
def __getattr__(self, attr):
270-
return getattr(self.cursor, attr)
271-
272-
def __iter__(self):
273-
return iter(self.cursor)
274-
275-
def __enter__(self):
276-
return self
277-
278-
def __exit__(self, type, value, traceback):
279-
self.close()
267+
return self._record(super().executemany, sql, param_list)

0 commit comments

Comments
 (0)