Skip to content

Commit d346f30

Browse files
committed
Merge pull request #224 from Apkawa/master
Add escaping string literal in sql for copy&pasting in dbshell
2 parents 5c55e3c + d6a3b17 commit d346f30

File tree

1 file changed

+33
-9
lines changed
  • debug_toolbar/utils/tracking

1 file changed

+33
-9
lines changed

debug_toolbar/utils/tracking/db.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,20 @@
99
from django.utils.encoding import force_unicode, smart_str
1010
from django.utils.hashcompat import sha_constructor
1111

12-
from debug_toolbar.utils import ms_from_timedelta, tidy_stacktrace, get_template_info, \
13-
get_stack
12+
from debug_toolbar.utils import ms_from_timedelta, tidy_stacktrace, \
13+
get_template_info, get_stack
1414
from debug_toolbar.utils.compat.db import connections
1515
# TODO:This should be set in the toolbar loader as a default and panels should
1616
# get a copy of the toolbar object with access to its config dictionary
1717
SQL_WARNING_THRESHOLD = getattr(settings, 'DEBUG_TOOLBAR_CONFIG', {}) \
1818
.get('SQL_WARNING_THRESHOLD', 500)
1919

20+
2021
class SQLQueryTriggered(Exception):
2122
"""Thrown when template panel triggers a query"""
2223
pass
2324

25+
2426
class ThreadLocalState(local):
2527
def __init__(self):
2628
self.enabled = True
@@ -34,12 +36,15 @@ def Wrapper(self):
3436
def recording(self, v):
3537
self.enabled = v
3638

39+
3740
state = ThreadLocalState()
38-
recording = state.recording # export function
41+
recording = state.recording # export function
42+
3943

4044
def CursorWrapper(*args, **kwds): # behave like a class
4145
return state.Wrapper(*args, **kwds)
4246

47+
4348
class ExceptionCursorWrapper(object):
4449
"""
4550
Wraps a cursor and raises an exception on any operation.
@@ -51,6 +56,7 @@ def __init__(self, cursor, db, logger):
5156
def __getattr__(self, attr):
5257
raise SQLQueryTriggered()
5358

59+
5460
class NormalCursorWrapper(object):
5561
"""
5662
Wraps a cursor and logs queries.
@@ -63,6 +69,19 @@ def __init__(self, cursor, db, logger):
6369
# logger must implement a ``record`` method
6470
self.logger = logger
6571

72+
def _quote_expr(self, element):
73+
if isinstance(element, basestring):
74+
element = element.replace("'", "''")
75+
return "'%s'" % element
76+
else:
77+
return repr(element)
78+
79+
def _quote_params(self, params):
80+
if isinstance(params, dict):
81+
return dict((key, self._quote_expr(value))
82+
for key, value in params.iteritems())
83+
return map(self._quote_expr, params)
84+
6685
def execute(self, sql, params=()):
6786
__traceback_hide__ = True
6887
start = datetime.now()
@@ -71,17 +90,20 @@ def execute(self, sql, params=()):
7190
finally:
7291
stop = datetime.now()
7392
duration = ms_from_timedelta(stop - start)
74-
enable_stacktraces = getattr(settings, 'DEBUG_TOOLBAR_CONFIG', {}) \
93+
enable_stacktraces = getattr(settings,
94+
'DEBUG_TOOLBAR_CONFIG', {}) \
7595
.get('ENABLE_STACKTRACES', True)
7696
if enable_stacktraces:
7797
stacktrace = tidy_stacktrace(reversed(get_stack()))
7898
else:
7999
stacktrace = []
80100
_params = ''
81101
try:
82-
_params = simplejson.dumps([force_unicode(x, strings_only=True) for x in params])
102+
_params = simplejson.dumps(
103+
[force_unicode(x, strings_only=True) for x in params]
104+
)
83105
except TypeError:
84-
pass # object not JSON serializable
106+
pass # object not JSON serializable
85107

86108
template_info = None
87109
cur_frame = sys._getframe().f_back
@@ -108,11 +130,14 @@ def execute(self, sql, params=()):
108130
params = {
109131
'engine': engine,
110132
'alias': alias,
111-
'sql': self.db.ops.last_executed_query(self.cursor, sql, params),
133+
'sql': self.db.ops.last_executed_query(self.cursor, sql,
134+
self._quote_params(params)),
112135
'duration': duration,
113136
'raw_sql': sql,
114137
'params': _params,
115-
'hash': sha_constructor(settings.SECRET_KEY + smart_str(sql) + _params).hexdigest(),
138+
'hash': sha_constructor(settings.SECRET_KEY \
139+
+ smart_str(sql) \
140+
+ _params).hexdigest(),
116141
'stacktrace': stacktrace,
117142
'start_time': start,
118143
'stop_time': stop,
@@ -129,7 +154,6 @@ def execute(self, sql, params=()):
129154
'encoding': conn.encoding,
130155
})
131156

132-
133157
# We keep `sql` to maintain backwards compatibility
134158
self.logger.record(**params)
135159

0 commit comments

Comments
 (0)