From cea9c5410295f7a9d7c489c82297b7094cd5923a Mon Sep 17 00:00:00 2001 From: Richard Si Date: Mon, 20 Feb 2023 18:14:17 -0500 Subject: [PATCH 1/5] [mypyc] Support iterating over a TypedDict An optimization to make iterating over dict.keys(), dict.values() and dict.items() faster caused mypyc to crash while compiling a TypedDict. This commit fixes `Builder.get_dict_base_type` to properly handle TypedDictType. --- mypyc/irbuild/builder.py | 9 ++- mypyc/test-data/irbuild-dict.test | 62 +++++++++++++++++++ test-data/unit/lib-stub/typing_extensions.pyi | 3 + 3 files changed, 72 insertions(+), 2 deletions(-) diff --git a/mypyc/irbuild/builder.py b/mypyc/irbuild/builder.py index f37fae608083..e001808a5c78 100644 --- a/mypyc/irbuild/builder.py +++ b/mypyc/irbuild/builder.py @@ -51,6 +51,7 @@ ProperType, TupleType, Type, + TypedDictType, TypeOfAny, UninhabitedType, UnionType, @@ -892,8 +893,12 @@ def get_dict_base_type(self, expr: Expression) -> list[Instance]: dict_types = [] for t in types: - assert isinstance(t, Instance), t - dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict") + if isinstance(t, TypedDictType): + t = t.fallback + dict_base = next(base for base in t.type.mro if base.fullname == "typing.Mapping") + else: + assert isinstance(t, Instance), t + dict_base = next(base for base in t.type.mro if base.fullname == "builtins.dict") dict_types.append(map_instance_to_supertype(t, dict_base)) return dict_types diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index 99643b9451f0..a51a5832a5fd 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -219,6 +219,12 @@ L0: [case testDictIterationMethods] from typing import Dict, Union +from typing_extensions import TypedDict + +class Person(TypedDict): + name: str + age: int + def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int]) -> None: for v in d1.values(): if v in d2: @@ -229,6 +235,10 @@ def union_of_dicts(d: Union[Dict[str, int], Dict[str, str]]) -> None: new = {} for k, v in d.items(): new[k] = int(v) +def typeddict(d: Person) -> None: + for k, v in d.items(): + if k == "name": + name = v [out] def print_dict_methods(d1, d2): d1, d2 :: dict @@ -370,6 +380,58 @@ L4: r19 = CPy_NoErrOccured() L5: return 1 +def typeddict(d): + d :: dict + r0 :: short_int + r1 :: native_int + r2 :: short_int + r3 :: object + r4 :: tuple[bool, short_int, object, object] + r5 :: short_int + r6 :: bool + r7, r8 :: object + r9 :: str + k, v :: object + r10 :: str + r11 :: object + r12 :: int32 + r13 :: bit + r14 :: bool + name :: object + r15, r16 :: bit +L0: + r0 = 0 + r1 = PyDict_Size(d) + r2 = r1 << 1 + r3 = CPyDict_GetItemsIter(d) +L1: + r4 = CPyDict_NextItem(r3, r0) + r5 = r4[1] + r0 = r5 + r6 = r4[0] + if r6 goto L2 else goto L6 :: bool +L2: + r7 = r4[2] + r8 = r4[3] + r9 = cast(str, r7) + k = r9 + v = r8 + r10 = 'name' + r11 = PyObject_RichCompare(k, r10, 2) + r12 = PyObject_IsTrue(r11) + r13 = r12 >= 0 :: signed + r14 = truncate r12: int32 to builtins.bool + if r14 goto L3 else goto L4 :: bool +L3: + name = v +L4: +L5: + r15 = CPyDict_CheckSize(d, r2) + goto L1 +L6: + r16 = CPy_NoErrOccured() +L7: + return 1 [case testDictLoadAddress] def f() -> None: diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 22b895971521..5dbde18bab63 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -50,6 +50,9 @@ class _TypedDict(Mapping[str, object]): # Mypy expects that 'default' has a type variable type. def pop(self, k: NoReturn, default: _T = ...) -> object: ... def update(self: _T, __m: _T) -> None: ... + def items(self) -> dict_items[str, object]: ... + def keys(self) -> dict_keys[str, object]: ... + def values(self) -> dict_values[str, object]: ... if sys.version_info < (3, 0): def has_key(self, k: str) -> bool: ... def __delitem__(self, k: NoReturn) -> None: ... From c5fe32b8e80bf27c8617c021bc5a7ded7caab7dc Mon Sep 17 00:00:00 2001 From: Richard Si Date: Thu, 23 Feb 2023 15:22:12 -0500 Subject: [PATCH 2/5] Use less specific return types in typing_extensions stub --- test-data/unit/lib-stub/typing_extensions.pyi | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 5dbde18bab63..31ed678b0808 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -1,5 +1,5 @@ import typing -from typing import Any, Mapping, Iterator, NoReturn as NoReturn, Dict, Type +from typing import Any, Mapping, Iterable, Iterator, NoReturn as NoReturn, Dict, Tuple, Type from typing import TYPE_CHECKING as TYPE_CHECKING from typing import NewType as NewType, overload as overload @@ -50,9 +50,9 @@ class _TypedDict(Mapping[str, object]): # Mypy expects that 'default' has a type variable type. def pop(self, k: NoReturn, default: _T = ...) -> object: ... def update(self: _T, __m: _T) -> None: ... - def items(self) -> dict_items[str, object]: ... - def keys(self) -> dict_keys[str, object]: ... - def values(self) -> dict_values[str, object]: ... + def items(self) -> Iterable[str, object]: ... + def keys(self) -> Iterable[str, object]: ... + def values(self) -> Iterable[Tuple[str, object]]: ... if sys.version_info < (3, 0): def has_key(self, k: str) -> bool: ... def __delitem__(self, k: NoReturn) -> None: ... From ab1a358e5ebc5fa6a7f11dd980340421b16766a6 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Fri, 3 Mar 2023 20:17:09 -0500 Subject: [PATCH 3/5] Actually fix the stubs --- test-data/unit/lib-stub/typing_extensions.pyi | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test-data/unit/lib-stub/typing_extensions.pyi b/test-data/unit/lib-stub/typing_extensions.pyi index 31ed678b0808..759f956d314b 100644 --- a/test-data/unit/lib-stub/typing_extensions.pyi +++ b/test-data/unit/lib-stub/typing_extensions.pyi @@ -50,9 +50,9 @@ class _TypedDict(Mapping[str, object]): # Mypy expects that 'default' has a type variable type. def pop(self, k: NoReturn, default: _T = ...) -> object: ... def update(self: _T, __m: _T) -> None: ... - def items(self) -> Iterable[str, object]: ... - def keys(self) -> Iterable[str, object]: ... - def values(self) -> Iterable[Tuple[str, object]]: ... + def items(self) -> Iterable[Tuple[str, object]]: ... + def keys(self) -> Iterable[str]: ... + def values(self) -> Iterable[object]: ... if sys.version_info < (3, 0): def has_key(self, k: str) -> bool: ... def __delitem__(self, k: NoReturn) -> None: ... From c55c24d1bc3eb086a80e8c2ce0f8c6d5375b422e Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 4 Mar 2023 12:03:45 -0500 Subject: [PATCH 4/5] Update the irbuild test, ugh --- mypyc/test-data/irbuild-dict.test | 41 ++++++++++++++++++------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/mypyc/test-data/irbuild-dict.test b/mypyc/test-data/irbuild-dict.test index a51a5832a5fd..d1fc4f956ce7 100644 --- a/mypyc/test-data/irbuild-dict.test +++ b/mypyc/test-data/irbuild-dict.test @@ -390,15 +390,15 @@ def typeddict(d): r5 :: short_int r6 :: bool r7, r8 :: object - r9 :: str - k, v :: object + r9, k :: str + v :: object r10 :: str - r11 :: object - r12 :: int32 - r13 :: bit - r14 :: bool + r11 :: int32 + r12 :: bit + r13 :: object + r14, r15, r16 :: bit name :: object - r15, r16 :: bit + r17, r18 :: bit L0: r0 = 0 r1 = PyDict_Size(d) @@ -409,7 +409,7 @@ L1: r5 = r4[1] r0 = r5 r6 = r4[0] - if r6 goto L2 else goto L6 :: bool + if r6 goto L2 else goto L9 :: bool L2: r7 = r4[2] r8 = r4[3] @@ -417,20 +417,27 @@ L2: k = r9 v = r8 r10 = 'name' - r11 = PyObject_RichCompare(k, r10, 2) - r12 = PyObject_IsTrue(r11) - r13 = r12 >= 0 :: signed - r14 = truncate r12: int32 to builtins.bool - if r14 goto L3 else goto L4 :: bool + r11 = PyUnicode_Compare(k, r10) + r12 = r11 == -1 + if r12 goto L3 else goto L5 :: bool L3: - name = v + r13 = PyErr_Occurred() + r14 = r13 != 0 + if r14 goto L4 else goto L5 :: bool L4: + r15 = CPy_KeepPropagating() L5: - r15 = CPyDict_CheckSize(d, r2) - goto L1 + r16 = r11 == 0 + if r16 goto L6 else goto L7 :: bool L6: - r16 = CPy_NoErrOccured() + name = v L7: +L8: + r17 = CPyDict_CheckSize(d, r2) + goto L1 +L9: + r18 = CPy_NoErrOccured() +L10: return 1 [case testDictLoadAddress] From 85e46ab136869d98f6d555b98b4dc449f7e174a5 Mon Sep 17 00:00:00 2001 From: Richard Si Date: Sat, 18 Mar 2023 19:38:36 -0400 Subject: [PATCH 5/5] Add run test --- mypyc/test-data/run-dicts.test | 34 ++++++++++++++++++++++++++++++++-- 1 file changed, 32 insertions(+), 2 deletions(-) diff --git a/mypyc/test-data/run-dicts.test b/mypyc/test-data/run-dicts.test index 41675e7fcc91..58b862e3f303 100644 --- a/mypyc/test-data/run-dicts.test +++ b/mypyc/test-data/run-dicts.test @@ -95,7 +95,13 @@ assert get_content_set(od) == ({1, 3}, {2, 4}, {(1, 2), (3, 4)}) [typing fixtures/typing-full.pyi] [case testDictIterationMethodsRun] -from typing import Dict +from typing import Dict, Union +from typing_extensions import TypedDict + +class ExtensionDict(TypedDict): + python: str + c: str + def print_dict_methods(d1: Dict[int, int], d2: Dict[int, int], d3: Dict[int, int]) -> None: @@ -107,13 +113,27 @@ def print_dict_methods(d1: Dict[int, int], for v in d3.values(): print(v) +def print_dict_methods_special(d1: Union[Dict[int, int], Dict[str, str]], + d2: ExtensionDict) -> None: + for k in d1.keys(): + print(k) + for k, v in d1.items(): + print(k) + print(v) + for v2 in d2.values(): + print(v2) + for k2, v2 in d2.items(): + print(k2) + print(v2) + + def clear_during_iter(d: Dict[int, int]) -> None: for k in d: d.clear() class Custom(Dict[int, int]): pass [file driver.py] -from native import print_dict_methods, Custom, clear_during_iter +from native import print_dict_methods, print_dict_methods_special, Custom, clear_during_iter from collections import OrderedDict print_dict_methods({}, {}, {}) print_dict_methods({1: 2}, {3: 4, 5: 6}, {7: 8}) @@ -124,6 +144,7 @@ print('==') d = OrderedDict([(1, 2), (3, 4)]) print_dict_methods(d, d, d) print('==') +print_dict_methods_special({1: 2}, {"python": ".py", "c": ".c"}) d.move_to_end(1) print_dict_methods(d, d, d) clear_during_iter({}) # OK @@ -185,6 +206,15 @@ else: 2 4 == +1 +1 +2 +.py +.c +python +.py +c +.c 3 1 3