Skip to content

Do not persist the context in validators #6172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/api-guide/fields.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ If set, this gives the default value that will be used for the field if no input

The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned.

May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context).
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds like changing default to use the same mechanism would be a good idea as well.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be not relevant due to the deepcloning. But I'm not sure.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think even if it's not strictly required for defaults, we should still adopt it for consistency.


When serializing the instance, default will be used if the the object attribute or dictionary key is not present in the instance.

Expand Down
15 changes: 10 additions & 5 deletions docs/api-guide/validators.md
Original file line number Diff line number Diff line change
Expand Up @@ -290,13 +290,18 @@ To write a class-based validator, use the `__call__` method. Class-based validat
message = 'This field must be a multiple of %d.' % self.base
raise serializers.ValidationError(message)

#### Using `set_context()`
#### Accessing the context

In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator.
In some advanced cases you might want a validator to be passed the serializer
field it is being used with as additional context. You can do so by using
`rest_framework.validators.ContextBasedValidator` as a base class for the
validator. The `__call__` method will then be called with the `serializer_field`
or `serializer` as an additional argument.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ContextBasedValidator is potentially confusing, as the serializer context is an entirely different concept. Possibly FieldContextValidator or FieldBasedValidator?


def set_context(self, serializer_field):
def __call__(self, value, serializer_field):
# Determine if this is an update or a create operation.
# In `__call__` we can then use that information to modify the validation behavior.
self.is_update = serializer_field.parent.instance is not None
is_update = serializer_field.parent.instance is not None

pass # implementation of the validator that uses `is_update`

[cite]: https://docs.djangoproject.com/en/stable/ref/validators/
15 changes: 14 additions & 1 deletion rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import inspect
import re
import uuid
import warnings
from collections import OrderedDict
from collections.abc import Mapping

Expand Down Expand Up @@ -519,13 +520,25 @@ def run_validators(self, value):
Test the given value against all the validators on the field,
and either raise a `ValidationError` or simply return.
"""
from rest_framework.validators import ContextBasedValidator

errors = []
for validator in self.validators:
if hasattr(validator, 'set_context'):
warnings.warn(
"Method `set_context` on validators is deprecated and will "
"no longer be called starting with 3.11. Instead derive the "
"validator from `rest_framwork.validators.ContextBasedValidator` "
"and accept the context as an additional argument.",
DeprecationWarning, stacklevel=2
)
validator.set_context(self)

try:
validator(value)
if isinstance(validator, ContextBasedValidator):
validator(value, self)
else:
validator(value)
except ValidationError as exc:
# If the validation error contains a mapping of fields to
# errors then simply raise it immediately rather than
Expand Down
139 changes: 67 additions & 72 deletions rest_framework/validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,17 @@ def qs_filter(queryset, **kwargs):
return queryset.none()


class UniqueValidator:
class ContextBasedValidator:
"""Base class for validators that need a context during evaluation.

In extension to regular validators their `__call__` method must not only
accept a value, but also an instance of a serializer.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should say "serializer field" here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure. Depends on the validator, I guess. I think it's mostly relevant for validators used on serializers since fields are deepcopied.

"""
def __call__(self, value, serializer):
raise NotImplementedError('`__call__()` must be implemented.')


class UniqueValidator(ContextBasedValidator):
"""
Validator that corresponds to `unique=True` on a model field.

Expand All @@ -44,37 +54,32 @@ def __init__(self, queryset, message=None, lookup='exact'):
self.message = message or self.message
self.lookup = lookup

def set_context(self, serializer_field):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer_field.parent, 'instance', None)

def filter_queryset(self, value, queryset):
def filter_queryset(self, value, queryset, field_name):
"""
Filter the queryset to all instances matching the given attribute.
"""
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)

def exclude_current_instance(self, queryset):
def exclude_current_instance(self, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset

def __call__(self, value):
def __call__(self, value, serializer_field):
# Determine the underlying model field name. This may not be the
# same as the serializer field name if `source=<>` is set.
field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer_field.parent, 'instance', None)

queryset = self.queryset
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
queryset = self.filter_queryset(value, queryset, field_name)
queryset = self.exclude_current_instance(queryset, instance)
if qs_exists(queryset):
raise ValidationError(self.message, code='unique')

Expand All @@ -85,7 +90,7 @@ def __repr__(self):
)


class UniqueTogetherValidator:
class UniqueTogetherValidator(ContextBasedValidator):
"""
Validator that corresponds to `unique_together = (...)` on a model class.

Expand All @@ -100,20 +105,12 @@ def __init__(self, queryset, fields, message=None):
self.serializer_field = None
self.message = message or self.message

def set_context(self, serializer):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)

def enforce_required_fields(self, attrs):
def enforce_required_fields(self, attrs, instance):
"""
The `UniqueTogetherValidator` always forces an implied 'required'
state on the fields it applies to.
"""
if self.instance is not None:
if instance is not None:
return

missing_items = {
Expand All @@ -124,16 +121,16 @@ def enforce_required_fields(self, attrs):
if missing_items:
raise ValidationError(missing_items, code='required')

def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, instance):
"""
Filter the queryset to all instances matching the given attributes.
"""
# If this is an update, then any unprovided field should
# have it's value set based on the existing instance attribute.
if self.instance is not None:
if instance is not None:
for field_name in self.fields:
if field_name not in attrs:
attrs[field_name] = getattr(self.instance, field_name)
attrs[field_name] = getattr(instance, field_name)

# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
Expand All @@ -142,20 +139,23 @@ def filter_queryset(self, attrs, queryset):
}
return qs_filter(queryset, **filter_kwargs)

def exclude_current_instance(self, attrs, queryset):
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset

def __call__(self, attrs):
self.enforce_required_fields(attrs)
def __call__(self, attrs, serializer):
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)

self.enforce_required_fields(attrs, instance)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset)
queryset = self.exclude_current_instance(attrs, queryset)
queryset = self.filter_queryset(attrs, queryset, instance)
queryset = self.exclude_current_instance(attrs, queryset, instance)

# Ignore validation if any field is None
checked_values = [
Expand All @@ -174,7 +174,7 @@ def __repr__(self):
)


class BaseUniqueForValidator:
class BaseUniqueForValidator(ContextBasedValidator):
message = None
missing_message = _('This field is required.')

Expand All @@ -184,18 +184,6 @@ def __init__(self, queryset, field, date_field, message=None):
self.date_field = date_field
self.message = message or self.message

def set_context(self, serializer):
"""
This hook is called by the serializer instance,
prior to the validation call being made.
"""
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)

def enforce_required_fields(self, attrs):
"""
The `UniqueFor<Range>Validator` classes always force an implied
Expand All @@ -209,23 +197,30 @@ def enforce_required_fields(self, attrs):
if missing_items:
raise ValidationError(missing_items, code='required')

def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
raise NotImplementedError('`filter_queryset` must be implemented.')

def exclude_current_instance(self, attrs, queryset):
def exclude_current_instance(self, attrs, queryset, instance):
"""
If an instance is being updated, then do not include
that instance itself as a uniqueness conflict.
"""
if self.instance is not None:
return queryset.exclude(pk=self.instance.pk)
if instance is not None:
return queryset.exclude(pk=instance.pk)
return queryset

def __call__(self, attrs):
def __call__(self, attrs, serializer):
# Determine the underlying model field names. These may not be the
# same as the serializer field names if `source=<>` is set.
field_name = serializer.fields[self.field].source_attrs[-1]
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
instance = getattr(serializer, 'instance', None)

self.enforce_required_fields(attrs)
queryset = self.queryset
queryset = self.filter_queryset(attrs, queryset)
queryset = self.exclude_current_instance(attrs, queryset)
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
queryset = self.exclude_current_instance(attrs, queryset, instance)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
Expand All @@ -244,39 +239,39 @@ def __repr__(self):
class UniqueForDateValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" date.')

def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]

filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__day' % self.date_field_name] = date.day
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs[field_name] = value
filter_kwargs['%s__day' % date_field_name] = date.day
filter_kwargs['%s__month' % date_field_name] = date.month
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)


class UniqueForMonthValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" month.')

def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]

filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs[field_name] = value
filter_kwargs['%s__month' % date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs)


class UniqueForYearValidator(BaseUniqueForValidator):
message = _('This field must be unique for the "{date_field}" year.')

def filter_queryset(self, attrs, queryset):
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
value = attrs[self.field]
date = attrs[self.date_field]

filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs[field_name] = value
filter_kwargs['%s__year' % date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
7 changes: 4 additions & 3 deletions tests/test_validators.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,7 @@ def filter(self, **kwargs):
queryset = MockQueryset()
validator = UniqueTogetherValidator(queryset, fields=('race_name',
'position'))
validator.instance = self.instance
validator.filter_queryset(attrs=data, queryset=queryset)
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
assert queryset.called_with == {'race_name': 'bar', 'position': 1}


Expand Down Expand Up @@ -586,4 +585,6 @@ def test_validator_raises_error_when_abstract_method_called(self):
validator = BaseUniqueForValidator(queryset=object(), field='foo',
date_field='bar')
with pytest.raises(NotImplementedError):
validator.filter_queryset(attrs=None, queryset=None)
validator.filter_queryset(
attrs=None, queryset=None, field_name='', date_field_name=''
)