Skip to content

Adding Generic support for various classes #7625

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion rest_framework/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import warnings
from collections import OrderedDict
from collections.abc import Mapping
from typing import Generic, TypeVar

from django.conf import settings
from django.core.exceptions import ObjectDoesNotExist
Expand Down Expand Up @@ -308,8 +309,13 @@ class SkipField(Exception):
'not exist in the `error_messages` dictionary.'
)

_IN = TypeVar("_IN") # Instance Type
_VT = TypeVar("_VT") # Value Type
_DT = TypeVar("_DT") # Data Type
_RP = TypeVar("_RP") # Representation Type

class Field:

class Field(Generic[_VT, _DT, _RP, _IN]):
_creation_counter = 0

default_error_messages = {
Expand Down
6 changes: 5 additions & 1 deletion rest_framework/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
on the request, such as form content or json encoded data.
"""
import codecs
from typing import Generic, TypeVar
from urllib import parse

from django.conf import settings
Expand All @@ -21,8 +22,11 @@
from rest_framework.settings import api_settings
from rest_framework.utils import json

_Data = TypeVar("_Data")
_Files = TypeVar("_Files")

class DataAndFiles:

class DataAndFiles(Generic[_Data, _Files]):
def __init__(self, data, files):
self.data = data
self.files = files
Expand Down
9 changes: 7 additions & 2 deletions rest_framework/relations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import sys
from collections import OrderedDict
from typing import Any, Generic, TypeVar
from urllib import parse

from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist
from django.db.models import Manager
from django.db.models import Manager, Model
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils.encoding import smart_str, uri_to_iri
Expand Down Expand Up @@ -85,8 +86,12 @@ def __str__(self):
'html_cutoff', 'html_cutoff_text'
)

_MT = TypeVar("_MT", bound=Model)
_DT = TypeVar("_DT") # Data Type
_PT = TypeVar("_PT") # Primitive Type

class RelatedField(Field):

class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]):
queryset = None
html_cutoff = None
html_cutoff_text = None
Expand Down
32 changes: 21 additions & 11 deletions rest_framework/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
"""
import copy
import inspect
import sys
import traceback
from collections import OrderedDict, defaultdict
from collections.abc import Mapping
from typing import Any, Generic, TypeVar

from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured
from django.core.exceptions import ValidationError as DjangoValidationError
Expand Down Expand Up @@ -66,6 +68,13 @@
)
from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip

if sys.version_info < (3, 7):
from typing import GenericMeta
else:
class GenericMeta(type):
pass


# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
LIST_SERIALIZER_KWARGS = (
Expand All @@ -79,8 +88,10 @@

# BaseSerializer
# --------------
_IN = TypeVar("_IN") # Instance Type


class BaseSerializer(Field):
class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]):
"""
The BaseSerializer class provides a minimal class which may be used
for writing custom serializer implementations.
Expand Down Expand Up @@ -121,10 +132,6 @@ def __new__(cls, *args, **kwargs):
return cls.many_init(*args, **kwargs)
return super().__new__(cls, *args, **kwargs)

# Allow type checkers to make serializers generic.
def __class_getitem__(cls, *args, **kwargs):
return cls

@classmethod
def many_init(cls, *args, **kwargs):
"""
Expand Down Expand Up @@ -268,7 +275,7 @@ def validated_data(self):
# Serializer & ListSerializer classes
# -----------------------------------

class SerializerMetaclass(type):
class SerializerMetaclass(GenericMeta):
"""
This metaclass sets a dictionary named `_declared_fields` on the class.

Expand Down Expand Up @@ -301,9 +308,9 @@ def visit(name):

return OrderedDict(base_fields + fields)

def __new__(cls, name, bases, attrs):
def __new__(cls, name, bases, attrs, *args, **kwargs):
attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs)
return super().__new__(cls, name, bases, attrs)
return super().__new__(cls, name, bases, attrs, *args, **kwargs)


def as_serializer_error(exc):
Expand Down Expand Up @@ -332,7 +339,7 @@ def as_serializer_error(exc):
}


class Serializer(BaseSerializer, metaclass=SerializerMetaclass):
class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass):
default_error_messages = {
'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.')
}
Expand Down Expand Up @@ -562,7 +569,7 @@ def errors(self):
# There's some replication of `ListField` here,
# but that's probably better than obfuscating the call hierarchy.

class ListSerializer(BaseSerializer):
class ListSerializer(BaseSerializer[_IN]):
child = None
many = True

Expand Down Expand Up @@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data):
)


class ModelSerializer(Serializer):
_MT = TypeVar("_MT", bound=models.Model) # Model Type


class ModelSerializer(Serializer[_MT]):
"""
A `ModelSerializer` is just a regular `Serializer`, except that:

Expand Down
7 changes: 6 additions & 1 deletion rest_framework/utils/field_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
keyword arguments that should be used for their equivalent serializer fields.
"""
import inspect
from typing import Generic, TypeVar

from django.core import validators
from django.db import models
Expand All @@ -16,7 +17,11 @@
)


class ClassLookupDict:
_K = TypeVar("_K", bound=type)
_V = TypeVar("_V")


class ClassLookupDict(Generic[_K, _V]):
"""
Takes a dictionary with classes as keys.
Lookups against this object will traverses the object's inheritance
Expand Down
20 changes: 19 additions & 1 deletion tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import rest_framework
from rest_framework import exceptions, serializers
from rest_framework.fields import (
BuiltinSignatureError, DjangoImageField, is_simple_callable
BuiltinSignatureError, DjangoImageField, Field, is_simple_callable
)

# Tests for helper functions.
Expand Down Expand Up @@ -2380,3 +2380,21 @@ def validate(self, obj):
),
]
}


class TestField:
def test_type_annotation(self):
assert Field[int, int, int, int] is not Field

def test_multiple_type_params_needed_when_hinting_class(self):
with pytest.raises(TypeError):
Field[int]

with pytest.raises(TypeError):
Field[int, int]

with pytest.raises(TypeError):
Field[int, int, int]

with pytest.raises(TypeError):
Field[int, int, int, int, int]
18 changes: 17 additions & 1 deletion tests/test_parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from rest_framework.exceptions import ParseError
from rest_framework.parsers import (
FileUploadParser, FormParser, JSONParser, MultiPartParser
DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser
)
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
Expand Down Expand Up @@ -176,3 +176,19 @@ def test_request_read_before_parsing(self):
with pytest.raises(RawPostDataException):
request.POST
request.data


class TestDataAndFiles:
def test_type_annotation(self):
"""
This class inherits directly from Generic, so adding type hints to it should
yield a different class.
"""
assert DataAndFiles[int, int] is not DataAndFiles

def test_need_multiple_type_params_when_hinting_class(self):
with pytest.raises(TypeError):
DataAndFiles[int]

with pytest.raises(TypeError):
DataAndFiles[int, int, int]
15 changes: 15 additions & 0 deletions tests/test_relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,18 @@ def test_can_be_pickled(self):
upkled = pickle.loads(pickle.dumps(self.default_hyperlink))
assert upkled == self.default_hyperlink
assert upkled.name == self.default_hyperlink.name


class TestRelatedField:
def test_type_annotation(self):
assert relations.RelatedField[int, int, int] is not relations.RelatedField

def test_multiple_type_params_needed_when_hinting_class(self):
with pytest.raises(TypeError):
relations.RelatedField[int]

with pytest.raises(TypeError):
relations.RelatedField[int, int]

with pytest.raises(TypeError):
relations.RelatedField[int, int, int, int]
61 changes: 59 additions & 2 deletions tests/test_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,19 @@ class ExampleSerializer(serializers.Serializer):
sys.version_info < (3, 7),
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_serializer_is_subscriptable(self):
assert serializers.Serializer is serializers.Serializer["foo"]
def test_type_annotation(self):
assert serializers.Serializer is not serializers.Serializer["foo"]

@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.Serializer[int] is not serializers.Serializer


class TestValidateMethod:
Expand Down Expand Up @@ -322,6 +333,28 @@ def test_validate_list(self):
{'id': 2, 'name': 'ann', 'domain': 'example.com'}
]

@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer

@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.BaseSerializer[int] is not serializers.BaseSerializer


class TestStarredSource:
"""
Expand Down Expand Up @@ -740,3 +773,27 @@ class TestSerializer(A, B):
'f4': serializers.CharField,
'f5': serializers.CharField,
}


class TestModelSerializer:
@pytest.mark.skipif(
sys.version_info < (3, 7),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer

@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ModelSerializer[int] is not serializers.ModelSerializer
13 changes: 12 additions & 1 deletion tests/test_serializer_lists.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,18 @@ def test_validate_html_input(self):
reason="subscriptable classes requires Python 3.7 or higher",
)
def test_list_serializer_is_subscriptable(self):
assert serializers.ListSerializer is serializers.ListSerializer["foo"]
assert serializers.ListSerializer is not serializers.ListSerializer["foo"]

@pytest.mark.skipif(
sys.version_info > (3, 5),
reason="generic meta class behaviour changed from 3.5 to 3.7",
)
def test_type_annotation_pre_36(self):
"""
This class does NOT inherit directly from Generic, so adding type hints to it
should not yield a different class.
"""
assert serializers.ListSerializer[int] is not serializers.ListSerializer


class TestListSerializerContainingNestedSerializer:
Expand Down
14 changes: 14 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from unittest import mock

import pytest
from django.test import TestCase, override_settings
from django.urls import path

Expand All @@ -8,6 +9,7 @@
from rest_framework.serializers import ModelSerializer
from rest_framework.utils import json
from rest_framework.utils.breadcrumbs import get_breadcrumbs
from rest_framework.utils.field_mapping import ClassLookupDict
from rest_framework.utils.formatting import lazy_format
from rest_framework.utils.urls import remove_query_param, replace_query_param
from rest_framework.views import APIView
Expand Down Expand Up @@ -267,3 +269,15 @@ def test_it_formats_lazily(self):
assert message.format.call_count == 1
str(formatted)
assert message.format.call_count == 1


class ClassLookupDictTests(TestCase):
def test_type_annotation(self):
assert ClassLookupDict[int, int] is not ClassLookupDict

def test_need_multiple_type_params_when_hinting_class(self):
with pytest.raises(TypeError):
ClassLookupDict[None]

with pytest.raises(TypeError):
ClassLookupDict[None, None, None]