From c0a90e97859ef5e8b6180cffde56967be2ec48ff Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 10 Dec 2016 10:45:02 +0100 Subject: [PATCH 1/7] Runtime implementation of TypedDict extension --- extensions/__init__.py | 1 + extensions/mypy_extensions.py | 64 ++++++++++++++++++--- mypy/test/testextensions.py | 102 ++++++++++++++++++++++++++++++++++ runtests.py | 2 +- 4 files changed, 161 insertions(+), 8 deletions(-) create mode 100644 extensions/__init__.py create mode 100644 mypy/test/testextensions.py diff --git a/extensions/__init__.py b/extensions/__init__.py new file mode 100644 index 000000000000..88a19e400e5b --- /dev/null +++ b/extensions/__init__.py @@ -0,0 +1 @@ +# This page intentioanlly left blank. diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index db66f586f051..7323d3796681 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -8,15 +8,65 @@ # NOTE: This module must support Python 2.7 in addition to Python 3.x -def TypedDict(typename, fields): - """TypedDict creates a dictionary type that expects all of its +import sys +from typing import _type_check # type: ignore + + +def _check_fails(cls, other): + if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + raise TypeError('TypedDict does not support instance and class checks') + +class _TypedDictMeta(type): + def __new__(cls, name, bases, ns): + tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns) + try: + tp_dict.__module__ = sys._getframe(2).f_globals.get('__name__', '__main__') + except (AttributeError, ValueError): + pass + anns = ns.get('__annotations__', {}) + msg = "TypedDict('Name', {f0: t0, f1: t1, ...}); each t must be a type" + anns = {n: _type_check(tp, msg) for n, tp in anns.items()} + for base in bases: + anns.update(base.__dict__.get('__annotations__', {})) + tp_dict.__annotations__ = anns + return tp_dict + + __instancecheck__ = __subclasscheck__ = _check_fails + + +class _TypedDict(object): + """A simple typed name space. At runtime it is equivalent to a plain dict. + + TypedDict creates a dictionary type that expects all of its instances to have a certain set of keys, with each key associated with a value of a consistent type. This expectation is not checked at runtime but is only enforced by typecheckers. + Usage:: + + Point2D = TypedDict('Point2D', {'x': int, 'y': int, 'label': str}) + a: Point2D = {'x': 1, 'y': 2, 'label': 'good'} # OK + b: Point2D = {'z': 3, 'label': 'bad'} # Fails type check + assert Point2D(x=1, y=2, label='first') == dict(x=1, y=2, label='first') + + The type info could be accessed via Point2D.__annotations__. TypedDict + supports two additional equivalent forms:: + + Point2D = TypedDict('Point2D', x=int, y=int, label=str) + + class Point2D(TypedDict): + x: int + y: int + label: str + + The latter syntax is only supported in Python 3.6+ """ - def new_dict(*args, **kwargs): - return dict(*args, **kwargs) + def __new__(cls, _typename, fields=None, **kwargs): + if fields is None: + fields = kwargs + elif kwargs: + raise TypeError("Either list of fields or keywords" + " can be provided to TypedDict, not both") + return cls.__class__(_typename, (), {'__annotations__': dict(fields)}) + - new_dict.__name__ = typename - new_dict.__supertype__ = dict - return new_dict +TypedDict = _TypedDictMeta('TypedDict', _TypedDict.__bases__, dict(_TypedDict.__dict__)) diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py new file mode 100644 index 000000000000..5071cbac2f6c --- /dev/null +++ b/mypy/test/testextensions.py @@ -0,0 +1,102 @@ +import sys +import pickle +from unittest import TestCase, main, skipUnless, SkipTest +from extensions.mypy_extensions import TypedDict + + +class BaseTestCase(TestCase): + + def assertIsSubclass(self, cls, class_or_tuple, msg=None): + if not issubclass(cls, class_or_tuple): + message = '%r is not a subclass of %r' % (cls, class_or_tuple) + if msg is not None: + message += ' : %s' % msg + raise self.failureException(message) + + def assertNotIsSubclass(self, cls, class_or_tuple, msg=None): + if issubclass(cls, class_or_tuple): + message = '%r is a subclass of %r' % (cls, class_or_tuple) + if msg is not None: + message += ' : %s' % msg + raise self.failureException(message) + + +PY36 = sys.version_info[:2] >= (3, 6) + +PY36_TESTS = """ +Label = TypedDict('Label', [('label', str)]) + +class Point2D(TypedDict): + x: int + y: int + +class LabelPoint2D(Point2D, Label): ... +""" + +if PY36: + exec(PY36_TESTS) + + +class TypedDictTests(BaseTestCase): + + def test_basics_iterable_syntax(self): + # Check that two iterables allowed + Emp = TypedDict('Emp', [('name', str), ('id', int)]) + Emp = TypedDict('Emp', {'name': str, 'id': int}) + self.assertIsSubclass(Emp, dict) + jim = Emp(name='Jim', id=1) + self.assertIsInstance(jim, Emp) + self.assertIsInstance(jim, dict) + self.assertEqual(jim['name'], 'Jim') + self.assertEqual(jim['id'], 1) + self.assertEqual(Emp.__name__, 'Emp') + self.assertEqual(Emp.__bases__, (dict,)) + self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) + + def test_basics_keywords_syntax(self): + Emp = TypedDict('Emp', name=str, id=int) + self.assertIsSubclass(Emp, dict) + jim = Emp(name='Jim', id=1) + self.assertIsInstance(jim, Emp) + self.assertIsInstance(jim, dict) + self.assertEqual(jim['name'], 'Jim') + self.assertEqual(jim['id'], 1) + self.assertEqual(Emp.__name__, 'Emp') + self.assertEqual(Emp.__bases__, (dict,)) + self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) + + def test_typeddict_errors(self): + Emp = TypedDict('Emp', {'name': str, 'id': int}) + with self.assertRaises(TypeError): + isinstance({}, Emp) + with self.assertRaises(TypeError): + issubclass(dict, Emp) + with self.assertRaises(TypeError): + TypedDict('Hi', x=1) + with self.assertRaises(TypeError): + TypedDict('Hi', [('x', int), ('y', 1)]) + with self.assertRaises(TypeError): + TypedDict('Hi', [('x', int)], y=int) + + @skipUnless(PY36, 'Python 3.6 required') + def test_class_syntax_usage(self): + self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa + self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa + not_origin = Point2D(x=0, y=1) # noqa + self.assertEqual(not_origin['x'], 0) + self.assertEqual(not_origin['y'], 1) + other = LabelPoint2D(x=0, y=1, label='hi') # noqa + self.assertEqual(other['label'], 'hi') + + def test_pickle(self): + global EmpD # pickle wants to reference the class by name + EmpD = TypedDict('EmpD', name=str, id=int) + jane = EmpD({'name': 'jane', 'id': 37}) + for proto in range(pickle.HIGHEST_PROTOCOL + 1): + z = pickle.dumps(jane, proto) + jane2 = pickle.loads(z) + self.assertEqual(jane2, jane) + self.assertEqual(jane2, {'name': 'jane', 'id': 37}) + +if __name__ == '__main__': + main() diff --git a/runtests.py b/runtests.py index caf27eea4fed..5c67f37b7b7b 100755 --- a/runtests.py +++ b/runtests.py @@ -207,7 +207,7 @@ def add_imports(driver: Driver) -> None: PYTEST_FILES = ['mypy/test/{}.py'.format(name) for name in [ - 'testcheck', + 'testcheck', 'testextensions', ]] From 04ee8e60a4d329de557437bb3b68c11b744ffeb8 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sat, 10 Dec 2016 11:22:34 +0100 Subject: [PATCH 2/7] Make flake8 happy --- mypy/test/testextensions.py | 1 + 1 file changed, 1 insertion(+) diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index 5071cbac2f6c..58470baac839 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -98,5 +98,6 @@ def test_pickle(self): self.assertEqual(jane2, jane) self.assertEqual(jane2, {'name': 'jane', 'id': 37}) + if __name__ == '__main__': main() From 742bc14daada9f9c58bd229004a4572c9ddc8909 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Sun, 11 Dec 2016 14:10:03 +0100 Subject: [PATCH 3/7] Response to comments --- extensions/mypy_extensions.py | 24 +++++++++++++++++------- mypy/test/testextensions.py | 10 ++++++---- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index 7323d3796681..e740d3442e5c 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -7,14 +7,21 @@ # NOTE: This module must support Python 2.7 in addition to Python 3.x - import sys +import types from typing import _type_check # type: ignore +def _dict_new(inst, cls, *args, **kwargs): + return dict(*args, **kwargs) + def _check_fails(cls, other): - if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: - raise TypeError('TypedDict does not support instance and class checks') + try: + if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + raise TypeError('TypedDict does not support instance and class checks') + except (AttributeError, ValueError): + pass + return False class _TypedDictMeta(type): def __new__(cls, name, bases, ns): @@ -29,6 +36,8 @@ def __new__(cls, name, bases, ns): for base in bases: anns.update(base.__dict__.get('__annotations__', {})) tp_dict.__annotations__ = anns + if tp_dict.__name__ != 'TypedDict': + tp_dict.__new__ = types.MethodType(_dict_new, tp_dict) return tp_dict __instancecheck__ = __subclasscheck__ = _check_fails @@ -60,13 +69,14 @@ class Point2D(TypedDict): The latter syntax is only supported in Python 3.6+ """ - def __new__(cls, _typename, fields=None, **kwargs): - if fields is None: - fields = kwargs + def __new__(cls, _typename, _fields=None, **kwargs): + if _fields is None: + _fields = kwargs elif kwargs: raise TypeError("Either list of fields or keywords" " can be provided to TypedDict, not both") - return cls.__class__(_typename, (), {'__annotations__': dict(fields)}) + return cls.__class__(_typename, (), {'__annotations__': dict(_fields)}) TypedDict = _TypedDictMeta('TypedDict', _TypedDict.__bases__, dict(_TypedDict.__dict__)) +TypedDict.__module__ = __name__ diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index 58470baac839..a468410fccd2 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -40,16 +40,14 @@ class LabelPoint2D(Point2D, Label): ... class TypedDictTests(BaseTestCase): def test_basics_iterable_syntax(self): - # Check that two iterables allowed - Emp = TypedDict('Emp', [('name', str), ('id', int)]) Emp = TypedDict('Emp', {'name': str, 'id': int}) self.assertIsSubclass(Emp, dict) jim = Emp(name='Jim', id=1) - self.assertIsInstance(jim, Emp) self.assertIsInstance(jim, dict) self.assertEqual(jim['name'], 'Jim') self.assertEqual(jim['id'], 1) self.assertEqual(Emp.__name__, 'Emp') + self.assertEqual(Emp.__module__, 'mypy.test.testextensions') self.assertEqual(Emp.__bases__, (dict,)) self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) @@ -57,18 +55,22 @@ def test_basics_keywords_syntax(self): Emp = TypedDict('Emp', name=str, id=int) self.assertIsSubclass(Emp, dict) jim = Emp(name='Jim', id=1) - self.assertIsInstance(jim, Emp) self.assertIsInstance(jim, dict) self.assertEqual(jim['name'], 'Jim') self.assertEqual(jim['id'], 1) self.assertEqual(Emp.__name__, 'Emp') + self.assertEqual(Emp.__module__, 'mypy.test.testextensions') self.assertEqual(Emp.__bases__, (dict,)) self.assertEqual(Emp.__annotations__, {'name': str, 'id': int}) def test_typeddict_errors(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) + self.assertEqual(TypedDict.__module__, 'extensions.mypy_extensions') + jim = Emp(name='Jim', id=1) with self.assertRaises(TypeError): isinstance({}, Emp) + with self.assertRaises(TypeError): + isinstance(jim, Emp) with self.assertRaises(TypeError): issubclass(dict, Emp) with self.assertRaises(TypeError): From b5fd6f3f0350bfdfc4d6205e0d0171c400113bad Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 12 Dec 2016 15:27:08 +0100 Subject: [PATCH 4/7] Second review --- extensions/__init__.py | 2 +- mypy/test/testextensions.py | 16 +++++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/extensions/__init__.py b/extensions/__init__.py index 88a19e400e5b..a45992386a98 100644 --- a/extensions/__init__.py +++ b/extensions/__init__.py @@ -1 +1 @@ -# This page intentioanlly left blank. +# This page intentionally left blank. diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index a468410fccd2..03186be0aeae 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -1,6 +1,11 @@ import sys import pickle -from unittest import TestCase, main, skipUnless, SkipTest +import typing +try: + import collections.abc as collections_abc +except ImportError: + import collections as collections_abc # type: ignore # PY32 and earlier +from unittest import TestCase, main, skipUnless from extensions.mypy_extensions import TypedDict @@ -42,8 +47,10 @@ class TypedDictTests(BaseTestCase): def test_basics_iterable_syntax(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) self.assertIsSubclass(Emp, dict) + self.assertIsSubclass(Emp, typing.MutableMapping) + self.assertNotIsSubclass(Emp, collections_abc.Sequence) jim = Emp(name='Jim', id=1) - self.assertIsInstance(jim, dict) + self.assertIs(type(jim), dict) self.assertEqual(jim['name'], 'Jim') self.assertEqual(jim['id'], 1) self.assertEqual(Emp.__name__, 'Emp') @@ -54,8 +61,10 @@ def test_basics_iterable_syntax(self): def test_basics_keywords_syntax(self): Emp = TypedDict('Emp', name=str, id=int) self.assertIsSubclass(Emp, dict) + self.assertIsSubclass(Emp, typing.MutableMapping) + self.assertNotIsSubclass(Emp, collections_abc.Sequence) jim = Emp(name='Jim', id=1) - self.assertIsInstance(jim, dict) + self.assertIs(type(jim), dict) self.assertEqual(jim['name'], 'Jim') self.assertEqual(jim['id'], 1) self.assertEqual(Emp.__name__, 'Emp') @@ -84,6 +93,7 @@ def test_typeddict_errors(self): def test_class_syntax_usage(self): self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa + self.assertNotIsSubclass(LabelPoint2D, typing.Sequence) # noqa not_origin = Point2D(x=0, y=1) # noqa self.assertEqual(not_origin['x'], 0) self.assertEqual(not_origin['y'], 1) From b56d81c93a7bdd8cc11d2f5e88a1fb1eeea700cf Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Dec 2016 00:17:04 +0100 Subject: [PATCH 5/7] REsponse to comments first part --- extensions/__init__.py | 1 - extensions/mypy_extensions.py | 36 ++++++++++++++++++----------------- mypy/test/testextensions.py | 4 ++-- 3 files changed, 21 insertions(+), 20 deletions(-) delete mode 100644 extensions/__init__.py diff --git a/extensions/__init__.py b/extensions/__init__.py deleted file mode 100644 index a45992386a98..000000000000 --- a/extensions/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# This page intentionally left blank. diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index e740d3442e5c..8e0c3765c2d8 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -12,9 +12,6 @@ from typing import _type_check # type: ignore -def _dict_new(inst, cls, *args, **kwargs): - return dict(*args, **kwargs) - def _check_fails(cls, other): try: if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: @@ -23,6 +20,17 @@ def _check_fails(cls, other): pass return False +def _dict_new(inst, cls, *args, **kwargs): + return dict(*args, **kwargs) + +def _typeddict_new(inst, cls, _typename, _fields=None, **kwargs): + if _fields is None: + _fields = kwargs + elif kwargs: + raise TypeError("TypedDict takes either a dict or " + "keyword arguments, but not both") + return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)}) + class _TypedDictMeta(type): def __new__(cls, name, bases, ns): tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns) @@ -36,14 +44,18 @@ def __new__(cls, name, bases, ns): for base in bases: anns.update(base.__dict__.get('__annotations__', {})) tp_dict.__annotations__ = anns - if tp_dict.__name__ != 'TypedDict': + if name == 'TypedDict': + tp_dict.__new__ = types.MethodType(_typeddict_new, tp_dict) + else: tp_dict.__new__ = types.MethodType(_dict_new, tp_dict) return tp_dict __instancecheck__ = __subclasscheck__ = _check_fails -class _TypedDict(object): +TypedDict = _TypedDictMeta('TypedDict', (dict,), {}) +TypedDict.__module__ = __name__ +TypedDict.__doc__ = \ """A simple typed name space. At runtime it is equivalent to a plain dict. TypedDict creates a dictionary type that expects all of its @@ -67,16 +79,6 @@ class Point2D(TypedDict): y: int label: str - The latter syntax is only supported in Python 3.6+ + The latter syntax is only supported in Python 3.6+, while two other + syntax forms work for Python 2.7 and 3.2+ """ - def __new__(cls, _typename, _fields=None, **kwargs): - if _fields is None: - _fields = kwargs - elif kwargs: - raise TypeError("Either list of fields or keywords" - " can be provided to TypedDict, not both") - return cls.__class__(_typename, (), {'__annotations__': dict(_fields)}) - - -TypedDict = _TypedDictMeta('TypedDict', _TypedDict.__bases__, dict(_TypedDict.__dict__)) -TypedDict.__module__ = __name__ diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index 03186be0aeae..a9a526d4f025 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -6,7 +6,7 @@ except ImportError: import collections as collections_abc # type: ignore # PY32 and earlier from unittest import TestCase, main, skipUnless -from extensions.mypy_extensions import TypedDict +from mypy_extensions import TypedDict class BaseTestCase(TestCase): @@ -90,7 +90,7 @@ def test_typeddict_errors(self): TypedDict('Hi', [('x', int)], y=int) @skipUnless(PY36, 'Python 3.6 required') - def test_class_syntax_usage(self): + def test_py36_class_syntax_usage(self): self.assertEqual(LabelPoint2D.__annotations__, {'x': int, 'y': int, 'label': str}) # noqa self.assertEqual(LabelPoint2D.__bases__, (dict,)) # noqa self.assertNotIsSubclass(LabelPoint2D, typing.Sequence) # noqa From 7b74d963a64048c05e1ecf1387a05f33aa3bcd00 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Dec 2016 00:54:43 +0100 Subject: [PATCH 6/7] Response to comments, part two --- mypy/semanal.py | 5 +++-- mypy/test/testextensions.py | 5 +++-- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index e8609fdadb1b..f332a88dc9c3 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1870,8 +1870,9 @@ def process_typeddict_definition(self, s: AssignmentStmt) -> None: return # Yes, it's a valid TypedDict definition. Add it to the symbol table. node = self.lookup(name, s) - node.kind = GDEF # TODO locally defined TypedDict - node.node = typed_dict + if node: + node.kind = GDEF # TODO locally defined TypedDict + node.node = typed_dict def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines a TypedDict. diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index a9a526d4f025..032dd45ba1db 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -6,6 +6,7 @@ except ImportError: import collections as collections_abc # type: ignore # PY32 and earlier from unittest import TestCase, main, skipUnless +sys.path[0:0] = ['extensions'] from mypy_extensions import TypedDict @@ -63,7 +64,7 @@ def test_basics_keywords_syntax(self): self.assertIsSubclass(Emp, dict) self.assertIsSubclass(Emp, typing.MutableMapping) self.assertNotIsSubclass(Emp, collections_abc.Sequence) - jim = Emp(name='Jim', id=1) + jim = Emp(name='Jim', id=1) # type: ignore # mypy doesn't support keyword syntax yet self.assertIs(type(jim), dict) self.assertEqual(jim['name'], 'Jim') self.assertEqual(jim['id'], 1) @@ -74,7 +75,7 @@ def test_basics_keywords_syntax(self): def test_typeddict_errors(self): Emp = TypedDict('Emp', {'name': str, 'id': int}) - self.assertEqual(TypedDict.__module__, 'extensions.mypy_extensions') + self.assertEqual(TypedDict.__module__, 'mypy_extensions') jim = Emp(name='Jim', id=1) with self.assertRaises(TypeError): isinstance({}, Emp) From 557c5a3680d40bfd7f221298c3f5a5e0f7296340 Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Wed, 14 Dec 2016 16:27:56 +0100 Subject: [PATCH 7/7] Response to comments: another round --- extensions/mypy_extensions.py | 24 +++++++++++++++--------- mypy/test/testextensions.py | 3 +++ 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/extensions/mypy_extensions.py b/extensions/mypy_extensions.py index 8e0c3765c2d8..248da3de4aae 100644 --- a/extensions/mypy_extensions.py +++ b/extensions/mypy_extensions.py @@ -8,33 +8,43 @@ # NOTE: This module must support Python 2.7 in addition to Python 3.x import sys -import types +# _type_check is NOT a part of public typing API, it is used here only to mimic +# the (convenient) behavior of types provided by typing module. from typing import _type_check # type: ignore def _check_fails(cls, other): try: if sys._getframe(1).f_globals['__name__'] not in ['abc', 'functools']: + # Typed dicts are only for static structural subtyping. raise TypeError('TypedDict does not support instance and class checks') except (AttributeError, ValueError): pass return False -def _dict_new(inst, cls, *args, **kwargs): +def _dict_new(cls, *args, **kwargs): return dict(*args, **kwargs) -def _typeddict_new(inst, cls, _typename, _fields=None, **kwargs): +def _typeddict_new(cls, _typename, _fields=None, **kwargs): if _fields is None: _fields = kwargs elif kwargs: - raise TypeError("TypedDict takes either a dict or " - "keyword arguments, but not both") + raise TypeError("TypedDict takes either a dict or keyword arguments," + " but not both") return _TypedDictMeta(_typename, (), {'__annotations__': dict(_fields)}) class _TypedDictMeta(type): def __new__(cls, name, bases, ns): + # Create new typed dict class object. + # This method is called directly when TypedDict is subclassed, + # or via _typeddict_new when TypedDict is instantiated. This way + # TypedDict supports all three syntaxes described in its docstring. + # Subclasses and instanes of TypedDict return actual dictionaries + # via _dict_new. + ns['__new__'] = _typeddict_new if name == 'TypedDict' else _dict_new tp_dict = super(_TypedDictMeta, cls).__new__(cls, name, (dict,), ns) try: + # Setting correct module is necessary to make typed dict classes pickleable. tp_dict.__module__ = sys._getframe(2).f_globals.get('__name__', '__main__') except (AttributeError, ValueError): pass @@ -44,10 +54,6 @@ def __new__(cls, name, bases, ns): for base in bases: anns.update(base.__dict__.get('__annotations__', {})) tp_dict.__annotations__ = anns - if name == 'TypedDict': - tp_dict.__new__ = types.MethodType(_typeddict_new, tp_dict) - else: - tp_dict.__new__ = types.MethodType(_dict_new, tp_dict) return tp_dict __instancecheck__ = __subclasscheck__ = _check_fails diff --git a/mypy/test/testextensions.py b/mypy/test/testextensions.py index 032dd45ba1db..eca45d7e54dd 100644 --- a/mypy/test/testextensions.py +++ b/mypy/test/testextensions.py @@ -110,6 +110,9 @@ def test_pickle(self): jane2 = pickle.loads(z) self.assertEqual(jane2, jane) self.assertEqual(jane2, {'name': 'jane', 'id': 37}) + ZZ = pickle.dumps(EmpD, proto) + EmpDnew = pickle.loads(ZZ) + self.assertEqual(EmpDnew({'name': 'jane', 'id': 37}), jane) if __name__ == '__main__':