-
-
Notifications
You must be signed in to change notification settings - Fork 6.9k
Do not persist the context in validators #6172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -290,13 +290,18 @@ To write a class-based validator, use the `__call__` method. Class-based validat | |
message = 'This field must be a multiple of %d.' % self.base | ||
raise serializers.ValidationError(message) | ||
|
||
#### Using `set_context()` | ||
#### Accessing the context | ||
|
||
In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator. | ||
In some advanced cases you might want a validator to be passed the serializer | ||
field it is being used with as additional context. You can do so by using | ||
`rest_framework.validators.ContextBasedValidator` as a base class for the | ||
validator. The `__call__` method will then be called with the `serializer_field` | ||
or `serializer` as an additional argument. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
def set_context(self, serializer_field): | ||
def __call__(self, value, serializer_field): | ||
# Determine if this is an update or a create operation. | ||
# In `__call__` we can then use that information to modify the validation behavior. | ||
self.is_update = serializer_field.parent.instance is not None | ||
is_update = serializer_field.parent.instance is not None | ||
|
||
pass # implementation of the validator that uses `is_update` | ||
|
||
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/ |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -30,7 +30,17 @@ def qs_filter(queryset, **kwargs): | |
return queryset.none() | ||
|
||
|
||
class UniqueValidator: | ||
class ContextBasedValidator: | ||
"""Base class for validators that need a context during evaluation. | ||
|
||
In extension to regular validators their `__call__` method must not only | ||
accept a value, but also an instance of a serializer. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it should say "serializer field" here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure. Depends on the validator, I guess. I think it's mostly relevant for validators used on serializers since fields are deepcopied. |
||
""" | ||
def __call__(self, value, serializer): | ||
raise NotImplementedError('`__call__()` must be implemented.') | ||
|
||
|
||
class UniqueValidator(ContextBasedValidator): | ||
""" | ||
Validator that corresponds to `unique=True` on a model field. | ||
|
||
|
@@ -44,37 +54,32 @@ def __init__(self, queryset, message=None, lookup='exact'): | |
self.message = message or self.message | ||
self.lookup = lookup | ||
|
||
def set_context(self, serializer_field): | ||
""" | ||
This hook is called by the serializer instance, | ||
prior to the validation call being made. | ||
""" | ||
# Determine the underlying model field name. This may not be the | ||
# same as the serializer field name if `source=<>` is set. | ||
self.field_name = serializer_field.source_attrs[-1] | ||
# Determine the existing instance, if this is an update operation. | ||
self.instance = getattr(serializer_field.parent, 'instance', None) | ||
|
||
def filter_queryset(self, value, queryset): | ||
def filter_queryset(self, value, queryset, field_name): | ||
""" | ||
Filter the queryset to all instances matching the given attribute. | ||
""" | ||
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value} | ||
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value} | ||
return qs_filter(queryset, **filter_kwargs) | ||
|
||
def exclude_current_instance(self, queryset): | ||
def exclude_current_instance(self, queryset, instance): | ||
""" | ||
If an instance is being updated, then do not include | ||
that instance itself as a uniqueness conflict. | ||
""" | ||
if self.instance is not None: | ||
return queryset.exclude(pk=self.instance.pk) | ||
if instance is not None: | ||
return queryset.exclude(pk=instance.pk) | ||
return queryset | ||
|
||
def __call__(self, value): | ||
def __call__(self, value, serializer_field): | ||
# Determine the underlying model field name. This may not be the | ||
# same as the serializer field name if `source=<>` is set. | ||
field_name = serializer_field.source_attrs[-1] | ||
# Determine the existing instance, if this is an update operation. | ||
instance = getattr(serializer_field.parent, 'instance', None) | ||
|
||
queryset = self.queryset | ||
queryset = self.filter_queryset(value, queryset) | ||
queryset = self.exclude_current_instance(queryset) | ||
queryset = self.filter_queryset(value, queryset, field_name) | ||
queryset = self.exclude_current_instance(queryset, instance) | ||
if qs_exists(queryset): | ||
raise ValidationError(self.message, code='unique') | ||
|
||
|
@@ -85,7 +90,7 @@ def __repr__(self): | |
) | ||
|
||
|
||
class UniqueTogetherValidator: | ||
class UniqueTogetherValidator(ContextBasedValidator): | ||
""" | ||
Validator that corresponds to `unique_together = (...)` on a model class. | ||
|
||
|
@@ -100,20 +105,12 @@ def __init__(self, queryset, fields, message=None): | |
self.serializer_field = None | ||
self.message = message or self.message | ||
|
||
def set_context(self, serializer): | ||
""" | ||
This hook is called by the serializer instance, | ||
prior to the validation call being made. | ||
""" | ||
# Determine the existing instance, if this is an update operation. | ||
self.instance = getattr(serializer, 'instance', None) | ||
|
||
def enforce_required_fields(self, attrs): | ||
def enforce_required_fields(self, attrs, instance): | ||
""" | ||
The `UniqueTogetherValidator` always forces an implied 'required' | ||
state on the fields it applies to. | ||
""" | ||
if self.instance is not None: | ||
if instance is not None: | ||
return | ||
|
||
missing_items = { | ||
|
@@ -124,16 +121,16 @@ def enforce_required_fields(self, attrs): | |
if missing_items: | ||
raise ValidationError(missing_items, code='required') | ||
|
||
def filter_queryset(self, attrs, queryset): | ||
def filter_queryset(self, attrs, queryset, instance): | ||
""" | ||
Filter the queryset to all instances matching the given attributes. | ||
""" | ||
# If this is an update, then any unprovided field should | ||
# have it's value set based on the existing instance attribute. | ||
if self.instance is not None: | ||
if instance is not None: | ||
for field_name in self.fields: | ||
if field_name not in attrs: | ||
attrs[field_name] = getattr(self.instance, field_name) | ||
attrs[field_name] = getattr(instance, field_name) | ||
|
||
# Determine the filter keyword arguments and filter the queryset. | ||
filter_kwargs = { | ||
|
@@ -142,20 +139,23 @@ def filter_queryset(self, attrs, queryset): | |
} | ||
return qs_filter(queryset, **filter_kwargs) | ||
|
||
def exclude_current_instance(self, attrs, queryset): | ||
def exclude_current_instance(self, attrs, queryset, instance): | ||
""" | ||
If an instance is being updated, then do not include | ||
that instance itself as a uniqueness conflict. | ||
""" | ||
if self.instance is not None: | ||
return queryset.exclude(pk=self.instance.pk) | ||
if instance is not None: | ||
return queryset.exclude(pk=instance.pk) | ||
return queryset | ||
|
||
def __call__(self, attrs): | ||
self.enforce_required_fields(attrs) | ||
def __call__(self, attrs, serializer): | ||
# Determine the existing instance, if this is an update operation. | ||
instance = getattr(serializer, 'instance', None) | ||
|
||
self.enforce_required_fields(attrs, instance) | ||
queryset = self.queryset | ||
queryset = self.filter_queryset(attrs, queryset) | ||
queryset = self.exclude_current_instance(attrs, queryset) | ||
queryset = self.filter_queryset(attrs, queryset, instance) | ||
queryset = self.exclude_current_instance(attrs, queryset, instance) | ||
|
||
# Ignore validation if any field is None | ||
checked_values = [ | ||
|
@@ -174,7 +174,7 @@ def __repr__(self): | |
) | ||
|
||
|
||
class BaseUniqueForValidator: | ||
class BaseUniqueForValidator(ContextBasedValidator): | ||
message = None | ||
missing_message = _('This field is required.') | ||
|
||
|
@@ -184,18 +184,6 @@ def __init__(self, queryset, field, date_field, message=None): | |
self.date_field = date_field | ||
self.message = message or self.message | ||
|
||
def set_context(self, serializer): | ||
""" | ||
This hook is called by the serializer instance, | ||
prior to the validation call being made. | ||
""" | ||
# Determine the underlying model field names. These may not be the | ||
# same as the serializer field names if `source=<>` is set. | ||
self.field_name = serializer.fields[self.field].source_attrs[-1] | ||
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1] | ||
# Determine the existing instance, if this is an update operation. | ||
self.instance = getattr(serializer, 'instance', None) | ||
|
||
def enforce_required_fields(self, attrs): | ||
""" | ||
The `UniqueFor<Range>Validator` classes always force an implied | ||
|
@@ -209,23 +197,30 @@ def enforce_required_fields(self, attrs): | |
if missing_items: | ||
raise ValidationError(missing_items, code='required') | ||
|
||
def filter_queryset(self, attrs, queryset): | ||
def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||
raise NotImplementedError('`filter_queryset` must be implemented.') | ||
|
||
def exclude_current_instance(self, attrs, queryset): | ||
def exclude_current_instance(self, attrs, queryset, instance): | ||
""" | ||
If an instance is being updated, then do not include | ||
that instance itself as a uniqueness conflict. | ||
""" | ||
if self.instance is not None: | ||
return queryset.exclude(pk=self.instance.pk) | ||
if instance is not None: | ||
return queryset.exclude(pk=instance.pk) | ||
return queryset | ||
|
||
def __call__(self, attrs): | ||
def __call__(self, attrs, serializer): | ||
# Determine the underlying model field names. These may not be the | ||
# same as the serializer field names if `source=<>` is set. | ||
field_name = serializer.fields[self.field].source_attrs[-1] | ||
date_field_name = serializer.fields[self.date_field].source_attrs[-1] | ||
# Determine the existing instance, if this is an update operation. | ||
instance = getattr(serializer, 'instance', None) | ||
|
||
self.enforce_required_fields(attrs) | ||
queryset = self.queryset | ||
queryset = self.filter_queryset(attrs, queryset) | ||
queryset = self.exclude_current_instance(attrs, queryset) | ||
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name) | ||
queryset = self.exclude_current_instance(attrs, queryset, instance) | ||
if qs_exists(queryset): | ||
message = self.message.format(date_field=self.date_field) | ||
raise ValidationError({ | ||
|
@@ -244,39 +239,39 @@ def __repr__(self): | |
class UniqueForDateValidator(BaseUniqueForValidator): | ||
message = _('This field must be unique for the "{date_field}" date.') | ||
|
||
def filter_queryset(self, attrs, queryset): | ||
def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||
value = attrs[self.field] | ||
date = attrs[self.date_field] | ||
|
||
filter_kwargs = {} | ||
filter_kwargs[self.field_name] = value | ||
filter_kwargs['%s__day' % self.date_field_name] = date.day | ||
filter_kwargs['%s__month' % self.date_field_name] = date.month | ||
filter_kwargs['%s__year' % self.date_field_name] = date.year | ||
filter_kwargs[field_name] = value | ||
filter_kwargs['%s__day' % date_field_name] = date.day | ||
filter_kwargs['%s__month' % date_field_name] = date.month | ||
filter_kwargs['%s__year' % date_field_name] = date.year | ||
return qs_filter(queryset, **filter_kwargs) | ||
|
||
|
||
class UniqueForMonthValidator(BaseUniqueForValidator): | ||
message = _('This field must be unique for the "{date_field}" month.') | ||
|
||
def filter_queryset(self, attrs, queryset): | ||
def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||
value = attrs[self.field] | ||
date = attrs[self.date_field] | ||
|
||
filter_kwargs = {} | ||
filter_kwargs[self.field_name] = value | ||
filter_kwargs['%s__month' % self.date_field_name] = date.month | ||
filter_kwargs[field_name] = value | ||
filter_kwargs['%s__month' % date_field_name] = date.month | ||
return qs_filter(queryset, **filter_kwargs) | ||
|
||
|
||
class UniqueForYearValidator(BaseUniqueForValidator): | ||
message = _('This field must be unique for the "{date_field}" year.') | ||
|
||
def filter_queryset(self, attrs, queryset): | ||
def filter_queryset(self, attrs, queryset, field_name, date_field_name): | ||
value = attrs[self.field] | ||
date = attrs[self.date_field] | ||
|
||
filter_kwargs = {} | ||
filter_kwargs[self.field_name] = value | ||
filter_kwargs['%s__year' % self.date_field_name] = date.year | ||
filter_kwargs[field_name] = value | ||
filter_kwargs['%s__year' % date_field_name] = date.year | ||
return qs_filter(queryset, **filter_kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like changing
default
to use the same mechanism would be a good idea as well.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be not relevant due to the deepcloning. But I'm not sure.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think even if it's not strictly required for defaults, we should still adopt it for consistency.