diff --git a/rest_framework/fields.py b/rest_framework/fields.py index fdfba13f26..3b41e4015b 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -8,6 +8,7 @@ import warnings from collections import OrderedDict from collections.abc import Mapping +from typing import Generic, TypeVar from django.conf import settings from django.core.exceptions import ObjectDoesNotExist @@ -308,8 +309,13 @@ class SkipField(Exception): 'not exist in the `error_messages` dictionary.' ) +_IN = TypeVar("_IN") # Instance Type +_VT = TypeVar("_VT") # Value Type +_DT = TypeVar("_DT") # Data Type +_RP = TypeVar("_RP") # Representation Type -class Field: + +class Field(Generic[_VT, _DT, _RP, _IN]): _creation_counter = 0 default_error_messages = { diff --git a/rest_framework/parsers.py b/rest_framework/parsers.py index fc4eb14283..8b0fc73810 100644 --- a/rest_framework/parsers.py +++ b/rest_framework/parsers.py @@ -5,6 +5,7 @@ on the request, such as form content or json encoded data. """ import codecs +from typing import Generic, TypeVar from urllib import parse from django.conf import settings @@ -21,8 +22,11 @@ from rest_framework.settings import api_settings from rest_framework.utils import json +_Data = TypeVar("_Data") +_Files = TypeVar("_Files") -class DataAndFiles: + +class DataAndFiles(Generic[_Data, _Files]): def __init__(self, data, files): self.data = data self.files = files diff --git a/rest_framework/relations.py b/rest_framework/relations.py index eaf27e1d96..3a535f34c6 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -1,9 +1,10 @@ import sys from collections import OrderedDict +from typing import Any, Generic, TypeVar from urllib import parse from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.db.models import Manager +from django.db.models import Manager, Model from django.db.models.query import QuerySet from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve from django.utils.encoding import smart_str, uri_to_iri @@ -85,8 +86,12 @@ def __str__(self): 'html_cutoff', 'html_cutoff_text' ) +_MT = TypeVar("_MT", bound=Model) +_DT = TypeVar("_DT") # Data Type +_PT = TypeVar("_PT") # Primitive Type -class RelatedField(Field): + +class RelatedField(Generic[_MT, _DT, _PT], Field[_MT, _DT, _PT, Any]): queryset = None html_cutoff = None html_cutoff_text = None diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 49eec82591..35d52610f9 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,9 +12,11 @@ """ import copy import inspect +import sys import traceback from collections import OrderedDict, defaultdict from collections.abc import Mapping +from typing import Any, Generic, TypeVar from django.core.exceptions import FieldDoesNotExist, ImproperlyConfigured from django.core.exceptions import ValidationError as DjangoValidationError @@ -66,6 +68,13 @@ ) from rest_framework.relations import Hyperlink, PKOnlyObject # NOQA # isort:skip +if sys.version_info < (3, 7): + from typing import GenericMeta +else: + class GenericMeta(type): + pass + + # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. LIST_SERIALIZER_KWARGS = ( @@ -79,8 +88,10 @@ # BaseSerializer # -------------- +_IN = TypeVar("_IN") # Instance Type + -class BaseSerializer(Field): +class BaseSerializer(Generic[_IN], Field[Any, Any, Any, _IN]): """ The BaseSerializer class provides a minimal class which may be used for writing custom serializer implementations. @@ -121,10 +132,6 @@ def __new__(cls, *args, **kwargs): return cls.many_init(*args, **kwargs) return super().__new__(cls, *args, **kwargs) - # Allow type checkers to make serializers generic. - def __class_getitem__(cls, *args, **kwargs): - return cls - @classmethod def many_init(cls, *args, **kwargs): """ @@ -268,7 +275,7 @@ def validated_data(self): # Serializer & ListSerializer classes # ----------------------------------- -class SerializerMetaclass(type): +class SerializerMetaclass(GenericMeta): """ This metaclass sets a dictionary named `_declared_fields` on the class. @@ -301,9 +308,9 @@ def visit(name): return OrderedDict(base_fields + fields) - def __new__(cls, name, bases, attrs): + def __new__(cls, name, bases, attrs, *args, **kwargs): attrs['_declared_fields'] = cls._get_declared_fields(bases, attrs) - return super().__new__(cls, name, bases, attrs) + return super().__new__(cls, name, bases, attrs, *args, **kwargs) def as_serializer_error(exc): @@ -332,7 +339,7 @@ def as_serializer_error(exc): } -class Serializer(BaseSerializer, metaclass=SerializerMetaclass): +class Serializer(BaseSerializer[_IN], metaclass=SerializerMetaclass): default_error_messages = { 'invalid': _('Invalid data. Expected a dictionary, but got {datatype}.') } @@ -562,7 +569,7 @@ def errors(self): # There's some replication of `ListField` here, # but that's probably better than obfuscating the call hierarchy. -class ListSerializer(BaseSerializer): +class ListSerializer(BaseSerializer[_IN]): child = None many = True @@ -836,7 +843,10 @@ def raise_errors_on_nested_writes(method_name, serializer, validated_data): ) -class ModelSerializer(Serializer): +_MT = TypeVar("_MT", bound=models.Model) # Model Type + + +class ModelSerializer(Serializer[_MT]): """ A `ModelSerializer` is just a regular `Serializer`, except that: diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index 4f8a4f1926..79154c19d3 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -3,6 +3,7 @@ keyword arguments that should be used for their equivalent serializer fields. """ import inspect +from typing import Generic, TypeVar from django.core import validators from django.db import models @@ -16,7 +17,11 @@ ) -class ClassLookupDict: +_K = TypeVar("_K", bound=type) +_V = TypeVar("_V") + + +class ClassLookupDict(Generic[_K, _V]): """ Takes a dictionary with classes as keys. Lookups against this object will traverses the object's inheritance diff --git a/tests/test_fields.py b/tests/test_fields.py index fdd570d8a6..c8d2f87a28 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -14,7 +14,7 @@ import rest_framework from rest_framework import exceptions, serializers from rest_framework.fields import ( - BuiltinSignatureError, DjangoImageField, is_simple_callable + BuiltinSignatureError, DjangoImageField, Field, is_simple_callable ) # Tests for helper functions. @@ -2380,3 +2380,21 @@ def validate(self, obj): ), ] } + + +class TestField: + def test_type_annotation(self): + assert Field[int, int, int, int] is not Field + + def test_multiple_type_params_needed_when_hinting_class(self): + with pytest.raises(TypeError): + Field[int] + + with pytest.raises(TypeError): + Field[int, int] + + with pytest.raises(TypeError): + Field[int, int, int] + + with pytest.raises(TypeError): + Field[int, int, int, int, int] diff --git a/tests/test_parsers.py b/tests/test_parsers.py index dcd62fac9d..0e5a915f5d 100644 --- a/tests/test_parsers.py +++ b/tests/test_parsers.py @@ -11,7 +11,7 @@ from rest_framework.exceptions import ParseError from rest_framework.parsers import ( - FileUploadParser, FormParser, JSONParser, MultiPartParser + DataAndFiles, FileUploadParser, FormParser, JSONParser, MultiPartParser ) from rest_framework.request import Request from rest_framework.test import APIRequestFactory @@ -176,3 +176,19 @@ def test_request_read_before_parsing(self): with pytest.raises(RawPostDataException): request.POST request.data + + +class TestDataAndFiles: + def test_type_annotation(self): + """ + This class inherits directly from Generic, so adding type hints to it should + yield a different class. + """ + assert DataAndFiles[int, int] is not DataAndFiles + + def test_need_multiple_type_params_when_hinting_class(self): + with pytest.raises(TypeError): + DataAndFiles[int] + + with pytest.raises(TypeError): + DataAndFiles[int, int, int] diff --git a/tests/test_relations.py b/tests/test_relations.py index 92aeecf6c4..f02b637ebe 100644 --- a/tests/test_relations.py +++ b/tests/test_relations.py @@ -376,3 +376,18 @@ def test_can_be_pickled(self): upkled = pickle.loads(pickle.dumps(self.default_hyperlink)) assert upkled == self.default_hyperlink assert upkled.name == self.default_hyperlink.name + + +class TestRelatedField: + def test_type_annotation(self): + assert relations.RelatedField[int, int, int] is not relations.RelatedField + + def test_multiple_type_params_needed_when_hinting_class(self): + with pytest.raises(TypeError): + relations.RelatedField[int] + + with pytest.raises(TypeError): + relations.RelatedField[int, int] + + with pytest.raises(TypeError): + relations.RelatedField[int, int, int, int] diff --git a/tests/test_serializer.py b/tests/test_serializer.py index afefd70e1c..cc40407e4b 100644 --- a/tests/test_serializer.py +++ b/tests/test_serializer.py @@ -209,8 +209,19 @@ class ExampleSerializer(serializers.Serializer): sys.version_info < (3, 7), reason="subscriptable classes requires Python 3.7 or higher", ) - def test_serializer_is_subscriptable(self): - assert serializers.Serializer is serializers.Serializer["foo"] + def test_type_annotation(self): + assert serializers.Serializer is not serializers.Serializer["foo"] + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.Serializer[int] is not serializers.Serializer class TestValidateMethod: @@ -322,6 +333,28 @@ def test_validate_list(self): {'id': 2, 'name': 'ann', 'domain': 'example.com'} ] + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.BaseSerializer[int] is not serializers.BaseSerializer + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.BaseSerializer[int] is not serializers.BaseSerializer + class TestStarredSource: """ @@ -740,3 +773,27 @@ class TestSerializer(A, B): 'f4': serializers.CharField, 'f5': serializers.CharField, } + + +class TestModelSerializer: + @pytest.mark.skipif( + sys.version_info < (3, 7), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ModelSerializer[int] is not serializers.ModelSerializer + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ModelSerializer[int] is not serializers.ModelSerializer diff --git a/tests/test_serializer_lists.py b/tests/test_serializer_lists.py index f35c4fcc9e..5ad7eeb28c 100644 --- a/tests/test_serializer_lists.py +++ b/tests/test_serializer_lists.py @@ -62,7 +62,18 @@ def test_validate_html_input(self): reason="subscriptable classes requires Python 3.7 or higher", ) def test_list_serializer_is_subscriptable(self): - assert serializers.ListSerializer is serializers.ListSerializer["foo"] + assert serializers.ListSerializer is not serializers.ListSerializer["foo"] + + @pytest.mark.skipif( + sys.version_info > (3, 5), + reason="generic meta class behaviour changed from 3.5 to 3.7", + ) + def test_type_annotation_pre_36(self): + """ + This class does NOT inherit directly from Generic, so adding type hints to it + should not yield a different class. + """ + assert serializers.ListSerializer[int] is not serializers.ListSerializer class TestListSerializerContainingNestedSerializer: diff --git a/tests/test_utils.py b/tests/test_utils.py index c72f680fe4..a8371d8b6b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,6 @@ from unittest import mock +import pytest from django.test import TestCase, override_settings from django.urls import path @@ -8,6 +9,7 @@ from rest_framework.serializers import ModelSerializer from rest_framework.utils import json from rest_framework.utils.breadcrumbs import get_breadcrumbs +from rest_framework.utils.field_mapping import ClassLookupDict from rest_framework.utils.formatting import lazy_format from rest_framework.utils.urls import remove_query_param, replace_query_param from rest_framework.views import APIView @@ -267,3 +269,15 @@ def test_it_formats_lazily(self): assert message.format.call_count == 1 str(formatted) assert message.format.call_count == 1 + + +class ClassLookupDictTests(TestCase): + def test_type_annotation(self): + assert ClassLookupDict[int, int] is not ClassLookupDict + + def test_need_multiple_type_params_when_hinting_class(self): + with pytest.raises(TypeError): + ClassLookupDict[None] + + with pytest.raises(TypeError): + ClassLookupDict[None, None, None]