diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 6b6324033c..d6b7cf8a8c 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -64,19 +64,19 @@ def __init__(self, title=None, url=None, patterns=None, urlconf=None): def get_schema(self, request=None): if self.endpoints is None: - self.endpoints = self.get_api_endpoints(self.patterns) + endpoints = self.get_api_endpoints(self.patterns) + self.endpoints = self.add_categories(endpoints) links = [] for path, method, category, action, callback in self.endpoints: - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) + view = self.get_view(callback) view.args = () view.kwargs = {} view.format_kwarg = None if request is not None: view.request = clone_request(request, method) + try: view.check_permissions(view.request) except exceptions.APIException: @@ -128,7 +128,7 @@ def get_api_endpoints(self, patterns, prefix=''): ) api_endpoints.extend(nested_endpoints) - return self.add_categories(api_endpoints) + return api_endpoints def add_categories(self, api_endpoints): """ @@ -144,6 +144,15 @@ def add_categories(self, api_endpoints): for (path, method, action, callback) in api_endpoints ] + def get_view(self, callback): + """ + Return constructed view with respect of overrided attributes by detail_route and list_route + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + return view + def get_path(self, path_regex): """ Given a URL conf regex, return a URI template string. @@ -174,9 +183,10 @@ def get_allowed_methods(self, callback): if hasattr(callback, 'actions'): return [method.upper() for method in callback.actions.keys()] + view = self.get_view(callback) return [ method for method in - callback.cls().allowed_methods if method not in ('OPTIONS', 'HEAD') + view.allowed_methods if method not in ('OPTIONS', 'HEAD') ] def get_action(self, path, method, callback): diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 81b796c35a..1b6ee30606 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -9,7 +9,7 @@ from rest_framework.response import Response from rest_framework.routers import DefaultRouter from rest_framework.schemas import SchemaGenerator -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView from rest_framework.viewsets import ModelViewSet @@ -33,15 +33,25 @@ class AnotherSerializer(serializers.Serializer): d = serializers.CharField(required=False) +class ForbidAll(permissions.BasePermission): + def has_permission(self, request, view): + return False + + class ExampleViewSet(ModelViewSet): pagination_class = ExamplePagination permission_classes = [permissions.IsAuthenticatedOrReadOnly] filter_backends = [filters.OrderingFilter] serializer_class = ExampleSerializer - @detail_route(methods=['post'], serializer_class=AnotherSerializer) + @detail_route(methods=['put', 'post'], + serializer_class=AnotherSerializer) def custom_action(self, request, pk): - return super(ExampleSerializer, self).retrieve(self, request) + return super(ExampleSerializer, self).update(self, request) + + @detail_route(permission_classes=[ForbidAll]) + def forbidden_action(self, request, pk): + return super(ExampleSerializer, self).update(self, request) @list_route() def custom_list_action(self, request): @@ -52,6 +62,15 @@ def get_serializer(self, *args, **kwargs): return super(ExampleViewSet, self).get_serializer(*args, **kwargs) +class RestrictiveViewSet(ModelViewSet): + permission_classes = [ForbidAll] + serializer_class = ExampleSerializer + + @detail_route(methods=['put'], permission_classes=[permissions.AllowAny]) + def allowed_action(self, request): + return super(RestrictiveViewSet, self).update(self, request) + + class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly] @@ -67,7 +86,14 @@ def post(self, request, *args, **kwargs): urlpatterns = [ url(r'^', include(router.urls)) ] -urlpatterns2 = [ + +router = DefaultRouter(schema_title='Restrictive API' if coreapi else None) +router.register('example', RestrictiveViewSet, base_name='example') +urlpatterns_restrict = [ + url(r'^', include(router.urls)) +] + +urlpatterns_view = [ url(r'^example-view/$', ExampleView.as_view(), name='example-view') ] @@ -142,6 +168,16 @@ def test_authenticated_request(self): coreapi.Field('pk', required=True, location='path') ] ), + 'custom_action': coreapi.Link( + url='/example/{pk}/custom_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('c', required=True, location='form'), + coreapi.Field('d', required=False, location='form'), + ] + ), 'custom_action': coreapi.Link( url='/example/{pk}/custom_action/', action='post', @@ -189,10 +225,39 @@ def test_authenticated_request(self): self.assertEqual(response.data, expected) +@unittest.skipUnless(coreapi, 'coreapi is not installed') +class TestSchemaForRestrictedMethods(TestCase): + def test_resctricted_methods(self): + schema_generator = SchemaGenerator(title='Restrictive API', patterns=urlpatterns_restrict) + factory = APIRequestFactory() + from rest_framework.request import Request + mock_request = factory.get('/') + schema = schema_generator.get_schema(request=Request(mock_request)) + expected = coreapi.Document( + url='', + title='Restrictive API', + content={ + 'example': { + 'allowed_action': coreapi.Link( + url='/example/{pk}/allowed_action/', + action='put', + encoding='application/json', + fields=[ + coreapi.Field('pk', required=True, location='path'), + coreapi.Field('a', required=True, location='form', description='A field description'), + coreapi.Field('b', required=False, location='form') + ] + ), + } + } + ) + self.assertEqual(schema, expected) + + @unittest.skipUnless(coreapi, 'coreapi is not installed') class TestSchemaGenerator(TestCase): def test_view(self): - schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns2) + schema_generator = SchemaGenerator(title='Test View', patterns=urlpatterns_view) schema = schema_generator.get_schema() expected = coreapi.Document( url='',