Skip to content

Commit 83c23f7

Browse files
committed
Support UniqueConstraint
1 parent 599e2b1 commit 83c23f7

File tree

3 files changed

+175
-47
lines changed

3 files changed

+175
-47
lines changed

rest_framework/serializers.py

Lines changed: 46 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1373,6 +1373,23 @@ def get_extra_kwargs(self):
13731373

13741374
return extra_kwargs
13751375

1376+
def get_unique_together_constraints(self, model):
1377+
"""
1378+
Returns iterator of (fields, queryset), each entry describe an unique together
1379+
constraint on `fields` in `queryset`.
1380+
"""
1381+
for parent_class in [model] + list(model._meta.parents):
1382+
for unique_together in parent_class._meta.unique_together:
1383+
yield unique_together, model._default_manager
1384+
for constraint in parent_class._meta.constraints:
1385+
if isinstance(constraint, models.UniqueConstraint) and len(constraint.fields) > 1:
1386+
yield (
1387+
constraint.fields,
1388+
model._default_manager
1389+
if constraint.condition is None
1390+
else model._default_manager.filter(constraint.condition)
1391+
)
1392+
13761393
def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs):
13771394
"""
13781395
Return any additional field options that need to be included as a
@@ -1401,12 +1418,11 @@ def get_uniqueness_extra_kwargs(self, field_names, declared_fields, extra_kwargs
14011418

14021419
unique_constraint_names -= {None}
14031420

1404-
# Include each of the `unique_together` field names,
1421+
# Include each of the `unique_together` and `UniqueConstraint` field names,
14051422
# so long as all the field names are included on the serializer.
1406-
for parent_class in [model] + list(model._meta.parents):
1407-
for unique_together_list in parent_class._meta.unique_together:
1408-
if set(field_names).issuperset(set(unique_together_list)):
1409-
unique_constraint_names |= set(unique_together_list)
1423+
for unique_together_list, queryset in self.get_unique_together_constraints(model):
1424+
if set(field_names).issuperset(set(unique_together_list)):
1425+
unique_constraint_names |= set(unique_together_list)
14101426

14111427
# Now we have all the field names that have uniqueness constraints
14121428
# applied, we can add the extra 'required=...' or 'default=...'
@@ -1503,11 +1519,6 @@ def get_unique_together_validators(self):
15031519
"""
15041520
Determine a default set of validators for any unique_together constraints.
15051521
"""
1506-
model_class_inheritance_tree = (
1507-
[self.Meta.model] +
1508-
list(self.Meta.model._meta.parents)
1509-
)
1510-
15111522
# The field names we're passing though here only include fields
15121523
# which may map onto a model field. Any dotted field name lookups
15131524
# cannot map to a field, and must be a traversal, so we're not
@@ -1533,34 +1544,33 @@ def get_unique_together_validators(self):
15331544
# Note that we make sure to check `unique_together` both on the
15341545
# base model class, but also on any parent classes.
15351546
validators = []
1536-
for parent_class in model_class_inheritance_tree:
1537-
for unique_together in parent_class._meta.unique_together:
1538-
# Skip if serializer does not map to all unique together sources
1539-
if not set(source_map).issuperset(set(unique_together)):
1540-
continue
1541-
1542-
for source in unique_together:
1543-
assert len(source_map[source]) == 1, (
1544-
"Unable to create `UniqueTogetherValidator` for "
1545-
"`{model}.{field}` as `{serializer}` has multiple "
1546-
"fields ({fields}) that map to this model field. "
1547-
"Either remove the extra fields, or override "
1548-
"`Meta.validators` with a `UniqueTogetherValidator` "
1549-
"using the desired field names."
1550-
.format(
1551-
model=self.Meta.model.__name__,
1552-
serializer=self.__class__.__name__,
1553-
field=source,
1554-
fields=', '.join(source_map[source]),
1555-
)
1556-
)
1547+
for unique_together, queryset in self.get_unique_together_constraints(self.Meta.model):
1548+
# Skip if serializer does not map to all unique together sources
1549+
if not set(source_map).issuperset(set(unique_together)):
1550+
continue
15571551

1558-
field_names = tuple(source_map[f][0] for f in unique_together)
1559-
validator = UniqueTogetherValidator(
1560-
queryset=parent_class._default_manager,
1561-
fields=field_names
1552+
for source in unique_together:
1553+
assert len(source_map[source]) == 1, (
1554+
"Unable to create `UniqueTogetherValidator` for "
1555+
"`{model}.{field}` as `{serializer}` has multiple "
1556+
"fields ({fields}) that map to this model field. "
1557+
"Either remove the extra fields, or override "
1558+
"`Meta.validators` with a `UniqueTogetherValidator` "
1559+
"using the desired field names."
1560+
.format(
1561+
model=self.Meta.model.__name__,
1562+
serializer=self.__class__.__name__,
1563+
field=source,
1564+
fields=', '.join(source_map[source]),
1565+
)
15621566
)
1563-
validators.append(validator)
1567+
1568+
field_names = tuple(source_map[f][0] for f in unique_together)
1569+
validator = UniqueTogetherValidator(
1570+
queryset=queryset,
1571+
fields=field_names
1572+
)
1573+
validators.append(validator)
15641574
return validators
15651575

15661576
def get_unique_for_date_validators(self):

rest_framework/utils/field_mapping.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,34 @@ def get_detail_view_name(model):
6363
}
6464

6565

66+
def get_unique_validators(field_name, model_field):
67+
"""
68+
Returns a list of UniqueValidators that should be applied to the field.
69+
"""
70+
field_set = set([field_name])
71+
conditions = {
72+
c.condition
73+
for c in model_field.model._meta.constraints
74+
if isinstance(c, models.UniqueConstraint) and set(c.fields) == field_set
75+
}
76+
if getattr(model_field, 'unique', False):
77+
conditions.add(None)
78+
if not conditions:
79+
return []
80+
unique_error_message = model_field.error_messages.get('unique', None)
81+
if unique_error_message:
82+
unique_error_message = unique_error_message % {
83+
'model_name': model_field.model._meta.verbose_name,
84+
'field_label': model_field.verbose_name
85+
}
86+
queryset = model_field.model._default_manager
87+
for condition in conditions:
88+
yield UniqueValidator(
89+
queryset=queryset if condition is None else queryset.filter(condition),
90+
message=unique_error_message
91+
)
92+
93+
6694
def get_field_kwargs(field_name, model_field):
6795
"""
6896
Creates a default instance of a basic non-relational field.
@@ -216,17 +244,7 @@ def get_field_kwargs(field_name, model_field):
216244
if not isinstance(validator, validators.MinLengthValidator)
217245
]
218246

219-
if getattr(model_field, 'unique', False):
220-
unique_error_message = model_field.error_messages.get('unique', None)
221-
if unique_error_message:
222-
unique_error_message = unique_error_message % {
223-
'model_name': model_field.model._meta.verbose_name,
224-
'field_label': model_field.verbose_name
225-
}
226-
validator = UniqueValidator(
227-
queryset=model_field.model._default_manager,
228-
message=unique_error_message)
229-
validator_kwarg.append(validator)
247+
validator_kwarg += get_unique_validators(field_name, model_field)
230248

231249
if validator_kwarg:
232250
kwargs['validators'] = validator_kwarg

tests/test_validators.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -452,6 +452,106 @@ def filter(self, **kwargs):
452452
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
453453

454454

455+
class UniqueConstraintModel(models.Model):
456+
race_name = models.CharField(max_length=100)
457+
position = models.IntegerField()
458+
global_id = models.IntegerField()
459+
fancy_conditions = models.IntegerField(null=True)
460+
461+
class Meta:
462+
constraints = [
463+
models.UniqueConstraint(
464+
name="unique_constraint_model_global_id_uniq",
465+
fields=('global_id',),
466+
),
467+
models.UniqueConstraint(
468+
name="unique_constraint_model_fancy_1_uniq",
469+
fields=('fancy_conditions',),
470+
condition=models.Q(global_id__lte=1)
471+
),
472+
models.UniqueConstraint(
473+
name="unique_constraint_model_fancy_3_uniq",
474+
fields=('fancy_conditions',),
475+
condition=models.Q(global_id__gte=3)
476+
),
477+
models.UniqueConstraint(
478+
name="unique_constraint_model_together_uniq",
479+
fields=('race_name', 'position'),
480+
condition=models.Q(race_name='example'),
481+
)
482+
]
483+
484+
485+
class UniqueConstraintSerializer(serializers.ModelSerializer):
486+
class Meta:
487+
model = UniqueConstraintModel
488+
fields = '__all__'
489+
490+
491+
class TestUniqueConstraintValidation(TestCase):
492+
def setUp(self):
493+
self.instance = UniqueConstraintModel.objects.create(
494+
race_name='example',
495+
position=1,
496+
global_id=1
497+
)
498+
UniqueConstraintModel.objects.create(
499+
race_name='example',
500+
position=2,
501+
global_id=2
502+
)
503+
UniqueConstraintModel.objects.create(
504+
race_name='other',
505+
position=1,
506+
global_id=3
507+
)
508+
509+
def test_repr(self):
510+
serializer = UniqueConstraintSerializer()
511+
# the order of validators isn't deterministic so delete
512+
# fancy_conditions field that has two of them
513+
del serializer.fields['fancy_conditions']
514+
expected = dedent("""
515+
UniqueConstraintSerializer():
516+
id = IntegerField(label='ID', read_only=True)
517+
race_name = CharField(max_length=100, required=True)
518+
position = IntegerField(required=True)
519+
global_id = IntegerField(validators=[<UniqueValidator(queryset=UniqueConstraintModel.objects.all())>])
520+
class Meta:
521+
validators = [<UniqueTogetherValidator(queryset=<QuerySet [<UniqueConstraintModel: UniqueConstraintModel object (1)>, <UniqueConstraintModel: UniqueConstraintModel object (2)>]>, fields=('race_name', 'position'))>]
522+
""")
523+
assert repr(serializer) == expected
524+
525+
def test_unique_together_field(self):
526+
"""
527+
UniqueConstraint fields and condition attributes must be passed
528+
to UniqueTogetherValidator as fields and queryset
529+
"""
530+
serializer = UniqueConstraintSerializer()
531+
assert len(serializer.validators) == 1
532+
validator = serializer.validators[0]
533+
assert validator.fields == ('race_name', 'position')
534+
assert set(validator.queryset.values_list(flat=True)) == set(
535+
UniqueConstraintModel.objects.filter(race_name='example').values_list(flat=True)
536+
)
537+
538+
def test_single_field_uniq_validators(self):
539+
"""
540+
UniqueConstraint with single field must be transformed into
541+
field's UniqueValidator
542+
"""
543+
serializer = UniqueConstraintSerializer()
544+
assert len(serializer.validators) == 1
545+
validators = serializer.fields['global_id'].validators
546+
assert len(validators) == 1
547+
assert validators[0].queryset == UniqueConstraintModel.objects
548+
549+
validators = serializer.fields['fancy_conditions'].validators
550+
assert len(validators) == 2
551+
ids_in_qs = {frozenset(v.queryset.values_list(flat=True)) for v in validators}
552+
assert ids_in_qs == {frozenset([1]), frozenset([3])}
553+
554+
455555
# Tests for `UniqueForDateValidator`
456556
# ----------------------------------
457557

0 commit comments

Comments
 (0)