Skip to content

Commit 0c97dd1

Browse files
committed
Merge remote-tracking branch 'reference/master'
2 parents 8fa3284 + 1957679 commit 0c97dd1

File tree

4 files changed

+134
-2
lines changed

4 files changed

+134
-2
lines changed

rest_framework/compat.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from __future__ import unicode_literals
88
from django.core.exceptions import ImproperlyConfigured
99
from django.conf import settings
10+
from django.db import connection, transaction
1011
from django.utils.encoding import force_text
1112
from django.utils.six.moves.urllib.parse import urlparse as _urlparse
1213
from django.utils import six
@@ -266,3 +267,19 @@ def apply_markdown(text):
266267
from django.utils.duration import duration_string
267268
else:
268269
DurationField = duration_string = parse_duration = None
270+
271+
272+
def set_rollback():
273+
if hasattr(transaction, 'set_rollback'):
274+
if connection.settings_dict.get('ATOMIC_REQUESTS', False):
275+
# If running in >=1.6 then mark a rollback as required,
276+
# and allow it to be handled by Django.
277+
transaction.set_rollback(True)
278+
elif transaction.is_managed():
279+
# Otherwise handle it explicitly if in managed mode.
280+
if transaction.is_dirty():
281+
transaction.rollback()
282+
transaction.leave_transaction_management()
283+
else:
284+
# transaction not managed
285+
pass

rest_framework/views.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from django.utils.translation import ugettext_lazy as _
1010
from django.views.decorators.csrf import csrf_exempt
1111
from rest_framework import status, exceptions
12-
from rest_framework.compat import HttpResponseBase, View
12+
from rest_framework.compat import HttpResponseBase, View, set_rollback
1313
from rest_framework.request import Request
1414
from rest_framework.response import Response
1515
from rest_framework.settings import api_settings
@@ -71,16 +71,21 @@ def exception_handler(exc, context):
7171
else:
7272
data = {'detail': exc.detail}
7373

74+
set_rollback()
7475
return Response(data, status=exc.status_code, headers=headers)
7576

7677
elif isinstance(exc, Http404):
7778
msg = _('Not found.')
7879
data = {'detail': six.text_type(msg)}
80+
81+
set_rollback()
7982
return Response(data, status=status.HTTP_404_NOT_FOUND)
8083

8184
elif isinstance(exc, PermissionDenied):
8285
msg = _('Permission denied.')
8386
data = {'detail': six.text_type(msg)}
87+
88+
set_rollback()
8489
return Response(data, status=status.HTTP_403_FORBIDDEN)
8590

8691
# Note: Unhandled exceptions will raise a 500 error.

tests/test_atomic_requests.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
from __future__ import unicode_literals
2+
3+
from django.db import connection, connections, transaction
4+
from django.test import TestCase
5+
from django.utils.unittest import skipUnless
6+
from rest_framework import status
7+
from rest_framework.exceptions import APIException
8+
from rest_framework.response import Response
9+
from rest_framework.test import APIRequestFactory
10+
from rest_framework.views import APIView
11+
from tests.models import BasicModel
12+
13+
14+
factory = APIRequestFactory()
15+
16+
17+
class BasicView(APIView):
18+
def post(self, request, *args, **kwargs):
19+
BasicModel.objects.create()
20+
return Response({'method': 'GET'})
21+
22+
23+
class ErrorView(APIView):
24+
def post(self, request, *args, **kwargs):
25+
BasicModel.objects.create()
26+
raise Exception
27+
28+
29+
class APIExceptionView(APIView):
30+
def post(self, request, *args, **kwargs):
31+
BasicModel.objects.create()
32+
raise APIException
33+
34+
35+
@skipUnless(connection.features.uses_savepoints,
36+
"'atomic' requires transactions and savepoints.")
37+
class DBTransactionTests(TestCase):
38+
def setUp(self):
39+
self.view = BasicView.as_view()
40+
connections.databases['default']['ATOMIC_REQUESTS'] = True
41+
42+
def tearDown(self):
43+
connections.databases['default']['ATOMIC_REQUESTS'] = False
44+
45+
def test_no_exception_conmmit_transaction(self):
46+
request = factory.post('/')
47+
48+
with self.assertNumQueries(1):
49+
response = self.view(request)
50+
self.assertFalse(transaction.get_rollback())
51+
self.assertEqual(response.status_code, status.HTTP_200_OK)
52+
assert BasicModel.objects.count() == 1
53+
54+
55+
@skipUnless(connection.features.uses_savepoints,
56+
"'atomic' requires transactions and savepoints.")
57+
class DBTransactionErrorTests(TestCase):
58+
def setUp(self):
59+
self.view = ErrorView.as_view()
60+
connections.databases['default']['ATOMIC_REQUESTS'] = True
61+
62+
def tearDown(self):
63+
connections.databases['default']['ATOMIC_REQUESTS'] = False
64+
65+
def test_generic_exception_delegate_transaction_management(self):
66+
"""
67+
Transaction is eventually managed by outer-most transaction atomic
68+
block. DRF do not try to interfere here.
69+
70+
We let django deal with the transaction when it will catch the Exception.
71+
"""
72+
request = factory.post('/')
73+
with self.assertNumQueries(3):
74+
# 1 - begin savepoint
75+
# 2 - insert
76+
# 3 - release savepoint
77+
with transaction.atomic():
78+
self.assertRaises(Exception, self.view, request)
79+
self.assertFalse(transaction.get_rollback())
80+
assert BasicModel.objects.count() == 1
81+
82+
83+
@skipUnless(connection.features.uses_savepoints,
84+
"'atomic' requires transactions and savepoints.")
85+
class DBTransactionAPIExceptionTests(TestCase):
86+
def setUp(self):
87+
self.view = APIExceptionView.as_view()
88+
connections.databases['default']['ATOMIC_REQUESTS'] = True
89+
90+
def tearDown(self):
91+
connections.databases['default']['ATOMIC_REQUESTS'] = False
92+
93+
def test_api_exception_rollback_transaction(self):
94+
"""
95+
Transaction is rollbacked by our transaction atomic block.
96+
"""
97+
request = factory.post('/')
98+
num_queries = (4 if getattr(connection.features,
99+
'can_release_savepoints', False) else 3)
100+
with self.assertNumQueries(num_queries):
101+
# 1 - begin savepoint
102+
# 2 - insert
103+
# 3 - rollback savepoint
104+
# 4 - release savepoint (django>=1.8 only)
105+
with transaction.atomic():
106+
response = self.view(request)
107+
self.assertTrue(transaction.get_rollback())
108+
self.assertEqual(response.status_code,
109+
status.HTTP_500_INTERNAL_SERVER_ERROR)
110+
assert BasicModel.objects.count() == 0

tox.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ envlist =
66
{py27,py32,py33,py34}-django{17,18,master}
77

88
[testenv]
9-
commands = ./runtests.py --fast
9+
commands = ./runtests.py --fast {posargs}
1010
setenv =
1111
PYTHONDONTWRITEBYTECODE=1
1212
deps =

0 commit comments

Comments
 (0)