Skip to content

Commit 4471c7e

Browse files
authored
Proposal: don't simplify unions in expand_type() (#14178)
Fixes #6730 Currently `expand_type()` is inherently recursive, going through `expand_type` -> `make_simplified_union` -> `is_proper_subtype` -> `map_instance_to_supertype` -> `expand_type`. TBH I never liked this, so I propose that we don't do this. One one hand, this is a significant change in semantics, but on the other hand: * This fixes a crash (actually a whole class of crashes) that can happen even without recursive aliases * This removes an ugly import and simplifies an import cycle in mypy code * This makes mypy 2% faster (measured on self-check) To make transition smoother, I propose to make trivial simplifications, like removing `<nothing>` (and `None` without strict optional), removing everything else if there is an `object` type, and remove strict duplicates. Notably, with these few things _all existing tests pass_ (and even without it, only half a dozen tests fail on `reveal_type()`).
1 parent 13bd201 commit 4471c7e

File tree

4 files changed

+89
-5
lines changed

4 files changed

+89
-5
lines changed

mypy/expandtype.py

+9-5
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434
UnionType,
3535
UnpackType,
3636
expand_param_spec,
37+
flatten_nested_unions,
3738
get_proper_type,
39+
remove_trivial,
3840
)
3941
from mypy.typevartuples import (
4042
find_unpack_in_list,
@@ -405,11 +407,13 @@ def visit_literal_type(self, t: LiteralType) -> Type:
405407
return t
406408

407409
def visit_union_type(self, t: UnionType) -> Type:
408-
# After substituting for type variables in t.items,
409-
# some of the resulting types might be subtypes of others.
410-
from mypy.typeops import make_simplified_union # asdf
411-
412-
return make_simplified_union(self.expand_types(t.items), t.line, t.column)
410+
expanded = self.expand_types(t.items)
411+
# After substituting for type variables in t.items, some resulting types
412+
# might be subtypes of others, however calling make_simplified_union()
413+
# can cause recursion, so we just remove strict duplicates.
414+
return UnionType.make_union(
415+
remove_trivial(flatten_nested_unions(expanded)), t.line, t.column
416+
)
413417

414418
def visit_partial_type(self, t: PartialType) -> Type:
415419
return t

mypy/types.py

+30
Original file line numberDiff line numberDiff line change
@@ -3487,3 +3487,33 @@ def store_argument_type(
34873487
if not isinstance(arg_type, ParamSpecType) and not typ.unpack_kwargs:
34883488
arg_type = named_type("builtins.dict", [named_type("builtins.str", []), arg_type])
34893489
defn.arguments[i].variable.type = arg_type
3490+
3491+
3492+
def remove_trivial(types: Iterable[Type]) -> list[Type]:
3493+
"""Make trivial simplifications on a list of types without calling is_subtype().
3494+
3495+
This makes following simplifications:
3496+
* Remove bottom types (taking into account strict optional setting)
3497+
* Remove everything else if there is an `object`
3498+
* Remove strict duplicate types
3499+
"""
3500+
removed_none = False
3501+
new_types = []
3502+
all_types = set()
3503+
for t in types:
3504+
p_t = get_proper_type(t)
3505+
if isinstance(p_t, UninhabitedType):
3506+
continue
3507+
if isinstance(p_t, NoneType) and not state.strict_optional:
3508+
removed_none = True
3509+
continue
3510+
if isinstance(p_t, Instance) and p_t.type.fullname == "builtins.object":
3511+
return [p_t]
3512+
if p_t not in all_types:
3513+
new_types.append(t)
3514+
all_types.add(p_t)
3515+
if new_types:
3516+
return new_types
3517+
if removed_none:
3518+
return [NoneType()]
3519+
return [UninhabitedType()]

test-data/unit/check-recursive-types.test

+34
Original file line numberDiff line numberDiff line change
@@ -837,3 +837,37 @@ def foo(x: T) -> C: ...
837837

838838
Nested = Union[C, Sequence[Nested]]
839839
x: Nested = foo(42)
840+
841+
[case testNoRecursiveExpandInstanceUnionCrash]
842+
from typing import List, Union
843+
844+
class Tag(List[Union[Tag, List[Tag]]]): ...
845+
Tag()
846+
847+
[case testNoRecursiveExpandInstanceUnionCrashGeneric]
848+
from typing import Generic, Iterable, TypeVar, Union
849+
850+
ValueT = TypeVar("ValueT")
851+
class Recursive(Iterable[Union[ValueT, Recursive[ValueT]]]):
852+
pass
853+
854+
class Base(Generic[ValueT]):
855+
def __init__(self, element: ValueT):
856+
pass
857+
class Sub(Base[Union[ValueT, Recursive[ValueT]]]):
858+
pass
859+
860+
x: Iterable[str]
861+
reveal_type(Sub) # N: Revealed type is "def [ValueT] (element: Union[ValueT`1, __main__.Recursive[ValueT`1]]) -> __main__.Sub[ValueT`1]"
862+
reveal_type(Sub(x)) # N: Revealed type is "__main__.Sub[typing.Iterable[builtins.str]]"
863+
864+
[case testNoRecursiveExpandInstanceUnionCrashInference]
865+
from typing import TypeVar, Union, Generic, List
866+
867+
T = TypeVar("T")
868+
InList = Union[T, InListRecurse[T]]
869+
class InListRecurse(Generic[T], List[InList[T]]): ...
870+
871+
def list_thing(transforming: InList[T]) -> T:
872+
...
873+
reveal_type(list_thing([5])) # N: Revealed type is "builtins.list[builtins.int]"

test-data/unit/pythoneval.test

+16
Original file line numberDiff line numberDiff line change
@@ -1735,3 +1735,19 @@ _testEnumNameWorkCorrectlyOn311.py:12: note: Revealed type is "Union[Literal[1]?
17351735
_testEnumNameWorkCorrectlyOn311.py:13: note: Revealed type is "Literal['X']?"
17361736
_testEnumNameWorkCorrectlyOn311.py:14: note: Revealed type is "builtins.int"
17371737
_testEnumNameWorkCorrectlyOn311.py:15: note: Revealed type is "builtins.int"
1738+
1739+
[case testTypedDictUnionGetFull]
1740+
from typing import Dict
1741+
from typing_extensions import TypedDict
1742+
1743+
class TD(TypedDict, total=False):
1744+
x: int
1745+
y: int
1746+
1747+
A = Dict[str, TD]
1748+
x: A
1749+
def foo(k: str) -> TD:
1750+
reveal_type(x.get(k, {}))
1751+
return x.get(k, {})
1752+
[out]
1753+
_testTypedDictUnionGetFull.py:11: note: Revealed type is "TypedDict('_testTypedDictUnionGetFull.TD', {'x'?: builtins.int, 'y'?: builtins.int})"

0 commit comments

Comments
 (0)