1
+ import asyncio
1
2
import datetime
2
3
import os
3
4
import unittest
4
5
from unittest .mock import patch
5
6
6
7
import django
8
+ from asgiref .sync import sync_to_async
7
9
from django .contrib .auth .models import User
8
10
from django .db import connection
9
11
from django .db .models import Count
16
18
17
19
from ..base import BaseTestCase
18
20
from ..models import PostgresJSON
21
+ from ..sync import database_sync_to_async
22
+
23
+
24
+ def sql_call (use_iterator = False ):
25
+ qs = User .objects .all ()
26
+ if use_iterator :
27
+ qs = qs .iterator ()
28
+ return list (qs )
19
29
20
30
21
31
class SQLPanelTestCase (BaseTestCase ):
@@ -30,7 +40,7 @@ def test_disabled(self):
30
40
def test_recording (self ):
31
41
self .assertEqual (len (self .panel ._queries ), 0 )
32
42
33
- list ( User . objects . all () )
43
+ sql_call ( )
34
44
35
45
# ensure query was logged
36
46
self .assertEqual (len (self .panel ._queries ), 1 )
@@ -49,29 +59,64 @@ def test_recording(self):
49
59
def test_recording_chunked_cursor (self ):
50
60
self .assertEqual (len (self .panel ._queries ), 0 )
51
61
52
- list ( User . objects . all (). iterator () )
62
+ sql_call ( use_iterator = True )
53
63
54
64
# ensure query was logged
55
65
self .assertEqual (len (self .panel ._queries ), 1 )
56
66
57
- @patch ("debug_toolbar.panels.sql.tracking.state" , wraps = sql_tracking .state )
58
- def test_cursor_wrapper_singleton (self , mock_state ):
59
- list (User .objects .all ())
67
+ @patch (
68
+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
69
+ wraps = sql_tracking .NormalCursorWrapper ,
70
+ )
71
+ def test_cursor_wrapper_singleton (self , mock_wrapper ):
72
+ sql_call ()
60
73
61
74
# ensure that cursor wrapping is applied only once
62
- self .assertEqual (mock_state . Wrapper .call_count , 1 )
75
+ self .assertEqual (mock_wrapper .call_count , 1 )
63
76
64
- @patch ("debug_toolbar.panels.sql.tracking.state" , wraps = sql_tracking .state )
65
- def test_chunked_cursor_wrapper_singleton (self , mock_state ):
66
- list (User .objects .all ().iterator ())
77
+ @patch (
78
+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
79
+ wraps = sql_tracking .NormalCursorWrapper ,
80
+ )
81
+ def test_chunked_cursor_wrapper_singleton (self , mock_wrapper ):
82
+ sql_call (use_iterator = True )
67
83
68
84
# ensure that cursor wrapping is applied only once
69
- self .assertEqual (mock_state .Wrapper .call_count , 1 )
85
+ self .assertEqual (mock_wrapper .call_count , 1 )
86
+
87
+ @patch (
88
+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
89
+ wraps = sql_tracking .NormalCursorWrapper ,
90
+ )
91
+ async def test_cursor_wrapper_async (self , mock_wrapper ):
92
+ await sync_to_async (sql_call )()
93
+
94
+ self .assertEqual (mock_wrapper .call_count , 1 )
95
+
96
+ @patch (
97
+ "debug_toolbar.panels.sql.tracking.NormalCursorWrapper" ,
98
+ wraps = sql_tracking .NormalCursorWrapper ,
99
+ )
100
+ async def test_cursor_wrapper_asyncio_ctx (self , mock_wrapper ):
101
+ self .assertTrue (sql_tracking .recording .get ())
102
+ await sync_to_async (sql_call )()
103
+
104
+ async def task ():
105
+ sql_tracking .recording .set (False )
106
+ # Calling this in another context requires the db connections
107
+ # to be closed properly.
108
+ await database_sync_to_async (sql_call )()
109
+
110
+ # Ensure this is called in another context
111
+ await asyncio .create_task (task ())
112
+ # Because it was called in another context, it should not have affected ours
113
+ self .assertTrue (sql_tracking .recording .get ())
114
+ self .assertEqual (mock_wrapper .call_count , 1 )
70
115
71
116
def test_generate_server_timing (self ):
72
117
self .assertEqual (len (self .panel ._queries ), 0 )
73
118
74
- list ( User . objects . all () )
119
+ sql_call ( )
75
120
76
121
response = self .panel .process_request (self .request )
77
122
self .panel .generate_stats (self .request , response )
@@ -337,7 +382,7 @@ def test_disable_stacktraces(self):
337
382
self .assertEqual (len (self .panel ._queries ), 0 )
338
383
339
384
with self .settings (DEBUG_TOOLBAR_CONFIG = {"ENABLE_STACKTRACES" : False }):
340
- list ( User . objects . all () )
385
+ sql_call ( )
341
386
342
387
# ensure query was logged
343
388
self .assertEqual (len (self .panel ._queries ), 1 )
0 commit comments