3
3
import json
4
4
from time import time
5
5
6
+ from django .db .backends .utils import CursorWrapper
6
7
from django .utils .encoding import force_str
7
8
8
9
from debug_toolbar import settings as dt_settings
@@ -43,21 +44,16 @@ def cursor(*args, **kwargs):
43
44
# See:
44
45
# https://github.com/jazzband/django-debug-toolbar/pull/615
45
46
# https://github.com/jazzband/django-debug-toolbar/pull/896
46
- cursor = connection ._djdt_cursor (* args , ** kwargs )
47
- if connection ._djdt_in_record :
48
- return cursor
49
47
if allow_sql .get ():
50
48
wrapper = NormalCursorWrapper
51
49
else :
52
50
wrapper = ExceptionCursorWrapper
53
- return wrapper (cursor , connection , panel )
51
+ return wrapper (connection . _djdt_cursor ( * args , ** kwargs ) , connection , panel )
54
52
55
53
def chunked_cursor (* args , ** kwargs ):
56
54
# prevent double wrapping
57
55
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
58
56
cursor = connection ._djdt_chunked_cursor (* args , ** kwargs )
59
- if connection ._djdt_in_record :
60
- return cursor
61
57
if not isinstance (cursor , BaseCursorWrapper ):
62
58
if allow_sql .get ():
63
59
wrapper = NormalCursorWrapper
@@ -68,7 +64,6 @@ def chunked_cursor(*args, **kwargs):
68
64
69
65
connection .cursor = cursor
70
66
connection .chunked_cursor = chunked_cursor
71
- connection ._djdt_in_record = False
72
67
73
68
74
69
def unwrap_cursor (connection ):
@@ -91,8 +86,11 @@ def unwrap_cursor(connection):
91
86
del connection ._djdt_chunked_cursor
92
87
93
88
94
class BaseCursorWrapper(CursorWrapper):
    """Common base class for the toolbar's cursor wrappers.

    Subclasses Django's ``CursorWrapper`` so that already-wrapped cursors
    can be detected with an ``isinstance`` check (prevents double wrapping).
    """

    def __init__(self, cursor, db, logger):
        # ``logger`` must implement a ``record`` method.
        self.logger = logger
        super().__init__(cursor, db)
96
94
97
95
98
96
class ExceptionCursorWrapper(BaseCursorWrapper):
    """
    Wraps a cursor and raises an exception on any operation.
    Used in Templates panel.
    """

    def __getattr__(self, attr):
        # Only the actual query-running/fetching operations trigger the
        # exception. Other attributes (e.g. ``connection`` and ``close``)
        # must stay reachable because psycopg needs them to determine the
        # last_executed_query via the mogrify function.
        blocked = {
            "callproc",
            "execute",
            "executemany",
            "fetchone",
            "fetchmany",
            "fetchall",
            "nextset",
        }
        if attr in blocked:
            raise SQLQueryTriggered(f"Attr: {attr} was accessed")
        return super().__getattr__(attr)
109
117
110
118
111
119
class NormalCursorWrapper (BaseCursorWrapper ):
112
120
"""
113
121
Wraps a cursor and logs queries.
114
122
"""
115
123
116
- def __init__ (self , cursor , db , logger ):
117
- self .cursor = cursor
118
- # Instance of a BaseDatabaseWrapper subclass
119
- self .db = db
120
- # logger must implement a ``record`` method
121
- self .logger = logger
122
-
123
124
def _quote_expr (self , element ):
124
125
if isinstance (element , str ):
125
126
return "'%s'" % element .replace ("'" , "''" )
@@ -159,115 +160,108 @@ def _decode(self, param):
159
160
except UnicodeDecodeError :
160
161
return "(encoded string)"
161
162
163
+ def _get_last_executed_query (self , sql , params ):
164
+ """Get the last executed query from the connection."""
165
+ # The pyscopg3 backend uses a mogrify function which creates a new cursor.
166
+ # We need to avoid hooking into that cursor.
167
+ reset_token = allow_sql .set (False )
168
+ sql_query = self .db .ops .last_executed_query (
169
+ self .cursor , sql , self ._quote_params (params )
170
+ )
171
+ allow_sql .reset (reset_token )
172
+ return sql_query
173
+
162
174
def _record (self , method , sql , params ):
163
- self .db ._djdt_in_record = True
164
- try :
165
- alias = self .db .alias
166
- vendor = self .db .vendor
175
+ alias = self .db .alias
176
+ vendor = self .db .vendor
167
177
168
- if vendor == "postgresql" :
169
- # The underlying DB connection (as opposed to Django's wrapper)
170
- conn = self .db .connection
171
- initial_conn_status = conn .info .transaction_status
178
+ if vendor == "postgresql" :
179
+ # The underlying DB connection (as opposed to Django's wrapper)
180
+ conn = self .db .connection
181
+ initial_conn_status = conn .info .transaction_status
172
182
173
- start_time = time ()
183
+ start_time = time ()
184
+ try :
185
+ return method (sql , params )
186
+ finally :
187
+ stop_time = time ()
188
+ duration = (stop_time - start_time ) * 1000
189
+ _params = ""
174
190
try :
175
- return method (sql , params )
176
- finally :
177
- stop_time = time ()
178
- duration = (stop_time - start_time ) * 1000
179
- _params = ""
191
+ _params = json .dumps (self ._decode (params ))
192
+ except TypeError :
193
+ pass # object not JSON serializable
194
+ template_info = get_template_info ()
195
+
196
+ # Sql might be an object (such as psycopg Composed).
197
+ # For logging purposes, make sure it's str.
198
+ if vendor == "postgresql" and not isinstance (sql , str ):
199
+ sql = sql .as_string (conn )
200
+ else :
201
+ sql = str (sql )
202
+
203
+ params = {
204
+ "vendor" : vendor ,
205
+ "alias" : alias ,
206
+ "sql" : self ._get_last_executed_query (sql , params ),
207
+ "duration" : duration ,
208
+ "raw_sql" : sql ,
209
+ "params" : _params ,
210
+ "raw_params" : params ,
211
+ "stacktrace" : get_stack_trace (skip = 2 ),
212
+ "start_time" : start_time ,
213
+ "stop_time" : stop_time ,
214
+ "is_slow" : (
215
+ duration > dt_settings .get_config ()["SQL_WARNING_THRESHOLD" ]
216
+ ),
217
+ "is_select" : sql .lower ().strip ().startswith ("select" ),
218
+ "template_info" : template_info ,
219
+ }
220
+
221
+ if vendor == "postgresql" :
222
+ # If an erroneous query was ran on the connection, it might
223
+ # be in a state where checking isolation_level raises an
224
+ # exception.
180
225
try :
181
- _params = json .dumps (self ._decode (params ))
182
- except TypeError :
183
- pass # object not JSON serializable
184
- template_info = get_template_info ()
185
-
186
- # Sql might be an object (such as psycopg Composed).
187
- # For logging purposes, make sure it's str.
188
- if vendor == "postgresql" and not isinstance (sql , str ):
189
- sql = sql .as_string (conn )
190
- else :
191
- sql = str (sql )
192
-
193
- params = {
194
- "vendor" : vendor ,
195
- "alias" : alias ,
196
- "sql" : self .db .ops .last_executed_query (
197
- self .cursor , sql , self ._quote_params (params )
198
- ),
199
- "duration" : duration ,
200
- "raw_sql" : sql ,
201
- "params" : _params ,
202
- "raw_params" : params ,
203
- "stacktrace" : get_stack_trace (skip = 2 ),
204
- "start_time" : start_time ,
205
- "stop_time" : stop_time ,
206
- "is_slow" : (
207
- duration > dt_settings .get_config ()["SQL_WARNING_THRESHOLD" ]
208
- ),
209
- "is_select" : sql .lower ().strip ().startswith ("select" ),
210
- "template_info" : template_info ,
211
- }
212
-
213
- if vendor == "postgresql" :
214
- # If an erroneous query was ran on the connection, it might
215
- # be in a state where checking isolation_level raises an
216
- # exception.
217
- try :
218
- iso_level = conn .isolation_level
219
- except conn .InternalError :
220
- iso_level = "unknown"
221
- # PostgreSQL does not expose any sort of transaction ID, so it is
222
- # necessary to generate synthetic transaction IDs here. If the
223
- # connection was not in a transaction when the query started, and was
224
- # after the query finished, a new transaction definitely started, so get
225
- # a new transaction ID from logger.new_transaction_id(). If the query
226
- # was in a transaction both before and after executing, make the
227
- # assumption that it is the same transaction and get the current
228
- # transaction ID from logger.current_transaction_id(). There is an edge
229
- # case where Django can start a transaction before the first query
230
- # executes, so in that case logger.current_transaction_id() will
231
- # generate a new transaction ID since one does not already exist.
232
- final_conn_status = conn .info .transaction_status
233
- if final_conn_status == STATUS_IN_TRANSACTION :
234
- if initial_conn_status == STATUS_IN_TRANSACTION :
235
- trans_id = self .logger .current_transaction_id (alias )
236
- else :
237
- trans_id = self .logger .new_transaction_id (alias )
226
+ iso_level = conn .isolation_level
227
+ except conn .InternalError :
228
+ iso_level = "unknown"
229
+ # PostgreSQL does not expose any sort of transaction ID, so it is
230
+ # necessary to generate synthetic transaction IDs here. If the
231
+ # connection was not in a transaction when the query started, and was
232
+ # after the query finished, a new transaction definitely started, so get
233
+ # a new transaction ID from logger.new_transaction_id(). If the query
234
+ # was in a transaction both before and after executing, make the
235
+ # assumption that it is the same transaction and get the current
236
+ # transaction ID from logger.current_transaction_id(). There is an edge
237
+ # case where Django can start a transaction before the first query
238
+ # executes, so in that case logger.current_transaction_id() will
239
+ # generate a new transaction ID since one does not already exist.
240
+ final_conn_status = conn .info .transaction_status
241
+ if final_conn_status == STATUS_IN_TRANSACTION :
242
+ if initial_conn_status == STATUS_IN_TRANSACTION :
243
+ trans_id = self .logger .current_transaction_id (alias )
238
244
else :
239
- trans_id = None
240
-
241
- params . update (
242
- {
243
- "trans_id" : trans_id ,
244
- "trans_status" : conn . info . transaction_status ,
245
- "iso_level " : iso_level ,
246
- }
247
- )
248
-
249
- # We keep `sql` to maintain backwards compatibility
250
- self . logger . record ( ** params )
251
- finally :
252
- self .db . _djdt_in_record = False
245
+ trans_id = self . logger . new_transaction_id ( alias )
246
+ else :
247
+ trans_id = None
248
+
249
+ params . update (
250
+ {
251
+ "trans_id " : trans_id ,
252
+ "trans_status" : conn . info . transaction_status ,
253
+ "iso_level" : iso_level ,
254
+ }
255
+ )
256
+
257
+ # We keep `sql` to maintain backwards compatibility
258
+ self .logger . record ( ** params )
253
259
254
260
def callproc (self , procname , params = None ):
255
- return self ._record (self . cursor .callproc , procname , params )
261
+ return self ._record (super () .callproc , procname , params )
256
262
257
263
def execute (self , sql , params = None ):
258
- return self ._record (self . cursor .execute , sql , params )
264
+ return self ._record (super () .execute , sql , params )
259
265
260
266
def executemany (self , sql , param_list ):
261
- return self ._record (self .cursor .executemany , sql , param_list )
262
-
263
- def __getattr__ (self , attr ):
264
- return getattr (self .cursor , attr )
265
-
266
- def __iter__ (self ):
267
- return iter (self .cursor )
268
-
269
- def __enter__ (self ):
270
- return self
271
-
272
- def __exit__ (self , type , value , traceback ):
273
- self .close ()
267
+ return self ._record (super ().executemany , sql , param_list )
0 commit comments