Skip to content

Commit 37f7b76

Browse files
committed
Merge pull request #3785 from sheppard/authtoken-import
don't import authtoken model until needed
2 parents dceb686 + 4f40714 commit 37f7b76

File tree

3 files changed

+50
-22
lines changed

3 files changed

+50
-22
lines changed

rest_framework/authentication.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from django.utils.translation import ugettext_lazy as _
1111

1212
from rest_framework import HTTP_HEADER_ENCODING, exceptions
13-
from rest_framework.authtoken.models import Token
1413

1514

1615
def get_authorization_header(request):
@@ -149,7 +148,14 @@ class TokenAuthentication(BaseAuthentication):
149148
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
150149
"""
151150

152-
model = Token
151+
model = None
152+
153+
def get_model(self):
154+
if self.model is not None:
155+
return self.model
156+
from rest_framework.authtoken.models import Token
157+
return Token
158+
153159
"""
154160
A custom token model may be used, but must have the following properties.
155161
@@ -179,9 +185,10 @@ def authenticate(self, request):
179185
return self.authenticate_credentials(token)
180186

181187
def authenticate_credentials(self, key):
188+
model = self.get_model()
182189
try:
183-
token = self.model.objects.select_related('user').get(key=key)
184-
except self.model.DoesNotExist:
190+
token = model.objects.select_related('user').get(key=key)
191+
except model.DoesNotExist:
185192
raise exceptions.AuthenticationFailed(_('Invalid token.'))
186193

187194
if not token.user.is_active:

rest_framework/authtoken/models.py

-8
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,6 @@ class Token(models.Model):
2121
user = models.OneToOneField(AUTH_USER_MODEL, related_name='auth_token')
2222
created = models.DateTimeField(auto_now_add=True)
2323

24-
class Meta:
25-
# Work around for a bug in Django:
26-
# https://code.djangoproject.com/ticket/19422
27-
#
28-
# Also see corresponding ticket:
29-
# https://github.com/tomchristie/django-rest-framework/issues/705
30-
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
31-
3224
def save(self, *args, **kwargs):
3325
if not self.key:
3426
self.key = self.generate_key()

tests/test_authentication.py

+39-10
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from django.conf.urls import include, url
88
from django.contrib.auth.models import User
9+
from django.db import models
910
from django.http import HttpResponse
1011
from django.test import TestCase
1112
from django.utils import six
@@ -25,6 +26,15 @@
2526
factory = APIRequestFactory()
2627

2728

29+
class CustomToken(models.Model):
30+
key = models.CharField(max_length=40, primary_key=True)
31+
user = models.OneToOneField(User)
32+
33+
34+
class CustomTokenAuthentication(TokenAuthentication):
35+
model = CustomToken
36+
37+
2838
class MockView(APIView):
2939
permission_classes = (permissions.IsAuthenticated,)
3040

@@ -42,6 +52,7 @@ def put(self, request):
4252
url(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
4353
url(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
4454
url(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
55+
url(r'^customtoken/$', MockView.as_view(authentication_classes=[CustomTokenAuthentication])),
4556
url(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
4657
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
4758
]
@@ -142,9 +153,11 @@ def test_post_form_session_auth_failing(self):
142153
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
143154

144155

145-
class TokenAuthTests(TestCase):
156+
class BaseTokenAuthTests(object):
146157
"""Token authentication"""
147158
urls = 'tests.test_authentication'
159+
model = None
160+
path = None
148161

149162
def setUp(self):
150163
self.csrf_client = APIClient(enforce_csrf_checks=True)
@@ -154,54 +167,65 @@ def setUp(self):
154167
self.user = User.objects.create_user(self.username, self.email, self.password)
155168

156169
self.key = 'abcd1234'
157-
self.token = Token.objects.create(key=self.key, user=self.user)
170+
self.token = self.model.objects.create(key=self.key, user=self.user)
158171

159172
def test_post_form_passing_token_auth(self):
160173
"""Ensure POSTing json over token auth with correct credentials passes and does not require CSRF"""
161174
auth = 'Token ' + self.key
162-
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
175+
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
163176
self.assertEqual(response.status_code, status.HTTP_200_OK)
164177

178+
def test_fail_post_form_passing_nonexistent_token_auth(self):
179+
# use a nonexistent token key
180+
auth = 'Token wxyz6789'
181+
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
182+
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
183+
165184
def test_fail_post_form_passing_invalid_token_auth(self):
166185
# add an 'invalid' unicode character
167186
auth = 'Token ' + self.key + "¸"
168-
response = self.csrf_client.post('/token/', {'example': 'example'}, HTTP_AUTHORIZATION=auth)
187+
response = self.csrf_client.post(self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth)
169188
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
170189

171190
def test_post_json_passing_token_auth(self):
172191
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
173192
auth = "Token " + self.key
174-
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
193+
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
175194
self.assertEqual(response.status_code, status.HTTP_200_OK)
176195

177196
def test_post_json_makes_one_db_query(self):
178197
"""Ensure that authenticating a user using a token performs only one DB query"""
179198
auth = "Token " + self.key
180199

181200
def func_to_test():
182-
return self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
201+
return self.csrf_client.post(self.path, {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
183202

184203
self.assertNumQueries(1, func_to_test)
185204

186205
def test_post_form_failing_token_auth(self):
187206
"""Ensure POSTing form over token auth without correct credentials fails"""
188-
response = self.csrf_client.post('/token/', {'example': 'example'})
207+
response = self.csrf_client.post(self.path, {'example': 'example'})
189208
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
190209

191210
def test_post_json_failing_token_auth(self):
192211
"""Ensure POSTing json over token auth without correct credentials fails"""
193-
response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
212+
response = self.csrf_client.post(self.path, {'example': 'example'}, format='json')
194213
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
195214

215+
216+
class TokenAuthTests(BaseTokenAuthTests, TestCase):
217+
model = Token
218+
path = '/token/'
219+
196220
def test_token_has_auto_assigned_key_if_none_provided(self):
197221
"""Ensure creating a token with no key will auto-assign a key"""
198222
self.token.delete()
199-
token = Token.objects.create(user=self.user)
223+
token = self.model.objects.create(user=self.user)
200224
self.assertTrue(bool(token.key))
201225

202226
def test_generate_key_returns_string(self):
203227
"""Ensure generate_key returns a string"""
204-
token = Token()
228+
token = self.model()
205229
key = token.generate_key()
206230
self.assertTrue(isinstance(key, six.string_types))
207231

@@ -236,6 +260,11 @@ def test_token_login_form(self):
236260
self.assertEqual(response.data['token'], self.key)
237261

238262

263+
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
264+
model = CustomToken
265+
path = '/customtoken/'
266+
267+
239268
class IncorrectCredentialsTests(TestCase):
240269
def test_incorrect_credentials(self):
241270
"""

0 commit comments

Comments
 (0)