Skip to content

add OpenAPI schema initialization #6670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions docs/api-guide/schemas.md
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,33 @@ May be used to pass a canonical URL for the schema.
url='https://www.example.org/api/'
)

#### `openapi_schema`

May be used to pass a static initial OpenAPI schema document, typically
containing top-level OpenAPI fields. The schema document will
be added to by the AutoSchema generator.

schema_view = get_schema_view(
openapi_schema = {
'info': {
'title': 'my title',
'version': '1.0',
'contact': {
'name': 'API Support',
'url': 'http://www.example.com/support',
'email': '[email protected]'
},
'license': {
'name': 'Apache 2.0',
'url': 'https://www.apache.org/licenses/LICENSE-2.0.html'
}.
'servers': [
{'url': 'https://api.example.com'}
]
}
)


#### `urlconf`

A string representing the import path to the URL conf that you want
Expand Down
17 changes: 12 additions & 5 deletions rest_framework/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@


def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
title=None, url=None, description=None, openapi_schema=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=None,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
Expand All @@ -41,10 +41,17 @@ def get_schema_view(
else:
generator_class = openapi.SchemaGenerator

generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)
if isinstance(generator_class, openapi.SchemaGenerator):
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
openapi_schema=openapi_schema,
)
else:
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
)

# Avoid import cycle on APIView
from .views import SchemaView
Expand Down
2 changes: 1 addition & 1 deletion rest_framework/schemas/generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class BaseSchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None

def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None, **kwargs):
if url and not url.endswith('/'):
url += '/'

Expand Down
15 changes: 14 additions & 1 deletion rest_framework/schemas/openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@


class SchemaGenerator(BaseSchemaGenerator):
def __init__(self, **kwargs):
super().__init__(**kwargs)
#: the openapi schema document:
self.openapi_schema = {}

def get_info(self):
info = {
Expand Down Expand Up @@ -43,6 +47,9 @@ def get_paths(self, request=None):
subpath = '/' + path[len(prefix):]
result.setdefault(subpath, {})
result[subpath][method.lower()] = operation
if hasattr(view.schema, 'openapi_schema'):
# TODO: shallow or deep merge?
self.openapi_schema = {**self.openapi_schema, **view.schema.openapi_schema}

return result

Expand All @@ -61,13 +68,19 @@ def get_schema(self, request=None, public=False):
'info': self.get_info(),
'paths': paths,
}
# TODO: shallow or deep merge?
self.openapi_schema = {**schema, **self.openapi_schema}

return schema
return self.openapi_schema

# View Inspectors


class AutoSchema(ViewInspector):
def __init__(self, openapi_schema={}):
super().__init__()
# TODO: call this manual_fields ala coreapi?
self.openapi_schema = openapi_schema

content_types = ['application/json']
method_mapping = {
Expand Down
5 changes: 5 additions & 0 deletions tests/schemas/test_get_schema_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ def test_openapi(self):
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes

def test_openapi_initialized(self):
schema_view = get_schema_view(openapi_schema={'info': {'title': 'With OpenAPI'}})
assert isinstance(schema_view.initkwargs['schema_generator'], openapi.SchemaGenerator)
assert renderers.OpenAPIRenderer in schema_view.cls().renderer_classes

@pytest.mark.skipif(not coreapi.coreapi, reason='coreapi is not installed')
def test_coreapi(self):
with override_settings(REST_FRAMEWORK={'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema'}):
Expand Down
23 changes: 23 additions & 0 deletions tests/schemas/test_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,29 @@ def test_schema_construction(self):

assert 'openapi' in schema
assert 'paths' in schema
assert 'info' in schema
assert 'title' in schema['info']
assert 'version' in schema['info']
assert schema['info']['title'] is None
assert schema['info']['version'] == 'TODO'

def test_schema_initializer(self):
"""Construction of top-level dictionary with an initializer."""
class MyListView(views.ExampleListView):
schema = AutoSchema(openapi_schema={'info': {'title': 'mytitle', 'version': 'myversion'}})

patterns = [
url(r'^example/?$', MyListView.as_view()),
]
generator = SchemaGenerator(patterns=patterns)

request = create_request('/')
schema = generator.get_schema(request=request)

assert 'info' in schema
assert 'title' in schema['info']
assert 'version' in schema['info']
assert schema['info']['title'] == 'mytitle' and schema['info']['version'] == 'myversion'

def test_serializer_datefield(self):
patterns = [
Expand Down