From ac196a73d339433e7d66ad33f64c3c5b71dd596a Mon Sep 17 00:00:00 2001 From: Nik Date: Thu, 11 Aug 2016 21:14:15 +0300 Subject: [PATCH 1/3] Construct view in schema generator with respect of overrides --- rest_framework/schemas.py | 19 ++++++++++++++----- tests/test_schemas.py | 24 ++++++++++++++++++++++-- 2 files changed, 36 insertions(+), 7 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 688deec881..842a0fc649 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -86,9 +86,10 @@ def get_schema(self, request=None): endpoints = [] for key, link, callback in self.endpoints: method = link.action.upper() - view = callback.cls() + view = self.get_view(callback) view.request = clone_request(request, method) view.format_kwarg = None + try: view.check_permissions(view.request) except exceptions.APIException: @@ -135,6 +136,15 @@ def get_api_endpoints(self, patterns, prefix=''): return api_endpoints + def get_view(self, callback): + """ + Return constructed view with respect of overrided attributes by detail_route and list_route + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).iteritems(): + setattr(view, attr, val) + return view + def get_path(self, path_regex): """ Given a URL conf regex, return a URI template string. @@ -165,9 +175,10 @@ def get_allowed_methods(self, callback): if hasattr(callback, 'actions'): return [method.upper() for method in callback.actions.keys()] + view = self.get_view(callback) return [ method for method in - callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') + view.allowed_methods if method not in ('OPTIONS', 'HEAD') ] def get_key(self, path, method, callback): @@ -194,9 +205,7 @@ def get_link(self, path, method, callback): """ Return a `coreapi.Link` instance for the given endpoint. """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) + view = self.get_view(callback) fields = self.get_path_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 6c02c9d230..5deb5fed8b 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -33,15 +33,25 @@ class AnotherSerializer(serializers.Serializer): d = serializers.CharField(required=False) +class ForbidAll(permissions.BasePermission): + def has_permission(self, request, view): + return False + + class ExampleViewSet(ModelViewSet): pagination_class = ExamplePagination permission_classes = [permissions.IsAuthenticatedOrReadOnly] filter_backends = [filters.OrderingFilter] serializer_class = ExampleSerializer - @detail_route(methods=['post'], serializer_class=AnotherSerializer) + @detail_route(methods=['put', 'post'], + serializer_class=AnotherSerializer) def custom_action(self, request, pk): - return super(ExampleSerializer, self).retrieve(self, request) + return super(ExampleSerializer, self).update(self, request) + + @detail_route(permission_classes=[ForbidAll]) + def forbidden_action(self, request, pk): + return super(ExampleSerializer, self).update(self, request) class ExampleView(APIView): @@ -130,6 +140,16 @@ def test_authenticated_request(self): coreapi.Field('pk', required=True, location='path') ] ), + 'custom_action': coreapi.Link( + url='/example/{pk}/custom_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('c', required=True, location='form'), + coreapi.Field('d', required=False, location='form'), + ] + ), 'custom_action': coreapi.Link( url='/example/{pk}/custom_action/', action='post', From 966934c0a2f5ec2e46f5dc5043709c862802f40d Mon Sep 17 00:00:00 2001 From: Nik Date: Thu, 11 Aug 2016 22:10:41 +0300 Subject: [PATCH 2/3] Add test for scheme allowed methods, fix categories for nested endpoints --- rest_framework/schemas.py | 5 ++-- tests/test_schemas.py | 51 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 51 insertions(+), 5 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index c635c1fd65..26d5e5dac3 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -64,7 +64,8 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None): def get_schema(self, request=None): if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) + endpoints = self.get_api_endpoints(self.patterns) + self.endpoints = self.add_categories(endpoints) links = [] for path, method, category, action, callback in self.endpoints: @@ -127,7 +128,7 @@ def get_api_endpoints(self, patterns, prefix=''): ) api_endpoints.extend(nested_endpoints) - return self.add_categories(api_endpoints) + return api_endpoints def add_categories(self, api_endpoints): """ diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 31ba97f2de..1b6ee30606 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -9,7 +9,7 @@ from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView from rest_framework.viewsets import ModelViewSet @@ -62,6 +62,15 @@ def get_serializer(self, *args, **kwargs): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) +class RestrictiveViewSet(ModelViewSet): + permission_classes = [ForbidAll] + serializer_class = ExampleSerializer + + @detail_route(methods=['put'], permission_classes=[permissions.AllowAny]) + def allowed_action(self, request): + return super(RestrictiveViewSet, self).update(self, request) + + class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly] @@ -77,7 +86,14 @@ def post(self, request, *args, **kwargs): urlpatterns = [ url(r'^', include(router.urls)) ] -urlpatterns2 = [ + +router = DefaultRouter(schema_title='Restrictive API' if coreapi else None) +router.register('example', RestrictiveViewSet, base_name='example') +urlpatterns_restrict = [ + url(r'^', include(router.urls)) +] + +urlpatterns_view = [ url(r'^example-view/$', ExampleView.as_view(), name='example-view') ] @@ -209,10 +225,39 @@ def test_authenticated_request(self): self.assertEqual(response.data, expected) +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaForRestrictedMethods(TestCase): + def test_resctricted_methods(self): + schema_generator = SchemaGenerator(title='Restrictive API', patterns=urlpatterns_restrict) + factory = APIRequestFactory() + from rest_framework.request import Request + mock_request = factory.get('/') + schema = schema_generator.get_schema(request=Request(mock_request)) + expected = coreapi.Document( + url='', + title='Restrictive API', + content={ + 'example': { + 'allowed_action': coreapi.Link( + url='/example/{pk}/allowed_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=True, location='form', description='A field description'), + coreapi.Field('b', required=False, location='form') + ] + ), + } + } + ) + self.assertEqual(schema, expected) + + @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): def test_view(self): - schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2) + schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns_view) schema = schema_generator.get_schema() expected = coreapi.Document( url='', From fa6b80ef1880ba1e3ef1e7f2faea467866f914c2 Mon Sep 17 00:00:00 2001 From: Nik Date: Thu, 11 Aug 2016 22:33:39 +0300 Subject: [PATCH 3/3] Python 3 compatibility --- rest_framework/schemas.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 26d5e5dac3..d6b7cf8a8c 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -149,7 +149,7 @@ def get_view(self, callback): Return constructed view with respect of overrided attributes by detail_route and list_route """ view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).iteritems(): + for attr, val in getattr(callback, 'initkwargs', {}).items(): setattr(view, attr, val) return view