From 08c7853655252cb9de0ade24eddfc9762a01cee7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Aug 2016 14:15:35 +0100 Subject: [PATCH 01/11] Start test case --- rest_framework/test.py | 5 +++++ tests/test_requests_client.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/test_requests_client.py diff --git a/rest_framework/test.py b/rest_framework/test.py index 3ba4059a9f..fb08c4a736 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -12,6 +12,7 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode +from requests import Session from rest_framework.settings import api_settings @@ -221,6 +222,10 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient + def _pre_setup(self): + super(APITestCase, self)._pre_setup() + self.requests = Session() + class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py new file mode 100644 index 0000000000..a36349a3f0 --- /dev/null +++ b/tests/test_requests_client.py @@ -0,0 +1,24 @@ +from __future__ import unicode_literals + +from django.conf.urls import url +from django.test import override_settings + +from rest_framework.response import Response +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({'hello': 'world'}) + + +urlpatterns = [ + url(r'^$', Root.as_view()), +] + + +@override_settings(ROOT_URLCONF='tests.test_requests_client') +class RequestsClientTests(APITestCase): + def test_get_root(self): + print self.requests.get('http://example.com') From 3d1fff3f26835612be17b6624d766a56520880ec Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 15:42:35 +0100 Subject: [PATCH 02/11] Added 'requests' test client --- rest_framework/request.py | 2 +- rest_framework/test.py | 86 +++++++++++++++++++++++- tests/test_requests_client.py | 119 +++++++++++++++++++++++++++++++++- 3 files changed, 202 insertions(+), 5 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index aafafcb325..f5738bfd50 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -373,7 +373,7 @@ def POST(self): if not _hasattr(self, '_data'): self._load_data_and_files() if is_form_media_type(self.content_type): - return self.data + return self._data return QueryDict('', encoding=self._request._encoding) @property diff --git a/rest_framework/test.py b/rest_framework/test.py index fb08c4a736..1fd530a0c3 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,10 @@ from django.utils.encoding import force_bytes from django.utils.http import urlencode from requests import Session +from requests.adapters import BaseAdapter +from requests.models import Response +from requests.structures import CaseInsensitiveDict +from requests.utils import get_encoding_from_headers from rest_framework.settings import api_settings @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +class DjangoTestAdapter(BaseAdapter): + """ + A transport adaptor for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests ovet the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = CaseInsensitiveDict(headers) + response.encoding = get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + +class DjangoTestSession(Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase): def _pre_setup(self): super(APITestCase, self)._pre_setup() - self.requests = Session() + self.requests = DjangoTestSession() class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index a36349a3f0..0687dd92e1 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,15 +10,128 @@ class Root(APIView): def get(self, request): - return Response({'hello': 'world'}) + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class Headers(APIView): + def get(self, request): + headers = { + key[5:]: value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) urlpatterns = [ url(r'^$', Root.as_view()), + url(r'^headers/$', Headers.as_view()), ] @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): - def test_get_root(self): - print self.requests.get('http://example.com') + def test_get_request(self): + response = self.requests.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + response = self.requests.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + response = self.requests.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + response = self.requests.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + response = self.requests.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = self.requests.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + # cookies/session auth From e76ca6eb8838148ccb1c25a2a8a735a42f644d99 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 16:06:04 +0100 Subject: [PATCH 03/11] Address typos --- rest_framework/test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 1fd530a0c3..e1d8eff82e 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -31,8 +31,8 @@ def force_authenticate(request, user=None, token=None): class DjangoTestAdapter(BaseAdapter): """ - A transport adaptor for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests ovet the network. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ def __init__(self): self.app = WSGIHandler() @@ -55,9 +55,9 @@ def get_environ(self, request): # Set request headers. for key, value in request.headers.items(): key = key.upper() - if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): continue - kwargs['HTTP_%s' % key] = value + kwargs['HTTP_%s' % key.replace('-', '_')] = value return self.factory.generic(method, url, **kwargs).environ From 6ede654315e415362f4c7c8e38a3f641039bfc46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 12:11:01 +0100 Subject: [PATCH 04/11] Graceful fallback if requests is not installed. --- rest_framework/compat.py | 7 ++ rest_framework/test.py | 161 +++++++++++++++++----------------- tests/test_requests_client.py | 6 +- 3 files changed, 92 insertions(+), 82 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84b..bda346fa82 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -178,6 +178,13 @@ def value_from_object(field, obj): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None diff --git a/rest_framework/test.py b/rest_framework/test.py index e1d8eff82e..eba4b96cfa 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -15,12 +15,8 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from requests import Session -from requests.adapters import BaseAdapter -from requests.models import Response -from requests.structures import CaseInsensitiveDict -from requests.utils import get_encoding_from_headers +from rest_framework.compat import requests from rest_framework.settings import api_settings @@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token -class DjangoTestAdapter(BaseAdapter): - """ - A transport adapter for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests over the network. - """ - def __init__(self): - self.app = WSGIHandler() - self.factory = DjangoRequestFactory() - - def get_environ(self, request): - """ - Given a `requests.PreparedRequest` instance, return a WSGI environ dict. - """ - method = request.method - url = request.url - kwargs = {} - - # Set request content, if any exists. - if request.body is not None: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] - - # Set request headers. - for key, value in request.headers.items(): - key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): - continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value - - return self.factory.generic(method, url, **kwargs).environ - - def send(self, request, *args, **kwargs): +if requests is not None: + class DjangoTestAdapter(requests.adapters.BaseAdapter): """ - Make an outgoing request to the Django WSGI application. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ - response = Response() - - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = CaseInsensitiveDict(headers) - response.encoding = get_encoding_from_headers(response.headers) - - environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) - - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) - - return response - - def close(self): - pass - - -class DjangoTestSession(Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) - - adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) - - def request(self, method, url, *args, **kwargs): - if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = requests.models.Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = requests.structures.CaseInsensitiveDict(headers) + response.encoding = requests.utils.get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) class APIRequestFactory(DjangoRequestFactory): @@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - def _pre_setup(self): - super(APITestCase, self)._pre_setup() - self.requests = DjangoTestSession() + @property + def requests(self): + if not hasattr(self, '_requests'): + assert requests is not None, 'requests must be installed' + self._requests = DjangoTestSession() + return self._requests class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 0687dd92e1..24e29d3b86 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -1,8 +1,11 @@ from __future__ import unicode_literals +import unittest + from django.conf.urls import url from django.test import override_settings +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -37,7 +40,7 @@ def post(self, request): class Headers(APIView): def get(self, request): headers = { - key[5:]: value + key[5:].replace('_', '-'): value for key, value in request.META.items() if key.startswith('HTTP_') } @@ -53,6 +56,7 @@ def get(self, request): ] +@unittest.skipUnless(requests, 'requests not installed') @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): From 049a39e060ab8bbd028ffaa0fb2f5104a71f400d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 15:43:12 +0100 Subject: [PATCH 05/11] Add cookie support --- rest_framework/test.py | 41 +++++++++++++------ tests/test_requests_client.py | 76 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index eba4b96cfa..bc8ecc5db6 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,7 +26,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: - class DjangoTestAdapter(requests.adapters.BaseAdapter): + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. @@ -62,23 +62,38 @@ def send(self, request, *args, **kwargs): """ Make an outgoing request to the Django WSGI application. """ - response = requests.models.Response() + raw_kwargs = {} - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = requests.structures.CaseInsensitiveDict(headers) - response.encoding = requests.utils.get_encoding_from_headers(response.headers) + def start_response(wsgi_status, wsgi_headers): + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) + self.closed = False + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + + status, _, reason = wsgi_status.partition(' ') + raw_kwargs['status'] = int(status) + raw_kwargs['reason'] = reason + raw_kwargs['headers'] = wsgi_headers + raw_kwargs['version'] = 11 + raw_kwargs['preload_content'] = False + raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + + # Make the outgoing request via WSGI. environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + wsgi_response = self.app(environ, start_response) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + # Build the underlying urllib3.HTTPResponse + raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - return response + # Build the requests.Response + return self.build_response(request, raw) def close(self): pass diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 24e29d3b86..10158efa7d 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -37,7 +37,7 @@ def post(self, request): }) -class Headers(APIView): +class HeadersView(APIView): def get(self, request): headers = { key[5:].replace('_', '-'): value @@ -50,9 +50,32 @@ def get(self, request): }) +class SessionView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.session.items() + }) + + def post(self, request): + for key, value in request.data.items(): + request.session[key] = value + return Response({ + key: value for key, value in request.session.items() + }) + + +class CookiesView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.COOKIES.items() + }) + + urlpatterns = [ url(r'^$', Root.as_view()), - url(r'^headers/$', Headers.as_view()), + url(r'^headers/$', HeadersView.as_view()), + url(r'^session/$', SessionView.as_view()), + url(r'^cookies/$', CookiesView.as_view()), ] @@ -138,4 +161,53 @@ def test_post_multipart_request(self): } assert response.json() == expected + def test_session(self): + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {} + assert response.json() == expected + + response = self.requests.post('/session/', json={'example': 'abc'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + def test_cookies(self): + """ + Test for explicitly setting a cookie. + """ + my_cookie = { + "version": 0, + "name": 'COOKIE_NAME', + "value": 'COOKIE_VALUE', + "port": None, + # "port_specified":False, + "domain": 'testserver.local', + # "domain_specified":False, + # "domain_initial_dot":False, + "path": '/', + # "path_specified":True, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {}, + "rfc2109": False + } + self.requests.cookies.set(**my_cookie) + response = self.requests.get('/cookies/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + assert response.json() == expected + # cookies/session auth From 64e19c738fce463df6fafd8f377ecc702f068813 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 17:54:03 +0100 Subject: [PATCH 06/11] Tests for auth and CSRF --- tests/test_requests_client.py | 85 ++++++++++++++++++++++------------- 1 file changed, 55 insertions(+), 30 deletions(-) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 10158efa7d..aa99a71da5 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -3,9 +3,14 @@ import unittest from django.conf.urls import url +from django.contrib.auth import authenticate, login +from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import override_settings +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie -from rest_framework.compat import requests +from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -64,18 +69,33 @@ def post(self, request): }) -class CookiesView(APIView): +class AuthView(APIView): + @method_decorator(ensure_csrf_cookie) def get(self, request): + if is_authenticated(request.user): + username = request.user.username + else: + username = None return Response({ - key: value for key, value in request.COOKIES.items() + 'username': username }) + @method_decorator(csrf_protect) + def post(self, request): + username = request.data['username'] + password = request.data['password'] + user = authenticate(username=username, password=password) + if user is None: + return Response({'error': 'incorrect credentials'}) + login(request, user) + return redirect('/auth/') + urlpatterns = [ url(r'^$', Root.as_view()), url(r'^headers/$', HeadersView.as_view()), url(r'^session/$', SessionView.as_view()), - url(r'^cookies/$', CookiesView.as_view()), + url(r'^auth/$', AuthView.as_view()), ] @@ -180,34 +200,39 @@ def test_session(self): expected = {'example': 'abc'} assert response.json() == expected - def test_cookies(self): - """ - Test for explicitly setting a cookie. - """ - my_cookie = { - "version": 0, - "name": 'COOKIE_NAME', - "value": 'COOKIE_VALUE', - "port": None, - # "port_specified":False, - "domain": 'testserver.local', - # "domain_specified":False, - # "domain_initial_dot":False, - "path": '/', - # "path_specified":True, - "secure": False, - "expires": None, - "discard": True, - "comment": None, - "comment_url": None, - "rest": {}, - "rfc2109": False + def test_auth(self): + # Confirm session is not authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': None } - self.requests.cookies.set(**my_cookie) - response = self.requests.get('/cookies/') + assert response.json() == expected + assert 'csrftoken' in response.cookies + csrftoken = response.cookies['csrftoken'] + + user = User.objects.create(username='tom') + user.set_password('password') + user.save() + + # Perform a login + response = self.requests.post('/auth/', json={ + 'username': 'tom', + 'password': 'password' + }, headers={'X-CSRFToken': csrftoken}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' - expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + expected = { + 'username': 'tom' + } assert response.json() == expected - # cookies/session auth + # Confirm session is authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected From da47c345c09eaf4fffca9cfcd18bba7acd844e8b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:09:19 +0100 Subject: [PATCH 07/11] Py3 compat --- rest_framework/test.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bc8ecc5db6..a95d185376 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,6 +26,21 @@ def force_authenticate(request, user=None, token=None): if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key): + return self.getheaders(self, key) + + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = HeaderDict(headers) + self.closed = False + + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the @@ -65,17 +80,6 @@ def send(self, request, *args, **kwargs): raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): - class MockOriginalResponse(object): - def __init__(self, headers): - self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) - self.closed = False - - def isclosed(self): - return self.closed - - def close(self): - self.closed = True - status, _, reason = wsgi_status.partition(' ') raw_kwargs['status'] = int(status) raw_kwargs['reason'] = reason From 53117698e042691a7045688d41edae9cc118643b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:47:01 +0100 Subject: [PATCH 08/11] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index a95d185376..bf22ff08d5 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -27,7 +27,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key): + def get_all(self, key, default): return self.getheaders(self, key) class MockOriginalResponse(object): From 0b3db028a2a7a0a91c3111fc6febbdbcd9cbd6b5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:50:02 +0100 Subject: [PATCH 09/11] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bf22ff08d5..ded9d5fe93 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -28,7 +28,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): - return self.getheaders(self, key) + return self.getheaders(key) class MockOriginalResponse(object): def __init__(self, headers): From 0cc3f5008fcd41d1597776c43dc57502ed2a7542 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Aug 2016 15:34:19 +0100 Subject: [PATCH 10/11] Add get_requests_client --- rest_framework/test.py | 12 +++++------- tests/test_requests_client.py | 37 ++++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index ded9d5fe93..e17c19a43f 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -121,6 +121,11 @@ def request(self, method, url, *args, **kwargs): return super(DjangoTestSession, self).request(method, url, *args, **kwargs) +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - @property - def requests(self): - if not hasattr(self, '_requests'): - assert requests is not None, 'requests must be installed' - self._requests = DjangoTestSession() - return self._requests - class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index aa99a71da5..37bde10922 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase +from rest_framework.test import APITestCase, get_requests_client from rest_framework.views import APIView @@ -103,7 +103,8 @@ def post(self, request): @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - response = self.requests.get('/') + client = get_requests_client() + response = client.get('/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -113,7 +114,8 @@ def test_get_request(self): assert response.json() == expected def test_get_request_query_params_in_url(self): - response = self.requests.get('/?key=value') + client = get_requests_client() + response = client.get('/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -123,7 +125,8 @@ def test_get_request_query_params_in_url(self): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - response = self.requests.get('/', params={'key': 'value'}) + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -133,14 +136,16 @@ def test_get_request_query_params_by_kwarg(self): assert response.json() == expected def test_get_with_headers(self): - response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - response = self.requests.post('/', data={'key': 'value'}) + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -153,7 +158,8 @@ def test_post_form_request(self): assert response.json() == expected def test_post_json_request(self): - response = self.requests.post('/', json={'key': 'value'}) + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -166,10 +172,11 @@ def test_post_json_request(self): assert response.json() == expected def test_post_multipart_request(self): + client = get_requests_client() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = self.requests.post('/', files=files) + response = client.post('/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -182,19 +189,20 @@ def test_post_multipart_request(self): assert response.json() == expected def test_session(self): - response = self.requests.get('/session/') + client = get_requests_client() + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = self.requests.post('/session/', json={'example': 'abc'}) + response = client.post('/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = self.requests.get('/session/') + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -202,7 +210,8 @@ def test_session(self): def test_auth(self): # Confirm session is not authenticated - response = self.requests.get('/auth/') + client = get_requests_client() + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -217,7 +226,7 @@ def test_auth(self): user.save() # Perform a login - response = self.requests.post('/auth/', json={ + response = client.post('/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -229,7 +238,7 @@ def test_auth(self): assert response.json() == expected # Confirm session is authenticated - response = self.requests.get('/auth/') + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { From 37b3475e5d048f7f7e751dee6bdb23695fe484bd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 09:46:28 +0100 Subject: [PATCH 11/11] API client (#4424) --- README.md | 2 +- docs/api-guide/permissions.md | 7 +- docs/topics/release-notes.md | 25 + requirements/requirements-optionals.txt | 2 +- rest_framework/__init__.py | 2 +- rest_framework/fields.py | 2 +- rest_framework/relations.py | 6 +- rest_framework/renderers.py | 9 +- rest_framework/request.py | 5 + rest_framework/schemas.py | 7 + rest_framework/serializers.py | 8 +- .../static/rest_framework/js/csrf.js | 2 +- .../templates/rest_framework/admin.html | 1 + .../templates/rest_framework/base.html | 1 + .../vertical/checkbox_multiple.html | 4 +- rest_framework/test.py | 15 +- rest_framework/views.py | 20 +- tests/test_api_client.py | 452 ++++++++++++++++++ tests/test_request.py | 11 + tests/test_schemas.py | 1 + 20 files changed, 557 insertions(+), 25 deletions(-) create mode 100644 tests/test_api_client.py diff --git a/README.md b/README.md index 179f2891a8..e1e2526091 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ You may also want to [follow the author on Twitter][twitter]. # Security -If you believe you’ve found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. +If you believe you've found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. Send a description of the issue via email to [rest-framework-security@googlegroups.com][security-mail]. The project maintainers will then work with you to resolve any issues where required, prior to any public disclosure. diff --git a/docs/api-guide/permissions.md b/docs/api-guide/permissions.md index e0838e94a9..7cdb595313 100644 --- a/docs/api-guide/permissions.md +++ b/docs/api-guide/permissions.md @@ -92,7 +92,7 @@ Or, if you're using the `@api_view` decorator with function based views. from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response - @api_view('GET') + @api_view(['GET']) @permission_classes((IsAuthenticated, )) def example_view(request, format=None): content = { @@ -261,6 +261,10 @@ The [REST Condition][rest-condition] package is another extension for building c The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to define different permissions for individual default and custom actions. This package is made for apps with permissions that are derived from relationships defined in the app's data model. It also supports permission checks being returned to a client app through the API's serializer. Additionally it supports adding permissions to the default and custom list actions to restrict the data they retrive per user. +## Django Rest Framework Roles + +The [Django Rest Framework Roles][django-rest-framework-roles] package makes it easier to parameterize your API over multiple types of users. + [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [authentication]: authentication.md [throttling]: throttling.md @@ -275,3 +279,4 @@ The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to [composed-permissions]: https://github.com/niwibe/djangorestframework-composed-permissions [rest-condition]: https://github.com/caxap/rest_condition [dry-rest-permissions]: https://github.com/Helioscene/dry-rest-permissions +[django-rest-framework-roles]: https://github.com/computer-lab/django-rest-framework-roles diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 78a5a8ba90..24728a252e 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,18 @@ You can determine your currently installed version using `pip freeze`: ## 3.4.x series +### 3.4.5 + +**Date**: [19th August 2016][3.4.5-milestone] + +* Improve debug error handling. ([#4416][gh4416], [#4409][gh4409]) +* Allow custom CSRF_HEADER_NAME setting. ([#4415][gh4415], [#4410][gh4410]) +* Include .action attribute on viewsets when generating schemas. ([#4408][gh4408], [#4398][gh4398]) +* Do not include request.FILES items in request.POST. ([#4407][gh4407]) +* Fix rendering of checkbox multiple. ([#4403][gh4403]) +* Fix docstring of Field.get_default. ([#4404][gh4404]) +* Replace utf8 character with its ascii counterpart in README. ([#4412][gh4412]) + ### 3.4.4 **Date**: [12th August 2016][3.4.4-milestone] @@ -560,6 +572,7 @@ For older release notes, [please see the version 2.x documentation][old-release- [3.4.2-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.2+Release%22 [3.4.3-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.3+Release%22 [3.4.4-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.4+Release%22 +[3.4.5-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.5+Release%22 [gh2013]: https://github.com/tomchristie/django-rest-framework/issues/2013 @@ -1065,3 +1078,15 @@ For older release notes, [please see the version 2.x documentation][old-release- [gh4392]: https://github.com/tomchristie/django-rest-framework/issues/4392 [gh4393]: https://github.com/tomchristie/django-rest-framework/issues/4393 [gh4394]: https://github.com/tomchristie/django-rest-framework/issues/4394 + + +[gh4416]: https://github.com/tomchristie/django-rest-framework/issues/4416 +[gh4409]: https://github.com/tomchristie/django-rest-framework/issues/4409 +[gh4415]: https://github.com/tomchristie/django-rest-framework/issues/4415 +[gh4410]: https://github.com/tomchristie/django-rest-framework/issues/4410 +[gh4408]: https://github.com/tomchristie/django-rest-framework/issues/4408 +[gh4398]: https://github.com/tomchristie/django-rest-framework/issues/4398 +[gh4407]: https://github.com/tomchristie/django-rest-framework/issues/4407 +[gh4403]: https://github.com/tomchristie/django-rest-framework/issues/4403 +[gh4404]: https://github.com/tomchristie/django-rest-framework/issues/4404 +[gh4412]: https://github.com/tomchristie/django-rest-framework/issues/4412 diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4c..afade0aa0c 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -2,4 +2,4 @@ markdown==2.6.4 django-guardian==1.4.3 django-filter==0.13.0 -coreapi==1.32.0 +coreapi==2.0.0 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 999c5de315..3f8736c258 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,7 +8,7 @@ """ __title__ = 'Django REST framework' -__version__ = '3.4.4' +__version__ = '3.4.5' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' __copyright__ = 'Copyright 2011-2016 Tom Christie' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8f12b2df48..f76e4e8011 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -432,7 +432,7 @@ def get_default(self): is provided for this field. If a default has not been set for this field then this will simply - return `empty`, indicating that no value should be set in the + raise `SkipField`, indicating that no value should be set in the validated data for this field. """ if self.default is empty or getattr(self.root, 'partial', False): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4b6b3bea45..65c4c03187 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,7 +10,7 @@ from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six -from django.utils.encoding import smart_text +from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ @@ -47,6 +47,7 @@ def __getnewargs__(self): is_hyperlink = True +@python_2_unicode_compatible class PKOnlyObject(object): """ This is a mock object, used for when we only need the pk of the object @@ -56,6 +57,9 @@ class PKOnlyObject(object): def __init__(self, pk): self.pk = pk + def __str__(self): + return "%s" % self.pk + # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 371cd6ec77..11e9fb9607 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -645,6 +645,12 @@ def get_context(self, data, accepted_media_type, renderer_context): else: paginator = None + csrf_cookie_name = settings.CSRF_COOKIE_NAME + csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8 + if csrf_header_name.startswith('HTTP_'): + csrf_header_name = csrf_header_name[5:] + csrf_header_name = csrf_header_name.replace('_', '-') + context = { 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'view': view, @@ -675,7 +681,8 @@ def get_context(self, data, accepted_media_type, renderer_context): 'display_edit_forms': bool(response.status_code != 403), 'api_settings': api_settings, - 'csrf_cookie_name': settings.CSRF_COOKIE_NAME, + 'csrf_cookie_name': csrf_cookie_name, + 'csrf_header_name': csrf_header_name } return context diff --git a/rest_framework/request.py b/rest_framework/request.py index f5738bfd50..355cccad77 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -391,3 +391,8 @@ def QUERY_PARAMS(self): '`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` ' 'since version 3.0, and has been fully removed as of version 3.2.' ) + + def force_plaintext_errors(self, value): + # Hack to allow our exception handler to force choice of + # plaintext or html error responses. + self._request.is_ajax = lambda: value diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 0618e94fd2..c9834c64d0 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -79,6 +79,13 @@ def get_schema(self, request=None): view.kwargs = {} view.format_kwarg = None + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + if request is not None: view.request = clone_request(request, method) try: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 41412af8ad..4d1ed63aef 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,6 +12,7 @@ """ from __future__ import unicode_literals +import traceback import warnings from django.db import models @@ -870,19 +871,20 @@ def create(self, validated_data): try: instance = ModelClass.objects.create(**validated_data) - except TypeError as exc: + except TypeError: + tb = traceback.format_exc() msg = ( 'Got a `TypeError` when calling `%s.objects.create()`. ' 'This may be because you have a writable field on the ' 'serializer class that is not a valid argument to ' '`%s.objects.create()`. You may need to make the field ' 'read-only, or override the %s.create() method to handle ' - 'this correctly.\nOriginal exception text was: %s.' % + 'this correctly.\nOriginal exception was:\n %s' % ( ModelClass.__name__, ModelClass.__name__, self.__class__.__name__, - exc + tb ) ) raise TypeError(msg) diff --git a/rest_framework/static/rest_framework/js/csrf.js b/rest_framework/static/rest_framework/js/csrf.js index f8ab4428cb..97c8d01242 100644 --- a/rest_framework/static/rest_framework/js/csrf.js +++ b/rest_framework/static/rest_framework/js/csrf.js @@ -46,7 +46,7 @@ $.ajaxSetup({ // Send the token to same-origin, relative URLs only. // Send the token only if the method warrants CSRF protection // Using the CSRFToken value acquired earlier - xhr.setRequestHeader("X-CSRFToken", csrftoken); + xhr.setRequestHeader(window.drf.csrfHeaderName, csrftoken); } } }); diff --git a/rest_framework/templates/rest_framework/admin.html b/rest_framework/templates/rest_framework/admin.html index 89af81ef74..eb2b8f1c7e 100644 --- a/rest_framework/templates/rest_framework/admin.html +++ b/rest_framework/templates/rest_framework/admin.html @@ -232,6 +232,7 @@ {% block script %} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 4c1136087c..989a086ea7 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -263,6 +263,7 @@

{{ name }}

{% block script %} diff --git a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html index b933f4ff51..7a43b3f58b 100644 --- a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html +++ b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html @@ -9,7 +9,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -18,7 +18,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/test.py b/rest_framework/test.py index e17c19a43f..b8e486b216 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -16,7 +16,7 @@ from django.utils.encoding import force_bytes from django.utils.http import urlencode -from rest_framework.compat import requests +from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -60,7 +60,10 @@ def get_environ(self, request): # Set request content, if any exists. if request.body is not None: - kwargs['data'] = request.body + if hasattr(request.body, 'read'): + kwargs['data'] = request.body.read() + else: + kwargs['data'] = request.body if 'content-type' in request.headers: kwargs['content_type'] = request.headers['content-type'] @@ -126,6 +129,14 @@ def get_requests_client(): return DjangoTestSession() +def get_api_client(): + assert coreapi is not None, 'coreapi must be installed' + session = get_requests_client() + return coreapi.Client(transports=[ + coreapi.transports.HTTPTransport(session=session) + ]) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT diff --git a/rest_framework/views.py b/rest_framework/views.py index b86bb7eaa6..15d8c6cde2 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,17 +3,14 @@ """ from __future__ import unicode_literals -import sys - from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import models from django.http import Http404 -from django.http.response import HttpResponse, HttpResponseBase +from django.http.response import HttpResponseBase from django.utils import six from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ -from django.views import debug from django.views.decorators.csrf import csrf_exempt from django.views.generic import View @@ -95,11 +92,6 @@ def exception_handler(exc, context): set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) - # throw django's error page if debug is True - if settings.DEBUG: - exception_reporter = debug.ExceptionReporter(context.get('request'), *sys.exc_info()) - return HttpResponse(exception_reporter.get_traceback_html(), status=500) - return None @@ -439,11 +431,19 @@ def handle_exception(self, exc): response = exception_handler(exc, context) if response is None: - raise + self.raise_uncaught_exception(exc) response.exception = True return response + def raise_uncaught_exception(self, exc): + if settings.DEBUG: + request = self.request + renderer_format = getattr(request.accepted_renderer, 'format') + use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') + request.force_plaintext_errors(use_plaintext_traceback) + raise + # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 0000000000..9daf3f3fe4 --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,452 @@ +from __future__ import unicode_literals + +import os +import tempfile +import unittest + +from django.conf.urls import url +from django.http import HttpResponse +from django.test import override_settings + +from rest_framework.compat import coreapi +from rest_framework.parsers import FileUploadParser +from rest_framework.renderers import CoreJSONRenderer +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_api_client +from rest_framework.views import APIView + + +def get_schema(): + return coreapi.Document( + url='https://api.example.com/', + title='Example API', + content={ + 'simple_link': coreapi.Link('/example/', description='example link'), + 'location': { + 'query': coreapi.Link('/example/', fields=[ + coreapi.Field(name='example', description='example field') + ]), + 'form': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example'), + ]), + 'body': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'path': coreapi.Link('/example/{id}', fields=[ + coreapi.Field(name='id', location='path') + ]) + }, + 'encoding': { + 'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example') + ]), + 'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example') + ]), + 'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ + coreapi.Field(name='example', location='body') + ]), + }, + 'response': { + 'download': coreapi.Link('/download/'), + 'text': coreapi.Link('/text/') + } + } + ) + + +def _iterlists(querydict): + if hasattr(querydict, 'iterlists'): + return querydict.iterlists() + return querydict.lists() + + +def _get_query_params(request): + # Return query params in a plain dict, using a list value if more + # than one item is present for a given key. + return { + key: (value[0] if len(value) == 1 else value) + for key, value in + _iterlists(request.query_params) + } + + +def _get_data(request): + if not isinstance(request.data, dict): + return request.data + # Coerce multidict into regular dict, and remove files to + # make assertions simpler. + if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'): + # Use a list value if a QueryDict contains multiple items for a key. + return { + key: value[0] if len(value) == 1 else value + for key, value in _iterlists(request.data) + if key not in request.FILES + } + return { + key: value + for key, value in request.data.items() + if key not in request.FILES + } + + +def _get_files(request): + if not request.FILES: + return {} + return { + key: {'name': value.name, 'content': value.read()} + for key, value in request.FILES.items() + } + + +class SchemaView(APIView): + renderer_classes = [CoreJSONRenderer] + + def get(self, request): + schema = get_schema() + return Response(schema) + + +class ListView(APIView): + def get(self, request): + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + def post(self, request): + if request.content_type: + content_type = request.content_type.split(';')[0] + else: + content_type = None + + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request), + 'data': _get_data(request), + 'files': _get_files(request), + 'content_type': content_type + }) + + +class DetailView(APIView): + def get(self, request, id): + return Response({ + 'id': id, + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + +class UploadView(APIView): + parser_classes = [FileUploadParser] + + def post(self, request): + return Response({ + 'method': request.method, + 'files': _get_files(request), + 'content_type': request.content_type + }) + + +class DownloadView(APIView): + def get(self, request): + return HttpResponse('some file content', content_type='image/png') + + +class TextView(APIView): + def get(self, request): + return HttpResponse('123', content_type='text/plain') + + +urlpatterns = [ + url(r'^$', SchemaView.as_view()), + url(r'^example/$', ListView.as_view()), + url(r'^example/(?P[0-9]+)/$', DetailView.as_view()), + url(r'^upload/$', UploadView.as_view()), + url(r'^download/$', DownloadView.as_view()), + url(r'^text/$', TextView.as_view()), +] + + +@unittest.skipUnless(coreapi, 'coreapi not installed') +@override_settings(ROOT_URLCONF='tests.test_api_client') +class APIClientTests(APITestCase): + def test_api_client(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + assert schema.title == 'Example API' + assert schema.url == 'https://api.example.com/' + assert schema['simple_link'].description == 'example link' + assert schema['location']['query'].fields[0].description == 'example field' + data = client.action(schema, ['simple_link']) + expected = { + 'method': 'GET', + 'query_params': {} + } + assert data == expected + + def test_query_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': 123}) + expected = { + 'method': 'GET', + 'query_params': {'example': '123'} + } + assert data == expected + + def test_query_params_with_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'GET', + 'query_params': {'example': ['1', '2', '3']} + } + assert data == expected + + def test_form_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'form'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': {'example': 123}, + 'files': {} + } + assert data == expected + + def test_body_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'body'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': 123, + 'files': {} + } + assert data == expected + + def test_path_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'path'], params={'id': 123}) + expected = { + 'method': 'GET', + 'query_params': {}, + 'id': '123' + } + assert data == expected + + def test_multipart_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'multipart'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': name, 'content': 'example file content'}} + } + assert data == expected + + def test_multipart_encoding_no_file(self): + # When no file is included, multipart encoding should still be used. + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_string_file_content(self): + # Test for `coreapi.utils.File` support. + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File(name='example.txt', content='123') + data = client.action(schema, ['encoding', 'multipart'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + def test_multipart_encoding_in_body(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} + data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'bar': 'abc'}, + 'files': {'foo': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + # URLencoded + + def test_urlencoded_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_in_body(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'foo': '123', 'bar': 'true'}, + 'files': {} + } + assert data == expected + + # Raw uploads + + def test_raw_upload(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': name, 'content': 'example file content'}}, + 'content_type': 'application/octet-stream' + } + assert data == expected + + def test_raw_upload_string_file_content(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/plain' + } + assert data == expected + + def test_raw_upload_explicit_content_type(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123', 'text/html') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/html' + } + assert data == expected + + # Responses + + def test_text_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'text']) + + expected = '123' + assert data == expected + + def test_download_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'download']) + assert data.basename == 'download.png' + assert data.read() == b'some file content' diff --git a/tests/test_request.py b/tests/test_request.py index dee636d766..dbfa695fd7 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -7,6 +7,7 @@ from django.contrib.auth import authenticate, login, logout from django.contrib.auth.models import User from django.contrib.sessions.middleware import SessionMiddleware +from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase, override_settings from django.utils import six @@ -78,6 +79,16 @@ def test_request_POST_with_form_content(self): request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.POST.items()), list(data.items())) + def test_request_POST_with_files(self): + """ + Ensure request.POST returns no content for POST request with file content. + """ + upload = SimpleUploadedFile("file.txt", b"file_content") + request = Request(factory.post('/', {'upload': upload})) + request.parsers = (FormParser(), MultiPartParser()) + self.assertEqual(list(request.POST.keys()), []) + self.assertEqual(list(request.FILES.keys()), ['upload']) + def test_standard_behaviour_determines_form_content_PUT(self): """ Ensure request.data returns content for PUT request with form content. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 81b796c35a..c866e09bed 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -49,6 +49,7 @@ def custom_list_action(self, request): def get_serializer(self, *args, **kwargs): assert self.request + assert self.action return super(ExampleViewSet, self).get_serializer(*args, **kwargs)