Skip to content

Commit c0dfaa1

Browse files
committed
Do not persist the context in validators
Fixes encode#5760
1 parent ac0f0a1 commit c0dfaa1

File tree

5 files changed

+96
-82
lines changed

5 files changed

+96
-82
lines changed

docs/api-guide/fields.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ If set, this gives the default value that will be used for the field if no input
4747

4848
The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned.
4949

50-
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context).
50+
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument.
5151

5252
When serializing the instance, default will be used if the the object attribute or dictionary key is not present in the instance.
5353

docs/api-guide/validators.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,18 @@ To write a class-based validator, use the `__call__` method. Class-based validat
290290
message = 'This field must be a multiple of %d.' % self.base
291291
raise serializers.ValidationError(message)
292292

293-
#### Using `set_context()`
293+
#### Accessing the context
294294

295-
In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator.
295+
In some advanced cases you might want a validator to be passed the serializer
296+
field it is being used with as additional context. You can do so by using
297+
`rest_framework.validators.ContextBasedValidator` as a base class for the
298+
validator. The `__call__` method will then be called with the `serializer_field`
299+
or `serializer` as an additional argument.
296300

297-
def set_context(self, serializer_field):
301+
def __call__(self, value, serializer_field):
298302
# Determine if this is an update or a create operation.
299-
# In `__call__` we can then use that information to modify the validation behavior.
300-
self.is_update = serializer_field.parent.instance is not None
303+
is_update = serializer_field.parent.instance is not None
304+
305+
pass # implementation of the validator that uses `is_update`
301306

302307
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/

rest_framework/fields.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import inspect
66
import re
77
import uuid
8+
import warnings
89
from collections import OrderedDict
910
from collections.abc import Mapping
1011

@@ -519,13 +520,25 @@ def run_validators(self, value):
519520
Test the given value against all the validators on the field,
520521
and either raise a `ValidationError` or simply return.
521522
"""
523+
from rest_framework.validators import ContextBasedValidator
524+
522525
errors = []
523526
for validator in self.validators:
524527
if hasattr(validator, 'set_context'):
528+
warnings.warn(
529+
"Method `set_context` on validators is deprecated and will "
530+
"no longer be called starting with 3.11. Instead derive the "
531+
"validator from `rest_framwork.validators.ContextBasedValidator` "
532+
"and accept the context as an additional argument.",
533+
DeprecationWarning, stacklevel=2
534+
)
525535
validator.set_context(self)
526536

527537
try:
528-
validator(value)
538+
if isinstance(validator, ContextBasedValidator):
539+
validator(value, self)
540+
else:
541+
validator(value)
529542
except ValidationError as exc:
530543
# If the validation error contains a mapping of fields to
531544
# errors then simply raise it immediately rather than

rest_framework/validators.py

Lines changed: 67 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,17 @@ def qs_filter(queryset, **kwargs):
3030
return queryset.none()
3131

3232

33-
class UniqueValidator:
33+
class ContextBasedValidator:
34+
"""Base class for validators that need a context during evaluation.
35+
36+
In extension to regular validators their `__call__` method must not only
37+
accept a value, but also an instance of a serializer.
38+
"""
39+
def __call__(self, value, serializer):
40+
raise NotImplementedError('`__call__()` must be implemented.')
41+
42+
43+
class UniqueValidator(ContextBasedValidator):
3444
"""
3545
Validator that corresponds to `unique=True` on a model field.
3646
@@ -44,37 +54,32 @@ def __init__(self, queryset, message=None, lookup='exact'):
4454
self.message = message or self.message
4555
self.lookup = lookup
4656

47-
def set_context(self, serializer_field):
48-
"""
49-
This hook is called by the serializer instance,
50-
prior to the validation call being made.
51-
"""
52-
# Determine the underlying model field name. This may not be the
53-
# same as the serializer field name if `source=<>` is set.
54-
self.field_name = serializer_field.source_attrs[-1]
55-
# Determine the existing instance, if this is an update operation.
56-
self.instance = getattr(serializer_field.parent, 'instance', None)
57-
58-
def filter_queryset(self, value, queryset):
57+
def filter_queryset(self, value, queryset, field_name):
5958
"""
6059
Filter the queryset to all instances matching the given attribute.
6160
"""
62-
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
61+
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
6362
return qs_filter(queryset, **filter_kwargs)
6463

65-
def exclude_current_instance(self, queryset):
64+
def exclude_current_instance(self, queryset, instance):
6665
"""
6766
If an instance is being updated, then do not include
6867
that instance itself as a uniqueness conflict.
6968
"""
70-
if self.instance is not None:
71-
return queryset.exclude(pk=self.instance.pk)
69+
if instance is not None:
70+
return queryset.exclude(pk=instance.pk)
7271
return queryset
7372

74-
def __call__(self, value):
73+
def __call__(self, value, serializer_field):
74+
# Determine the underlying model field name. This may not be the
75+
# same as the serializer field name if `source=<>` is set.
76+
field_name = serializer_field.source_attrs[-1]
77+
# Determine the existing instance, if this is an update operation.
78+
instance = getattr(serializer_field.parent, 'instance', None)
79+
7580
queryset = self.queryset
76-
queryset = self.filter_queryset(value, queryset)
77-
queryset = self.exclude_current_instance(queryset)
81+
queryset = self.filter_queryset(value, queryset, field_name)
82+
queryset = self.exclude_current_instance(queryset, instance)
7883
if qs_exists(queryset):
7984
raise ValidationError(self.message, code='unique')
8085

@@ -85,7 +90,7 @@ def __repr__(self):
8590
)
8691

8792

88-
class UniqueTogetherValidator:
93+
class UniqueTogetherValidator(ContextBasedValidator):
8994
"""
9095
Validator that corresponds to `unique_together = (...)` on a model class.
9196
@@ -100,20 +105,12 @@ def __init__(self, queryset, fields, message=None):
100105
self.serializer_field = None
101106
self.message = message or self.message
102107

103-
def set_context(self, serializer):
104-
"""
105-
This hook is called by the serializer instance,
106-
prior to the validation call being made.
107-
"""
108-
# Determine the existing instance, if this is an update operation.
109-
self.instance = getattr(serializer, 'instance', None)
110-
111-
def enforce_required_fields(self, attrs):
108+
def enforce_required_fields(self, attrs, instance):
112109
"""
113110
The `UniqueTogetherValidator` always forces an implied 'required'
114111
state on the fields it applies to.
115112
"""
116-
if self.instance is not None:
113+
if instance is not None:
117114
return
118115

119116
missing_items = {
@@ -124,16 +121,16 @@ def enforce_required_fields(self, attrs):
124121
if missing_items:
125122
raise ValidationError(missing_items, code='required')
126123

127-
def filter_queryset(self, attrs, queryset):
124+
def filter_queryset(self, attrs, queryset, instance):
128125
"""
129126
Filter the queryset to all instances matching the given attributes.
130127
"""
131128
# If this is an update, then any unprovided field should
132129
# have it's value set based on the existing instance attribute.
133-
if self.instance is not None:
130+
if instance is not None:
134131
for field_name in self.fields:
135132
if field_name not in attrs:
136-
attrs[field_name] = getattr(self.instance, field_name)
133+
attrs[field_name] = getattr(instance, field_name)
137134

138135
# Determine the filter keyword arguments and filter the queryset.
139136
filter_kwargs = {
@@ -142,20 +139,23 @@ def filter_queryset(self, attrs, queryset):
142139
}
143140
return qs_filter(queryset, **filter_kwargs)
144141

145-
def exclude_current_instance(self, attrs, queryset):
142+
def exclude_current_instance(self, attrs, queryset, instance):
146143
"""
147144
If an instance is being updated, then do not include
148145
that instance itself as a uniqueness conflict.
149146
"""
150-
if self.instance is not None:
151-
return queryset.exclude(pk=self.instance.pk)
147+
if instance is not None:
148+
return queryset.exclude(pk=instance.pk)
152149
return queryset
153150

154-
def __call__(self, attrs):
155-
self.enforce_required_fields(attrs)
151+
def __call__(self, attrs, serializer):
152+
# Determine the existing instance, if this is an update operation.
153+
instance = getattr(serializer, 'instance', None)
154+
155+
self.enforce_required_fields(attrs, instance)
156156
queryset = self.queryset
157-
queryset = self.filter_queryset(attrs, queryset)
158-
queryset = self.exclude_current_instance(attrs, queryset)
157+
queryset = self.filter_queryset(attrs, queryset, instance)
158+
queryset = self.exclude_current_instance(attrs, queryset, instance)
159159

160160
# Ignore validation if any field is None
161161
checked_values = [
@@ -174,7 +174,7 @@ def __repr__(self):
174174
)
175175

176176

177-
class BaseUniqueForValidator:
177+
class BaseUniqueForValidator(ContextBasedValidator):
178178
message = None
179179
missing_message = _('This field is required.')
180180

@@ -184,18 +184,6 @@ def __init__(self, queryset, field, date_field, message=None):
184184
self.date_field = date_field
185185
self.message = message or self.message
186186

187-
def set_context(self, serializer):
188-
"""
189-
This hook is called by the serializer instance,
190-
prior to the validation call being made.
191-
"""
192-
# Determine the underlying model field names. These may not be the
193-
# same as the serializer field names if `source=<>` is set.
194-
self.field_name = serializer.fields[self.field].source_attrs[-1]
195-
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
196-
# Determine the existing instance, if this is an update operation.
197-
self.instance = getattr(serializer, 'instance', None)
198-
199187
def enforce_required_fields(self, attrs):
200188
"""
201189
The `UniqueFor<Range>Validator` classes always force an implied
@@ -209,23 +197,30 @@ def enforce_required_fields(self, attrs):
209197
if missing_items:
210198
raise ValidationError(missing_items, code='required')
211199

212-
def filter_queryset(self, attrs, queryset):
200+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
213201
raise NotImplementedError('`filter_queryset` must be implemented.')
214202

215-
def exclude_current_instance(self, attrs, queryset):
203+
def exclude_current_instance(self, attrs, queryset, instance):
216204
"""
217205
If an instance is being updated, then do not include
218206
that instance itself as a uniqueness conflict.
219207
"""
220-
if self.instance is not None:
221-
return queryset.exclude(pk=self.instance.pk)
208+
if instance is not None:
209+
return queryset.exclude(pk=instance.pk)
222210
return queryset
223211

224-
def __call__(self, attrs):
212+
def __call__(self, attrs, serializer):
213+
# Determine the underlying model field names. These may not be the
214+
# same as the serializer field names if `source=<>` is set.
215+
field_name = serializer.fields[self.field].source_attrs[-1]
216+
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
217+
# Determine the existing instance, if this is an update operation.
218+
instance = getattr(serializer, 'instance', None)
219+
225220
self.enforce_required_fields(attrs)
226221
queryset = self.queryset
227-
queryset = self.filter_queryset(attrs, queryset)
228-
queryset = self.exclude_current_instance(attrs, queryset)
222+
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
223+
queryset = self.exclude_current_instance(attrs, queryset, instance)
229224
if qs_exists(queryset):
230225
message = self.message.format(date_field=self.date_field)
231226
raise ValidationError({
@@ -244,39 +239,39 @@ def __repr__(self):
244239
class UniqueForDateValidator(BaseUniqueForValidator):
245240
message = _('This field must be unique for the "{date_field}" date.')
246241

247-
def filter_queryset(self, attrs, queryset):
242+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
248243
value = attrs[self.field]
249244
date = attrs[self.date_field]
250245

251246
filter_kwargs = {}
252-
filter_kwargs[self.field_name] = value
253-
filter_kwargs['%s__day' % self.date_field_name] = date.day
254-
filter_kwargs['%s__month' % self.date_field_name] = date.month
255-
filter_kwargs['%s__year' % self.date_field_name] = date.year
247+
filter_kwargs[field_name] = value
248+
filter_kwargs['%s__day' % date_field_name] = date.day
249+
filter_kwargs['%s__month' % date_field_name] = date.month
250+
filter_kwargs['%s__year' % date_field_name] = date.year
256251
return qs_filter(queryset, **filter_kwargs)
257252

258253

259254
class UniqueForMonthValidator(BaseUniqueForValidator):
260255
message = _('This field must be unique for the "{date_field}" month.')
261256

262-
def filter_queryset(self, attrs, queryset):
257+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
263258
value = attrs[self.field]
264259
date = attrs[self.date_field]
265260

266261
filter_kwargs = {}
267-
filter_kwargs[self.field_name] = value
268-
filter_kwargs['%s__month' % self.date_field_name] = date.month
262+
filter_kwargs[field_name] = value
263+
filter_kwargs['%s__month' % date_field_name] = date.month
269264
return qs_filter(queryset, **filter_kwargs)
270265

271266

272267
class UniqueForYearValidator(BaseUniqueForValidator):
273268
message = _('This field must be unique for the "{date_field}" year.')
274269

275-
def filter_queryset(self, attrs, queryset):
270+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
276271
value = attrs[self.field]
277272
date = attrs[self.date_field]
278273

279274
filter_kwargs = {}
280-
filter_kwargs[self.field_name] = value
281-
filter_kwargs['%s__year' % self.date_field_name] = date.year
275+
filter_kwargs[field_name] = value
276+
filter_kwargs['%s__year' % date_field_name] = date.year
282277
return qs_filter(queryset, **filter_kwargs)

tests/test_validators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ def filter(self, **kwargs):
361361
queryset = MockQueryset()
362362
validator = UniqueTogetherValidator(queryset, fields=('race_name',
363363
'position'))
364-
validator.instance = self.instance
365-
validator.filter_queryset(attrs=data, queryset=queryset)
364+
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
366365
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
367366

368367

@@ -586,4 +585,6 @@ def test_validator_raises_error_when_abstract_method_called(self):
586585
validator = BaseUniqueForValidator(queryset=object(), field='foo',
587586
date_field='bar')
588587
with pytest.raises(NotImplementedError):
589-
validator.filter_queryset(attrs=None, queryset=None)
588+
validator.filter_queryset(
589+
attrs=None, queryset=None, field_name='', date_field_name=''
590+
)

0 commit comments

Comments
 (0)