4
4
from time import time
5
5
6
6
import django .test .testcases
7
+ from django .db .backends .utils import CursorWrapper
7
8
from django .utils .encoding import force_str
8
9
9
10
from debug_toolbar import settings as dt_settings
@@ -64,7 +65,7 @@ def chunked_cursor(*args, **kwargs):
64
65
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
65
66
logger = connection ._djdt_logger
66
67
cursor = connection ._djdt_chunked_cursor (* args , ** kwargs )
67
- if logger is not None and not isinstance (cursor , BaseCursorWrapper ):
68
+ if logger is not None and not isinstance (cursor , DjDTCursorWrapper ):
68
69
if allow_sql .get ():
69
70
wrapper = NormalCursorWrapper
70
71
else :
@@ -76,35 +77,28 @@ def chunked_cursor(*args, **kwargs):
76
77
connection .chunked_cursor = chunked_cursor
77
78
78
79
79
- class BaseCursorWrapper :
80
- pass
80
+ class DjDTCursorWrapper (CursorWrapper ):
81
+ def __init__ (self , cursor , db , logger ):
82
+ super ().__init__ (cursor , db )
83
+ # logger must implement a ``record`` method
84
+ self .logger = logger
81
85
82
86
83
- class ExceptionCursorWrapper (BaseCursorWrapper ):
87
+ class ExceptionCursorWrapper (DjDTCursorWrapper ):
84
88
"""
85
89
Wraps a cursor and raises an exception on any operation.
86
90
Used in Templates panel.
87
91
"""
88
92
89
- def __init__ (self , cursor , db , logger ):
90
- pass
91
-
92
93
def __getattr__ (self , attr ):
93
94
raise SQLQueryTriggered ()
94
95
95
96
96
- class NormalCursorWrapper (BaseCursorWrapper ):
97
+ class NormalCursorWrapper (DjDTCursorWrapper ):
97
98
"""
98
99
Wraps a cursor and logs queries.
99
100
"""
100
101
101
- def __init__ (self , cursor , db , logger ):
102
- self .cursor = cursor
103
- # Instance of a BaseDatabaseWrapper subclass
104
- self .db = db
105
- # logger must implement a ``record`` method
106
- self .logger = logger
107
-
108
102
def _quote_expr (self , element ):
109
103
if isinstance (element , str ):
110
104
return "'%s'" % element .replace ("'" , "''" )
@@ -246,22 +240,10 @@ def _record(self, method, sql, params):
246
240
self .logger .record (** params )
247
241
248
242
def callproc (self , procname , params = None ):
249
- return self ._record (self . cursor .callproc , procname , params )
243
+ return self ._record (super () .callproc , procname , params )
250
244
251
245
def execute (self , sql , params = None ):
252
- return self ._record (self . cursor .execute , sql , params )
246
+ return self ._record (super () .execute , sql , params )
253
247
254
248
def executemany (self , sql , param_list ):
255
- return self ._record (self .cursor .executemany , sql , param_list )
256
-
257
- def __getattr__ (self , attr ):
258
- return getattr (self .cursor , attr )
259
-
260
- def __iter__ (self ):
261
- return iter (self .cursor )
262
-
263
- def __enter__ (self ):
264
- return self
265
-
266
- def __exit__ (self , type , value , traceback ):
267
- self .close ()
249
+ return self ._record (super ().executemany , sql , param_list )
0 commit comments