Skip to content

Commit fd12150

Browse files
committed
add OAS securitySchemes and security objects
1 parent 3875d32 commit fd12150

File tree

4 files changed

+219
-1
lines changed

4 files changed

+219
-1
lines changed

docs/api-guide/schemas.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,6 +389,26 @@ differentiate between request and response objects.
389389
By default returns `get_serializer()` but can be overridden to
390390
differentiate between request and response objects.
391391

392+
#### `get_security_schemes()`
393+
394+
Generates the OpenAPI `securitySchemes` components based on:
395+
- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`)
396+
- Per-view non-default `authentication_classes`
397+
398+
These are generated using the authentication classes' `openapi_security_scheme()` class method. If you
399+
extend `BaseAuthentication` with your own authentication class, you can add this class method to return
400+
the appropriate security scheme object.
401+
402+
#### `get_security_requirements()`
403+
404+
Root-level security requirements (the top-level `security` object) are generated based on the
405+
default authentication classes. Operation-level security requirements are generated only if the given view's
406+
`authentication_classes` differ from the defaults.
407+
408+
These are generated using the authentication classes' `openapi_security_requirement()` class
409+
method. If you extended `BaseAuthentication` with your own authentication class, you can add this
410+
class method to return the appropriate list of security requirements objects.
411+
392412
### `AutoSchema.__init__()` kwargs
393413

394414
`AutoSchema` provides a number of `__init__()` kwargs that can be used for

rest_framework/authentication.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,32 @@ def authenticate_header(self, request):
4949
"""
5050
pass
5151

52+
#: Name of openapi security scheme. Override if you want to customize it.
53+
openapi_security_scheme_name = None
54+
55+
@classmethod
56+
def openapi_security_scheme(cls):
57+
"""
58+
Override this to return an Open API Specification `securityScheme object
59+
<http://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
60+
"""
61+
return {}
62+
63+
@classmethod
64+
def openapi_security_requirement(cls, view, method):
65+
"""
66+
Override this to return an Open API Specification `security requirement object
67+
<http://spec.openapis.org/oas/v3.0.3#security-requirement-object>`_
68+
69+
:param view: used to find view attributes used by a permission class or None for root-level
70+
:param method: used to distinguish among method-specific permissions or None for root-level
71+
:return:list: [security requirement objects]
72+
"""
73+
# At this point, none of the built-in DRF authentication classes fill in the
74+
# requirement list: OAuth2/OIDC are the only security types that currently uses the list
75+
# (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2.
76+
return [{}]
77+
5278

5379
class BasicAuthentication(BaseAuthentication):
5480
"""
@@ -108,6 +134,22 @@ def authenticate_credentials(self, userid, password, request=None):
108134
def authenticate_header(self, request):
109135
return 'Basic realm="%s"' % self.www_authenticate_realm
110136

137+
openapi_security_scheme_name = 'basicAuth'
138+
139+
@classmethod
140+
def openapi_security_scheme(cls):
141+
return {
142+
cls.openapi_security_scheme_name: {
143+
'type': 'http',
144+
'scheme': 'basic',
145+
'description': 'Basic Authentication'
146+
}
147+
}
148+
149+
@classmethod
150+
def openapi_security_requirement(cls, view, method):
151+
return [{cls.openapi_security_scheme_name: []}]
152+
111153

112154
class SessionAuthentication(BaseAuthentication):
113155
"""
@@ -147,6 +189,23 @@ def dummy_get_response(request): # pragma: no cover
147189
# CSRF failed, bail with explicit error message
148190
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
149191

192+
openapi_security_scheme_name = 'sessionAuth'
193+
194+
@classmethod
195+
def openapi_security_scheme(cls):
196+
return {
197+
cls.openapi_security_scheme_name: {
198+
'type': 'apiKey',
199+
'in': 'cookie',
200+
'name': 'JSESSIONID',
201+
'description': 'Session authentication'
202+
}
203+
}
204+
205+
@classmethod
206+
def openapi_security_requirement(cls, view, method):
207+
return [{cls.openapi_security_scheme_name: []}]
208+
150209

151210
class TokenAuthentication(BaseAuthentication):
152211
"""
@@ -210,6 +269,23 @@ def authenticate_credentials(self, key):
210269
def authenticate_header(self, request):
211270
return self.keyword
212271

272+
openapi_security_scheme_name = 'tokenAuth'
273+
274+
@classmethod
275+
def openapi_security_scheme(cls):
276+
return {
277+
cls.openapi_security_scheme_name: {
278+
'type': 'http',
279+
'in': 'header',
280+
'name': 'Authorization', # Authorization: token ...
281+
'description': 'Token authentication'
282+
}
283+
}
284+
285+
@classmethod
286+
def openapi_security_requirement(cls, view, method):
287+
return [{cls.openapi_security_scheme_name: []}]
288+
213289

214290
class RemoteUserAuthentication(BaseAuthentication):
215291
"""
@@ -230,3 +306,20 @@ def authenticate(self, request):
230306
user = authenticate(remote_user=request.META.get(self.header))
231307
if user and user.is_active:
232308
return (user, None)
309+
310+
openapi_security_scheme_name = 'remoteUserAuth'
311+
312+
@classmethod
313+
def openapi_security_scheme(cls):
314+
return {
315+
cls.openapi_security_scheme_name: {
316+
'type': 'http',
317+
'in': 'header',
318+
'name': 'REMOTE_USER',
319+
'description': 'Remote User authentication'
320+
}
321+
}
322+
323+
@classmethod
324+
def openapi_security_requirement(cls, view, method):
325+
return [{cls.openapi_security_scheme_name: []}]

rest_framework/schemas/openapi.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False):
7070
"""
7171
self._initialise_endpoints()
7272
components_schemas = {}
73+
security_schemes_schemas = {}
74+
root_security_requirements = []
75+
76+
if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
77+
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
78+
req = auth_class.openapi_security_requirement(None, None)
79+
if req:
80+
root_security_requirements += req
7381

7482
# Iterate endpoints generating per method path operations.
7583
paths = {}
@@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False):
8088

8189
operation = view.schema.get_operation(path, method)
8290
components = view.schema.get_components(path, method)
91+
8392
for k in components.keys():
8493
if k not in components_schemas:
8594
continue
@@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False):
8998

9099
components_schemas.update(components)
91100

101+
security_schemes = view.schema.get_security_schemes(path, method)
102+
for k in security_schemes.keys():
103+
if k not in security_schemes_schemas:
104+
continue
105+
if security_schemes_schemas[k] == security_schemes[k]:
106+
continue
107+
warnings.warn('Security scheme component "{}" has been overriden with a different '
108+
'value.'.format(k))
109+
security_schemes_schemas.update(security_schemes)
110+
92111
# Normalise path for any provided mount url.
93112
if path.startswith('/'):
94113
path = path[1:]
@@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False):
111130
'schemas': components_schemas
112131
}
113132

133+
if len(security_schemes_schemas) > 0:
134+
if 'components' not in schema:
135+
schema['components'] = {}
136+
schema['components']['securitySchemes'] = security_schemes_schemas
137+
138+
if len(root_security_requirements) > 0:
139+
schema['security'] = root_security_requirements
140+
114141
return schema
115142

116143
# View Inspectors
@@ -146,6 +173,9 @@ def get_operation(self, path, method):
146173

147174
operation['operationId'] = self.get_operation_id(path, method)
148175
operation['description'] = self.get_description(path, method)
176+
security = self.get_security_requirements(path, method)
177+
if security is not None:
178+
operation['security'] = security
149179

150180
parameters = []
151181
parameters += self.get_path_parameters(path, method)
@@ -713,6 +743,34 @@ def get_tags(self, path, method):
713743

714744
return [path.split('/')[0].replace('_', '-')]
715745

746+
def get_security_schemes(self, path, method):
747+
"""
748+
Get components.schemas.securitySchemes required by this path.
749+
returns dict of securitySchemes.
750+
"""
751+
schemes = {}
752+
for auth_class in self.view.authentication_classes:
753+
if hasattr(auth_class, 'openapi_security_scheme'):
754+
schemes.update(auth_class.openapi_security_scheme())
755+
return schemes
756+
757+
def get_security_requirements(self, path, method):
758+
"""
759+
Get Security Requirement Object list for this operation.
760+
Returns a list of security requirement objects based on the view's authentication classes
761+
unless this view's authentication classes are the same as the root-level defaults.
762+
"""
763+
# references the securityScheme names described above in get_security_schemes()
764+
security = []
765+
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
766+
return None
767+
for auth_class in self.view.authentication_classes:
768+
if hasattr(auth_class, 'openapi_security_requirement'):
769+
req = auth_class.openapi_security_requirement(self.view, method)
770+
if req:
771+
security += req
772+
return security
773+
716774
def _get_path_parameters(self, path, method):
717775
warnings.warn(
718776
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "

tests/schemas/test_openapi.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.utils.translation import gettext_lazy as _
88

99
from rest_framework import filters, generics, pagination, routers, serializers
10+
from rest_framework.authentication import TokenAuthentication
1011
from rest_framework.authtoken.views import obtain_auth_token
1112
from rest_framework.compat import uritemplate
1213
from rest_framework.parsers import JSONParser, MultiPartParser
@@ -1216,5 +1217,51 @@ class ExampleView(generics.DestroyAPIView):
12161217
]
12171218
generator = SchemaGenerator(patterns=url_patterns)
12181219
schema = generator.get_schema(request=create_request('/'))
1219-
assert 'components' not in schema
1220+
assert 'schemas' not in schema['components']
12201221
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']
1222+
1223+
def test_default_root_security_schemes(self):
1224+
patterns = [
1225+
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1226+
]
1227+
1228+
generator = SchemaGenerator(patterns=patterns)
1229+
1230+
request = create_request('/')
1231+
schema = generator.get_schema(request=request)
1232+
assert 'security' in schema
1233+
assert {'sessionAuth': []} in schema['security']
1234+
assert {'basicAuth': []} in schema['security']
1235+
assert 'security' not in schema['paths']['/example/']['get']
1236+
1237+
@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
1238+
def test_no_default_root_security_schemes(self):
1239+
patterns = [
1240+
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1241+
]
1242+
1243+
generator = SchemaGenerator(patterns=patterns)
1244+
1245+
request = create_request('/')
1246+
schema = generator.get_schema(request=request)
1247+
assert 'security' not in schema
1248+
1249+
def test_operation_security_schemes(self):
1250+
class MyExample(views.ExampleAutoSchemaComponentName):
1251+
authentication_classes = [TokenAuthentication]
1252+
1253+
patterns = [
1254+
path('^example/?$', MyExample.as_view()),
1255+
]
1256+
1257+
generator = SchemaGenerator(patterns=patterns)
1258+
1259+
request = create_request('/')
1260+
schema = generator.get_schema(request=request)
1261+
assert 'security' in schema
1262+
assert {'sessionAuth': []} in schema['security']
1263+
assert {'basicAuth': []} in schema['security']
1264+
get_operation = schema['paths']['/example/']['get']
1265+
assert 'security' in get_operation
1266+
assert {'tokenAuth': []} in get_operation['security']
1267+
assert len(get_operation['security']) == 1

0 commit comments

Comments
 (0)