From c0dfaa1744abb126347b653adf496e6fe3826c8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Michael=20K=C3=A4ufl?= Date: Sat, 18 May 2019 19:23:20 +0200 Subject: [PATCH] Do not persist the context in validators Fixes encode/django-rest-framework#5760 --- docs/api-guide/fields.md | 2 +- docs/api-guide/validators.md | 15 ++-- rest_framework/fields.py | 15 +++- rest_framework/validators.py | 139 +++++++++++++++++------------------ tests/test_validators.py | 7 +- 5 files changed, 96 insertions(+), 82 deletions(-) diff --git a/docs/api-guide/fields.md b/docs/api-guide/fields.md index b2830d0c9d..4b53d613b0 100644 --- a/docs/api-guide/fields.md +++ b/docs/api-guide/fields.md @@ -47,7 +47,7 @@ If set, this gives the default value that will be used for the field if no input The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned. -May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context). +May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. When serializing the instance, default will be used if the the object attribute or dictionary key is not present in the instance. diff --git a/docs/api-guide/validators.md b/docs/api-guide/validators.md index 9b2fc82ed7..7d0b8c8196 100644 --- a/docs/api-guide/validators.md +++ b/docs/api-guide/validators.md @@ -290,13 +290,18 @@ To write a class-based validator, use the `__call__` method. Class-based validat message = 'This field must be a multiple of %d.' % self.base raise serializers.ValidationError(message) -#### Using `set_context()` +#### Accessing the context -In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator. +In some advanced cases you might want a validator to be passed the serializer +field it is being used with as additional context. You can do so by using +`rest_framework.validators.ContextBasedValidator` as a base class for the +validator. The `__call__` method will then be called with the `serializer_field` +or `serializer` as an additional argument. - def set_context(self, serializer_field): + def __call__(self, value, serializer_field): # Determine if this is an update or a create operation. - # In `__call__` we can then use that information to modify the validation behavior. - self.is_update = serializer_field.parent.instance is not None + is_update = serializer_field.parent.instance is not None + + pass # implementation of the validator that uses `is_update` [cite]: https://docs.djangoproject.com/en/stable/ref/validators/ diff --git a/rest_framework/fields.py b/rest_framework/fields.py index aecfa33024..3c287faaf0 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -5,6 +5,7 @@ import inspect import re import uuid +import warnings from collections import OrderedDict from collections.abc import Mapping @@ -519,13 +520,25 @@ def run_validators(self, value): Test the given value against all the validators on the field, and either raise a `ValidationError` or simply return. """ + from rest_framework.validators import ContextBasedValidator + errors = [] for validator in self.validators: if hasattr(validator, 'set_context'): + warnings.warn( + "Method `set_context` on validators is deprecated and will " + "no longer be called starting with 3.11. Instead derive the " + "validator from `rest_framwork.validators.ContextBasedValidator` " + "and accept the context as an additional argument.", + DeprecationWarning, stacklevel=2 + ) validator.set_context(self) try: - validator(value) + if isinstance(validator, ContextBasedValidator): + validator(value, self) + else: + validator(value) except ValidationError as exc: # If the validation error contains a mapping of fields to # errors then simply raise it immediately rather than diff --git a/rest_framework/validators.py b/rest_framework/validators.py index 1cbe31b5ea..0ead180890 100644 --- a/rest_framework/validators.py +++ b/rest_framework/validators.py @@ -30,7 +30,17 @@ def qs_filter(queryset, **kwargs): return queryset.none() -class UniqueValidator: +class ContextBasedValidator: + """Base class for validators that need a context during evaluation. + + In extension to regular validators their `__call__` method must not only + accept a value, but also an instance of a serializer. + """ + def __call__(self, value, serializer): + raise NotImplementedError('`__call__()` must be implemented.') + + +class UniqueValidator(ContextBasedValidator): """ Validator that corresponds to `unique=True` on a model field. @@ -44,37 +54,32 @@ def __init__(self, queryset, message=None, lookup='exact'): self.message = message or self.message self.lookup = lookup - def set_context(self, serializer_field): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the underlying model field name. This may not be the - # same as the serializer field name if `source=<>` is set. - self.field_name = serializer_field.source_attrs[-1] - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer_field.parent, 'instance', None) - - def filter_queryset(self, value, queryset): + def filter_queryset(self, value, queryset, field_name): """ Filter the queryset to all instances matching the given attribute. """ - filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} + filter_kwargs = {'%s__%s' % (field_name, self.lookup): value} return qs_filter(queryset, **filter_kwargs) - def exclude_current_instance(self, queryset): + def exclude_current_instance(self, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, value): + def __call__(self, value, serializer_field): + # Determine the underlying model field name. This may not be the + # same as the serializer field name if `source=<>` is set. + field_name = serializer_field.source_attrs[-1] + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer_field.parent, 'instance', None) + queryset = self.queryset - queryset = self.filter_queryset(value, queryset) - queryset = self.exclude_current_instance(queryset) + queryset = self.filter_queryset(value, queryset, field_name) + queryset = self.exclude_current_instance(queryset, instance) if qs_exists(queryset): raise ValidationError(self.message, code='unique') @@ -85,7 +90,7 @@ def __repr__(self): ) -class UniqueTogetherValidator: +class UniqueTogetherValidator(ContextBasedValidator): """ Validator that corresponds to `unique_together = (...)` on a model class. @@ -100,20 +105,12 @@ def __init__(self, queryset, fields, message=None): self.serializer_field = None self.message = message or self.message - def set_context(self, serializer): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) - - def enforce_required_fields(self, attrs): + def enforce_required_fields(self, attrs, instance): """ The `UniqueTogetherValidator` always forces an implied 'required' state on the fields it applies to. """ - if self.instance is not None: + if instance is not None: return missing_items = { @@ -124,16 +121,16 @@ def enforce_required_fields(self, attrs): if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, instance): """ Filter the queryset to all instances matching the given attributes. """ # If this is an update, then any unprovided field should # have it's value set based on the existing instance attribute. - if self.instance is not None: + if instance is not None: for field_name in self.fields: if field_name not in attrs: - attrs[field_name] = getattr(self.instance, field_name) + attrs[field_name] = getattr(instance, field_name) # Determine the filter keyword arguments and filter the queryset. filter_kwargs = { @@ -142,20 +139,23 @@ def filter_queryset(self, attrs, queryset): } return qs_filter(queryset, **filter_kwargs) - def exclude_current_instance(self, attrs, queryset): + def exclude_current_instance(self, attrs, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, attrs): - self.enforce_required_fields(attrs) + def __call__(self, attrs, serializer): + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer, 'instance', None) + + self.enforce_required_fields(attrs, instance) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset) - queryset = self.exclude_current_instance(attrs, queryset) + queryset = self.filter_queryset(attrs, queryset, instance) + queryset = self.exclude_current_instance(attrs, queryset, instance) # Ignore validation if any field is None checked_values = [ @@ -174,7 +174,7 @@ def __repr__(self): ) -class BaseUniqueForValidator: +class BaseUniqueForValidator(ContextBasedValidator): message = None missing_message = _('This field is required.') @@ -184,18 +184,6 @@ def __init__(self, queryset, field, date_field, message=None): self.date_field = date_field self.message = message or self.message - def set_context(self, serializer): - """ - This hook is called by the serializer instance, - prior to the validation call being made. - """ - # Determine the underlying model field names. These may not be the - # same as the serializer field names if `source=<>` is set. - self.field_name = serializer.fields[self.field].source_attrs[-1] - self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] - # Determine the existing instance, if this is an update operation. - self.instance = getattr(serializer, 'instance', None) - def enforce_required_fields(self, attrs): """ The `UniqueForValidator` classes always force an implied @@ -209,23 +197,30 @@ def enforce_required_fields(self, attrs): if missing_items: raise ValidationError(missing_items, code='required') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): raise NotImplementedError('`filter_queryset` must be implemented.') - def exclude_current_instance(self, attrs, queryset): + def exclude_current_instance(self, attrs, queryset, instance): """ If an instance is being updated, then do not include that instance itself as a uniqueness conflict. """ - if self.instance is not None: - return queryset.exclude(pk=self.instance.pk) + if instance is not None: + return queryset.exclude(pk=instance.pk) return queryset - def __call__(self, attrs): + def __call__(self, attrs, serializer): + # Determine the underlying model field names. These may not be the + # same as the serializer field names if `source=<>` is set. + field_name = serializer.fields[self.field].source_attrs[-1] + date_field_name = serializer.fields[self.date_field].source_attrs[-1] + # Determine the existing instance, if this is an update operation. + instance = getattr(serializer, 'instance', None) + self.enforce_required_fields(attrs) queryset = self.queryset - queryset = self.filter_queryset(attrs, queryset) - queryset = self.exclude_current_instance(attrs, queryset) + queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) + queryset = self.exclude_current_instance(attrs, queryset, instance) if qs_exists(queryset): message = self.message.format(date_field=self.date_field) raise ValidationError({ @@ -244,39 +239,39 @@ def __repr__(self): class UniqueForDateValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" date.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__day' % self.date_field_name] = date.day - filter_kwargs['%s__month' % self.date_field_name] = date.month - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs[field_name] = value + filter_kwargs['%s__day' % date_field_name] = date.day + filter_kwargs['%s__month' % date_field_name] = date.month + filter_kwargs['%s__year' % date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) class UniqueForMonthValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" month.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__month' % self.date_field_name] = date.month + filter_kwargs[field_name] = value + filter_kwargs['%s__month' % date_field_name] = date.month return qs_filter(queryset, **filter_kwargs) class UniqueForYearValidator(BaseUniqueForValidator): message = _('This field must be unique for the "{date_field}" year.') - def filter_queryset(self, attrs, queryset): + def filter_queryset(self, attrs, queryset, field_name, date_field_name): value = attrs[self.field] date = attrs[self.date_field] filter_kwargs = {} - filter_kwargs[self.field_name] = value - filter_kwargs['%s__year' % self.date_field_name] = date.year + filter_kwargs[field_name] = value + filter_kwargs['%s__year' % date_field_name] = date.year return qs_filter(queryset, **filter_kwargs) diff --git a/tests/test_validators.py b/tests/test_validators.py index fe31ba2357..bb29a4305b 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -361,8 +361,7 @@ def filter(self, **kwargs): queryset = MockQueryset() validator = UniqueTogetherValidator(queryset, fields=('race_name', 'position')) - validator.instance = self.instance - validator.filter_queryset(attrs=data, queryset=queryset) + validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance) assert queryset.called_with == {'race_name': 'bar', 'position': 1} @@ -586,4 +585,6 @@ def test_validator_raises_error_when_abstract_method_called(self): validator = BaseUniqueForValidator(queryset=object(), field='foo', date_field='bar') with pytest.raises(NotImplementedError): - validator.filter_queryset(attrs=None, queryset=None) + validator.filter_queryset( + attrs=None, queryset=None, field_name='', date_field_name='' + )