Skip to content

Commit 568452c

Browse files
gnuletikcarltongibson
authored andcommitted
Generate components for OpenAPI schemas.
1 parent 4137ef4 commit 568452c

File tree

3 files changed

+114
-14
lines changed

3 files changed

+114
-14
lines changed

rest_framework/schemas/openapi.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import re
12
import warnings
23
from collections import OrderedDict
34
from decimal import Decimal
@@ -39,16 +40,21 @@ def get_schema(self, request=None, public=False):
3940
Generate a OpenAPI schema.
4041
"""
4142
self._initialise_endpoints()
43+
components_schemas = {}
4244

4345
# Iterate endpoints generating per method path operations.
44-
# TODO: …and reference components.
4546
paths = {}
4647
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
4748
for path, method, view in view_endpoints:
4849
if not self.has_view_permissions(path, method, view):
4950
continue
5051

5152
operation = view.schema.get_operation(path, method)
53+
component = view.schema.get_components(path, method)
54+
55+
if component is not None:
56+
components_schemas.update(component)
57+
5258
# Normalise path for any provided mount url.
5359
if path.startswith('/'):
5460
path = path[1:]
@@ -64,6 +70,11 @@ def get_schema(self, request=None, public=False):
6470
'paths': paths,
6571
}
6672

73+
if len(components_schemas) > 0:
74+
schema['components'] = {
75+
'schemas': components_schemas
76+
}
77+
6778
return schema
6879

6980
# View Inspectors
@@ -101,6 +112,34 @@ def get_operation(self, path, method):
101112

102113
return operation
103114

115+
def _get_serializer_component_name(self, serializer):
116+
if not hasattr(serializer, 'Meta'):
117+
return None
118+
119+
if hasattr(serializer.Meta, 'schema_component_name'):
120+
return serializer.Meta.schema_component_name
121+
122+
# If the serializer has no Meta.schema_component_name, we use
123+
# the serializer's class name as the component name.
124+
component_name = serializer.__class__.__name__
125+
# We remove the "serializer" string from the class name.
126+
pattern = re.compile("serializer", re.IGNORECASE)
127+
return pattern.sub("", component_name)
128+
129+
def get_components(self, path, method):
130+
serializer = self._get_serializer(path, method)
131+
132+
if not isinstance(serializer, serializers.Serializer):
133+
return None
134+
135+
component_name = self._get_serializer_component_name(serializer)
136+
137+
if component_name is None:
138+
return None
139+
140+
content = self._map_serializer(serializer)
141+
return {component_name: content}
142+
104143
def _get_operation_id(self, path, method):
105144
"""
106145
Compute an operation ID from the model, serializer or view name.
@@ -491,6 +530,10 @@ def _get_serializer(self, path, method):
491530
.format(view.__class__.__name__, method, path))
492531
return None
493532

533+
def _get_reference(self, serializer):
534+
component_name = self._get_serializer_component_name(serializer)
535+
return {'$ref': '#/components/schemas/{}'.format(component_name)}
536+
494537
def _get_request_body(self, path, method):
495538
if method not in ('PUT', 'PATCH', 'POST'):
496539
return {}
@@ -500,20 +543,30 @@ def _get_request_body(self, path, method):
500543
serializer = self._get_serializer(path, method)
501544

502545
if not isinstance(serializer, serializers.Serializer):
503-
return {}
504-
505-
content = self._map_serializer(serializer)
506-
# No required fields for PATCH
507-
if method == 'PATCH':
508-
content.pop('required', None)
509-
# No read_only fields for request.
510-
for name, schema in content['properties'].copy().items():
511-
if 'readOnly' in schema:
512-
del content['properties'][name]
546+
item_schema = {}
547+
elif hasattr(serializer, 'Meta'):
548+
# If possible, the serializer should use a reference
549+
item_schema = self._get_reference(serializer)
550+
else:
551+
# There is no model, we'll map the serializer's fields
552+
item_schema = self._map_serializer(serializer)
553+
# No required fields for PATCH
554+
if method == 'PATCH':
555+
item_schema.pop('required', None)
556+
# No read_only fields for request.
557+
# No write_only fields for response.
558+
for name, schema in item_schema['properties'].copy().items():
559+
if 'writeOnly' in schema:
560+
del item_schema['properties'][name]
561+
if 'required' in item_schema:
562+
item_schema['required'] = [f for f in item_schema['required'] if f != name]
563+
for name, schema in item_schema['properties'].copy().items():
564+
if 'readOnly' in schema:
565+
del item_schema['properties'][name]
513566

514567
return {
515568
'content': {
516-
ct: {'schema': content}
569+
ct: {'schema': item_schema}
517570
for ct in self.request_media_types
518571
}
519572
}
@@ -529,10 +582,15 @@ def _get_responses(self, path, method):
529582

530583
self.response_media_types = self.map_renderers(path, method)
531584

532-
item_schema = {}
533585
serializer = self._get_serializer(path, method)
534586

535-
if isinstance(serializer, serializers.Serializer):
587+
if not isinstance(serializer, serializers.Serializer):
588+
item_schema = {}
589+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
590+
# If the serializer uses a model, we should use a reference
591+
item_schema = self._get_reference(serializer)
592+
else:
593+
# There is no model, we'll map the serializer's fields
536594
item_schema = self._map_serializer(serializer)
537595
# No write_only fields for response.
538596
for name, schema in item_schema['properties'].copy().items():

tests/schemas/test_openapi.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,3 +795,18 @@ def test_schema_information_empty(self):
795795

796796
assert schema['info']['title'] == ''
797797
assert schema['info']['version'] == ''
798+
799+
def test_serializer_model(self):
800+
"""Construction of the top level dictionary."""
801+
patterns = [
802+
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
803+
]
804+
generator = SchemaGenerator(patterns=patterns)
805+
806+
request = create_request('/')
807+
schema = generator.get_schema(request=request)
808+
809+
print(schema)
810+
assert 'components' in schema
811+
assert 'schemas' in schema['components']
812+
assert 'ExampleModel' in schema['components']['schemas']

tests/schemas/views.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DecimalValidator, MaxLengthValidator, MaxValueValidator,
55
MinLengthValidator, MinValueValidator, RegexValidator
66
)
7+
from django.db import models
78

89
from rest_framework import generics, permissions, serializers
910
from rest_framework.decorators import action
@@ -137,3 +138,29 @@ def get(self, *args, **kwargs):
137138
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
138139
ip='192.168.1.1')
139140
return Response(serializer.data)
141+
142+
143+
# Serializer with model.
144+
class OpenAPIExample(models.Model):
145+
first_name = models.CharField(max_length=30)
146+
147+
148+
class ExampleSerializerModel(serializers.Serializer):
149+
date = serializers.DateField()
150+
datetime = serializers.DateTimeField()
151+
hstore = serializers.HStoreField()
152+
uuid_field = serializers.UUIDField(default=uuid.uuid4)
153+
154+
class Meta:
155+
model = OpenAPIExample
156+
157+
158+
class ExampleGenericAPIViewModel(generics.GenericAPIView):
159+
serializer_class = ExampleSerializerModel
160+
161+
def get(self, *args, **kwargs):
162+
from datetime import datetime
163+
now = datetime.now()
164+
165+
serializer = self.get_serializer(data=now.date(), datetime=now)
166+
return Response(serializer.data)

0 commit comments

Comments
 (0)