Skip to content

Commit 83f0810

Browse files
committed
Rebase on #7127 and remove _get_item_schema refactoring
1 parent 67549c3 commit 83f0810

File tree

3 files changed

+102
-15
lines changed

3 files changed

+102
-15
lines changed

rest_framework/schemas/openapi.py

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from enum import Enum
23
from operator import attrgetter
34
from urllib.parse import urljoin
45

@@ -37,16 +38,21 @@ def get_schema(self, request=None, public=False):
3738
Generate a OpenAPI schema.
3839
"""
3940
self._initialise_endpoints()
41+
components_schemas = {}
4042

4143
# Iterate endpoints generating per method path operations.
42-
# TODO: …and reference components.
4344
paths = {}
4445
_, view_endpoints = self._get_paths_and_endpoints(None if public else request)
4546
for path, method, view in view_endpoints:
4647
if not self.has_view_permissions(path, method, view):
4748
continue
4849

4950
operation = view.schema.get_operation(path, method)
51+
component = view.schema.get_components(path, method)
52+
53+
if component is not None:
54+
components_schemas.update(component)
55+
5056
# Normalise path for any provided mount url.
5157
if path.startswith('/'):
5258
path = path[1:]
@@ -59,9 +65,14 @@ def get_schema(self, request=None, public=False):
5965
schema = {
6066
'openapi': '3.0.2',
6167
'info': self.get_info(),
62-
'paths': paths,
68+
'paths': paths
6369
}
6470

71+
if len(components_schemas) > 0:
72+
schema['components'] = {
73+
'schemas': components_schemas
74+
}
75+
6576
return schema
6677

6778
# View Inspectors
@@ -99,6 +110,21 @@ def get_operation(self, path, method):
99110

100111
return operation
101112

113+
def get_components(self, path, method):
114+
serializer = self._get_serializer(path, method)
115+
116+
if not isinstance(serializer, serializers.Serializer):
117+
return None
118+
119+
# If the model has no model, then the serializer will be inlined
120+
if not hasattr(serializer, 'Meta') or not hasattr(serializer.Meta, 'model'):
121+
return None
122+
123+
model_name = serializer.Meta.model.__name__
124+
content = self._map_serializer(serializer)
125+
126+
return {model_name: content}
127+
102128
def _get_operation_id(self, path, method):
103129
"""
104130
Compute an operation ID from the model, serializer or view name.
@@ -470,6 +496,10 @@ def _get_serializer(self, method, path):
470496
.format(view.__class__.__name__, method, path))
471497
return None
472498

499+
def _get_reference(self, serializer):
500+
model_name = serializer.Meta.model.__name__
501+
return {'$ref': '#/components/schemas/{}'.format(model_name)}
502+
473503
def _get_request_body(self, path, method):
474504
if method not in ('PUT', 'PATCH', 'POST'):
475505
return {}
@@ -479,20 +509,30 @@ def _get_request_body(self, path, method):
479509
serializer = self._get_serializer(path, method)
480510

481511
if not isinstance(serializer, serializers.Serializer):
482-
return {}
483-
484-
content = self._map_serializer(serializer)
485-
# No required fields for PATCH
486-
if method == 'PATCH':
487-
content.pop('required', None)
488-
# No read_only fields for request.
489-
for name, schema in content['properties'].copy().items():
490-
if 'readOnly' in schema:
491-
del content['properties'][name]
512+
item_schema = {}
513+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
514+
# If the serializer uses a model, we should use a reference
515+
item_schema = self._get_reference(serializer)
516+
else:
517+
# There is no model, we'll map the serializer's fields
518+
item_schema = self._map_serializer(serializer)
519+
# No required fields for PATCH
520+
if method == 'PATCH':
521+
item_schema.pop('required', None)
522+
# No read_only fields for request.
523+
# No write_only fields for response.
524+
for name, schema in item_schema['properties'].copy().items():
525+
if 'writeOnly' in schema:
526+
del item_schema['properties'][name]
527+
if 'required' in item_schema:
528+
item_schema['required'] = [f for f in item_schema['required'] if f != name]
529+
for name, schema in item_schema['properties'].copy().items():
530+
if 'readOnly' in schema:
531+
del item_schema['properties'][name]
492532

493533
return {
494534
'content': {
495-
ct: {'schema': content}
535+
ct: {'schema': item_schema}
496536
for ct in self.request_media_types
497537
}
498538
}
@@ -508,10 +548,15 @@ def _get_responses(self, path, method):
508548

509549
self.response_media_types = self.map_renderers(path, method)
510550

511-
item_schema = {}
512551
serializer = self._get_serializer(path, method)
513552

514-
if isinstance(serializer, serializers.Serializer):
553+
if not isinstance(serializer, serializers.Serializer):
554+
item_schema = {}
555+
elif hasattr(serializer, 'Meta') and hasattr(serializer.Meta, 'model'):
556+
# If the serializer uses a model, we should use a reference
557+
item_schema = self._get_reference(serializer)
558+
else:
559+
# There is no model, we'll map the serializer's fields
515560
item_schema = self._map_serializer(serializer)
516561
# No write_only fields for response.
517562
for name, schema in item_schema['properties'].copy().items():

tests/schemas/test_openapi.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -742,3 +742,18 @@ def test_schema_information_empty(self):
742742

743743
assert schema['info']['title'] == ''
744744
assert schema['info']['version'] == ''
745+
746+
def test_serializer_model(self):
747+
"""Construction of the top level dictionary."""
748+
patterns = [
749+
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
750+
]
751+
generator = SchemaGenerator(patterns=patterns)
752+
753+
request = create_request('/')
754+
schema = generator.get_schema(request=request)
755+
756+
print(schema)
757+
assert 'components' in schema
758+
assert 'schemas' in schema['components']
759+
assert 'OpenAPIExample' in schema['components']['schemas']

tests/schemas/views.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
DecimalValidator, MaxLengthValidator, MaxValueValidator,
55
MinLengthValidator, MinValueValidator, RegexValidator
66
)
7+
from django.db import models
78

89
from rest_framework import generics, permissions, serializers
910
from rest_framework.decorators import action
@@ -137,3 +138,29 @@ def get(self, *args, **kwargs):
137138
url='http://localhost', uuid=uuid.uuid4(), ip4='127.0.0.1', ip6='::1',
138139
ip='192.168.1.1')
139140
return Response(serializer.data)
141+
142+
143+
# Serializer with model.
144+
class OpenAPIExample(models.Model):
145+
first_name = models.CharField(max_length=30)
146+
147+
148+
class ExampleSerializerModel(serializers.Serializer):
149+
date = serializers.DateField()
150+
datetime = serializers.DateTimeField()
151+
hstore = serializers.HStoreField()
152+
uuid_field = serializers.UUIDField(default=uuid.uuid4)
153+
154+
class Meta:
155+
model = OpenAPIExample
156+
157+
158+
class ExampleGenericAPIViewModel(generics.GenericAPIView):
159+
serializer_class = ExampleSerializerModel
160+
161+
def get(self, *args, **kwargs):
162+
from datetime import datetime
163+
now = datetime.now()
164+
165+
serializer = self.get_serializer(data=now.date(), datetime=now)
166+
return Response(serializer.data)

0 commit comments

Comments
 (0)