diff --git a/debug_toolbar/panels/sql/panel.py b/debug_toolbar/panels/sql/panel.py index 00737a42d..23f453567 100644 --- a/debug_toolbar/panels/sql/panel.py +++ b/debug_toolbar/panels/sql/panel.py @@ -61,35 +61,29 @@ def __init__(self, *args, **kwargs): self._num_queries = 0 self._queries = [] self._databases = {} - self._transaction_status = {} + # synthetic transaction IDs, keyed by DB alias self._transaction_ids = {} - def get_transaction_id(self, alias): - if alias not in connections: - return - conn = connections[alias].connection - if not conn: - return - - if conn.vendor == "postgresql": - cur_status = conn.get_transaction_status() - else: - raise ValueError(conn.vendor) - - last_status = self._transaction_status.get(alias) - self._transaction_status[alias] = cur_status - - if not cur_status: - # No available state - return None - - if cur_status != last_status: - if cur_status: - self._transaction_ids[alias] = uuid.uuid4().hex - else: - self._transaction_ids[alias] = None - - return self._transaction_ids[alias] + def new_transaction_id(self, alias): + """ + Generate and return a new synthetic transaction ID for the specified DB alias. + """ + trans_id = uuid.uuid4().hex + self._transaction_ids[alias] = trans_id + return trans_id + + def current_transaction_id(self, alias): + """ + Return the current synthetic transaction ID for the specified DB alias. + """ + trans_id = self._transaction_ids.get(alias) + # Sometimes it is not possible to detect the beginning of the first transaction, + # so current_transaction_id() will be called before new_transaction_id(). In + # that case there won't yet be a transaction ID. so it is necessary to generate + # one using new_transaction_id(). + if trans_id is None: + trans_id = self.new_transaction_id(alias) + return trans_id def record(self, alias, **kwargs): self._queries.append((alias, kwargs)) @@ -184,23 +178,25 @@ def duplicate_key(query): rgb[nn] = nc db["rgb_color"] = rgb - trans_ids = {} - trans_id = None - i = 0 + # the last query recorded for each DB alias + last_by_alias = {} for alias, query in self._queries: query_similar[alias][similar_key(query)] += 1 query_duplicates[alias][duplicate_key(query)] += 1 trans_id = query.get("trans_id") - last_trans_id = trans_ids.get(alias) - - if trans_id != last_trans_id: - if last_trans_id: - self._queries[(i - 1)][1]["ends_trans"] = True - trans_ids[alias] = trans_id - if trans_id: + prev_query = last_by_alias.get(alias, {}) + prev_trans_id = prev_query.get("trans_id") + + # If two consecutive queries for a given DB alias have different + # transaction ID values, a transaction started, finished, or both, so + # annotate the queries as appropriate. + if trans_id != prev_trans_id: + if prev_trans_id is not None: + prev_query["ends_trans"] = True + if trans_id is not None: query["starts_trans"] = True - if trans_id: + if trans_id is not None: query["in_trans"] = True query["alias"] = alias @@ -228,12 +224,16 @@ def duplicate_key(query): query["end_offset"] = query["width_ratio"] + query["start_offset"] width_ratio_tally += query["width_ratio"] query["stacktrace"] = render_stacktrace(query["stacktrace"]) - i += 1 query["trace_color"] = trace_colors[query["stacktrace"]] - if trans_id: - self._queries[(i - 1)][1]["ends_trans"] = True + last_by_alias[alias] = query + + # Close out any transactions that were in progress, since there is no + # explicit way to know when a transaction finishes. + for final_query in last_by_alias.values(): + if final_query.get("trans_id") is not None: + final_query["ends_trans"] = True # Queries are similar / duplicates only if there's as least 2 of them. # Also, to hide queries, we need to give all the duplicate groups an id diff --git a/debug_toolbar/panels/sql/tracking.py b/debug_toolbar/panels/sql/tracking.py index 93304b21f..c479a8b5d 100644 --- a/debug_toolbar/panels/sql/tracking.py +++ b/debug_toolbar/panels/sql/tracking.py @@ -10,8 +10,10 @@ try: from psycopg2._json import Json as PostgresJson + from psycopg2.extensions import STATUS_IN_TRANSACTION except ImportError: PostgresJson = None + STATUS_IN_TRANSACTION = None # Prevents SQL queries from being sent to the DB. It's used # by the TemplatePanel to prevent the toolbar from issuing @@ -60,9 +62,22 @@ def chunked_cursor(*args, **kwargs): def unwrap_cursor(connection): if hasattr(connection, "_djdt_cursor"): + # Sometimes the cursor()/chunked_cursor() methods of the DatabaseWrapper + # instance are already monkey patched before wrap_cursor() is called. (In + # particular, Django's SimpleTestCase monkey patches those methods for any + # disallowed databases to raise an exception if they are accessed.) Thus only + # delete our monkey patch if the method we saved is the same as the class + # method. Otherwise, restore the prior monkey patch from our saved method. + if connection._djdt_cursor == connection.__class__.cursor: + del connection.cursor + else: + connection.cursor = connection._djdt_cursor del connection._djdt_cursor - del connection.cursor - del connection.chunked_cursor + if connection._djdt_chunked_cursor == connection.__class__.chunked_cursor: + del connection.chunked_cursor + else: + connection.chunked_cursor = connection._djdt_chunked_cursor + del connection._djdt_chunked_cursor class BaseCursorWrapper: @@ -126,6 +141,14 @@ def _decode(self, param): return "(encoded string)" def _record(self, method, sql, params): + alias = self.db.alias + vendor = self.db.vendor + + if vendor == "postgresql": + # The underlying DB connection (as opposed to Django's wrapper) + conn = self.db.connection + initial_conn_status = conn.status + start_time = time() try: return method(sql, params) @@ -143,10 +166,6 @@ def _record(self, method, sql, params): pass # object not JSON serializable template_info = get_template_info() - alias = getattr(self.db, "alias", "default") - conn = self.db.connection - vendor = getattr(conn, "vendor", "unknown") - # Sql might be an object (such as psycopg Composed). # For logging purposes, make sure it's str. sql = str(sql) @@ -177,12 +196,31 @@ def _record(self, method, sql, params): iso_level = conn.isolation_level except conn.InternalError: iso_level = "unknown" + # PostgreSQL does not expose any sort of transaction ID, so it is + # necessary to generate synthetic transaction IDs here. If the + # connection was not in a transaction when the query started, and was + # after the query finished, a new transaction definitely started, so get + # a new transaction ID from logger.new_transaction_id(). If the query + # was in a transaction both before and after executing, make the + # assumption that it is the same transaction and get the current + # transaction ID from logger.current_transaction_id(). There is an edge + # case where Django can start a transaction before the first query + # executes, so in that case logger.current_transaction_id() will + # generate a new transaction ID since one does not already exist. + final_conn_status = conn.status + if final_conn_status == STATUS_IN_TRANSACTION: + if initial_conn_status == STATUS_IN_TRANSACTION: + trans_id = self.logger.current_transaction_id(alias) + else: + trans_id = self.logger.new_transaction_id(alias) + else: + trans_id = None + params.update( { - "trans_id": self.logger.get_transaction_id(alias), + "trans_id": trans_id, "trans_status": conn.get_transaction_status(), "iso_level": iso_level, - "encoding": conn.encoding, } ) diff --git a/tests/base.py b/tests/base.py index ccd9f053c..5cc432add 100644 --- a/tests/base.py +++ b/tests/base.py @@ -1,7 +1,7 @@ import html5lib from asgiref.local import Local from django.http import HttpResponse -from django.test import Client, RequestFactory, TestCase +from django.test import Client, RequestFactory, TestCase, TransactionTestCase from debug_toolbar.toolbar import DebugToolbar @@ -29,7 +29,7 @@ def handle_toolbar_created(sender, toolbar=None, **kwargs): rf = RequestFactory() -class BaseTestCase(TestCase): +class BaseMixin: client_class = ToolbarTestClient panel_id = None @@ -67,6 +67,14 @@ def assertValidHTML(self, content): raise self.failureException("\n".join(msg_parts)) +class BaseTestCase(BaseMixin, TestCase): + pass + + +class BaseMultiDBTestCase(BaseMixin, TransactionTestCase): + databases = {"default", "replica"} + + class IntegrationTestCase(TestCase): """Base TestCase for tests involving clients making requests.""" diff --git a/tests/panels/test_sql.py b/tests/panels/test_sql.py index 9824a1bec..40ec83dbb 100644 --- a/tests/panels/test_sql.py +++ b/tests/panels/test_sql.py @@ -7,7 +7,7 @@ import django from asgiref.sync import sync_to_async from django.contrib.auth.models import User -from django.db import connection +from django.db import connection, transaction from django.db.models import Count from django.db.utils import DatabaseError from django.shortcuts import render @@ -16,7 +16,7 @@ import debug_toolbar.panels.sql.tracking as sql_tracking from debug_toolbar import settings as dt_settings -from ..base import BaseTestCase +from ..base import BaseMultiDBTestCase, BaseTestCase from ..models import PostgresJSON @@ -506,3 +506,114 @@ def test_nested_template_information(self): self.assertEqual(template_name, "included.html") self.assertEqual(template_info["context"][0]["content"].strip(), "{{ users }}") self.assertEqual(template_info["context"][0]["highlight"], True) + + +class SQLPanelMultiDBTestCase(BaseMultiDBTestCase): + panel_id = "SQLPanel" + + def test_aliases(self): + self.assertFalse(self.panel._queries) + + list(User.objects.all()) + list(User.objects.using("replica").all()) + + response = self.panel.process_request(self.request) + self.panel.generate_stats(self.request, response) + + self.assertTrue(self.panel._queries) + + query = self.panel._queries[0] + self.assertEqual(query[0], "default") + + query = self.panel._queries[-1] + self.assertEqual(query[0], "replica") + + def test_transaction_status(self): + """ + Test case for tracking the transaction status is properly associated with + queries on PostgreSQL, and that transactions aren't broken on other database + engines. + """ + self.assertEqual(len(self.panel._queries), 0) + + with transaction.atomic(): + list(User.objects.all()) + list(User.objects.using("replica").all()) + + with transaction.atomic(using="replica"): + list(User.objects.all()) + list(User.objects.using("replica").all()) + + with transaction.atomic(): + list(User.objects.all()) + + list(User.objects.using("replica").all()) + + response = self.panel.process_request(self.request) + self.panel.generate_stats(self.request, response) + + if connection.vendor == "postgresql": + # Connection tracking is currently only implemented for PostgreSQL. + self.assertEqual(len(self.panel._queries), 6) + + query = self.panel._queries[0] + self.assertEqual(query[0], "default") + self.assertIsNotNone(query[1]["trans_id"]) + self.assertTrue(query[1]["starts_trans"]) + self.assertTrue(query[1]["in_trans"]) + self.assertFalse("end_trans" in query[1]) + + query = self.panel._queries[-1] + self.assertEqual(query[0], "replica") + self.assertIsNone(query[1]["trans_id"]) + self.assertFalse("starts_trans" in query[1]) + self.assertFalse("in_trans" in query[1]) + self.assertFalse("end_trans" in query[1]) + + query = self.panel._queries[2] + self.assertEqual(query[0], "default") + self.assertIsNotNone(query[1]["trans_id"]) + self.assertEqual( + query[1]["trans_id"], self.panel._queries[0][1]["trans_id"] + ) + self.assertFalse("starts_trans" in query[1]) + self.assertTrue(query[1]["in_trans"]) + self.assertTrue(query[1]["ends_trans"]) + + query = self.panel._queries[3] + self.assertEqual(query[0], "replica") + self.assertIsNotNone(query[1]["trans_id"]) + self.assertNotEqual( + query[1]["trans_id"], self.panel._queries[0][1]["trans_id"] + ) + self.assertTrue(query[1]["starts_trans"]) + self.assertTrue(query[1]["in_trans"]) + self.assertTrue(query[1]["ends_trans"]) + + query = self.panel._queries[4] + self.assertEqual(query[0], "default") + self.assertIsNotNone(query[1]["trans_id"]) + self.assertNotEqual( + query[1]["trans_id"], self.panel._queries[0][1]["trans_id"] + ) + self.assertNotEqual( + query[1]["trans_id"], self.panel._queries[3][1]["trans_id"] + ) + self.assertTrue(query[1]["starts_trans"]) + self.assertTrue(query[1]["in_trans"]) + self.assertTrue(query[1]["ends_trans"]) + + query = self.panel._queries[5] + self.assertEqual(query[0], "replica") + self.assertIsNone(query[1]["trans_id"]) + self.assertFalse("starts_trans" in query[1]) + self.assertFalse("in_trans" in query[1]) + self.assertFalse("end_trans" in query[1]) + else: + # Ensure that nothing was recorded for other database engines. + self.assertTrue(self.panel._queries) + for query in self.panel._queries: + self.assertFalse("trans_id" in query[1]) + self.assertFalse("starts_trans" in query[1]) + self.assertFalse("in_trans" in query[1]) + self.assertFalse("end_trans" in query[1]) diff --git a/tests/settings.py b/tests/settings.py index da5067fbf..b3c281242 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -104,6 +104,20 @@ "USER": "default_test", }, }, + "replica": { + "ENGINE": "django.{}db.backends.{}".format( + "contrib.gis." if USE_GIS else "", os.getenv("DB_BACKEND", "sqlite3") + ), + "NAME": os.getenv("DB_NAME", ":memory:"), + "USER": os.getenv("DB_USER"), + "PASSWORD": os.getenv("DB_PASSWORD"), + "HOST": os.getenv("DB_HOST", ""), + "PORT": os.getenv("DB_PORT", ""), + "TEST": { + "USER": "default_test", + "MIRROR": "default", + }, + }, } DEFAULT_AUTO_FIELD = "django.db.models.AutoField" diff --git a/tox.ini b/tox.ini index f2f62a3a1..212e31f39 100644 --- a/tox.ini +++ b/tox.ini @@ -42,7 +42,7 @@ whitelist_externals = make pip_pre = True commands = python -b -W always -m coverage run -m django test -v2 {posargs:tests} -[testenv:py{37,38,39,310}-dj{40,41,main}-postgresql] +[testenv:py{37,38,39,310}-dj{32,40,41,main}-postgresql] setenv = {[testenv]setenv} DB_BACKEND = postgresql