From 7c8edbc31534cfe0d68349ec1f1df4989994e2fd Mon Sep 17 00:00:00 2001 From: Martin Desrumaux Date: Mon, 2 Mar 2020 18:36:36 +0100 Subject: [PATCH] Implement OpenAPI Components --- docs/api-guide/schemas.md | 63 +++++++++ rest_framework/schemas/openapi.py | 94 ++++++++---- tests/schemas/test_openapi.py | 228 ++++++++++++++++++++++-------- tests/schemas/views.py | 48 +++++++ 4 files changed, 347 insertions(+), 86 deletions(-) diff --git a/docs/api-guide/schemas.md b/docs/api-guide/schemas.md index 5766a6a61c..1d1e09b46a 100644 --- a/docs/api-guide/schemas.md +++ b/docs/api-guide/schemas.md @@ -316,6 +316,65 @@ class CustomSchema(AutoSchema): def get_operation_id(self, path, method): pass +class MyView(APIView): + schema = AutoSchema(component_name="Ulysses") +``` + +### Components + +Since DRF 3.12, Schema uses the [OpenAPI Components](openapi-components). This method defines components in the schema and [references them](openapi-reference) inside request and response objects. By default, the component's name is deduced from the Serializer's name. + +Using OpenAPI's components provides the following advantages: +* The schema is more readable and lightweight. +* If you use the schema to generate an SDK (using [openapi-generator](openapi-generator) or [swagger-codegen](swagger-codegen)). The generator can name your SDK's models. + +### Handling component's schema errors + +You may get the following error while generating the schema: +``` +"Serializer" is an invalid class name for schema generation. +Serializer's class name should be unique and explicit. e.g. "ItemSerializer". +``` + +This error occurs when the Serializer name is "Serializer". You should choose a component's name unique across your schema and different than "Serializer". + +You may also get the following warning: +``` +Schema component "ComponentName" has been overriden with a different value. +``` + +This warning occurs when different components have the same name in one schema. Your component name should be unique across your project. This is likely an error that may lead to an invalid schema. + +You have two ways to solve the previous issues: +* You can rename your serializer with a unique name and another name than "Serializer". +* You can set the `component_name` kwarg parameter of the AutoSchema constructor (see below). +* You can override the `get_component_name` method of the AutoSchema class (see below). + +#### Set a custom component's name for your view + +To override the component's name in your view, you can use the `component_name` parameter of the AutoSchema constructor: + +```python +from rest_framework.schemas.openapi import AutoSchema + +class MyView(APIView): + schema = AutoSchema(component_name="Ulysses") +``` + +#### Override the default implementation + +If you want to have more control and customization about how the schema's components are generated, you can override the `get_component_name` and `get_components` method from the AutoSchema class. + +```python +from rest_framework.schemas.openapi import AutoSchema + +class CustomSchema(AutoSchema): + def get_components(self, path, method): + # Implement your custom implementation + + def get_component_name(self, serializer): + # Implement your custom implementation + class CustomView(APIView): """APIView subclass with custom schema introspection.""" schema = CustomSchema() @@ -326,3 +385,7 @@ class CustomView(APIView): [openapi-operation]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#operationObject [openapi-tags]: https://swagger.io/specification/#tagObject [openapi-operationid]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#fixed-fields-17 +[openapi-components]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#componentsObject +[openapi-reference]: https://github.com/OAI/OpenAPI-Specification/blob/master/versions/3.0.2.md#referenceObject +[openapi-generator]: https://github.com/OpenAPITools/openapi-generator +[swagger-codegen]: https://github.com/swagger-api/swagger-codegen diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index d3a373aaa8..6bed120922 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -1,3 +1,4 @@ +import re import warnings from collections import OrderedDict from decimal import Decimal @@ -65,9 +66,9 @@ def get_schema(self, request=None, public=False): Generate a OpenAPI schema. """ self._initialise_endpoints() + components_schemas = {} # Iterate endpoints generating per method path operations. - # TODO: …and reference components. paths = {} _, view_endpoints = self._get_paths_and_endpoints(None if public else request) for path, method, view in view_endpoints: @@ -75,6 +76,16 @@ def get_schema(self, request=None, public=False): continue operation = view.schema.get_operation(path, method) + components = view.schema.get_components(path, method) + for k in components.keys(): + if k not in components_schemas: + continue + if components_schemas[k] == components[k]: + continue + warnings.warn('Schema component "{}" has been overriden with a different value.'.format(k)) + + components_schemas.update(components) + # Normalise path for any provided mount url. if path.startswith('/'): path = path[1:] @@ -92,6 +103,11 @@ def get_schema(self, request=None, public=False): 'paths': paths, } + if len(components_schemas) > 0: + schema['components'] = { + 'schemas': components_schemas + } + return schema # View Inspectors @@ -99,14 +115,16 @@ def get_schema(self, request=None, public=False): class AutoSchema(ViewInspector): - def __init__(self, operation_id_base=None, tags=None): + def __init__(self, tags=None, operation_id_base=None, component_name=None): """ :param operation_id_base: user-defined name in operationId. If empty, it will be deducted from the Model/Serializer/View name. + :param component_name: user-defined component's name. If empty, it will be deducted from the Serializer's class name. """ if tags and not all(isinstance(tag, str) for tag in tags): raise ValueError('tags must be a list or tuple of string.') self._tags = tags self.operation_id_base = operation_id_base + self.component_name = component_name super().__init__() request_media_types = [] @@ -140,6 +158,43 @@ def get_operation(self, path, method): return operation + def get_component_name(self, serializer): + """ + Compute the component's name from the serializer. + Raise an exception if the serializer's class name is "Serializer" (case-insensitive). + """ + if self.component_name is not None: + return self.component_name + + # use the serializer's class name as the component name. + component_name = serializer.__class__.__name__ + # We remove the "serializer" string from the class name. + pattern = re.compile("serializer", re.IGNORECASE) + component_name = pattern.sub("", component_name) + + if component_name == "": + raise Exception( + '"{}" is an invalid class name for schema generation. ' + 'Serializer\'s class name should be unique and explicit. e.g. "ItemSerializer"' + .format(serializer.__class__.__name__) + ) + + return component_name + + def get_components(self, path, method): + """ + Return components with their properties from the serializer. + """ + serializer = self._get_serializer(path, method) + + if not isinstance(serializer, serializers.Serializer): + return {} + + component_name = self.get_component_name(serializer) + + content = self._map_serializer(serializer) + return {component_name: content} + def get_operation_id_base(self, path, method, action): """ Compute the base part for operation ID from the model, serializer or view name. @@ -434,10 +489,6 @@ def _map_min_max(self, field, content): def _map_serializer(self, serializer): # Assuming we have a valid serializer instance. - # TODO: - # - field is Nested or List serializer. - # - Handle read_only/write_only for request/response differences. - # - could do this with readOnly/writeOnly and then filter dict. required = [] properties = {} @@ -542,6 +593,9 @@ def _get_serializer(self, path, method): .format(view.__class__.__name__, method, path)) return None + def _get_reference(self, serializer): + return {'$ref': '#/components/schemas/{}'.format(self.get_component_name(serializer))} + def _get_request_body(self, path, method): if method not in ('PUT', 'PATCH', 'POST'): return {} @@ -551,20 +605,13 @@ def _get_request_body(self, path, method): serializer = self._get_serializer(path, method) if not isinstance(serializer, serializers.Serializer): - return {} - - content = self._map_serializer(serializer) - # No required fields for PATCH - if method == 'PATCH': - content.pop('required', None) - # No read_only fields for request. - for name, schema in content['properties'].copy().items(): - if 'readOnly' in schema: - del content['properties'][name] + item_schema = {} + else: + item_schema = self._get_reference(serializer) return { 'content': { - ct: {'schema': content} + ct: {'schema': item_schema} for ct in self.request_media_types } } @@ -580,17 +627,12 @@ def _get_responses(self, path, method): self.response_media_types = self.map_renderers(path, method) - item_schema = {} serializer = self._get_serializer(path, method) - if isinstance(serializer, serializers.Serializer): - item_schema = self._map_serializer(serializer) - # No write_only fields for response. - for name, schema in item_schema['properties'].copy().items(): - if 'writeOnly' in schema: - del item_schema['properties'][name] - if 'required' in item_schema: - item_schema['required'] = [f for f in item_schema['required'] if f != name] + if not isinstance(serializer, serializers.Serializer): + item_schema = {} + else: + item_schema = self._get_reference(serializer) if is_list_view(path, method, self.view): response_schema = { diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index b3f30b258b..95101403a3 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -85,12 +85,12 @@ def test_list_field_mapping(self): assert inspector._map_field(field) == mapping def test_lazy_string_field(self): - class Serializer(serializers.Serializer): + class ItemSerializer(serializers.Serializer): text = serializers.CharField(help_text=_('lazy string')) inspector = AutoSchema() - data = inspector._map_serializer(Serializer()) + data = inspector._map_serializer(ItemSerializer()) assert isinstance(data['properties']['text']['description'], str), "description must be str" def test_boolean_default_field(self): @@ -186,6 +186,33 @@ def test_request_body(self): path = '/' method = 'POST' + class ItemSerializer(serializers.Serializer): + text = serializers.CharField() + read_only = serializers.CharField(read_only=True) + + class View(generics.GenericAPIView): + serializer_class = ItemSerializer + + view = create_view( + View, + method, + create_request(path) + ) + inspector = AutoSchema() + inspector.view = view + + request_body = inspector._get_request_body(path, method) + print(request_body) + assert request_body['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' + + components = inspector.get_components(path, method) + assert components['Item']['required'] == ['text'] + assert sorted(list(components['Item']['properties'].keys())) == ['read_only', 'text'] + + def test_invalid_serializer_class_name(self): + path = '/' + method = 'POST' + class Serializer(serializers.Serializer): text = serializers.CharField() read_only = serializers.CharField(read_only=True) @@ -201,20 +228,22 @@ class View(generics.GenericAPIView): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) - assert request_body['content']['application/json']['schema']['required'] == ['text'] - assert list(request_body['content']['application/json']['schema']['properties'].keys()) == ['text'] + serializer = inspector._get_serializer(path, method) + + with pytest.raises(Exception) as exc: + inspector.get_component_name(serializer) + assert "is an invalid class name for schema generation" in str(exc.value) def test_empty_required(self): path = '/' method = 'POST' - class Serializer(serializers.Serializer): + class ItemSerializer(serializers.Serializer): read_only = serializers.CharField(read_only=True) write_only = serializers.CharField(write_only=True, required=False) class View(generics.GenericAPIView): - serializer_class = Serializer + serializer_class = ItemSerializer view = create_view( View, @@ -224,23 +253,24 @@ class View(generics.GenericAPIView): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) + components = inspector.get_components(path, method) + component = components['Item'] # there should be no empty 'required' property, see #6834 - assert 'required' not in request_body['content']['application/json']['schema'] + assert 'required' not in component for response in inspector._get_responses(path, method).values(): - assert 'required' not in response['content']['application/json']['schema'] + assert 'required' not in component def test_empty_required_with_patch_method(self): path = '/' method = 'PATCH' - class Serializer(serializers.Serializer): + class ItemSerializer(serializers.Serializer): read_only = serializers.CharField(read_only=True) write_only = serializers.CharField(write_only=True, required=False) class View(generics.GenericAPIView): - serializer_class = Serializer + serializer_class = ItemSerializer view = create_view( View, @@ -250,22 +280,23 @@ class View(generics.GenericAPIView): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) + components = inspector.get_components(path, method) + component = components['Item'] # there should be no empty 'required' property, see #6834 - assert 'required' not in request_body['content']['application/json']['schema'] + assert 'required' not in component for response in inspector._get_responses(path, method).values(): - assert 'required' not in response['content']['application/json']['schema'] + assert 'required' not in component def test_response_body_generation(self): path = '/' method = 'POST' - class Serializer(serializers.Serializer): + class ItemSerializer(serializers.Serializer): text = serializers.CharField() write_only = serializers.CharField(write_only=True) class View(generics.GenericAPIView): - serializer_class = Serializer + serializer_class = ItemSerializer view = create_view( View, @@ -276,9 +307,11 @@ class View(generics.GenericAPIView): inspector.view = view responses = inspector._get_responses(path, method) - assert '201' in responses - assert responses['201']['content']['application/json']['schema']['required'] == ['text'] - assert list(responses['201']['content']['application/json']['schema']['properties'].keys()) == ['text'] + assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' + + components = inspector.get_components(path, method) + assert sorted(components['Item']['required']) == ['text', 'write_only'] + assert sorted(list(components['Item']['properties'].keys())) == ['text', 'write_only'] assert 'description' in responses['201'] def test_response_body_nested_serializer(self): @@ -288,12 +321,12 @@ def test_response_body_nested_serializer(self): class NestedSerializer(serializers.Serializer): number = serializers.IntegerField() - class Serializer(serializers.Serializer): + class ItemSerializer(serializers.Serializer): text = serializers.CharField() nested = NestedSerializer() class View(generics.GenericAPIView): - serializer_class = Serializer + serializer_class = ItemSerializer view = create_view( View, @@ -304,7 +337,11 @@ class View(generics.GenericAPIView): inspector.view = view responses = inspector._get_responses(path, method) - schema = responses['201']['content']['application/json']['schema'] + assert responses['201']['content']['application/json']['schema']['$ref'] == '#/components/schemas/Item' + components = inspector.get_components(path, method) + assert components['Item'] + + schema = components['Item'] assert sorted(schema['required']) == ['nested', 'text'] assert sorted(list(schema['properties'].keys())) == ['nested', 'text'] assert schema['properties']['nested']['type'] == 'object' @@ -339,19 +376,25 @@ class View(generics.GenericAPIView): 'schema': { 'type': 'array', 'items': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], + '$ref': '#/components/schemas/Item' }, }, }, }, }, } + components = inspector.get_components(path, method) + assert components == { + 'Item': { + 'type': 'object', + 'properties': { + 'text': { + 'type': 'string', + }, + }, + 'required': ['text'], + } + } def test_paginated_list_response_body_generation(self): """Test that pagination properties are added for a paginated list view.""" @@ -391,13 +434,7 @@ class View(generics.GenericAPIView): 'item': { 'type': 'array', 'items': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], + '$ref': '#/components/schemas/Item' }, }, }, @@ -405,6 +442,18 @@ class View(generics.GenericAPIView): }, }, } + components = inspector.get_components(path, method) + assert components == { + 'Item': { + 'type': 'object', + 'properties': { + 'text': { + 'type': 'string', + }, + }, + 'required': ['text'], + } + } def test_delete_response_body_generation(self): """Test that a view's delete method generates a proper response body schema.""" @@ -508,10 +557,10 @@ class View(generics.CreateAPIView): inspector = AutoSchema() inspector.view = view - request_body = inspector._get_request_body(path, method) - mp_media = request_body['content']['multipart/form-data'] - attachment = mp_media['schema']['properties']['attachment'] - assert attachment['format'] == 'binary' + components = inspector.get_components(path, method) + component = components['Item'] + properties = component['properties'] + assert properties['attachment']['format'] == 'binary' def test_retrieve_response_body_generation(self): """ @@ -551,19 +600,26 @@ class View(generics.GenericAPIView): 'content': { 'application/json': { 'schema': { - 'type': 'object', - 'properties': { - 'text': { - 'type': 'string', - }, - }, - 'required': ['text'], + '$ref': '#/components/schemas/Item' }, }, }, }, } + components = inspector.get_components(path, method) + assert components == { + 'Item': { + 'type': 'object', + 'properties': { + 'text': { + 'type': 'string', + }, + }, + 'required': ['text'], + } + } + def test_operation_id_generation(self): path = '/' method = 'GET' @@ -689,9 +745,9 @@ def test_serializer_datefield(self): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + components = inspector.get_components(path, method) + component = components['Example'] + properties = component['properties'] assert properties['date']['type'] == properties['datetime']['type'] == 'string' assert properties['date']['format'] == 'date' assert properties['datetime']['format'] == 'date-time' @@ -707,9 +763,9 @@ def test_serializer_hstorefield(self): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + components = inspector.get_components(path, method) + component = components['Example'] + properties = component['properties'] assert properties['hstore']['type'] == 'object' def test_serializer_callable_default(self): @@ -723,9 +779,9 @@ def test_serializer_callable_default(self): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + components = inspector.get_components(path, method) + component = components['Example'] + properties = component['properties'] assert 'default' not in properties['uuid_field'] def test_serializer_validators(self): @@ -739,9 +795,9 @@ def test_serializer_validators(self): inspector = AutoSchema() inspector.view = view - responses = inspector._get_responses(path, method) - response_schema = responses['200']['content']['application/json']['schema'] - properties = response_schema['items']['properties'] + components = inspector.get_components(path, method) + component = components['ExampleValidated'] + properties = component['properties'] assert properties['integer']['type'] == 'integer' assert properties['integer']['maximum'] == 99 @@ -819,6 +875,7 @@ class ExampleStringTagsViewSet(views.ExampleGenericViewSet): def test_auto_generated_apiview_tags(self): class RestaurantAPIView(views.ExampleGenericAPIView): + schema = AutoSchema(operation_id_base="restaurant") pass class BranchAPIView(views.ExampleGenericAPIView): @@ -932,3 +989,54 @@ def test_schema_information_empty(self): assert schema['info']['title'] == '' assert schema['info']['version'] == '' + + def test_serializer_model(self): + """Construction of the top level dictionary.""" + patterns = [ + url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()), + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + print(schema) + assert 'components' in schema + assert 'schemas' in schema['components'] + assert 'ExampleModel' in schema['components']['schemas'] + + def test_component_name(self): + patterns = [ + url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + print(schema) + assert 'components' in schema + assert 'schemas' in schema['components'] + assert 'Ulysses' in schema['components']['schemas'] + + def test_duplicate_component_name(self): + patterns = [ + url(r'^duplicate1/?$', views.ExampleAutoSchemaDuplicate1.as_view()), + url(r'^duplicate2/?$', views.ExampleAutoSchemaDuplicate2.as_view()), + ] + + generator = SchemaGenerator(patterns=patterns) + request = create_request('/') + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + schema = generator.get_schema(request=request) + + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert 'has been overriden with a different value.' in str(w[-1].message) + + assert 'components' in schema + assert 'schemas' in schema['components'] + assert 'Duplicate' in schema['components']['schemas'] diff --git a/tests/schemas/views.py b/tests/schemas/views.py index 5835a55727..1c8235b425 100644 --- a/tests/schemas/views.py +++ b/tests/schemas/views.py @@ -9,6 +9,7 @@ from rest_framework import generics, permissions, serializers from rest_framework.decorators import action from rest_framework.response import Response +from rest_framework.schemas.openapi import AutoSchema from rest_framework.views import APIView from rest_framework.viewsets import GenericViewSet @@ -167,3 +168,50 @@ class ExampleOperationIdDuplicate2(generics.GenericAPIView): def get(self, *args, **kwargs): pass + + +class ExampleGenericAPIViewModel(generics.GenericAPIView): + serializer_class = ExampleSerializerModel + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data) + + +class ExampleAutoSchemaComponentName(generics.GenericAPIView): + serializer_class = ExampleSerializerModel + schema = AutoSchema(component_name="Ulysses") + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data) + + +class ExampleAutoSchemaDuplicate1(generics.GenericAPIView): + serializer_class = ExampleValidatedSerializer + schema = AutoSchema(component_name="Duplicate") + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data) + + +class ExampleAutoSchemaDuplicate2(generics.GenericAPIView): + serializer_class = ExampleSerializerModel + schema = AutoSchema(component_name="Duplicate") + + def get(self, *args, **kwargs): + from datetime import datetime + now = datetime.now() + + serializer = self.get_serializer(data=now.date(), datetime=now) + return Response(serializer.data)