Skip to content

Commit 69c6355

Browse files
committed
Add schema references for OpenAPI generation, see #6984
1 parent 373e521 commit 69c6355

File tree

3 files changed

+118
-33
lines changed

3 files changed

+118
-33
lines changed

rest_framework/schemas/openapi.py

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from enum import Enum
23
from operator import attrgetter
34
from urllib.parse import urljoin
45

@@ -39,12 +40,18 @@ def get_paths(self, request=None):
3940

4041
# Only generate the path prefix for paths that will be included
4142
if not paths:
42-
return None
43+
return None, None
44+
45+
components_schemas = {}
4346

4447
for path, method, view in view_endpoints:
4548
if not self.has_view_permissions(path, method, view):
4649
continue
47-
operation = view.schema.get_operation(path, method)
50+
operation, operation_schema = view.schema.get_operation(path, method)
51+
52+
if operation_schema is not None:
53+
components_schemas.update(operation_schema)
54+
4855
# Normalise path for any provided mount url.
4956
if path.startswith('/'):
5057
path = path[1:]
@@ -53,24 +60,29 @@ def get_paths(self, request=None):
5360
result.setdefault(path, {})
5461
result[path][method.lower()] = operation
5562

56-
return result
63+
return result, components_schemas
5764

5865
def get_schema(self, request=None, public=False):
5966
"""
6067
Generate a OpenAPI schema.
6168
"""
6269
self._initialise_endpoints()
6370

64-
paths = self.get_paths(None if public else request)
71+
paths, components_schemas = self.get_paths(None if public else request)
6572
if not paths:
6673
return None
6774

6875
schema = {
6976
'openapi': '3.0.2',
7077
'info': self.get_info(),
71-
'paths': paths,
78+
'paths': paths
7279
}
7380

81+
if len(components_schemas) > 0:
82+
schema['components'] = {
83+
'schemas': components_schemas
84+
}
85+
7486
return schema
7587

7688
# View Inspectors
@@ -106,7 +118,9 @@ def get_operation(self, path, method):
106118
operation['requestBody'] = request_body
107119
operation['responses'] = self._get_responses(path, method)
108120

109-
return operation
121+
component_schema = self._get_component_schema(path, method)
122+
123+
return operation, component_schema
110124

111125
def _get_operation_id(self, path, method):
112126
"""
@@ -479,29 +493,67 @@ def _get_serializer(self, method, path):
479493
.format(view.__class__.__name__, method, path))
480494
return None
481495

482-
def _get_request_body(self, path, method):
483-
if method not in ('PUT', 'PATCH', 'POST'):
496+
class SchemaMode(Enum):
497+
RESPONSE = 1
498+
BODY = 2
499+
500+
def _get_item_schema(self, serializer, schema_mode, method):
501+
if not isinstance(serializer, serializers.Serializer):
484502
return {}
485503

486-
self.request_media_types = self.map_parsers(path, method)
504+
# If the serializer uses a model, we should use a reference
505+
if hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
506+
model_name = serializer.Meta.model.__name__
507+
return {'$ref': '#/components/schemas/{}'.format(model_name)}
508+
509+
# There is no model, we'll map the serializer's fields
510+
item_schema = self._map_serializer(serializer)
487511

512+
if schema_mode == self.SchemaMode.RESPONSE:
513+
# No write_only fields for response.
514+
for name, schema in item_schema['properties'].copy().items():
515+
if 'writeOnly' in schema:
516+
del item_schema['properties'][name]
517+
if 'required' in item_schema:
518+
item_schema['required'] = [f for f in item_schema['required'] if f != name]
519+
520+
elif schema_mode == self.SchemaMode.BODY:
521+
# No required fields for PATCH
522+
if method == 'PATCH':
523+
item_schema.pop('required', None)
524+
# No read_only fields for request.
525+
for name, schema in item_schema['properties'].copy().items():
526+
if 'readOnly' in schema:
527+
del item_schema['properties'][name]
528+
529+
return item_schema
530+
531+
def _get_component_schema(self, path, method):
488532
serializer = self._get_serializer(path, method)
489533

490534
if not isinstance(serializer, serializers.Serializer):
491-
return {}
535+
return None
536+
537+
# If the model has no model, then the serializer will be inlined
538+
if not hasattr(serializer, 'Meta') or not hasattr(serializer.Meta, 'model'):
539+
return None
492540

541+
model_name = serializer.Meta.model.__name__
493542
content = self._map_serializer(serializer)
494-
# No required fields for PATCH
495-
if method == 'PATCH':
496-
content.pop('required', None)
497-
# No read_only fields for request.
498-
for name, schema in content['properties'].copy().items():
499-
if 'readOnly' in schema:
500-
del content['properties'][name]
543+
544+
return {model_name: content}
545+
546+
def _get_request_body(self, path, method):
547+
if method not in ('PUT', 'PATCH', 'POST'):
548+
return {}
549+
550+
self.request_media_types = self.map_parsers(path, method)
551+
552+
serializer = self._get_serializer(path, method)
501553

502554
return {
503555
'content': {
504-
ct: {'schema': content}
556+
ct: {'schema': self._get_item_schema(serializer, self.SchemaMode.BODY, method)}
505557
for ct in self.request_media_types
506558
}
507559
}
@@ -517,17 +569,8 @@ def _get_responses(self, path, method):
517569

518570
self.response_media_types = self.map_renderers(path, method)
519571

520-
item_schema = {}
521572
serializer = self._get_serializer(path, method)
522-
523-
if isinstance(serializer, serializers.Serializer):
524-
item_schema = self._map_serializer(serializer)
525-
# No write_only fields for response.
526-
for name, schema in item_schema['properties'].copy().items():
527-
if 'writeOnly' in schema:
528-
del item_schema['properties'][name]
529-
if 'required' in item_schema:
530-
item_schema['required'] = [f for f in item_schema['required'] if f != name]
573+
item_schema = self._get_item_schema(serializer, self.SchemaMode.RESPONSE, method)
531574

532575
if is_list_view(path, method, self.view):
533576
response_schema = {

tests/schemas/test_openapi.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def test_path_without_parameters(self):
8484
inspector = AutoSchema()
8585
inspector.view = view
8686

87-
operation = inspector.get_operation(path, method)
87+
operation, _ = inspector.get_operation(path, method)
8888
assert operation == {
8989
'operationId': 'listDocStringExamples',
9090
'description': 'A description of my GET operation.',
@@ -116,7 +116,7 @@ def test_path_with_id_parameter(self):
116116
inspector = AutoSchema()
117117
inspector.view = view
118118

119-
operation = inspector.get_operation(path, method)
119+
operation, _ = inspector.get_operation(path, method)
120120
assert operation == {
121121
'operationId': 'RetrieveDocStringExampleDetail',
122122
'description': 'A description of my GET operation.',
@@ -659,7 +659,7 @@ def test_paths_construction(self):
659659
generator = SchemaGenerator(patterns=patterns)
660660
generator._initialise_endpoints()
661661

662-
paths = generator.get_paths()
662+
paths, _ = generator.get_paths()
663663

664664
assert '/example/' in paths
665665
example_operations = paths['/example/']
@@ -676,7 +676,7 @@ def test_prefixed_paths_construction(self):
676676
generator = SchemaGenerator(patterns=patterns)
677677
generator._initialise_endpoints()
678678

679-
paths = generator.get_paths()
679+
paths, _ = generator.get_paths()
680680

681681
assert '/v1/example/' in paths
682682
assert '/v1/example/{id}/' in paths
@@ -689,7 +689,7 @@ def test_mount_url_prefixed_to_paths(self):
689689
generator = SchemaGenerator(patterns=patterns, url='/api')
690690
generator._initialise_endpoints()
691691

692-
paths = generator.get_paths()
692+
paths, _ = generator.get_paths()
693693

694694
assert '/api/example/' in paths
695695
assert '/api/example/{id}/' in paths
@@ -733,3 +733,18 @@ def test_schema_information_empty(self):
733733

734734
assert schema['info']['title'] == ''
735735
assert schema['info']['version'] == ''
736+
737+
def test_serializer_model(self):
738+
"""Construction of the top level dictionary."""
739+
patterns = [
740+
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
741+
]
742+
generator = SchemaGenerator(patterns=patterns)
743+
744+
request = create_request('/')
745+
schema = generator.get_schema(request=request)
746+
747+
print(schema)
748+
assert 'components' in schema
749+
assert 'schemas' in schema['components']
750+
assert 'OpenAPIExample' in schema['components']['schemas']

tests/schemas/views.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DecimalValidator, MaxLengthValidator, MaxValueValidator,
55
MinLengthValidator, MinValueValidator, RegexValidator
66
)
7+
from django.db import models
78

89
from rest_framework import generics, permissions, serializers
910
from rest_framework.decorators import action
@@ -137,3 +138,29 @@ def get(self, *args, **kwargs):
137138
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
138139
ip='192.168.1.1')
139140
return Response(serializer.data)
141+
142+
143+
# Serializer with model.
144+
class OpenAPIExample(models.Model):
145+
first_name = models.CharField(max_length=30)
146+
147+
148+
class ExampleSerializerModel(serializers.Serializer):
149+
date = serializers.DateField()
150+
datetime = serializers.DateTimeField()
151+
hstore = serializers.HStoreField()
152+
uuid_field = serializers.UUIDField(default=uuid.uuid4)
153+
154+
class Meta:
155+
model = OpenAPIExample
156+
157+
158+
class ExampleGenericAPIViewModel(generics.GenericAPIView):
159+
serializer_class = ExampleSerializerModel
160+
161+
def get(self, *args, **kwargs):
162+
from datetime import datetime
163+
now = datetime.now()
164+
165+
serializer = self.get_serializer(data=now.date(), datetime=now)
166+
return Response(serializer.data)

0 commit comments

Comments
 (0)