Skip to content

Commit 83be9c8

Browse files
committed
add OAS securitySchemes and security objects
1 parent 327cbef commit 83be9c8

File tree

4 files changed

+219
-1
lines changed

4 files changed

+219
-1
lines changed

docs/api-guide/schemas.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -375,6 +375,26 @@ operationIds.
375375
In order to work around this, you can override `get_operation_id_base()` to
376376
provide a different base for name part of the ID.
377377

378+
#### `get_security_schemes()`
379+
380+
Generates the OpenAPI `securitySchemes` components based on:
381+
- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`)
382+
- Per-view non-default `authentication_classes`
383+
384+
These are generated using the authentication classes' `openapi_security_scheme()` class method. If you
385+
extend `BaseAuthentication` with your own authentication class, you can add this class method to return
386+
the appropriate security scheme object.
387+
388+
#### `get_security_requirements()`
389+
390+
Root-level security requirements (the top-level `security` object) are generated based on the
391+
default authentication classes. Operation-level security requirements are generated only if the given view's
392+
`authentication_classes` differ from the defaults.
393+
394+
These are generated using the authentication classes' `openapi_security_requirement()` class
395+
method. If you extended `BaseAuthentication` with your own authentication class, you can add this
396+
class method to return the appropriate list of security requirements objects.
397+
378398
### `AutoSchema.__init__()` kwargs
379399

380400
`AutoSchema` provides a number of `__init__()` kwargs that can be used for

rest_framework/authentication.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,32 @@ def authenticate_header(self, request):
4949
"""
5050
pass
5151

52+
#: Name of openapi security scheme. Override if you want to customize it.
53+
openapi_security_scheme_name = None
54+
55+
@classmethod
56+
def openapi_security_scheme(cls):
57+
"""
58+
Override this to return an Open API Specification `securityScheme object
59+
<http://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
60+
"""
61+
return {}
62+
63+
@classmethod
64+
def openapi_security_requirement(cls, view, method):
65+
"""
66+
Override this to return an Open API Specification `security requirement object
67+
<http://spec.openapis.org/oas/v3.0.3#security-requirement-object>`_
68+
69+
:param view: used to find view attributes used by a permission class or None for root-level
70+
:param method: used to distinguish among method-specific permissions or None for root-level
71+
:return:list: [security requirement objects]
72+
"""
73+
# At this point, none of the built-in DRF authentication classes fill in the
74+
# requirement list: OAuth2/OIDC are the only security types that currently uses the list
75+
# (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2.
76+
return [{}]
77+
5278

5379
class BasicAuthentication(BaseAuthentication):
5480
"""
@@ -108,6 +134,22 @@ def authenticate_credentials(self, userid, password, request=None):
108134
def authenticate_header(self, request):
109135
return 'Basic realm="%s"' % self.www_authenticate_realm
110136

137+
openapi_security_scheme_name = 'basicAuth'
138+
139+
@classmethod
140+
def openapi_security_scheme(cls):
141+
return {
142+
cls.openapi_security_scheme_name: {
143+
'type': 'http',
144+
'scheme': 'basic',
145+
'description': 'Basic Authentication'
146+
}
147+
}
148+
149+
@classmethod
150+
def openapi_security_requirement(cls, view, method):
151+
return [{cls.openapi_security_scheme_name: []}]
152+
111153

112154
class SessionAuthentication(BaseAuthentication):
113155
"""
@@ -144,6 +186,23 @@ def enforce_csrf(self, request):
144186
# CSRF failed, bail with explicit error message
145187
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
146188

189+
openapi_security_scheme_name = 'sessionAuth'
190+
191+
@classmethod
192+
def openapi_security_scheme(cls):
193+
return {
194+
cls.openapi_security_scheme_name: {
195+
'type': 'apiKey',
196+
'in': 'cookie',
197+
'name': 'JSESSIONID',
198+
'description': 'Session authentication'
199+
}
200+
}
201+
202+
@classmethod
203+
def openapi_security_requirement(cls, view, method):
204+
return [{cls.openapi_security_scheme_name: []}]
205+
147206

148207
class TokenAuthentication(BaseAuthentication):
149208
"""
@@ -207,6 +266,23 @@ def authenticate_credentials(self, key):
207266
def authenticate_header(self, request):
208267
return self.keyword
209268

269+
openapi_security_scheme_name = 'tokenAuth'
270+
271+
@classmethod
272+
def openapi_security_scheme(cls):
273+
return {
274+
cls.openapi_security_scheme_name: {
275+
'type': 'http',
276+
'in': 'header',
277+
'name': 'Authorization', # Authorization: token ...
278+
'description': 'Token authentication'
279+
}
280+
}
281+
282+
@classmethod
283+
def openapi_security_requirement(cls, view, method):
284+
return [{cls.openapi_security_scheme_name: []}]
285+
210286

211287
class RemoteUserAuthentication(BaseAuthentication):
212288
"""
@@ -227,3 +303,20 @@ def authenticate(self, request):
227303
user = authenticate(remote_user=request.META.get(self.header))
228304
if user and user.is_active:
229305
return (user, None)
306+
307+
openapi_security_scheme_name = 'remoteUserAuth'
308+
309+
@classmethod
310+
def openapi_security_scheme(cls):
311+
return {
312+
cls.openapi_security_scheme_name: {
313+
'type': 'http',
314+
'in': 'header',
315+
'name': 'REMOTE_USER',
316+
'description': 'Remote User authentication'
317+
}
318+
}
319+
320+
@classmethod
321+
def openapi_security_requirement(cls, view, method):
322+
return [{cls.openapi_security_scheme_name: []}]

rest_framework/schemas/openapi.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,14 @@ def get_schema(self, request=None, public=False):
7070
"""
7171
self._initialise_endpoints()
7272
components_schemas = {}
73+
security_schemes_schemas = {}
74+
root_security_requirements = []
75+
76+
if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
77+
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
78+
req = auth_class.openapi_security_requirement(None, None)
79+
if req:
80+
root_security_requirements += req
7381

7482
# Iterate endpoints generating per method path operations.
7583
paths = {}
@@ -80,6 +88,7 @@ def get_schema(self, request=None, public=False):
8088

8189
operation = view.schema.get_operation(path, method)
8290
components = view.schema.get_components(path, method)
91+
8392
for k in components.keys():
8493
if k not in components_schemas:
8594
continue
@@ -89,6 +98,16 @@ def get_schema(self, request=None, public=False):
8998

9099
components_schemas.update(components)
91100

101+
security_schemes = view.schema.get_security_schemes(path, method)
102+
for k in security_schemes.keys():
103+
if k not in security_schemes_schemas:
104+
continue
105+
if security_schemes_schemas[k] == security_schemes[k]:
106+
continue
107+
warnings.warn('Security scheme component "{}" has been overriden with a different '
108+
'value.'.format(k))
109+
security_schemes_schemas.update(security_schemes)
110+
92111
# Normalise path for any provided mount url.
93112
if path.startswith('/'):
94113
path = path[1:]
@@ -111,6 +130,14 @@ def get_schema(self, request=None, public=False):
111130
'schemas': components_schemas
112131
}
113132

133+
if len(security_schemes_schemas) > 0:
134+
if 'components' not in schema:
135+
schema['components'] = {}
136+
schema['components']['securitySchemes'] = security_schemes_schemas
137+
138+
if len(root_security_requirements) > 0:
139+
schema['security'] = root_security_requirements
140+
114141
return schema
115142

116143
# View Inspectors
@@ -146,6 +173,9 @@ def get_operation(self, path, method):
146173

147174
operation['operationId'] = self.get_operation_id(path, method)
148175
operation['description'] = self.get_description(path, method)
176+
security = self.get_security_requirements(path, method)
177+
if security is not None:
178+
operation['security'] = security
149179

150180
parameters = []
151181
parameters += self.get_path_parameters(path, method)
@@ -692,6 +722,34 @@ def get_tags(self, path, method):
692722

693723
return [path.split('/')[0].replace('_', '-')]
694724

725+
def get_security_schemes(self, path, method):
726+
"""
727+
Get components.schemas.securitySchemes required by this path.
728+
returns dict of securitySchemes.
729+
"""
730+
schemes = {}
731+
for auth_class in self.view.authentication_classes:
732+
if hasattr(auth_class, 'openapi_security_scheme'):
733+
schemes.update(auth_class.openapi_security_scheme())
734+
return schemes
735+
736+
def get_security_requirements(self, path, method):
737+
"""
738+
Get Security Requirement Object list for this operation.
739+
Returns a list of security requirement objects based on the view's authentication classes
740+
unless this view's authentication classes are the same as the root-level defaults.
741+
"""
742+
# references the securityScheme names described above in get_security_schemes()
743+
security = []
744+
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
745+
return None
746+
for auth_class in self.view.authentication_classes:
747+
if hasattr(auth_class, 'openapi_security_requirement'):
748+
req = auth_class.openapi_security_requirement(self.view, method)
749+
if req:
750+
security += req
751+
return security
752+
695753
def _get_path_parameters(self, path, method):
696754
warnings.warn(
697755
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "

tests/schemas/test_openapi.py

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from django.utils.translation import gettext_lazy as _
88

99
from rest_framework import filters, generics, pagination, routers, serializers
10+
from rest_framework.authentication import TokenAuthentication
1011
from rest_framework.authtoken.views import obtain_auth_token
1112
from rest_framework.compat import uritemplate
1213
from rest_framework.parsers import JSONParser, MultiPartParser
@@ -1108,5 +1109,51 @@ class ExampleView(generics.DestroyAPIView):
11081109
]
11091110
generator = SchemaGenerator(patterns=url_patterns)
11101111
schema = generator.get_schema(request=create_request('/'))
1111-
assert 'components' not in schema
1112+
assert 'schemas' not in schema['components']
11121113
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']
1114+
1115+
def test_default_root_security_schemes(self):
1116+
patterns = [
1117+
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1118+
]
1119+
1120+
generator = SchemaGenerator(patterns=patterns)
1121+
1122+
request = create_request('/')
1123+
schema = generator.get_schema(request=request)
1124+
assert 'security' in schema
1125+
assert {'sessionAuth': []} in schema['security']
1126+
assert {'basicAuth': []} in schema['security']
1127+
assert 'security' not in schema['paths']['/example/']['get']
1128+
1129+
@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
1130+
def test_no_default_root_security_schemes(self):
1131+
patterns = [
1132+
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
1133+
]
1134+
1135+
generator = SchemaGenerator(patterns=patterns)
1136+
1137+
request = create_request('/')
1138+
schema = generator.get_schema(request=request)
1139+
assert 'security' not in schema
1140+
1141+
def test_operation_security_schemes(self):
1142+
class MyExample(views.ExampleAutoSchemaComponentName):
1143+
authentication_classes = [TokenAuthentication]
1144+
1145+
patterns = [
1146+
path('^example/?$', MyExample.as_view()),
1147+
]
1148+
1149+
generator = SchemaGenerator(patterns=patterns)
1150+
1151+
request = create_request('/')
1152+
schema = generator.get_schema(request=request)
1153+
assert 'security' in schema
1154+
assert {'sessionAuth': []} in schema['security']
1155+
assert {'basicAuth': []} in schema['security']
1156+
get_operation = schema['paths']['/example/']['get']
1157+
assert 'security' in get_operation
1158+
assert {'tokenAuth': []} in get_operation['security']
1159+
assert len(get_operation['security']) == 1

0 commit comments

Comments
 (0)