Skip to content

Commit f68f012

Browse files
authored
Merge pull request #1585 from bellini666/main
Fix sql recording for async views
2 parents 54e63f0 + 468660f commit f68f012

File tree

4 files changed

+94
-39
lines changed

4 files changed

+94
-39
lines changed

debug_toolbar/panels/sql/tracking.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
import contextvars
12
import datetime
23
import json
3-
from threading import local
44
from time import time
55

66
from django.utils.encoding import force_str
@@ -13,30 +13,12 @@
1313
except ImportError:
1414
PostgresJson = None
1515

16+
recording = contextvars.ContextVar("debug-toolbar-recording", default=True)
17+
1618

1719
class SQLQueryTriggered(Exception):
1820
"""Thrown when template panel triggers a query"""
1921

20-
pass
21-
22-
23-
class ThreadLocalState(local):
24-
def __init__(self):
25-
self.enabled = True
26-
27-
@property
28-
def Wrapper(self):
29-
if self.enabled:
30-
return NormalCursorWrapper
31-
return ExceptionCursorWrapper
32-
33-
def recording(self, v):
34-
self.enabled = v
35-
36-
37-
state = ThreadLocalState()
38-
recording = state.recording # export function
39-
4022

4123
def wrap_cursor(connection, panel):
4224
if not hasattr(connection, "_djdt_cursor"):
@@ -50,16 +32,22 @@ def cursor(*args, **kwargs):
5032
# See:
5133
# https://github.com/jazzband/django-debug-toolbar/pull/615
5234
# https://github.com/jazzband/django-debug-toolbar/pull/896
53-
return state.Wrapper(
54-
connection._djdt_cursor(*args, **kwargs), connection, panel
55-
)
35+
if recording.get():
36+
wrapper = NormalCursorWrapper
37+
else:
38+
wrapper = ExceptionCursorWrapper
39+
return wrapper(connection._djdt_cursor(*args, **kwargs), connection, panel)
5640

5741
def chunked_cursor(*args, **kwargs):
5842
# prevent double wrapping
5943
# solves https://github.com/jazzband/django-debug-toolbar/issues/1239
6044
cursor = connection._djdt_chunked_cursor(*args, **kwargs)
6145
if not isinstance(cursor, BaseCursorWrapper):
62-
return state.Wrapper(cursor, connection, panel)
46+
if recording.get():
47+
wrapper = NormalCursorWrapper
48+
else:
49+
wrapper = ExceptionCursorWrapper
50+
return wrapper(cursor, connection, panel)
6351
return cursor
6452

6553
connection.cursor = cursor

debug_toolbar/panels/templates/panel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def _store_template_info(self, sender, **kwargs):
118118
value.model._meta.label,
119119
)
120120
else:
121-
recording(False)
121+
token = recording.set(False)
122122
try:
123123
saferepr(value) # this MAY trigger a db query
124124
except SQLQueryTriggered:
@@ -130,7 +130,7 @@ def _store_template_info(self, sender, **kwargs):
130130
else:
131131
temp_layer[key] = value
132132
finally:
133-
recording(True)
133+
recording.reset(token)
134134
pformatted = pformat(temp_layer)
135135
self.pformat_layers.append((context_layer, pformatted))
136136
context_list.append(pformatted)

tests/panels/test_sql.py

Lines changed: 57 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import asyncio
12
import datetime
23
import os
34
import unittest
45
from unittest.mock import patch
56

67
import django
8+
from asgiref.sync import sync_to_async
79
from django.contrib.auth.models import User
810
from django.db import connection
911
from django.db.models import Count
@@ -16,6 +18,14 @@
1618

1719
from ..base import BaseTestCase
1820
from ..models import PostgresJSON
21+
from ..sync import database_sync_to_async
22+
23+
24+
def sql_call(use_iterator=False):
25+
qs = User.objects.all()
26+
if use_iterator:
27+
qs = qs.iterator()
28+
return list(qs)
1929

2030

2131
class SQLPanelTestCase(BaseTestCase):
@@ -30,7 +40,7 @@ def test_disabled(self):
3040
def test_recording(self):
3141
self.assertEqual(len(self.panel._queries), 0)
3242

33-
list(User.objects.all())
43+
sql_call()
3444

3545
# ensure query was logged
3646
self.assertEqual(len(self.panel._queries), 1)
@@ -49,29 +59,64 @@ def test_recording(self):
4959
def test_recording_chunked_cursor(self):
5060
self.assertEqual(len(self.panel._queries), 0)
5161

52-
list(User.objects.all().iterator())
62+
sql_call(use_iterator=True)
5363

5464
# ensure query was logged
5565
self.assertEqual(len(self.panel._queries), 1)
5666

57-
@patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state)
58-
def test_cursor_wrapper_singleton(self, mock_state):
59-
list(User.objects.all())
67+
@patch(
68+
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
69+
wraps=sql_tracking.NormalCursorWrapper,
70+
)
71+
def test_cursor_wrapper_singleton(self, mock_wrapper):
72+
sql_call()
6073

6174
# ensure that cursor wrapping is applied only once
62-
self.assertEqual(mock_state.Wrapper.call_count, 1)
75+
self.assertEqual(mock_wrapper.call_count, 1)
6376

64-
@patch("debug_toolbar.panels.sql.tracking.state", wraps=sql_tracking.state)
65-
def test_chunked_cursor_wrapper_singleton(self, mock_state):
66-
list(User.objects.all().iterator())
77+
@patch(
78+
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
79+
wraps=sql_tracking.NormalCursorWrapper,
80+
)
81+
def test_chunked_cursor_wrapper_singleton(self, mock_wrapper):
82+
sql_call(use_iterator=True)
6783

6884
# ensure that cursor wrapping is applied only once
69-
self.assertEqual(mock_state.Wrapper.call_count, 1)
85+
self.assertEqual(mock_wrapper.call_count, 1)
86+
87+
@patch(
88+
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
89+
wraps=sql_tracking.NormalCursorWrapper,
90+
)
91+
async def test_cursor_wrapper_async(self, mock_wrapper):
92+
await sync_to_async(sql_call)()
93+
94+
self.assertEqual(mock_wrapper.call_count, 1)
95+
96+
@patch(
97+
"debug_toolbar.panels.sql.tracking.NormalCursorWrapper",
98+
wraps=sql_tracking.NormalCursorWrapper,
99+
)
100+
async def test_cursor_wrapper_asyncio_ctx(self, mock_wrapper):
101+
self.assertTrue(sql_tracking.recording.get())
102+
await sync_to_async(sql_call)()
103+
104+
async def task():
105+
sql_tracking.recording.set(False)
106+
# Calling this in another context requires the db connections
107+
# to be closed properly.
108+
await database_sync_to_async(sql_call)()
109+
110+
# Ensure this is called in another context
111+
await asyncio.create_task(task())
112+
# Because it was called in another context, it should not have affected ours
113+
self.assertTrue(sql_tracking.recording.get())
114+
self.assertEqual(mock_wrapper.call_count, 1)
70115

71116
def test_generate_server_timing(self):
72117
self.assertEqual(len(self.panel._queries), 0)
73118

74-
list(User.objects.all())
119+
sql_call()
75120

76121
response = self.panel.process_request(self.request)
77122
self.panel.generate_stats(self.request, response)
@@ -337,7 +382,7 @@ def test_disable_stacktraces(self):
337382
self.assertEqual(len(self.panel._queries), 0)
338383

339384
with self.settings(DEBUG_TOOLBAR_CONFIG={"ENABLE_STACKTRACES": False}):
340-
list(User.objects.all())
385+
sql_call()
341386

342387
# ensure query was logged
343388
self.assertEqual(len(self.panel._queries), 1)

tests/sync.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
"""
2+
Taken from channels.db
3+
"""
4+
from asgiref.sync import SyncToAsync
5+
from django.db import close_old_connections
6+
7+
8+
class DatabaseSyncToAsync(SyncToAsync):
9+
"""
10+
SyncToAsync version that cleans up old database connections when it exits.
11+
"""
12+
13+
def thread_handler(self, loop, *args, **kwargs):
14+
close_old_connections()
15+
try:
16+
return super().thread_handler(loop, *args, **kwargs)
17+
finally:
18+
close_old_connections()
19+
20+
21+
# The class is TitleCased, but we want to encourage use as a callable/decorator
22+
database_sync_to_async = DatabaseSyncToAsync

0 commit comments

Comments
 (0)