Commit d71474a
Utilize the allow_sql context var to record SQL queries.
This also switches our cursor wrappers to inherit from django.db.backends.utils.CursorWrapper, which reduces the amount of code we need to maintain. It also explicitly disallows specific cursor methods, because the psycopg3 backend's mogrify function creates a new cursor in order to determine the last executed query, which means that ExceptionCursorWrapper still has to allow access to attributes such as connection and close. Rather than defining an odd allow list, an explicit deny list made more sense. One area of concern is that this doesn't cover the __iter__ method.
1 parent d935165 commit d71474a
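The change hinges on a contextvars-based flag: cursors are only wrapped for recording while allow_sql is true, and internal helpers flip it off so that cursors they create along the way are not hooked. Below is a minimal, self-contained sketch of that set()/reset() pattern; it defines its own ContextVar and a stand-in record_query/last_executed_query pair, and is not the toolbar's actual module.

# Minimal sketch of the allow_sql guard (hypothetical helpers, not the real code).
from contextvars import ContextVar

allow_sql = ContextVar("allow_sql", default=True)

def record_query(sql):
    # Stand-in for the panel's recording hook: only record while allowed.
    if allow_sql.get():
        print(f"recorded: {sql}")

def last_executed_query(sql):
    # Mirrors the set()/reset() dance in _get_last_executed_query(): recording
    # is suppressed while the backend resolves the query, so any cursor created
    # internally (e.g. psycopg's mogrify path) is not hooked.
    token = allow_sql.set(False)
    try:
        record_query(sql)  # suppressed: allow_sql is False in this context
        return f"-- resolved: {sql}"
    finally:
        allow_sql.reset(token)

record_query("SELECT 1")                # recorded
print(last_executed_query("SELECT 2"))  # nothing recorded while resolving
record_query("SELECT 3")                # recorded again after reset()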

File tree

1 file changed: +114 -120 lines changed

debug_toolbar/panels/sql/tracking.py

Lines changed: 114 additions & 120 deletions
@@ -3,6 +3,7 @@
 import json
 from time import time
 
+from django.db.backends.utils import CursorWrapper
 from django.utils.encoding import force_str
 
 from debug_toolbar import settings as dt_settings
@@ -43,21 +44,16 @@ def cursor(*args, **kwargs):
             # See:
             # https://github.com/jazzband/django-debug-toolbar/pull/615
             # https://github.com/jazzband/django-debug-toolbar/pull/896
-            cursor = connection._djdt_cursor(*args, **kwargs)
-            if connection._djdt_in_record:
-                return cursor
             if allow_sql.get():
                 wrapper = NormalCursorWrapper
             else:
                 wrapper = ExceptionCursorWrapper
-            return wrapper(cursor, connection, panel)
+            return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
 
         def chunked_cursor(*args, **kwargs):
             # prevent double wrapping
             # solves https://github.com/jazzband/django-debug-toolbar/issues/1239
             cursor = connection._djdt_chunked_cursor(*args, **kwargs)
-            if connection._djdt_in_record:
-                return cursor
             if not isinstance(cursor, BaseCursorWrapper):
                 if allow_sql.get():
                     wrapper = NormalCursorWrapper
@@ -68,7 +64,6 @@ def chunked_cursor(*args, **kwargs):
 
         connection.cursor = cursor
         connection.chunked_cursor = chunked_cursor
-        connection._djdt_in_record = False
 
 
 def unwrap_cursor(connection):
@@ -91,8 +86,11 @@ def unwrap_cursor(connection):
         del connection._djdt_chunked_cursor
 
 
-class BaseCursorWrapper:
-    pass
+class BaseCursorWrapper(CursorWrapper):
+    def __init__(self, cursor, db, logger):
+        super().__init__(cursor, db)
+        # logger must implement a ``record`` method
+        self.logger = logger
 
 
 class ExceptionCursorWrapper(BaseCursorWrapper):
@@ -101,25 +99,28 @@ class ExceptionCursorWrapper(BaseCursorWrapper):
     Used in Templates panel.
     """
 
-    def __init__(self, cursor, db, logger):
-        pass
-
     def __getattr__(self, attr):
-        raise SQLQueryTriggered()
+        # This allows the cursor to access connection and close which
+        # are needed in psycopg to determine the last_executed_query via
+        # the mogrify function.
+        if attr in (
+            "callproc",
+            "execute",
+            "executemany",
+            "fetchone",
+            "fetchmany",
+            "fetchall",
+            "nextset",
+        ):
+            raise SQLQueryTriggered(f"Attr: {attr} was accessed")
+        return super().__getattr__(attr)
 
 
 class NormalCursorWrapper(BaseCursorWrapper):
     """
     Wraps a cursor and logs queries.
     """
 
-    def __init__(self, cursor, db, logger):
-        self.cursor = cursor
-        # Instance of a BaseDatabaseWrapper subclass
-        self.db = db
-        # logger must implement a ``record`` method
-        self.logger = logger
-
     def _quote_expr(self, element):
         if isinstance(element, str):
             return "'%s'" % element.replace("'", "''")
@@ -159,115 +160,108 @@ def _decode(self, param):
         except UnicodeDecodeError:
             return "(encoded string)"
 
+    def _get_last_executed_query(self, sql, params):
+        """Get the last executed query from the connection."""
+        # The pyscopg3 backend uses a mogrify function which creates a new cursor.
+        # We need to avoid hooking into that cursor.
+        reset_token = allow_sql.set(False)
+        sql_query = self.db.ops.last_executed_query(
+            self.cursor, sql, self._quote_params(params)
+        )
+        allow_sql.reset(reset_token)
+        return sql_query
+
     def _record(self, method, sql, params):
-        self.db._djdt_in_record = True
-        try:
-            alias = self.db.alias
-            vendor = self.db.vendor
+        alias = self.db.alias
+        vendor = self.db.vendor
 
-            if vendor == "postgresql":
-                # The underlying DB connection (as opposed to Django's wrapper)
-                conn = self.db.connection
-                initial_conn_status = conn.info.transaction_status
+        if vendor == "postgresql":
+            # The underlying DB connection (as opposed to Django's wrapper)
+            conn = self.db.connection
+            initial_conn_status = conn.info.transaction_status
 
-            start_time = time()
+        start_time = time()
+        try:
+            return method(sql, params)
+        finally:
+            stop_time = time()
+            duration = (stop_time - start_time) * 1000
+            _params = ""
             try:
-                return method(sql, params)
-            finally:
-                stop_time = time()
-                duration = (stop_time - start_time) * 1000
-                _params = ""
+                _params = json.dumps(self._decode(params))
+            except TypeError:
+                pass  # object not JSON serializable
+            template_info = get_template_info()
+
+            # Sql might be an object (such as psycopg Composed).
+            # For logging purposes, make sure it's str.
+            if vendor == "postgresql" and not isinstance(sql, str):
+                sql = sql.as_string(conn)
+            else:
+                sql = str(sql)
+
+            params = {
+                "vendor": vendor,
+                "alias": alias,
+                "sql": self._get_last_executed_query(sql, params),
+                "duration": duration,
+                "raw_sql": sql,
+                "params": _params,
+                "raw_params": params,
+                "stacktrace": get_stack_trace(skip=2),
+                "start_time": start_time,
+                "stop_time": stop_time,
+                "is_slow": (
+                    duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
+                ),
+                "is_select": sql.lower().strip().startswith("select"),
+                "template_info": template_info,
+            }
+
+            if vendor == "postgresql":
+                # If an erroneous query was ran on the connection, it might
+                # be in a state where checking isolation_level raises an
+                # exception.
                 try:
-                    _params = json.dumps(self._decode(params))
-                except TypeError:
-                    pass  # object not JSON serializable
-                template_info = get_template_info()
-
-                # Sql might be an object (such as psycopg Composed).
-                # For logging purposes, make sure it's str.
-                if vendor == "postgresql" and not isinstance(sql, str):
-                    sql = sql.as_string(conn)
-                else:
-                    sql = str(sql)
-
-                params = {
-                    "vendor": vendor,
-                    "alias": alias,
-                    "sql": self.db.ops.last_executed_query(
-                        self.cursor, sql, self._quote_params(params)
-                    ),
-                    "duration": duration,
-                    "raw_sql": sql,
-                    "params": _params,
-                    "raw_params": params,
-                    "stacktrace": get_stack_trace(skip=2),
-                    "start_time": start_time,
-                    "stop_time": stop_time,
-                    "is_slow": (
-                        duration > dt_settings.get_config()["SQL_WARNING_THRESHOLD"]
-                    ),
-                    "is_select": sql.lower().strip().startswith("select"),
-                    "template_info": template_info,
-                }
-
-                if vendor == "postgresql":
-                    # If an erroneous query was ran on the connection, it might
-                    # be in a state where checking isolation_level raises an
-                    # exception.
-                    try:
-                        iso_level = conn.isolation_level
-                    except conn.InternalError:
-                        iso_level = "unknown"
-                    # PostgreSQL does not expose any sort of transaction ID, so it is
-                    # necessary to generate synthetic transaction IDs here. If the
-                    # connection was not in a transaction when the query started, and was
-                    # after the query finished, a new transaction definitely started, so get
-                    # a new transaction ID from logger.new_transaction_id(). If the query
-                    # was in a transaction both before and after executing, make the
-                    # assumption that it is the same transaction and get the current
-                    # transaction ID from logger.current_transaction_id(). There is an edge
-                    # case where Django can start a transaction before the first query
-                    # executes, so in that case logger.current_transaction_id() will
-                    # generate a new transaction ID since one does not already exist.
-                    final_conn_status = conn.info.transaction_status
-                    if final_conn_status == STATUS_IN_TRANSACTION:
-                        if initial_conn_status == STATUS_IN_TRANSACTION:
-                            trans_id = self.logger.current_transaction_id(alias)
-                        else:
-                            trans_id = self.logger.new_transaction_id(alias)
+                    iso_level = conn.isolation_level
+                except conn.InternalError:
+                    iso_level = "unknown"
+                # PostgreSQL does not expose any sort of transaction ID, so it is
+                # necessary to generate synthetic transaction IDs here. If the
+                # connection was not in a transaction when the query started, and was
+                # after the query finished, a new transaction definitely started, so get
+                # a new transaction ID from logger.new_transaction_id(). If the query
+                # was in a transaction both before and after executing, make the
+                # assumption that it is the same transaction and get the current
+                # transaction ID from logger.current_transaction_id(). There is an edge
+                # case where Django can start a transaction before the first query
+                # executes, so in that case logger.current_transaction_id() will
+                # generate a new transaction ID since one does not already exist.
+                final_conn_status = conn.info.transaction_status
+                if final_conn_status == STATUS_IN_TRANSACTION:
+                    if initial_conn_status == STATUS_IN_TRANSACTION:
+                        trans_id = self.logger.current_transaction_id(alias)
                     else:
-                        trans_id = None
-
-                params.update(
-                    {
-                        "trans_id": trans_id,
-                        "trans_status": conn.info.transaction_status,
-                        "iso_level": iso_level,
-                    }
-                )
-
-                # We keep `sql` to maintain backwards compatibility
-                self.logger.record(**params)
-        finally:
-            self.db._djdt_in_record = False
+                        trans_id = self.logger.new_transaction_id(alias)
+                else:
+                    trans_id = None
+
+                params.update(
+                    {
+                        "trans_id": trans_id,
+                        "trans_status": conn.info.transaction_status,
+                        "iso_level": iso_level,
+                    }
+                )
+
+            # We keep `sql` to maintain backwards compatibility
+            self.logger.record(**params)
 
     def callproc(self, procname, params=None):
-        return self._record(self.cursor.callproc, procname, params)
+        return self._record(super().callproc, procname, params)
 
     def execute(self, sql, params=None):
-        return self._record(self.cursor.execute, sql, params)
+        return self._record(super().execute, sql, params)
 
     def executemany(self, sql, param_list):
-        return self._record(self.cursor.executemany, sql, param_list)
-
-    def __getattr__(self, attr):
-        return getattr(self.cursor, attr)
-
-    def __iter__(self):
-        return iter(self.cursor)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, type, value, traceback):
-        self.close()
+        return self._record(super().executemany, sql, param_list)

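For readers unfamiliar with the deny-list approach taken by ExceptionCursorWrapper above, here is a standalone sketch of the same idea with hypothetical names and the standard-library sqlite3 module standing in for a Django connection: attribute access is proxied to the wrapped cursor, except for the query-executing and row-fetching methods, which raise immediately.

# Standalone illustration of the deny-list idea; DenyListCursor and DENIED are
# hypothetical names, and sqlite3 stands in for a real Django-managed cursor.
import sqlite3

class SQLQueryTriggered(Exception):
    """Raised when SQL execution is attempted where it is not allowed."""

DENIED = {
    "callproc", "execute", "executemany",
    "fetchone", "fetchmany", "fetchall", "nextset",
}

class DenyListCursor:
    def __init__(self, cursor):
        self._cursor = cursor

    def __getattr__(self, attr):
        # Everything else (connection, close, ...) passes through, which is
        # what psycopg needs when it resolves last_executed_query via mogrify.
        if attr in DENIED:
            raise SQLQueryTriggered(f"Attr: {attr} was accessed")
        return getattr(self._cursor, attr)

conn = sqlite3.connect(":memory:")
wrapped = DenyListCursor(conn.cursor())
try:
    wrapped.execute("SELECT 1")
except SQLQueryTriggered as exc:
    print(exc)   # Attr: execute was accessed
wrapped.close()  # allowed: not on the deny list

The real wrapper inherits django.db.backends.utils.CursorWrapper, so the passthrough happens via super().__getattr__() rather than a hand-rolled getattr(); as the commit message notes, dunder lookups such as __iter__ are resolved on the class and so are not caught by this deny list.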