diff --git a/requirements.txt b/requirements.txt index 43c8f5d..3a71b61 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,6 @@ Django==2.2.19 djangorestframework==3.12.2 +django-readers>=0.0.7 django-zen-queries==2.0.1 coverage==4.2 flake8==3.7.5 diff --git a/serialization_spec/serialization.py b/serialization_spec/serialization.py index a4328ca..7065a41 100644 --- a/serialization_spec/serialization.py +++ b/serialization_spec/serialization.py @@ -1,38 +1,5 @@ -from django.core.exceptions import ImproperlyConfigured from django.db.models import Prefetch -from rest_framework.utils import model_meta -from rest_framework.fields import Field, ReadOnlyField -from rest_framework.serializers import ModelSerializer -from zen_queries.rest_framework import QueriesDisabledViewMixin - -from typing import List, Dict, Union -from collections import OrderedDict -import copy - -""" -Parse a serialization spec such as: - -class ProductVersionDetail(SerializationSpecMixin, generics.RetrieveAPIView): - - queryset = ProductVersion.objects.all() - serialization_spec = [ - 'id', - {'product': [ - 'id', - 'name' - ]}, - {'report_templates': [ - 'id', - 'name' - ]} - ] - -1. fetch the data required to populate this -2. output it - -mixin should implement get_queryset() and get_serializer() - -""" +from django_readers import specs, pairs, qs, rest_framework class SerializationSpecPlugin: @@ -78,140 +45,6 @@ def __init__(self, field_name, serialization_spec=None): self.serialization_spec = serialization_spec -class SerializationSpecPluginField(Field): - def __init__(self, plugin): - self.plugin = plugin - super().__init__(source='*', read_only=True) - - def to_representation(self, value): - return self.plugin.get_value(value) - - -class AliasedField(ReadOnlyField): - def __init__(self, field_name): - super().__init__(source=field_name, read_only=True) - - -def get_fields(serialization_spec): - return sum( - [list(each.keys()) if isinstance(each, dict) else [each] for each in serialization_spec], - [] - ) - - -def get_only_fields(model, serialization_spec): - field_info = model_meta.get_field_info(model) - fields = set(field_info.fields_and_pk.keys()) | set(field_info.forward_relations.keys()) - return [ - field for field in get_fields(serialization_spec) - if field in fields - ] - - -def get_childspecs(serialization_spec): - return [each for each in serialization_spec if isinstance(each, dict)] - - -def handle_filtered(item): - key, values = item - if isinstance(values, Filtered): - return key, values.field_name or key, values.serialization_spec - return key, key, values - - -def make_serializer_class(model, serialization_spec): - relations = model_meta.get_field_info(model).relations - - return type( - 'MySerializer', - (ModelSerializer,), - { - 'Meta': type( - 'Meta', - (object,), - {'model': model, 'fields': get_fields(serialization_spec)} - ), - **{ - key: ( - SerializationSpecPluginField(values) if isinstance(values, SerializationSpecPlugin) - else AliasedField(field_name) if values is None - else make_serializer_class( - relations[field_name].related_model, - values - )(many=relations[field_name].to_many) - ) - for key, field_name, values - in [handle_filtered(item) for each in get_childspecs(serialization_spec) for item in each.items()] - }, - } - ) - - -def has_plugin(spec): - return isinstance(spec, list) and any( - isinstance(childspec, SerializationSpecPlugin) or has_plugin(childspec) - for each in spec if isinstance(each, dict) - for key, childspec in each.items() - ) - - -def prefetch_related(request_user, queryset, model, prefixes, serialization_spec, use_select_related): - relations = model_meta.get_field_info(model).relations - - for each in serialization_spec: - if isinstance(each, dict): - for key, childspec in each.items(): - if isinstance(childspec, SerializationSpecPlugin): - childspec.key = key - childspec.request_user = request_user - queryset = childspec.modify_queryset(queryset) - - else: - filters, to_attr = None, None - if isinstance(childspec, Filtered): - if not childspec.serialization_spec: - continue - - filters = childspec.filters - if childspec.field_name: - to_attr = key - key = childspec.field_name - childspec = childspec.serialization_spec - - relation = relations[key] - related_model = relation.related_model - - key_path = '__'.join(prefixes + [key]) - - if (relation.model_field and relation.model_field.one_to_one) or (use_select_related and not relation.to_many) and not has_plugin(childspec): - # no way to .only() on a select_related field - queryset = queryset.select_related(key_path) - queryset = prefetch_related(request_user, queryset, related_model, prefixes + [key], childspec, use_select_related) - else: - only_fields = get_only_fields(related_model, childspec) - if relation.reverse and not relation.has_through_model: - # need to include the reverse FK to allow prefetch to stitch results together - # Unfortunately that info is in the model._meta but is not in the RelationInfo tuple - reverse_fk = next( - rel.field.name - for rel in model._meta.related_objects - if rel.get_accessor_name() == key - ) - has_reverse_fk = any(field.name == reverse_fk for field in relation.related_model._meta.fields) - if has_reverse_fk: - only_fields += ['%s_id' % reverse_fk] - inner_queryset = prefetch_related(request_user, related_model.objects.only(*only_fields), related_model, [], childspec, use_select_related) - if filters: - inner_queryset = inner_queryset.filter(filters).distinct() - queryset = queryset.prefetch_related(Prefetch( - key_path, - queryset=inner_queryset, - **({'to_attr': to_attr} if to_attr else {}) - )) - - return queryset - - def get_serialization_spec(view_or_plugin, request_user=None): if hasattr(view_or_plugin, 'get_serialization_spec'): view_or_plugin.request_user = request_user @@ -219,130 +52,60 @@ def get_serialization_spec(view_or_plugin, request_user=None): return getattr(view_or_plugin, 'serialization_spec', None) -def expand_nested_specs(serialization_spec, request_user): - expanded_serialization_spec = [] - - for each in serialization_spec: - if not isinstance(each, dict): - expanded_serialization_spec.append(each) - else: - expanded_dict = {} - for key, childspec in each.items(): - if isinstance(childspec, SerializationSpecPlugin): - serialization_spec = get_serialization_spec(childspec, request_user) - if serialization_spec is not None: - plugin_copy = copy.deepcopy(childspec) - plugin_copy.serialization_spec = expand_nested_specs(plugin_copy.serialization_spec, request_user) - expanded_serialization_spec += plugin_copy.serialization_spec - expanded_dict[key] = plugin_copy - else: - expanded_dict[key] = childspec - elif isinstance(childspec, Filtered): - if childspec.serialization_spec: - childspec.serialization_spec = expand_nested_specs(childspec.serialization_spec, request_user) - expanded_dict[key] = childspec +def adapt_plugin_spec(plugin_spec, request_user=None): + assert len(plugin_spec) == 1 + key, plugin = next(iter(plugin_spec.items())) + plugin.key = key + plugin.request_user = request_user + + plugin_spec = get_serialization_spec(plugin) + if plugin_spec: + prepare, _ = specs.process(preprocess_spec(plugin_spec, request_user=request_user)) + else: + prepare = plugin.modify_queryset + + return prepare, plugin.get_value + + +def preprocess_item(item, request_user=None): + if isinstance(item, dict): + processed_item = [] + for key, value in item.items(): + if isinstance(value, list): + processed_item.append({key: preprocess_spec(value, request_user=request_user)}) + elif isinstance(value, SerializationSpecPlugin): + processed_item.append({key: adapt_plugin_spec({key: value}, request_user=request_user)}) + elif isinstance(value, Filtered): + if value.serialization_spec is None: + spec = {key: value.field_name} else: - expanded_dict[key] = expand_nested_specs(childspec, request_user) - expanded_serialization_spec.append(expanded_dict) - - return expanded_serialization_spec - - -class NormalisedSpec: - def __init__(self): - self.spec = None - self.fields = OrderedDict() - self.relations = OrderedDict() - - -def normalise_spec(serialization_spec): - def normalise(spec, normalised_spec): - if isinstance(spec, SerializationSpecPlugin) or isinstance(spec, Filtered): - normalised_spec.spec = spec - return - - for each in spec: - if isinstance(each, dict): - for key, childspec in each.items(): - if key not in normalised_spec.relations: - normalised_spec.relations[key] = NormalisedSpec() - normalise(childspec, normalised_spec.relations[key]) + relationship_spec = preprocess_spec(value.serialization_spec, request_user=request_user) + if value.filters: + relationship_spec.append( + pairs.prepare_only( + qs.pipe( + qs.filter(value.filters), + qs.distinct() + ) + ) + ) + to_attr = key if value.field_name and value.field_name != key else None + spec = specs.relationship(value.field_name or key, relationship_spec, to_attr=to_attr) + processed_item.append(spec) else: - normalised_spec.fields[each] = True - - def combine(normalised_spec): - return normalised_spec.spec or ( - list(normalised_spec.fields.keys()) + ([{ - key: combine(value) - for key, value in normalised_spec.relations.items() - }] if normalised_spec.relations else []) - ) - - normalised_spec = NormalisedSpec() - normalise(serialization_spec, normalised_spec) - return combine(normalised_spec) - - -def expand_many2many_id_fields(model, serialization_spec): - # Convert raw M2M fields to ManyToManyIDsPlugin - many_related_models = { - field_name: relation.related_model - for field_name, relation in model_meta.get_field_info(model).relations.items() - if relation.to_many - } - - for idx, each in enumerate(serialization_spec): - if not isinstance(each, dict): - if each in many_related_models: - serialization_spec[idx] = {each: ManyToManyIDsPlugin(many_related_models[each], each)} - else: - for key, childspec in each.items(): - if key in many_related_models: - expand_many2many_id_fields(many_related_models[key], each[key]) - - -def prefetch_queryset(queryset, serialization_spec, user=None, use_select_related=False): - expand_many2many_id_fields(queryset.model, serialization_spec) - serialization_spec = expand_nested_specs(serialization_spec, user) - serialization_spec = normalise_spec(serialization_spec) - queryset = queryset.only(*get_only_fields(queryset.model, serialization_spec)) - return prefetch_related(user, queryset, queryset.model, [], serialization_spec, use_select_related) - - -class SerializationSpecMixin(QueriesDisabledViewMixin): - - serialization_spec = None # type: SerializationSpec - - def get_object(self): - self.use_select_related = True - return super().get_object() - - def get_queryset(self): - self.serialization_spec = get_serialization_spec(self) - if self.serialization_spec is None: - raise ImproperlyConfigured('SerializationSpecMixin requires serialization_spec or get_serialization_spec') - - return prefetch_queryset(self.queryset, self.serialization_spec, self.request.user, getattr(self, 'use_select_related', False)) - - def get_serializer_class(self): - return make_serializer_class(self.queryset.model, self.serialization_spec) + processed_item.append({key: value}) + return processed_item + return [item] -""" -serialization_spec type should be +def preprocess_spec(spec, request_user=None): + processed_spec = [] + for item in spec: + processed_spec += preprocess_item(item, request_user=request_user) + return processed_spec - SerializationSpec = List[Union[str, Dict[str, Union[SerializationSpecPlugin, 'SerializationSpec']]]] -But recursive types are not yet implemented :( -So we specify to an (arbitrary) depth of 5 -""" -SerializationSpec = List[Union[str, Dict[str, Union[Filtered, SerializationSpecPlugin, - List[Union[str, Dict[str, Union[Filtered, SerializationSpecPlugin, - List[Union[str, Dict[str, Union[Filtered, SerializationSpecPlugin, - List[Union[str, Dict[str, Union[Filtered, SerializationSpecPlugin, - List[Union[str, Dict[str, Union[Filtered, SerializationSpecPlugin, - List]]]] - ]]]] - ]]]] - ]]]] -]]]] +class SerializationSpecMixin(rest_framework.SpecMixin): + def get_spec(self): + spec = get_serialization_spec(self) or super().get_spec() + return preprocess_spec(spec, request_user=self.request.user) diff --git a/setup.py b/setup.py index 72c9c61..d6d5f22 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,8 @@ install_requires = [ 'Django>=1.11', 'djangorestframework>=3.5.3', - 'django-zen-queries>=1.0.0' + 'django-zen-queries>=1.0.0', + 'django-readers>=0.0.7', ] def get_version(package): diff --git a/tests/test_api.py b/tests/test_api.py index cf483c5..978971c 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -3,6 +3,7 @@ from django.core.exceptions import ImproperlyConfigured from django.test import TestCase from rest_framework.test import APIClient +from rest_framework.utils.encoders import JSONEncoder from django.urls import reverse from django.db import connection from django.test.utils import CaptureQueriesContext @@ -40,8 +41,8 @@ def assert_status(self, response, status_code): def assertJsonEqual(self, expected, actual): self.assertEqual( - json.dumps(expected, indent=4, sort_keys=True), - json.dumps(actual, indent=4, sort_keys=True) + json.dumps(expected, indent=4, sort_keys=True, cls=JSONEncoder), + json.dumps(actual, indent=4, sort_keys=True, cls=JSONEncoder) ) @@ -91,7 +92,8 @@ def test_single_fk_and_reverse_fk(self): sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id" FROM "tests_class" WHERE "tests_class"."teacher_id" IN ('00000000000000000000000000000002') ORDER BY "tests_class"."id" ASC""", - """SELECT "tests_teacher"."id", "tests_teacher"."name", "tests_teacher"."school_id", "tests_school"."id", "tests_school"."created", "tests_school"."modified", "tests_school"."name", "tests_school"."lea_id" FROM "tests_teacher" INNER JOIN "tests_school" ON ("tests_teacher"."school_id" = "tests_school"."id") WHERE "tests_teacher"."id" = '00000000000000000000000000000002'""", + """SELECT "tests_school"."id", "tests_school"."name" FROM "tests_school" WHERE "tests_school"."id" IN ('00000000000000000000000000000001') ORDER BY "tests_school"."id" ASC""", + """SELECT "tests_teacher"."id", "tests_teacher"."name", "tests_teacher"."school_id" FROM "tests_teacher" WHERE "tests_teacher"."id" = '00000000000000000000000000000002'""", ] ) @@ -150,8 +152,10 @@ def test_single_fk_on_fk_and_reverse_m2m(self): self.assertJsonEqual( sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ - """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id", "tests_school"."id", "tests_school"."created", "tests_school"."modified", "tests_school"."name", "tests_school"."lea_id" FROM "tests_class" INNER JOIN "tests_teacher" ON ("tests_class"."teacher_id" = "tests_teacher"."id") INNER JOIN "tests_school" ON ("tests_teacher"."school_id" = "tests_school"."id") WHERE "tests_class"."id" = '00000000000000000000000000000006'""", - """SELECT ("tests_student_classes"."class_id") AS "_prefetch_related_val_class_id", "tests_student"."id", "tests_student"."name" FROM "tests_student" INNER JOIN "tests_student_classes" ON ("tests_student"."id" = "tests_student_classes"."student_id") WHERE "tests_student_classes"."class_id" IN ('00000000000000000000000000000006') ORDER BY "tests_student"."id" ASC""" + """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id" FROM "tests_class" WHERE "tests_class"."id" = '00000000000000000000000000000006'""", + """SELECT "tests_school"."id", "tests_school"."name" FROM "tests_school" WHERE "tests_school"."id" IN ('00000000000000000000000000000001') ORDER BY "tests_school"."id" ASC""", + """SELECT "tests_teacher"."id", "tests_teacher"."name", "tests_teacher"."school_id" FROM "tests_teacher" WHERE "tests_teacher"."id" IN ('00000000000000000000000000000002') ORDER BY "tests_teacher"."id" ASC""", + """SELECT ("tests_student_classes"."class_id") AS "_prefetch_related_val_class_id", "tests_student"."id", "tests_student"."name" FROM "tests_student" INNER JOIN "tests_student_classes" ON ("tests_student"."id" = "tests_student_classes"."student_id") WHERE "tests_student_classes"."class_id" IN ('00000000000000000000000000000006') ORDER BY "tests_student"."id" ASC""", ] ) @@ -179,8 +183,9 @@ def test_single_fk_on_many_to_many(self): self.assertJsonEqual( sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ - """SELECT "tests_class"."id", "tests_class"."subject_id", "tests_class"."name", "tests_class"."teacher_id", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id" FROM "tests_class" INNER JOIN "tests_teacher" ON ("tests_class"."teacher_id" = "tests_teacher"."id") WHERE "tests_class"."subject_id" IN ('00000000000000000000000000000004') ORDER BY "tests_class"."id" ASC""", + """SELECT "tests_class"."id", "tests_class"."subject_id", "tests_class"."name", "tests_class"."teacher_id" FROM "tests_class" WHERE "tests_class"."subject_id" IN ('00000000000000000000000000000004') ORDER BY "tests_class"."id" ASC""", """SELECT "tests_subject"."id", "tests_subject"."name" FROM "tests_subject" WHERE "tests_subject"."id" = '00000000000000000000000000000004'""", + """SELECT "tests_teacher"."id", "tests_teacher"."name" FROM "tests_teacher" WHERE "tests_teacher"."id" IN ('00000000000000000000000000000002') ORDER BY "tests_teacher"."id" ASC""", ] ) @@ -207,8 +212,9 @@ def test_single_reverse_fk_on_fk(self): self.assertJsonEqual( sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ - """SELECT "tests_school"."id", "tests_school"."name", "tests_school"."lea_id" FROM "tests_school" WHERE "tests_school"."lea_id" IN ('00000000000000000000000000000000') ORDER BY "tests_school"."id" ASC""", - """SELECT "tests_school"."id", "tests_school"."name", "tests_school"."lea_id", "tests_lea"."id", "tests_lea"."created", "tests_lea"."modified", "tests_lea"."name" FROM "tests_school" INNER JOIN "tests_lea" ON ("tests_school"."lea_id" = "tests_lea"."id") WHERE "tests_school"."id" = '00000000000000000000000000000001'""", + """SELECT "tests_lea"."id", "tests_lea"."name" FROM "tests_lea" WHERE "tests_lea"."id" IN ('00000000000000000000000000000000') ORDER BY "tests_lea"."id" ASC""", + """SELECT "tests_school"."id", "tests_school"."name", "tests_school"."lea_id" FROM "tests_school" WHERE "tests_school"."id" = '00000000000000000000000000000001'""", + """SELECT "tests_school"."id", "tests_school"."name", "tests_school"."lea_id" FROM "tests_school" WHERE "tests_school"."lea_id" IN ('00000000000000000000000000000000') ORDER BY "tests_school"."id" ASC""" ] ) @@ -239,7 +245,8 @@ def test_single_many_to_many_with_through(self): self.assertJsonEqual( sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ - """SELECT "tests_assignmentstudent"."id", "tests_assignmentstudent"."is_complete", "tests_assignmentstudent"."assignment_id", "tests_assignmentstudent"."student_id", "tests_assignment"."id", "tests_assignment"."created", "tests_assignment"."modified", "tests_assignment"."name", "tests_assignment"."clasz_id" FROM "tests_assignmentstudent" INNER JOIN "tests_assignment" ON ("tests_assignmentstudent"."assignment_id" = "tests_assignment"."id") WHERE "tests_assignmentstudent"."student_id" IN ('00000000000000000000000000000015') ORDER BY "tests_assignmentstudent"."id" ASC""", + """SELECT "tests_assignment"."id", "tests_assignment"."name" FROM "tests_assignment" WHERE "tests_assignment"."id" IN ('00000000000000000000000000000020', '00000000000000000000000000000021') ORDER BY "tests_assignment"."id" ASC""", + """SELECT "tests_assignmentstudent"."id", "tests_assignmentstudent"."is_complete", "tests_assignmentstudent"."assignment_id", "tests_assignmentstudent"."student_id" FROM "tests_assignmentstudent" WHERE "tests_assignmentstudent"."student_id" IN ('00000000000000000000000000000015') ORDER BY "tests_assignmentstudent"."id" ASC""", """SELECT "tests_student"."id", "tests_student"."name" FROM "tests_student" WHERE "tests_student"."id" = '00000000000000000000000000000015'""", """SELECT ("tests_assignmentstudent"."student_id") AS "_prefetch_related_val_student_id", "tests_assignment"."id", "tests_assignment"."name" FROM "tests_assignment" INNER JOIN "tests_assignmentstudent" ON ("tests_assignment"."id" = "tests_assignmentstudent"."assignment_id") WHERE "tests_assignmentstudent"."student_id" IN ('00000000000000000000000000000015') ORDER BY "tests_assignment"."id" ASC""", ] @@ -274,7 +281,9 @@ def test_single_count_plugin(self): sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ """SELECT "tests_assignment"."id", "tests_assignment"."name", "tests_assignment"."clasz_id" FROM "tests_assignment" WHERE "tests_assignment"."id" = '00000000000000000000000000000020'""", - """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id", COUNT(DISTINCT "tests_student_classes"."student_id") AS "student_count", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id" FROM "tests_class" LEFT OUTER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") INNER JOIN "tests_teacher" ON ("tests_class"."teacher_id" = "tests_teacher"."id") WHERE "tests_class"."id" IN ('00000000000000000000000000000006') GROUP BY "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id\"""", + """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id" FROM "tests_class" WHERE "tests_class"."id" IN ('00000000000000000000000000000006') ORDER BY "tests_class"."id" ASC""", + """SELECT "tests_class"."id", COUNT(DISTINCT "tests_student_classes"."student_id") AS "student_count" FROM "tests_class" LEFT OUTER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") WHERE "tests_class"."id" IN ('00000000000000000000000000000006') GROUP BY "tests_class"."id\"""", + """SELECT "tests_teacher"."id", "tests_teacher"."name" FROM "tests_teacher" WHERE "tests_teacher"."id" IN ('00000000000000000000000000000002') ORDER BY "tests_teacher"."id" ASC""", """SELECT ("tests_assignmentstudent"."assignment_id") AS "_prefetch_related_val_assignment_id", "tests_student"."id", "tests_student"."name", COUNT(DISTINCT "tests_student_classes"."class_id") AS "classes_count" FROM "tests_student" LEFT OUTER JOIN "tests_student_classes" ON ("tests_student"."id" = "tests_student_classes"."student_id") INNER JOIN "tests_assignmentstudent" ON ("tests_student"."id" = "tests_assignmentstudent"."student_id") WHERE "tests_assignmentstudent"."assignment_id" IN ('00000000000000000000000000000020') GROUP BY ("tests_assignmentstudent"."assignment_id"), "tests_student"."id", "tests_student"."name\"""", """SELECT ("tests_student_classes"."student_id") AS "_prefetch_related_val_student_id", "tests_class"."id" FROM "tests_class" INNER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") WHERE "tests_student_classes"."student_id" IN ('00000000000000000000000000000015') ORDER BY "tests_class"."id" ASC""", ] @@ -284,7 +293,9 @@ def test_single_count_plugin(self): sorted(query['sql'] for query in django_version_compat(capture.captured_queries)), [ """SELECT "tests_assignment"."id", "tests_assignment"."name", "tests_assignment"."clasz_id" FROM "tests_assignment" WHERE "tests_assignment"."id" = '00000000000000000000000000000020'""", - """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id", COUNT(DISTINCT "tests_student_classes"."student_id") AS "student_count", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id" FROM "tests_class" LEFT OUTER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") INNER JOIN "tests_teacher" ON ("tests_class"."teacher_id" = "tests_teacher"."id") WHERE "tests_class"."id" IN ('00000000000000000000000000000006') GROUP BY "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id", "tests_teacher"."id", "tests_teacher"."created", "tests_teacher"."modified", "tests_teacher"."name", "tests_teacher"."school_id" ORDER BY "tests_class"."id" ASC""", + """SELECT "tests_class"."id", "tests_class"."name", "tests_class"."teacher_id" FROM "tests_class" WHERE "tests_class"."id" IN ('00000000000000000000000000000006') ORDER BY "tests_class"."id" ASC""", + """SELECT "tests_class"."id", COUNT(DISTINCT "tests_student_classes"."student_id") AS "student_count" FROM "tests_class" LEFT OUTER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") WHERE "tests_class"."id" IN ('00000000000000000000000000000006') GROUP BY "tests_class"."id" ORDER BY "tests_class"."id" ASC""", + """SELECT "tests_teacher"."id", "tests_teacher"."name" FROM "tests_teacher" WHERE "tests_teacher"."id" IN ('00000000000000000000000000000002') ORDER BY "tests_teacher"."id" ASC""", """SELECT ("tests_assignmentstudent"."assignment_id") AS "_prefetch_related_val_assignment_id", "tests_student"."id", "tests_student"."name", COUNT(DISTINCT "tests_student_classes"."class_id") AS "classes_count" FROM "tests_student" LEFT OUTER JOIN "tests_student_classes" ON ("tests_student"."id" = "tests_student_classes"."student_id") INNER JOIN "tests_assignmentstudent" ON ("tests_student"."id" = "tests_assignmentstudent"."student_id") WHERE "tests_assignmentstudent"."assignment_id" IN ('00000000000000000000000000000020') GROUP BY ("tests_assignmentstudent"."assignment_id"), "tests_student"."id", "tests_student"."name" ORDER BY "tests_student"."id" ASC""", """SELECT ("tests_student_classes"."student_id") AS "_prefetch_related_val_student_id", "tests_class"."id" FROM "tests_class" INNER JOIN "tests_student_classes" ON ("tests_class"."id" = "tests_student_classes"."class_id") WHERE "tests_student_classes"."student_id" IN ('00000000000000000000000000000015') ORDER BY "tests_class"."id" ASC""", ] @@ -369,7 +380,7 @@ def test_view_must_have_serialization_spec(self): with self.assertRaises(ImproperlyConfigured) as cm: self.client.get(reverse('misconfigured')) - self.assertEqual(str(cm.exception), 'SerializationSpecMixin requires serialization_spec or get_serialization_spec') + self.assertEqual(str(cm.exception), 'SpecMixin requires spec or get_spec') class CollidingFieldsRegressionTestCase(SerializationSpecTestCase): diff --git a/tests/test_normalisation.py b/tests/test_normalisation.py deleted file mode 100644 index 3f08b0c..0000000 --- a/tests/test_normalisation.py +++ /dev/null @@ -1,88 +0,0 @@ -from django.test import TestCase - -from serialization_spec.serialization import normalise_spec - - -class NormalisationTestCase(TestCase): - - def test_base_case(self): - spec = [ - 'one', - {'two': [ - 'three', - ]}, - {'four': []}, - ] - - self.assertEqual(normalise_spec(spec), [ - 'one', - { - 'two': [ - 'three', - ], - 'four': [], - }, - ]) - - def test_merge_dupes_one_level(self): - spec = [ - 'one', - {'two': [ - 'three', - ]}, - 'one', - ] - - self.assertEqual(normalise_spec(spec), [ - 'one', - {'two': [ - 'three', - ]}, - ]) - - def test_merge_dupes_two_levels(self): - spec = [ - 'one', - {'two': [ - 'three', - ]}, - {'two': [ - 'four', - ]}, - ] - - self.assertEqual(normalise_spec(spec), [ - 'one', - {'two': [ - 'three', - 'four', - ]}, - ]) - - def test_merge_dupes_three_levels(self): - spec = [ - 'one', - {'two': [ - {'three': [ - 'five' - ]} - ]}, - {'two': [ - 'four', - {'three': [ - 'five', - 'six' - ]} - ]}, - ] - - self.assertEqual(normalise_spec(spec), [ - 'one', - {'two': [ - 'four', - {'three': [ - 'five', - 'six', - ]} - ]} - ]) diff --git a/tests/tests.py b/tests/tests.py index 1daac3d..f34f3b4 100644 --- a/tests/tests.py +++ b/tests/tests.py @@ -2,6 +2,7 @@ from .models import Teacher, Class from django.db.models.query import Q +from django_readers import pairs, specs from rest_framework import generics from unittest.mock import MagicMock from serialization_spec.serialization import SerializationSpecMixin, SerializationSpecPlugin, Filtered, Aliased @@ -65,7 +66,7 @@ def get_value(self, instance): {'school_name_upper': SchoolNameUpper()}, ] - with self.assertNumQueries(1): + with self.assertNumQueries(2): response = self.detail_view.retrieve(self.request) self.assertJsonEqual(response.data, { @@ -75,25 +76,25 @@ def get_value(self, instance): def test_merge_specs(self): class ClassNames(SerializationSpecPlugin): serialization_spec = [ - {'class_set': [ + specs.relationship('class_set', [ 'name', - ]} + ], to_attr="class_set_for_name") ] def get_value(self, instance): - return ', '.join(each.name for each in instance.class_set.all()) + return ', '.join(each.name for each in instance.class_set_for_name) class SubjectNames(SerializationSpecPlugin): serialization_spec = [ - {'class_set': [ + specs.relationship('class_set', [ {'subject': [ 'name', ]} - ]} + ], to_attr="class_set_for_subject_name") ] def get_value(self, instance): - return ', '.join(each.subject.name for each in instance.class_set.all()) + return ', '.join(each.subject.name for each in instance.class_set_for_subject_name) self.detail_view.serialization_spec = [ 'name', @@ -111,7 +112,7 @@ def get_value(self, instance): def test_reverse_fk_list_ids(self): self.detail_view.serialization_spec = [ - 'class_set' + {"class_set": pairs.pk_list('class_set')} ] response = self.detail_view.retrieve(self.request) @@ -125,7 +126,7 @@ class ClassDetailView(SerializationSpecMixin, generics.RetrieveAPIView): queryset = Class.objects.all() serialization_spec = [ - 'student_set' + {"student_set": pairs.pk_list('student_set')} ] detail_view = ClassDetailView( @@ -166,7 +167,7 @@ def get_value(self, instance): {'school_name_upper': SchoolNameUpper()}, ] - with self.assertNumQueries(2): + with self.assertNumQueries(3): response = self.detail_view.retrieve(self.request) self.assertJsonEqual(response.data, { @@ -183,9 +184,8 @@ def test_spec_with_filter(self): ]}, ] - with self.assertNumQueries(2): + with self.assertNumQueries(3): response = self.detail_view.retrieve(self.request) - self.assertJsonEqual(response.data, { "school": { "name": "Kitteh High", @@ -207,7 +207,7 @@ def test_spec_with_aliased_field(self): ]}, ] - with self.assertNumQueries(2): + with self.assertNumQueries(3): response = self.detail_view.retrieve(self.request) self.assertJsonEqual(response.data, { diff --git a/tests/views.py b/tests/views.py index bc0b16b..e76bd75 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,3 +1,4 @@ +from django_readers import pairs, specs from rest_framework import generics from serialization_spec.serialization import SerializationSpecMixin, SerializationSpecPlugin, Aliased from serialization_spec.plugins import CountOf @@ -138,23 +139,22 @@ class StudentWithAssignmentsDetailView(SerializationSpecMixin, generics.Retrieve class ClassName(SerializationSpecPlugin): serialization_spec = [ - {'clasz': [ + {'class_for_class_name': specs.relationship('clasz', [ 'name', {'teacher': [ 'name' ]}, - ]}, + ], to_attr="class_for_class_name")}, ] def get_value(self, instance): - return '%s - %s' % (instance.clasz.name, instance.clasz.teacher.name) + return '%s - %s' % (instance.class_for_class_name.name, instance.class_for_class_name.teacher.name) class AssignmentDetailView(SerializationSpecMixin, generics.RetrieveAPIView): queryset = Assignment.objects.all() lookup_field = 'id' - serialization_spec = [ 'id', 'name', @@ -162,7 +162,7 @@ class AssignmentDetailView(SerializationSpecMixin, generics.RetrieveAPIView): 'id', 'name', {'classes_count': CountOf('classes')}, - 'classes', + {'classes': pairs.pk_list('classes')}, ]}, {'class_name': ClassName()}, {'clasz': [ @@ -179,8 +179,8 @@ class StudentWithClassesAndAssignmentsDetailView(SerializationSpecMixin, generic serialization_spec = [ 'id', 'name', - 'assignments', - 'classes', + {'assignments': pairs.pk_list('assignments')}, + {'classes': pairs.pk_list('classes')}, ]