Skip to content

Commit e070344

Browse files
scardinepaulos
authored andcommitted
Add schema to ObtainAuthToken
Add encoding parameter to ManualSchema
1 parent d12005c commit e070344

File tree

2 files changed

+29
-3
lines changed

2 files changed

+29
-3
lines changed

rest_framework/authtoken/views.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
from rest_framework.authtoken.serializers import AuthTokenSerializer
44
from rest_framework.response import Response
55
from rest_framework.views import APIView
6+
from rest_framework.schemas import ManualSchema
7+
import coreapi
8+
import coreschema
69

710

811
class ObtainAuthToken(APIView):
@@ -11,6 +14,29 @@ class ObtainAuthToken(APIView):
1114
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
1215
renderer_classes = (renderers.JSONRenderer,)
1316
serializer_class = AuthTokenSerializer
17+
schema = ManualSchema(
18+
fields=[
19+
coreapi.Field(
20+
name="username",
21+
required=True,
22+
location='form',
23+
schema=coreschema.String(
24+
title="Username",
25+
description="Valid username for authentication",
26+
),
27+
),
28+
coreapi.Field(
29+
name="password",
30+
required=True,
31+
location='form',
32+
schema=coreschema.String(
33+
title="Password",
34+
description="Valid password for authentication",
35+
),
36+
),
37+
],
38+
encoding="application/json",
39+
)
1440

1541
def post(self, request, *args, **kwargs):
1642
serializer = self.serializer_class(data=request.data,
@@ -20,5 +46,4 @@ def post(self, request, *args, **kwargs):
2046
token, created = Token.objects.get_or_create(user=user)
2147
return Response({'token': token.key})
2248

23-
2449
obtain_auth_token = ObtainAuthToken.as_view()

rest_framework/schemas/inspectors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -432,7 +432,7 @@ class ManualSchema(ViewInspector):
432432
Allows providing a list of coreapi.Fields,
433433
plus an optional description.
434434
"""
435-
def __init__(self, fields, description=''):
435+
def __init__(self, fields, description='', encoding=None):
436436
"""
437437
Parameters:
438438
@@ -442,6 +442,7 @@ def __init__(self, fields, description=''):
442442
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
443443
self._fields = fields
444444
self._description = description
445+
self._encoding = encoding
445446

446447
def get_link(self, path, method, base_url):
447448

@@ -451,7 +452,7 @@ def get_link(self, path, method, base_url):
451452
return coreapi.Link(
452453
url=urlparse.urljoin(base_url, path),
453454
action=method.lower(),
454-
encoding=None,
455+
encoding=self._encoding,
455456
fields=self._fields,
456457
description=self._description
457458
)

0 commit comments

Comments
 (0)