Skip to content

Hyperlinked PK optimization. #2242

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 10, 2014
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 27 additions & 30 deletions rest_framework/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,20 @@ def get_queryset(self):
queryset = queryset.all()
return queryset

def get_iterable(self, instance, source_attrs):
relationship = get_attribute(instance, source_attrs)
return relationship.all() if (hasattr(relationship, 'all')) else relationship
def use_pk_only_optimization(self):
return False

def get_attribute(self, instance):
if self.use_pk_only_optimization() and self.source_attrs:
# Optimized case, return a mock object only containing the pk attribute.
try:
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
pass

# Standard case, return the object instance.
return get_attribute(instance, self.source_attrs)

@property
def choices(self):
Expand Down Expand Up @@ -120,6 +131,9 @@ class PrimaryKeyRelatedField(RelatedField):
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
}

def use_pk_only_optimization(self):
return True

def to_internal_value(self, data):
try:
return self.get_queryset().get(pk=data)
Expand All @@ -128,32 +142,6 @@ def to_internal_value(self, data):
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)

def get_attribute(self, instance):
# We customize `get_attribute` here for performance reasons.
# For relationships the instance will already have the pk of
# the related object. We return this directly instead of returning the
# object itself, which would require a database lookup.
try:
instance = get_attribute(instance, self.source_attrs[:-1])
return PKOnlyObject(pk=instance.serializable_value(self.source_attrs[-1]))
except AttributeError:
return get_attribute(instance, self.source_attrs)

def get_iterable(self, instance, source_attrs):
# For consistency with `get_attribute` we're using `serializable_value()`
# here. Typically there won't be any difference, but some custom field
# types might return a non-primitive value for the pk otherwise.
#
# We could try to get smart with `values_list('pk', flat=True)`, which
# would be better in some case, but would actually end up with *more*
# queries if the developer is using `prefetch_related` across the
# relationship.
relationship = super(PrimaryKeyRelatedField, self).get_iterable(instance, source_attrs)
return [
PKOnlyObject(pk=item.serializable_value('pk'))
for item in relationship
]

def to_representation(self, value):
return value.pk

Expand Down Expand Up @@ -184,6 +172,9 @@ def __init__(self, view_name=None, **kwargs):

super(HyperlinkedRelatedField, self).__init__(**kwargs)

def use_pk_only_optimization(self):
return self.lookup_field == 'pk'

def get_object(self, view_name, view_args, view_kwargs):
"""
Return the object corresponding to a matched URL.
Expand Down Expand Up @@ -285,6 +276,11 @@ def __init__(self, view_name=None, **kwargs):
kwargs['source'] = '*'
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)

def use_pk_only_optimization(self):
# We have the complete object instance already. We don't need
# to run the 'only get the pk for this relationship' code.
return False


class SlugRelatedField(RelatedField):
"""
Expand Down Expand Up @@ -349,7 +345,8 @@ def to_internal_value(self, data):
]

def get_attribute(self, instance):
return self.child_relation.get_iterable(instance, self.source_attrs)
relationship = get_attribute(instance, self.source_attrs)
return relationship.all() if (hasattr(relationship, 'all')) else relationship

def to_representation(self, iterable):
return [
Expand Down
18 changes: 14 additions & 4 deletions tests/test_relations_hyperlink.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,14 @@ def test_many_to_many_retrieve(self):
{'url': 'http://testserver/manytomanysource/2/', 'name': 'source-2', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/']},
{'url': 'http://testserver/manytomanysource/3/', 'name': 'source-3', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True, context={'request': request})
with self.assertNumQueries(2):
serializer.data

def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
Expand All @@ -99,7 +106,8 @@ def test_reverse_many_to_many_retrieve(self):
{'url': 'http://testserver/manytomanytarget/2/', 'name': 'target-2', 'sources': ['http://testserver/manytomanysource/2/', 'http://testserver/manytomanysource/3/']},
{'url': 'http://testserver/manytomanytarget/3/', 'name': 'target-3', 'sources': ['http://testserver/manytomanysource/3/']}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_many_to_many_update(self):
data = {'url': 'http://testserver/manytomanysource/1/', 'name': 'source-1', 'targets': ['http://testserver/manytomanytarget/1/', 'http://testserver/manytomanytarget/2/', 'http://testserver/manytomanytarget/3/']}
Expand Down Expand Up @@ -197,7 +205,8 @@ def test_foreign_key_retrieve(self):
{'url': 'http://testserver/foreignkeysource/2/', 'name': 'source-2', 'target': 'http://testserver/foreignkeytarget/1/'},
{'url': 'http://testserver/foreignkeysource/3/', 'name': 'source-3', 'target': 'http://testserver/foreignkeytarget/1/'}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(1):
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Without this fix, this is 4.

self.assertEqual(serializer.data, expected)

def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
Expand All @@ -206,7 +215,8 @@ def test_reverse_foreign_key_retrieve(self):
{'url': 'http://testserver/foreignkeytarget/1/', 'name': 'target-1', 'sources': ['http://testserver/foreignkeysource/1/', 'http://testserver/foreignkeysource/2/', 'http://testserver/foreignkeysource/3/']},
{'url': 'http://testserver/foreignkeytarget/2/', 'name': 'target-2', 'sources': []},
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected)

def test_foreign_key_update(self):
data = {'url': 'http://testserver/foreignkeysource/1/', 'name': 'source-1', 'target': 'http://testserver/foreignkeytarget/2/'}
Expand Down
12 changes: 12 additions & 0 deletions tests/test_relations_pk.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@ def test_many_to_many_retrieve(self):
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_many_to_many_retrieve_prefetch_related(self):
queryset = ManyToManySource.objects.all().prefetch_related('targets')
serializer = ManyToManySourceSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data

def test_reverse_many_to_many_retrieve(self):
queryset = ManyToManyTarget.objects.all()
serializer = ManyToManyTargetSerializer(queryset, many=True)
Expand Down Expand Up @@ -188,6 +194,12 @@ def test_reverse_foreign_key_retrieve(self):
with self.assertNumQueries(3):
self.assertEqual(serializer.data, expected)

def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data

def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 2}
instance = ForeignKeySource.objects.get(pk=1)
Expand Down
15 changes: 14 additions & 1 deletion tests/test_relations_slug.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,14 @@ def test_foreign_key_retrieve(self):
{'id': 2, 'name': 'source-2', 'target': 'target-1'},
{'id': 3, 'name': 'source-3', 'target': 'target-1'}
]
self.assertEqual(serializer.data, expected)
with self.assertNumQueries(4):
self.assertEqual(serializer.data, expected)

def test_foreign_key_retrieve_select_related(self):
queryset = ForeignKeySource.objects.all().select_related('target')
serializer = ForeignKeySourceSerializer(queryset, many=True)
with self.assertNumQueries(1):
serializer.data

def test_reverse_foreign_key_retrieve(self):
queryset = ForeignKeyTarget.objects.all()
Expand All @@ -65,6 +72,12 @@ def test_reverse_foreign_key_retrieve(self):
]
self.assertEqual(serializer.data, expected)

def test_reverse_foreign_key_retrieve_prefetch_related(self):
queryset = ForeignKeyTarget.objects.all().prefetch_related('sources')
serializer = ForeignKeyTargetSerializer(queryset, many=True)
with self.assertNumQueries(2):
serializer.data

def test_foreign_key_update(self):
data = {'id': 1, 'name': 'source-1', 'target': 'target-2'}
instance = ForeignKeySource.objects.get(pk=1)
Expand Down