From c910a96f5f6f4a9a1c4af38fe75a0ea2e5caab71 Mon Sep 17 00:00:00 2001
From: Paul Mitchell-Gears
Date: Thu, 6 Feb 2020 15:04:59 +0000
Subject: [PATCH] Try using values_list() and manual stitching to create nested data outputs

---
 serialization_spec/serializer.py | 279 +++++++++++++++++++++++++++++++
 tests/test_serialize.py          | 169 +++++++++++++++++++++
 2 files changed, 448 insertions(+)
 create mode 100644 serialization_spec/serializer.py
 create mode 100644 tests/test_serialize.py

diff --git a/serialization_spec/serializer.py b/serialization_spec/serializer.py
new file mode 100644
index 0000000..4b1311e
--- /dev/null
+++ b/serialization_spec/serializer.py
@@ -0,0 +1,279 @@
+"""Serialize querysets into nested plain data using values() queries.
+
+A serialization spec is a list whose items are either a field or relation
+name (``str``) or a ``(relation_name, child_spec)`` tuple asking for nested
+objects. Related rows are fetched per relation with ``__in`` filters and
+stitched onto their parents in Python.
+"""
+from rest_framework.utils import model_meta
+from django.db.models.fields.reverse_related import ManyToManyRel
+
+
+def split(items, predicate):
+    """Partition items into (matching, non_matching) lists, keeping order."""
+    matching, rest = [], []
+    for item in items:
+        (matching if predicate(item) else rest).append(item)
+    return matching, rest
+
+
+def _group_pairs(pairs):
+    """Group (own_id, other) pairs into {own_id: [other, ...]}, keeping order."""
+    grouped = {}
+    for own_id, other in pairs:
+        # Append in place; rebuilding the list per pair is quadratic.
+        grouped.setdefault(own_id, []).append(other)
+    return grouped
+
+
+def _reverse_fk_name(model, key):
+    """Return the name of the FK field behind the reverse accessor `key`."""
+    return next(
+        rel.field.name
+        for rel in model._meta.related_objects
+        if rel.get_accessor_name() == key
+    )
+
+
+def _with_id(spec):
+    """Return spec with 'id' appended unless it is already requested."""
+    return spec if 'id' in spec else spec + ['id']
+
+
+def _index_by_id(rows, spec):
+    """Map row['id'] to row, dropping 'id' again unless spec requested it."""
+    indexed = {row['id']: row for row in rows}
+    if 'id' not in spec:
+        for row in indexed.values():
+            del row['id']
+    return indexed
+
+
+def get_reverse_related_object_ids(own_ids, model, key):
+    """Return {own_id: [related_id, ...]} for the reverse FK accessor `key`.
+
+    Ids in `own_ids` with no related rows are absent from the result.
+    """
+    relation = model_meta.get_field_info(model).reverse_relations[key]
+    reverse_fk = _reverse_fk_name(model, key)
+    related_objects = relation.related_model.objects.filter(**{reverse_fk + '__in': own_ids})
+    return _group_pairs(related_objects.values_list(reverse_fk, 'id'))
+
+
+def _get_m2m_related_object_ids(own_ids, model, key):
+    """Return a lazy values_list of (own_id, other_id) through-table pairs."""
+    relation = model_meta.get_field_info(model).forward_relations[key]
+    m2m_field_name = relation.model_field.m2m_field_name()
+    m2m_reverse_field_name = relation.model_field.m2m_reverse_field_name()
+    related_objects = relation.model_field.remote_field.through.objects.filter(**{m2m_field_name + '__in': own_ids})
+    return related_objects.values_list(m2m_field_name, m2m_reverse_field_name)
+
+
+def get_m2m_related_object_ids(own_ids, model, key):
+    """Return {own_id: [related_id, ...]} for the forward m2m field `key`."""
+    return _group_pairs(_get_m2m_related_object_ids(own_ids, model, key))
+
+
+def _get_reverse_m2m_related_object_ids(own_ids, model, key):
+    """Return a lazy values_list of (own_id, other_id) pairs for reverse m2m `key`."""
+    rel = model._meta.fields_map[key]
+    m2m_field_name = rel.field.m2m_field_name()
+    m2m_reverse_field_name = rel.field.m2m_reverse_field_name()
+    related_objects = rel.field.remote_field.through.objects.filter(**{m2m_reverse_field_name + '__in': own_ids})
+    return related_objects.values_list(m2m_reverse_field_name, m2m_field_name)
+
+
+def get_reverse_m2m_related_object_ids(own_ids, model, key):
+    """Return {own_id: [related_id, ...]} for the reverse m2m relation `key`."""
+    return _group_pairs(_get_reverse_m2m_related_object_ids(own_ids, model, key))
+
+
+def get_reverse_related_objects(own_ids, model, key, spec):
+    """Return {own_id: [serialized_row, ...]} for the reverse FK accessor `key`.
+
+    The FK back to the parent is fetched for grouping and removed from the
+    rows again unless `spec` asked for it.
+    """
+    relation = model_meta.get_field_info(model).reverse_relations[key]
+    reverse_fk = _reverse_fk_name(model, key)
+    related_objects = relation.related_model.objects.filter(**{reverse_fk + '__in': own_ids})
+
+    fk_requested = reverse_fk in spec
+    data = serialize(related_objects, spec if fk_requested else spec + [reverse_fk])
+    data_dict = _group_pairs((each[reverse_fk], each) for each in data)
+    if not fk_requested:
+        for each in data:
+            del each[reverse_fk]
+    return data_dict
+
+
+def _stitch_m2m(id_pairs, related_model, spec):
+    """Serialize the far side of m2m `id_pairs` and group the rows by own id."""
+    # Materialise once: the pairs are walked twice and a lazy queryset
+    # would otherwise run its query both times.
+    id_pairs = list(id_pairs)
+    other_ids = [other_id for _, other_id in id_pairs]
+    related_objects = related_model.objects.filter(id__in=other_ids)
+    other_data = _index_by_id(serialize(related_objects, _with_id(spec)), spec)
+    return _group_pairs((own_id, other_data[other_id]) for own_id, other_id in id_pairs)
+
+
+def get_m2m_related_objects(own_ids, model, key, spec):
+    """Return {own_id: [serialized_row, ...]} for the forward m2m field `key`."""
+    relation = model_meta.get_field_info(model).forward_relations[key]
+    id_pairs = _get_m2m_related_object_ids(own_ids, model, key)
+    return _stitch_m2m(id_pairs, relation.related_model, spec)
+
+
+def get_reverse_m2m_related_objects(own_ids, model, key, spec):
+    """Return {own_id: [serialized_row, ...]} for the reverse m2m relation `key`."""
+    rel = model._meta.fields_map[key]
+    id_pairs = _get_reverse_m2m_related_object_ids(own_ids, model, key)
+    return _stitch_m2m(id_pairs, rel.field.remote_field.related_model, spec)
+
+
+def get_forward_related_objects(fks, model, key, spec):
+    """Return {related_id: serialized_row} for the forward FK field `key`."""
+    relation = model_meta.get_field_info(model).forward_relations[key]
+    related_objects = relation.related_model.objects.filter(id__in=fks)
+    return _index_by_id(serialize(related_objects, _with_id(spec)), spec)
+
+
+def validate_serialization_spec(model, serialization_spec):
+    """Sort the items of a serialization spec into buckets by relation kind.
+
+    Returns a 9-tuple: (fields, fks, reverse_fks, reverse_m2ms, m2m_fields,
+    fk_objects, reverse_fk_objects, reverse_m2m_objects, m2m_objects). The
+    first five hold names; the last four hold (name, child_spec) tuples.
+    Nothing is rejected here: unrecognised names stay in `fields` or
+    `fk_objects`.
+    """
+    field_info = model_meta.get_field_info(model)
+    fields, tuples = split(serialization_spec, lambda each: isinstance(each, str))
+
+    def is_reverse(key):
+        return key in field_info.reverse_relations
+
+    def is_forward(key):
+        return key in field_info.forward_relations
+
+    def is_m2m(key):
+        return key in field_info.forward_relations and field_info.forward_relations[key].to_many
+
+    def is_reverse_m2m(key):
+        return isinstance(model._meta.fields_map.get(key), ManyToManyRel)
+
+    reverse_fks, fields = split(fields, is_reverse)
+    reverse_m2ms, reverse_fks = split(reverse_fks, is_reverse_m2m)
+    m2m_fields, fields = split(fields, is_m2m)
+    fks, fields = split(fields, is_forward)
+
+    reverse_fk_objects, fk_objects = split(tuples, lambda each: is_reverse(each[0]))
+    reverse_m2m_objects, reverse_fk_objects = split(reverse_fk_objects, lambda each: is_reverse_m2m(each[0]))
+    m2m_objects, fk_objects = split(fk_objects, lambda each: is_m2m(each[0]))
+
+    return (
+        fields,
+        fks,
+        reverse_fks,
+        reverse_m2ms,
+        m2m_fields,
+        fk_objects,
+        reverse_fk_objects,
+        reverse_m2m_objects,
+        m2m_objects
+    )
+
+
+def serialize(queryset, serialization_spec):
+    """Return a list of dicts, one per row of `queryset`, shaped by the spec.
+
+    Plain fields and FK names come from a single values() query. To-many
+    relations become lists of ids (or of nested dicts for tuple items) and
+    are [] when nothing is related; FK tuple items become a nested dict, or
+    None when the FK is null.
+    """
+    model = queryset.model
+
+    (
+        fields,
+        fks,
+        reverse_fks,
+        reverse_m2ms,
+        m2m_fields,
+        fk_objects,
+        reverse_fk_objects,
+        reverse_m2m_objects,
+        m2m_objects
+    ) = validate_serialization_spec(model, serialization_spec)
+
+    needs_id = reverse_fks or reverse_fk_objects or m2m_fields or m2m_objects or reverse_m2ms or reverse_m2m_objects
+    to_fetch = fields + fks + (['id'] if needs_id else []) + [fk for fk, _ in fk_objects]
+    # dict.fromkeys drops repeated names (callers may append an 'id' the spec
+    # already had) while keeping their order.
+    data = list(queryset.values(*dict.fromkeys(to_fetch)))
+
+    if needs_id:
+        own_ids = [each['id'] for each in data]
+
+        id_fetchers = (
+            (reverse_fks, get_reverse_related_object_ids),
+            (m2m_fields, get_m2m_related_object_ids),
+            (reverse_m2ms, get_reverse_m2m_related_object_ids),
+        )
+        for keys, fetch in id_fetchers:
+            for key in keys:
+                related_items = fetch(own_ids, model, key)
+                for each in data:
+                    # Rows with nothing related are missing from the mapping.
+                    each[key] = related_items.get(each['id'], [])
+
+        object_fetchers = (
+            (reverse_fk_objects, get_reverse_related_objects),
+            (m2m_objects, get_m2m_related_objects),
+            (reverse_m2m_objects, get_reverse_m2m_related_objects),
+        )
+        for specs, fetch in object_fetchers:
+            for key, spec in specs:
+                related_items = fetch(own_ids, model, key, spec)
+                for each in data:
+                    each[key] = related_items.get(each['id'], [])
+
+        if 'id' not in fields:
+            for each in data:
+                del each['id']
+
+    for key, spec in fk_objects:
+        fk_values = [each[key] for each in data]
+        related_objects = get_forward_related_objects(fk_values, model, key, spec)
+        for each in data:
+            # A null FK has no related row; emit None rather than raising.
+            each[key] = related_objects.get(each[key])
+
+    return data
+
+
+def jsonify(data):
+    """Recursively convert `data` into JSON-friendly builtins.
+
+    None, bool, int, float and str pass through; lists and dicts are
+    converted element-wise; anything else is replaced by its str().
+    """
+    if data is None or isinstance(data, (bool, int, float, str)):
+        return data
+    if isinstance(data, list):
+        return [jsonify(value) for value in data]
+    if isinstance(data, dict):
+        return {key: jsonify(value) for key, value in data.items()}
+    return str(data)
+
+
+def serializej(qs, spec):
+    """Serialize `qs` with `spec` and return the JSON-friendly result."""
+    return jsonify(serialize(qs, spec))
+
+
+def serializep(qs, spec):
+    """Print the serialized `qs` as indented JSON (debugging aid)."""
+    import json
+    print(json.dumps(serializej(qs, spec), indent=2))
diff --git a/tests/test_serialize.py b/tests/test_serialize.py
new file mode 100644
index 0000000..7b48973
--- /dev/null
+++ b/tests/test_serialize.py
@@ -0,0 +1,169 @@
+from django.test import TestCase
+
+from .models import LEA, School, Teacher, Subject, Class, Student, Assignment, AssignmentStudent
+
+from serialization_spec.serializer import serializej
+
+
+class SerializationTestCase(TestCase):
+    def setUp(self):
+        self.lea = LEA.objects.create(name='Brighton & Hove')
+        self.school = School.objects.create(name='Kitteh High', lea=self.lea)
+        School.objects.create(name='Hove High', lea=self.lea)
+        self.teacher = Teacher.objects.create(name='Mr Cat', school=self.school)
+        Teacher.objects.create(name='Ms Dog', school=self.school)
+        self.french = Subject.objects.create(name='French')
+        self.math = Subject.objects.create(name='Math')
+        self.french_class = Class.objects.create(name='French A', subject=self.french, teacher=self.teacher)
+        self.math_class = Class.objects.create(name='Math B', subject=self.math, teacher=self.teacher)
+        students = [
+            Student.objects.create(name='Student %d' % idx, school=self.school)
+            for idx in range(4)
+        ]
+        self.french_class.student_set.set(students[:2])
+        self.math_class.student_set.set(students[2:])
+
+        self.student = students[0]
+        self.assignments = []
+        for clasz in [self.french_class, self.math_class]:
+            is_math = clasz == self.math_class
+            assignment = Assignment.objects.create(clasz=clasz, name=clasz.name + ' Assignment')
+            AssignmentStudent.objects.create(
+                assignment=assignment, student=self.student, is_complete=is_math
+            )
+            self.assignments.append(assignment)
+
+    def test_fields(self):
+        serialization_spec = [
+            'id',
+            'name',
+        ]
+        data = serializej(Class.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'id': str(self.french_class.id), 'name': 'French A'},
+            {'id': str(self.math_class.id), 'name': 'Math B'},
+        ])
+
+    def test_fk(self):
+        serialization_spec = [
+            'name',
+            'subject',
+        ]
+        data = serializej(Class.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'French A', 'subject': str(self.french.id)},
+            {'name': 'Math B', 'subject': str(self.math.id)},
+        ])
+
+    def test_reverse_fk(self):
+        serialization_spec = [
+            'name',
+            'assignment_set',
+        ]
+        data = serializej(Class.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'French A', 'assignment_set': [str(self.assignments[0].id)]},
+            {'name': 'Math B', 'assignment_set': [str(self.assignments[1].id)]},
+        ])
+
+    def test_fk_with_spec(self):
+        serialization_spec = [
+            'name',
+            ('subject', [
+                'name'
+            ]),
+        ]
+        data = serializej(Class.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'French A', 'subject': {'name': 'French'}},
+            {'name': 'Math B', 'subject': {'name': 'Math'}},
+        ])
+
+    def test_fk_with_nested_spec(self):
+        serialization_spec = [
+            ('teacher', [
+                'name',
+                ('school', [
+                    'name'
+                ])
+            ])
+        ]
+        data = serializej(Class.objects.all()[:1], serialization_spec)
+        self.assertEqual(data, [
+            {
+                'teacher': {
+                    'name': 'Mr Cat',
+                    'school': {
+                        'name': 'Kitteh High',
+                    }
+                }
+            },
+        ])
+
+    def xtest_reverse_fk_with_spec(self):
+        serialization_spec = [
+            'name',
+            ('teacher_set', [
+                'name',
+            ])
+        ]
+        data = serializej(School.objects.all(), serialization_spec)
+        self.assertEqual(data, [{
+            'name': 'Kitteh High',
+            'teacher_set': [
+                {'name': 'Mr Cat'},
+                {'name': 'Ms Dog'},
+            ]
+        }])
+
+    def test_m2m(self):
+        serialization_spec = [
+            'name',
+            'classes'
+        ]
+        data = serializej(Student.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'Student 0', 'classes': [str(self.french_class.id)]},
+            {'name': 'Student 1', 'classes': [str(self.french_class.id)]},
+            {'name': 'Student 2', 'classes': [str(self.math_class.id)]},
+            {'name': 'Student 3', 'classes': [str(self.math_class.id)]},
+        ])
+
+    def test_m2m_with_spec(self):
+        serialization_spec = [
+            'name',
+            ('classes', [
+                'name'
+            ])
+        ]
+        data = serializej(Student.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'Student 0', 'classes': [{'name': 'French A'}]},
+            {'name': 'Student 1', 'classes': [{'name': 'French A'}]},
+            {'name': 'Student 2', 'classes': [{'name': 'Math B'}]},
+            {'name': 'Student 3', 'classes': [{'name': 'Math B'}]},
+        ])
+
+    def test_reverse_m2m(self):
+        serialization_spec = [
+            'name',
+            'assignees',
+        ]
+        data = serializej(Assignment.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'name': 'French A Assignment', 'assignees': [str(self.student.id)]},
+            {'name': 'Math B Assignment', 'assignees': [str(self.student.id)]},
+        ])
+
+    def test_reverse_m2m_with_spec(self):
+        serialization_spec = [
+            'name',
+            ('assignees', [
+                'name'
+            ])
+        ]
+        data = serializej(Assignment.objects.all(), serialization_spec)
+        self.assertEqual(data, [
+            {'assignees': [{'name': 'Student 0'}], 'name': 'French A Assignment'},
+            {'assignees': [{'name': 'Student 0'}], 'name': 'Math B Assignment'},
+        ])