diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86a..abaac0c223 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -18,13 +18,20 @@ def validate(self, attrs): if user: if not user.is_active: msg = _('User account is disabled.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization') else: msg = _('Unable to log in with provided credentials.') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization') + else: msg = _('Must include "username" and "password".') - raise serializers.ValidationError(msg) + raise serializers.ValidationError( + msg, + code='authorization') attrs['user'] = user return attrs diff --git a/rest_framework/exceptions.py b/rest_framework/exceptions.py index 29afaffe00..6e30834e6a 100644 --- a/rest_framework/exceptions.py +++ b/rest_framework/exceptions.py @@ -58,6 +58,14 @@ def __str__(self): return self.detail +def build_error_from_django_validation_error(exc_info): + code = getattr(exc_info, 'code', None) or 'invalid' + return [ + ValidationErrorMessage(msg, code=code) + for msg in exc_info.messages + ] + + # The recommended style for using `ValidationError` is to keep it namespaced # under `serializers`, in order to minimize potential confusion with Django's # built in `ValidationError`. For example: @@ -65,10 +73,25 @@ def __str__(self): # from rest_framework import serializers # raise serializers.ValidationError('Value was invalid') +class ValidationErrorMessage(six.text_type): + code = None + + def __new__(cls, string, code=None, *args, **kwargs): + self = super(ValidationErrorMessage, cls).__new__( + cls, string, *args, **kwargs) + + self.code = code + return self + + class ValidationError(APIException): status_code = status.HTTP_400_BAD_REQUEST - def __init__(self, detail): + def __init__(self, detail, code=None): + # If code is there, this means we are dealing with a message. + if code and not isinstance(detail, ValidationErrorMessage): + detail = ValidationErrorMessage(detail, code=code) + # For validation errors the 'detail' key is always required. # The details should always be coerced to a list if not already. if not isinstance(detail, dict) and not isinstance(detail, list): diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f76e4e8011..3962595840 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -34,7 +34,9 @@ from rest_framework.compat import ( get_remote_field, unicode_repr, unicode_to_repr, value_from_object ) -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ( + ValidationError, build_error_from_django_validation_error +) from rest_framework.settings import api_settings from rest_framework.utils import html, humanize_datetime, representation @@ -509,7 +511,9 @@ def run_validators(self, value): raise errors.extend(exc.detail) except DjangoValidationError as exc: - errors.extend(exc.messages) + errors.extend( + build_error_from_django_validation_error(exc) + ) if errors: raise ValidationError(errors) @@ -547,7 +551,7 @@ def fail(self, key, **kwargs): msg = MISSING_ERROR_MESSAGE.format(class_name=class_name, key=key) raise AssertionError(msg) message_string = msg.format(**kwargs) - raise ValidationError(message_string) + raise ValidationError(message_string, code=key) @cached_property def root(self): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63aef..5b3ef37709 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -22,8 +22,10 @@ from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ +from rest_framework import exceptions from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.exceptions import ValidationErrorMessage from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -220,7 +222,6 @@ def is_valid(self, raise_exception=False): if self._errors and raise_exception: raise ValidationError(self.errors) - return not bool(self._errors) @property @@ -301,7 +302,8 @@ def get_validation_error_detail(exc): # exception class as well for simpler compat. # Eg. Calling Model.clean() explicitly inside Serializer.validate() return { - api_settings.NON_FIELD_ERRORS_KEY: list(exc.messages) + api_settings.NON_FIELD_ERRORS_KEY: + exceptions.build_error_from_django_validation_error(exc) } elif isinstance(exc.detail, dict): # If errors may be a dict we use the standard {key: list of values}. @@ -423,8 +425,9 @@ def to_internal_value(self, data): message = self.error_messages['invalid'].format( datatype=type(data).__name__ ) + error = ValidationErrorMessage(message, code='invalid') raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) ret = OrderedDict() @@ -441,7 +444,9 @@ def to_internal_value(self, data): except ValidationError as exc: errors[field.field_name] = exc.detail except DjangoValidationError as exc: - errors[field.field_name] = list(exc.messages) + errors[field.field_name] = ( + exceptions.build_error_from_django_validation_error(exc) + ) except SkipField: pass else: @@ -580,12 +585,18 @@ def to_internal_value(self, data): message = self.error_messages['not_a_list'].format( input_type=type(data).__name__ ) + error = ValidationErrorMessage( + message, + code='not_a_list' + ) raise ValidationError({ - api_settings.NON_FIELD_ERRORS_KEY: [message] + api_settings.NON_FIELD_ERRORS_KEY: [error] }) if not self.allow_empty and len(data) == 0: - message = self.error_messages['empty'] + message = ValidationErrorMessage( + self.error_messages['empty'], + code='empty_not_allowed') raise ValidationError({ api_settings.NON_FIELD_ERRORS_KEY: [message] }) diff --git a/rest_framework/validators.py b/rest_framework/validators.py index ef23b9bd70..3b8678a70d 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -12,7 +12,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import unicode_to_repr -from rest_framework.exceptions import ValidationError +from rest_framework.exceptions import ValidationError, ValidationErrorMessage from rest_framework.utils.representation import smart_repr @@ -79,7 +79,7 @@ def __call__(self, value): queryset = self.filter_queryset(value, queryset) queryset = self.exclude_current_instance(queryset) if qs_exists(queryset): - raise ValidationError(self.message) + raise ValidationError(self.message, code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s)>' % ( @@ -120,7 +120,10 @@ def enforce_required_fields(self, attrs): return missing = { - field_name: self.missing_message + field_name: ValidationErrorMessage( + self.missing_message, + code='required') + for field_name in self.fields if field_name not in attrs } @@ -166,7 +169,9 @@ def __call__(self, attrs): ] if None not in checked_values and qs_exists(queryset): field_names = ', '.join(self.fields) - raise ValidationError(self.message.format(field_names=field_names)) + raise ValidationError( + self.message.format(field_names=field_names), + code='unique') def __repr__(self): return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % ( @@ -204,7 +209,9 @@ def enforce_required_fields(self, attrs): 'required' state on the fields they are applied to. """ missing = { - field_name: self.missing_message + field_name: ValidationErrorMessage( + self.missing_message, + code='required') for field_name in [self.field, self.date_field] if field_name not in attrs } @@ -230,7 +237,9 @@ def __call__(self, attrs): queryset = self.exclude_current_instance(attrs, queryset) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) - raise ValidationError({self.field: message}) + raise ValidationError({ + self.field: ValidationErrorMessage(message, code='unique'), + }) def __repr__(self): return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % ( diff --git a/tests/test_validation_error.py b/tests/test_validation_error.py new file mode 100644 index 0000000000..a9d244176d --- /dev/null +++ b/tests/test_validation_error.py @@ -0,0 +1,74 @@ +from django.test import TestCase + +from rest_framework import serializers, status +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory +from rest_framework.views import APIView + +factory = APIRequestFactory() + + +class ExampleSerializer(serializers.Serializer): + char = serializers.CharField() + integer = serializers.IntegerField() + + +class ErrorView(APIView): + def get(self, request, *args, **kwargs): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +@api_view(['GET']) +def error_view(request): + ExampleSerializer(data={}).is_valid(raise_exception=True) + + +class TestValidationErrorWithCode(TestCase): + def setUp(self): + self.DEFAULT_HANDLER = api_settings.EXCEPTION_HANDLER + + def exception_handler(exc, request): + return_errors = {} + for field_name, errors in exc.detail.items(): + return_errors[field_name] = [] + for error in errors: + return_errors[field_name].append({ + 'code': error.code, + 'message': error + }) + + return Response(return_errors, status=status.HTTP_400_BAD_REQUEST) + + api_settings.EXCEPTION_HANDLER = exception_handler + + self.expected_response_data = { + 'char': [{ + 'message': 'This field is required.', + 'code': 'required', + }], + 'integer': [{ + 'message': 'This field is required.', + 'code': 'required' + }], + } + + def tearDown(self): + api_settings.EXCEPTION_HANDLER = self.DEFAULT_HANDLER + + def test_class_based_view_exception_handler(self): + view = ErrorView.as_view() + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data) + + def test_function_based_view_exception_handler(self): + view = error_view + + request = factory.get('/', content_type='application/json') + response = view(request) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertEqual(response.data, self.expected_response_data)