Skip to content
Merged
82 changes: 41 additions & 41 deletions debug_toolbar/panels/sql/panel.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,35 +61,29 @@ def __init__(self, *args, **kwargs):
self._num_queries = 0
self._queries = []
self._databases = {}
self._transaction_status = {}
# synthetic transaction IDs, keyed by DB alias
self._transaction_ids = {}

def get_transaction_id(self, alias):
if alias not in connections:
return
conn = connections[alias].connection
if not conn:
return

if conn.vendor == "postgresql":
cur_status = conn.get_transaction_status()
else:
raise ValueError(conn.vendor)

last_status = self._transaction_status.get(alias)
self._transaction_status[alias] = cur_status

if not cur_status:
# No available state
return None

if cur_status != last_status:
if cur_status:
self._transaction_ids[alias] = uuid.uuid4().hex
else:
self._transaction_ids[alias] = None

return self._transaction_ids[alias]
def new_transaction_id(self, alias):
"""
Generate and return a new synthetic transaction ID for the specified DB alias.
"""
trans_id = uuid.uuid4().hex
self._transaction_ids[alias] = trans_id
return trans_id

def current_transaction_id(self, alias):
"""
Return the current synthetic transaction ID for the specified DB alias.
"""
trans_id = self._transaction_ids.get(alias)
# Sometimes it is not possible to detect the beginning of the first transaction,
# so current_transaction_id() will be called before new_transaction_id(). In
# that case there won't yet be a transaction ID. so it is necessary to generate
# one using new_transaction_id().
if trans_id is None:
trans_id = self.new_transaction_id(alias)
return trans_id

def record(self, alias, **kwargs):
self._queries.append((alias, kwargs))
Expand Down Expand Up @@ -184,23 +178,25 @@ def duplicate_key(query):
rgb[nn] = nc
db["rgb_color"] = rgb

trans_ids = {}
trans_id = None
i = 0
# the last query recorded for each DB alias
last_by_alias = {}
for alias, query in self._queries:
query_similar[alias][similar_key(query)] += 1
query_duplicates[alias][duplicate_key(query)] += 1

trans_id = query.get("trans_id")
last_trans_id = trans_ids.get(alias)

if trans_id != last_trans_id:
if last_trans_id:
self._queries[(i - 1)][1]["ends_trans"] = True
trans_ids[alias] = trans_id
if trans_id:
prev_query = last_by_alias.get(alias, {})
prev_trans_id = prev_query.get("trans_id")

# If two consecutive queries for a given DB alias have different
# transaction ID values, a transaction started, finished, or both, so
# annotate the queries as appropriate.
if trans_id != prev_trans_id:
if prev_trans_id is not None:
prev_query["ends_trans"] = True
if trans_id is not None:
query["starts_trans"] = True
if trans_id:
if trans_id is not None:
query["in_trans"] = True

query["alias"] = alias
Expand Down Expand Up @@ -228,12 +224,16 @@ def duplicate_key(query):
query["end_offset"] = query["width_ratio"] + query["start_offset"]
width_ratio_tally += query["width_ratio"]
query["stacktrace"] = render_stacktrace(query["stacktrace"])
i += 1

query["trace_color"] = trace_colors[query["stacktrace"]]

if trans_id:
self._queries[(i - 1)][1]["ends_trans"] = True
last_by_alias[alias] = query

# Close out any transactions that were in progress, since there is no
# explicit way to know when a transaction finishes.
for final_query in last_by_alias.values():
if final_query.get("trans_id") is not None:
final_query["ends_trans"] = True

# Queries are similar / duplicates only if there's as least 2 of them.
# Also, to hide queries, we need to give all the duplicate groups an id
Expand Down
54 changes: 46 additions & 8 deletions debug_toolbar/panels/sql/tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

try:
from psycopg2._json import Json as PostgresJson
from psycopg2.extensions import STATUS_IN_TRANSACTION
except ImportError:
PostgresJson = None
STATUS_IN_TRANSACTION = None

# Prevents SQL queries from being sent to the DB. It's used
# by the TemplatePanel to prevent the toolbar from issuing
Expand Down Expand Up @@ -60,9 +62,22 @@ def chunked_cursor(*args, **kwargs):

def unwrap_cursor(connection):
if hasattr(connection, "_djdt_cursor"):
# Sometimes the cursor()/chunked_cursor() methods of the DatabaseWrapper
# instance are already monkey patched before wrap_cursor() is called. (In
# particular, Django's SimpleTestCase monkey patches those methods for any
# disallowed databases to raise an exception if they are accessed.) Thus only
# delete our monkey patch if the method we saved is the same as the class
# method. Otherwise, restore the prior monkey patch from our saved method.
if connection._djdt_cursor == connection.__class__.cursor:
del connection.cursor
else:
connection.cursor = connection._djdt_cursor
del connection._djdt_cursor
del connection.cursor
del connection.chunked_cursor
if connection._djdt_chunked_cursor == connection.__class__.chunked_cursor:
del connection.chunked_cursor
else:
connection.chunked_cursor = connection._djdt_chunked_cursor
del connection._djdt_chunked_cursor


class BaseCursorWrapper:
Expand Down Expand Up @@ -126,6 +141,14 @@ def _decode(self, param):
return "(encoded string)"

def _record(self, method, sql, params):
alias = self.db.alias
vendor = self.db.vendor

if vendor == "postgresql":
# The underlying DB connection (as opposed to Django's wrapper)
conn = self.db.connection
initial_conn_status = conn.status

start_time = time()
try:
return method(sql, params)
Expand All @@ -143,10 +166,6 @@ def _record(self, method, sql, params):
pass # object not JSON serializable
template_info = get_template_info()

alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, "vendor", "unknown")

# Sql might be an object (such as psycopg Composed).
# For logging purposes, make sure it's str.
sql = str(sql)
Expand Down Expand Up @@ -177,12 +196,31 @@ def _record(self, method, sql, params):
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = "unknown"
# PostgreSQL does not expose any sort of transaction ID, so it is
# necessary to generate synthetic transaction IDs here. If the
# connection was not in a transaction when the query started, and was
# after the query finished, a new transaction definitely started, so get
# a new transaction ID from logger.new_transaction_id(). If the query
# was in a transaction both before and after executing, make the
# assumption that it is the same transaction and get the current
# transaction ID from logger.current_transaction_id(). There is an edge
# case where Django can start a transaction before the first query
# executes, so in that case logger.current_transaction_id() will
# generate a new transaction ID since one does not already exist.
final_conn_status = conn.status
if final_conn_status == STATUS_IN_TRANSACTION:
if initial_conn_status == STATUS_IN_TRANSACTION:
trans_id = self.logger.current_transaction_id(alias)
else:
trans_id = self.logger.new_transaction_id(alias)
else:
trans_id = None

params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_id": trans_id,
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)

Expand Down
12 changes: 10 additions & 2 deletions tests/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import html5lib
from asgiref.local import Local
from django.http import HttpResponse
from django.test import Client, RequestFactory, TestCase
from django.test import Client, RequestFactory, TestCase, TransactionTestCase

from debug_toolbar.toolbar import DebugToolbar

Expand Down Expand Up @@ -29,7 +29,7 @@ def handle_toolbar_created(sender, toolbar=None, **kwargs):
rf = RequestFactory()


class BaseTestCase(TestCase):
class BaseMixin:
client_class = ToolbarTestClient

panel_id = None
Expand Down Expand Up @@ -67,6 +67,14 @@ def assertValidHTML(self, content):
raise self.failureException("\n".join(msg_parts))


class BaseTestCase(BaseMixin, TestCase):
pass


class BaseMultiDBTestCase(BaseMixin, TransactionTestCase):
databases = {"default", "replica"}


class IntegrationTestCase(TestCase):
"""Base TestCase for tests involving clients making requests."""

Expand Down
115 changes: 113 additions & 2 deletions tests/panels/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import django
from asgiref.sync import sync_to_async
from django.contrib.auth.models import User
from django.db import connection
from django.db import connection, transaction
from django.db.models import Count
from django.db.utils import DatabaseError
from django.shortcuts import render
Expand All @@ -16,7 +16,7 @@
import debug_toolbar.panels.sql.tracking as sql_tracking
from debug_toolbar import settings as dt_settings

from ..base import BaseTestCase
from ..base import BaseMultiDBTestCase, BaseTestCase
from ..models import PostgresJSON


Expand Down Expand Up @@ -506,3 +506,114 @@ def test_nested_template_information(self):
self.assertEqual(template_name, "included.html")
self.assertEqual(template_info["context"][0]["content"].strip(), "{{ users }}")
self.assertEqual(template_info["context"][0]["highlight"], True)


class SQLPanelMultiDBTestCase(BaseMultiDBTestCase):
panel_id = "SQLPanel"

def test_aliases(self):
self.assertFalse(self.panel._queries)

list(User.objects.all())
list(User.objects.using("replica").all())

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

self.assertTrue(self.panel._queries)

query = self.panel._queries[0]
self.assertEqual(query[0], "default")

query = self.panel._queries[-1]
self.assertEqual(query[0], "replica")

def test_transaction_status(self):
"""
Test case for tracking the transaction status is properly associated with
queries on PostgreSQL, and that transactions aren't broken on other database
engines.
"""
self.assertEqual(len(self.panel._queries), 0)

with transaction.atomic():
list(User.objects.all())
list(User.objects.using("replica").all())

with transaction.atomic(using="replica"):
list(User.objects.all())
list(User.objects.using("replica").all())

with transaction.atomic():
list(User.objects.all())

list(User.objects.using("replica").all())

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

if connection.vendor == "postgresql":
# Connection tracking is currently only implemented for PostgreSQL.
self.assertEqual(len(self.panel._queries), 6)

query = self.panel._queries[0]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertFalse("end_trans" in query[1])

query = self.panel._queries[-1]
self.assertEqual(query[0], "replica")
self.assertIsNone(query[1]["trans_id"])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])

query = self.panel._queries[2]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertFalse("starts_trans" in query[1])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[3]
self.assertEqual(query[0], "replica")
self.assertIsNotNone(query[1]["trans_id"])
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[4]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[3][1]["trans_id"]
)
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[5]
self.assertEqual(query[0], "replica")
self.assertIsNone(query[1]["trans_id"])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])
else:
# Ensure that nothing was recorded for other database engines.
self.assertTrue(self.panel._queries)
for query in self.panel._queries:
self.assertFalse("trans_id" in query[1])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])
14 changes: 14 additions & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,20 @@
"USER": "default_test",
},
},
"replica": {
"ENGINE": "django.{}db.backends.{}".format(
"contrib.gis." if USE_GIS else "", os.getenv("DB_BACKEND", "sqlite3")
),
"NAME": os.getenv("DB_NAME", ":memory:"),
"USER": os.getenv("DB_USER"),
"PASSWORD": os.getenv("DB_PASSWORD"),
"HOST": os.getenv("DB_HOST", ""),
"PORT": os.getenv("DB_PORT", ""),
"TEST": {
"USER": "default_test",
"MIRROR": "default",
},
},
}

DEFAULT_AUTO_FIELD = "django.db.models.AutoField"
Expand Down
Loading