Get PostgreSQL transaction tracking working #1619

Merged: 11 commits, May 24, 2022

82 changes: 41 additions & 41 deletions debug_toolbar/panels/sql/panel.py
@@ -61,35 +61,29 @@ def __init__(self, *args, **kwargs):
self._num_queries = 0
self._queries = []
self._databases = {}
self._transaction_status = {}
# synthetic transaction IDs, keyed by DB alias
self._transaction_ids = {}

def get_transaction_id(self, alias):
if alias not in connections:
return
conn = connections[alias].connection
if not conn:
return

if conn.vendor == "postgresql":
cur_status = conn.get_transaction_status()
else:
raise ValueError(conn.vendor)

last_status = self._transaction_status.get(alias)
self._transaction_status[alias] = cur_status

if not cur_status:
# No available state
return None

if cur_status != last_status:
if cur_status:
self._transaction_ids[alias] = uuid.uuid4().hex
else:
self._transaction_ids[alias] = None

return self._transaction_ids[alias]
def new_transaction_id(self, alias):
"""
Generate and return a new synthetic transaction ID for the specified DB alias.
"""
trans_id = uuid.uuid4().hex
self._transaction_ids[alias] = trans_id
return trans_id

def current_transaction_id(self, alias):
"""
Return the current synthetic transaction ID for the specified DB alias.
"""
trans_id = self._transaction_ids.get(alias)
# Sometimes it is not possible to detect the beginning of the first transaction,
# so current_transaction_id() will be called before new_transaction_id(). In
# that case there won't yet be a transaction ID, so it is necessary to generate
# one using new_transaction_id().
if trans_id is None:
trans_id = self.new_transaction_id(alias)
return trans_id
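
As a rough behavioral sketch of these two helpers, here is a standalone snippet (the class name is hypothetical and the method bodies simply mirror the patch above; it is not part of the change itself):

import uuid


class TransactionIdTracker:
    """Minimal stand-in for the SQLPanel bookkeeping shown above."""

    def __init__(self):
        self._transaction_ids = {}

    def new_transaction_id(self, alias):
        trans_id = uuid.uuid4().hex
        self._transaction_ids[alias] = trans_id
        return trans_id

    def current_transaction_id(self, alias):
        # Fall back to generating an ID when none has been recorded yet.
        trans_id = self._transaction_ids.get(alias)
        if trans_id is None:
            trans_id = self.new_transaction_id(alias)
        return trans_id


tracker = TransactionIdTracker()
first = tracker.current_transaction_id("default")       # generated lazily
assert tracker.current_transaction_id("default") == first
assert tracker.new_transaction_id("default") != first   # new transaction, new ID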

def record(self, alias, **kwargs):
self._queries.append((alias, kwargs))
@@ -184,23 +178,25 @@ def duplicate_key(query):
rgb[nn] = nc
db["rgb_color"] = rgb

trans_ids = {}
trans_id = None
i = 0
# the last query recorded for each DB alias
last_by_alias = {}
for alias, query in self._queries:
query_similar[alias][similar_key(query)] += 1
query_duplicates[alias][duplicate_key(query)] += 1

trans_id = query.get("trans_id")
last_trans_id = trans_ids.get(alias)

if trans_id != last_trans_id:
if last_trans_id:
self._queries[(i - 1)][1]["ends_trans"] = True
trans_ids[alias] = trans_id
if trans_id:
prev_query = last_by_alias.get(alias, {})
prev_trans_id = prev_query.get("trans_id")

# If two consecutive queries for a given DB alias have different
# transaction ID values, a transaction started, finished, or both, so
# annotate the queries as appropriate.
if trans_id != prev_trans_id:
if prev_trans_id is not None:
prev_query["ends_trans"] = True
if trans_id is not None:
query["starts_trans"] = True
if trans_id:
if trans_id is not None:
query["in_trans"] = True

query["alias"] = alias
@@ -228,12 +224,16 @@ def duplicate_key(query):
query["end_offset"] = query["width_ratio"] + query["start_offset"]
width_ratio_tally += query["width_ratio"]
query["stacktrace"] = render_stacktrace(query["stacktrace"])
i += 1

query["trace_color"] = trace_colors[query["stacktrace"]]

if trans_id:
self._queries[(i - 1)][1]["ends_trans"] = True
last_by_alias[alias] = query

# Close out any transactions that were in progress, since there is no
# explicit way to know when a transaction finishes.
for final_query in last_by_alias.values():
if final_query.get("trans_id") is not None:
final_query["ends_trans"] = True

# Queries are similar / duplicates only if there are at least 2 of them.
# Also, to hide queries, we need to give all the duplicate groups an id
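
To make the annotation pass above easier to follow, here is a simplified standalone sketch that applies the same starts_trans/in_trans/ends_trans rules to a hard-coded list of (alias, query) pairs (illustrative only, not the panel code itself):

def annotate_transactions(queries):
    # queries is a list of (alias, query_dict) pairs; query_dict may carry a
    # "trans_id" key, as recorded by the cursor wrapper.
    last_by_alias = {}
    for alias, query in queries:
        trans_id = query.get("trans_id")
        prev_query = last_by_alias.get(alias, {})
        prev_trans_id = prev_query.get("trans_id")
        if trans_id != prev_trans_id:
            if prev_trans_id is not None:
                prev_query["ends_trans"] = True
            if trans_id is not None:
                query["starts_trans"] = True
        if trans_id is not None:
            query["in_trans"] = True
        last_by_alias[alias] = query
    # The last in-transaction query per alias implicitly ends its transaction.
    for final_query in last_by_alias.values():
        if final_query.get("trans_id") is not None:
            final_query["ends_trans"] = True


queries = [
    ("default", {"trans_id": "a"}),
    ("default", {"trans_id": "a"}),
    ("default", {"trans_id": None}),
]
annotate_transactions(queries)
assert queries[0][1].get("starts_trans") and queries[1][1].get("ends_trans")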
54 changes: 46 additions & 8 deletions debug_toolbar/panels/sql/tracking.py
@@ -10,8 +10,10 @@

try:
from psycopg2._json import Json as PostgresJson
from psycopg2.extensions import STATUS_IN_TRANSACTION
except ImportError:
PostgresJson = None
STATUS_IN_TRANSACTION = None

# Prevents SQL queries from being sent to the DB. It's used
# by the TemplatePanel to prevent the toolbar from issuing
@@ -60,9 +62,22 @@ def chunked_cursor(*args, **kwargs):

def unwrap_cursor(connection):
if hasattr(connection, "_djdt_cursor"):
# Sometimes the cursor()/chunked_cursor() methods of the DatabaseWrapper
# instance are already monkey patched before wrap_cursor() is called. (In
# particular, Django's SimpleTestCase monkey patches those methods for any
# disallowed databases to raise an exception if they are accessed.) Thus only
# delete our monkey patch if the method we saved is the same as the class
# method. Otherwise, restore the prior monkey patch from our saved method.
if connection._djdt_cursor == connection.__class__.cursor:
del connection.cursor
else:
connection.cursor = connection._djdt_cursor
del connection._djdt_cursor
del connection.cursor
del connection.chunked_cursor
if connection._djdt_chunked_cursor == connection.__class__.chunked_cursor:
del connection.chunked_cursor
else:
connection.chunked_cursor = connection._djdt_chunked_cursor
del connection._djdt_chunked_cursor


class BaseCursorWrapper:
@@ -126,6 +141,14 @@ def _decode(self, param):
return "(encoded string)"

def _record(self, method, sql, params):
alias = self.db.alias
vendor = self.db.vendor

if vendor == "postgresql":
# The underlying DB connection (as opposed to Django's wrapper)
conn = self.db.connection
initial_conn_status = conn.status

start_time = time()
try:
return method(sql, params)
@@ -143,10 +166,6 @@ def _record(self, method, sql, params):
pass # object not JSON serializable
template_info = get_template_info()

alias = getattr(self.db, "alias", "default")
conn = self.db.connection
vendor = getattr(conn, "vendor", "unknown")

# Sql might be an object (such as psycopg Composed).
# For logging purposes, make sure it's str.
sql = str(sql)
@@ -177,12 +196,31 @@ def _record(self, method, sql, params):
iso_level = conn.isolation_level
except conn.InternalError:
iso_level = "unknown"
# PostgreSQL does not expose any sort of transaction ID, so it is
# necessary to generate synthetic transaction IDs here. If the
# connection was not in a transaction when the query started, and was
# after the query finished, a new transaction definitely started, so get
# a new transaction ID from logger.new_transaction_id(). If the query
# was in a transaction both before and after executing, make the
# assumption that it is the same transaction and get the current
# transaction ID from logger.current_transaction_id(). There is an edge
# case where Django can start a transaction before the first query
# executes, so in that case logger.current_transaction_id() will
# generate a new transaction ID since one does not already exist.
final_conn_status = conn.status
if final_conn_status == STATUS_IN_TRANSACTION:
if initial_conn_status == STATUS_IN_TRANSACTION:
trans_id = self.logger.current_transaction_id(alias)
else:
trans_id = self.logger.new_transaction_id(alias)
else:
trans_id = None

params.update(
{
"trans_id": self.logger.get_transaction_id(alias),
"trans_id": trans_id,
"trans_status": conn.get_transaction_status(),
"iso_level": iso_level,
"encoding": conn.encoding,
}
)

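
The status-based decision above can be restated as a small helper (a sketch for clarity only, not code added by the patch; the boolean flags stand for comparisons of conn.status against psycopg2's STATUS_IN_TRANSACTION before and after the query ran):

def synthetic_trans_id(logger, alias, initially_in_transaction, finally_in_transaction):
    # Only queries that finish inside a transaction get a synthetic ID; a fresh
    # ID is generated when the transaction started during this query.
    if not finally_in_transaction:
        return None
    if initially_in_transaction:
        return logger.current_transaction_id(alias)
    return logger.new_transaction_id(alias)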
12 changes: 10 additions & 2 deletions tests/base.py
@@ -1,7 +1,7 @@
import html5lib
from asgiref.local import Local
from django.http import HttpResponse
from django.test import Client, RequestFactory, TestCase
from django.test import Client, RequestFactory, TestCase, TransactionTestCase

from debug_toolbar.toolbar import DebugToolbar

@@ -29,7 +29,7 @@ def handle_toolbar_created(sender, toolbar=None, **kwargs):
rf = RequestFactory()


class BaseTestCase(TestCase):
class BaseMixin:
client_class = ToolbarTestClient

panel_id = None
@@ -67,6 +67,14 @@ def assertValidHTML(self, content):
raise self.failureException("\n".join(msg_parts))


class BaseTestCase(BaseMixin, TestCase):
pass


class BaseMultiDBTestCase(BaseMixin, TransactionTestCase):
databases = {"default", "replica"}


class IntegrationTestCase(TestCase):
"""Base TestCase for tests involving clients making requests."""

115 changes: 113 additions & 2 deletions tests/panels/test_sql.py
@@ -7,7 +7,7 @@
import django
from asgiref.sync import sync_to_async
from django.contrib.auth.models import User
from django.db import connection
from django.db import connection, transaction
from django.db.models import Count
from django.db.utils import DatabaseError
from django.shortcuts import render
@@ -16,7 +16,7 @@
import debug_toolbar.panels.sql.tracking as sql_tracking
from debug_toolbar import settings as dt_settings

from ..base import BaseTestCase
from ..base import BaseMultiDBTestCase, BaseTestCase
from ..models import PostgresJSON


@@ -506,3 +506,114 @@ def test_nested_template_information(self):
self.assertEqual(template_name, "included.html")
self.assertEqual(template_info["context"][0]["content"].strip(), "{{ users }}")
self.assertEqual(template_info["context"][0]["highlight"], True)


class SQLPanelMultiDBTestCase(BaseMultiDBTestCase):
panel_id = "SQLPanel"

def test_aliases(self):
self.assertFalse(self.panel._queries)

list(User.objects.all())
list(User.objects.using("replica").all())

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

self.assertTrue(self.panel._queries)

query = self.panel._queries[0]
self.assertEqual(query[0], "default")

query = self.panel._queries[-1]
self.assertEqual(query[0], "replica")

def test_transaction_status(self):
"""
Test that the transaction status is properly associated with queries on
PostgreSQL, and that transactions aren't broken on other database
engines.
"""
self.assertEqual(len(self.panel._queries), 0)

with transaction.atomic():
list(User.objects.all())
list(User.objects.using("replica").all())

with transaction.atomic(using="replica"):
list(User.objects.all())
list(User.objects.using("replica").all())

with transaction.atomic():
list(User.objects.all())

list(User.objects.using("replica").all())

response = self.panel.process_request(self.request)
self.panel.generate_stats(self.request, response)

if connection.vendor == "postgresql":
# Transaction tracking is currently only implemented for PostgreSQL.
self.assertEqual(len(self.panel._queries), 6)

query = self.panel._queries[0]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertFalse("end_trans" in query[1])

query = self.panel._queries[-1]
self.assertEqual(query[0], "replica")
self.assertIsNone(query[1]["trans_id"])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])

query = self.panel._queries[2]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertFalse("starts_trans" in query[1])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[3]
self.assertEqual(query[0], "replica")
self.assertIsNotNone(query[1]["trans_id"])
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[4]
self.assertEqual(query[0], "default")
self.assertIsNotNone(query[1]["trans_id"])
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[0][1]["trans_id"]
)
self.assertNotEqual(
query[1]["trans_id"], self.panel._queries[3][1]["trans_id"]
)
self.assertTrue(query[1]["starts_trans"])
self.assertTrue(query[1]["in_trans"])
self.assertTrue(query[1]["ends_trans"])

query = self.panel._queries[5]
self.assertEqual(query[0], "replica")
self.assertIsNone(query[1]["trans_id"])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])
else:
# Ensure that no transaction data was recorded for other database engines.
self.assertTrue(self.panel._queries)
for query in self.panel._queries:
self.assertFalse("trans_id" in query[1])
self.assertFalse("starts_trans" in query[1])
self.assertFalse("in_trans" in query[1])
self.assertFalse("end_trans" in query[1])
14 changes: 14 additions & 0 deletions tests/settings.py
@@ -104,6 +104,20 @@
"USER": "default_test",
},
},
"replica": {
"ENGINE": "django.{}db.backends.{}".format(
"contrib.gis." if USE_GIS else "", os.getenv("DB_BACKEND", "sqlite3")
),
"NAME": os.getenv("DB_NAME", ":memory:"),
"USER": os.getenv("DB_USER"),
"PASSWORD": os.getenv("DB_PASSWORD"),
"HOST": os.getenv("DB_HOST", ""),
"PORT": os.getenv("DB_PORT", ""),
"TEST": {
"USER": "default_test",
"MIRROR": "default",
},
},
}
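
For context on the MIRROR entry (standard Django test-database behavior, not something introduced by this patch): during tests the "replica" alias is redirected to the "default" alias's test database, so data written through default is visible through replica without any real replication. A hypothetical check inside a test might look like this:

from django.contrib.auth.models import User

# Because "replica" mirrors "default" during tests, a row created on the
# default alias is immediately visible via the replica alias.
User.objects.create_user(username="mirrored")
assert User.objects.using("replica").filter(username="mirrored").exists()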

DEFAULT_AUTO_FIELD = "django.db.models.AutoField"