Skip to content

Commit dd0503e

Browse files
authored
Don't consider a branch unreachable if there is possible promotion (#14077)
Fixes #14030 FWIW this looks like an acceptable compromise after discussions in the issue. Also it is easy to implement. Let's see what `mypy_primer` will show.
1 parent cf59b82 commit dd0503e

File tree

7 files changed

+180
-20
lines changed

7 files changed

+180
-20
lines changed

mypy/checker.py

+20-11
Original file line numberDiff line numberDiff line change
@@ -4824,7 +4824,7 @@ def make_fake_typeinfo(
48244824
return cdef, info
48254825

48264826
def intersect_instances(
4827-
self, instances: tuple[Instance, Instance], ctx: Context
4827+
self, instances: tuple[Instance, Instance], errors: list[tuple[str, str]]
48284828
) -> Instance | None:
48294829
"""Try creating an ad-hoc intersection of the given instances.
48304830
@@ -4851,6 +4851,17 @@ def intersect_instances(
48514851
curr_module = self.scope.stack[0]
48524852
assert isinstance(curr_module, MypyFile)
48534853

4854+
# First, retry narrowing while allowing promotions (they are disabled by default
4855+
# for isinstance() checks, etc). This way we will still type-check branches like
4856+
# x: complex = 1
4857+
# if isinstance(x, int):
4858+
# ...
4859+
left, right = instances
4860+
if is_proper_subtype(left, right, ignore_promotions=False):
4861+
return left
4862+
if is_proper_subtype(right, left, ignore_promotions=False):
4863+
return right
4864+
48544865
def _get_base_classes(instances_: tuple[Instance, Instance]) -> list[Instance]:
48554866
base_classes_ = []
48564867
for inst in instances_:
@@ -4891,17 +4902,10 @@ def _make_fake_typeinfo_and_full_name(
48914902
self.check_multiple_inheritance(info)
48924903
info.is_intersection = True
48934904
except MroError:
4894-
if self.should_report_unreachable_issues():
4895-
self.msg.impossible_intersection(
4896-
pretty_names_list, "inconsistent method resolution order", ctx
4897-
)
4905+
errors.append((pretty_names_list, "inconsistent method resolution order"))
48984906
return None
4899-
49004907
if local_errors.has_new_errors():
4901-
if self.should_report_unreachable_issues():
4902-
self.msg.impossible_intersection(
4903-
pretty_names_list, "incompatible method signatures", ctx
4904-
)
4908+
errors.append((pretty_names_list, "incompatible method signatures"))
49054909
return None
49064910

49074911
curr_module.names[full_name] = SymbolTableNode(GDEF, info)
@@ -6355,15 +6359,20 @@ def conditional_types_with_intersection(
63556359
possible_target_types.append(item)
63566360

63576361
out = []
6362+
errors: list[tuple[str, str]] = []
63586363
for v in possible_expr_types:
63596364
if not isinstance(v, Instance):
63606365
return yes_type, no_type
63616366
for t in possible_target_types:
6362-
intersection = self.intersect_instances((v, t), ctx)
6367+
intersection = self.intersect_instances((v, t), errors)
63636368
if intersection is None:
63646369
continue
63656370
out.append(intersection)
63666371
if len(out) == 0:
6372+
# Only report errors if no element in the union worked.
6373+
if self.should_report_unreachable_issues():
6374+
for types, reason in errors:
6375+
self.msg.impossible_intersection(types, reason, ctx)
63676376
return UninhabitedType(), expr_type
63686377
new_yes_type = make_simplified_union(out)
63696378
return new_yes_type, expr_type

mypy/join.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,11 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
141141

142142

143143
def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
144-
"""Return a simple least upper bound given the declared type."""
145-
# TODO: check infinite recursion for aliases here?
144+
"""Return a simple least upper bound given the declared type.
145+
146+
This function should be only used by binder, and should not recurse.
147+
For all other uses, use `join_types()`.
148+
"""
146149
declaration = get_proper_type(declaration)
147150
s = get_proper_type(s)
148151
t = get_proper_type(t)
@@ -158,10 +161,10 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
158161
if isinstance(s, ErasedType):
159162
return t
160163

161-
if is_proper_subtype(s, t):
164+
if is_proper_subtype(s, t, ignore_promotions=True):
162165
return t
163166

164-
if is_proper_subtype(t, s):
167+
if is_proper_subtype(t, s, ignore_promotions=True):
165168
return s
166169

167170
if isinstance(declaration, UnionType):
@@ -176,6 +179,9 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
176179
# Meets/joins require callable type normalization.
177180
s, t = normalize_callables(s, t)
178181

182+
if isinstance(s, UnionType) and not isinstance(t, UnionType):
183+
s, t = t, s
184+
179185
value = t.accept(TypeJoinVisitor(s))
180186
if declaration is None or is_subtype(value, declaration):
181187
return value

mypy/meet.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,15 @@ def narrow_declared_type(declared: Type, narrowed: Type) -> Type:
124124
[
125125
narrow_declared_type(x, narrowed)
126126
for x in declared.relevant_items()
127-
if is_overlapping_types(x, narrowed, ignore_promotions=True)
127+
# This (ugly) special-casing is needed to support checking
128+
# branches like this:
129+
# x: Union[float, complex]
130+
# if isinstance(x, int):
131+
# ...
132+
if (
133+
is_overlapping_types(x, narrowed, ignore_promotions=True)
134+
or is_subtype(narrowed, x, ignore_promotions=False)
135+
)
128136
]
129137
)
130138
if is_enum_overlapping_union(declared, narrowed):

test-data/unit/check-classes.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -7209,7 +7209,7 @@ from typing import Callable
72097209
class C:
72107210
x: Callable[[C], int] = lambda x: x.y.g() # E: "C" has no attribute "y"
72117211

7212-
[case testOpWithInheritedFromAny]
7212+
[case testOpWithInheritedFromAny-xfail]
72137213
from typing import Any
72147214
C: Any
72157215
class D(C):

test-data/unit/check-isinstance.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -2392,7 +2392,7 @@ class B:
23922392
x1: Literal[1] = self.f()
23932393

23942394
def t2(self) -> None:
2395-
if isinstance(self, (A0, A1)): # E: Subclass of "B" and "A0" cannot exist: would have incompatible method signatures
2395+
if isinstance(self, (A0, A1)):
23962396
reveal_type(self) # N: Revealed type is "__main__.<subclass of "A1" and "B">1"
23972397
x0: Literal[0] = self.f() # E: Incompatible types in assignment (expression has type "Literal[1]", variable has type "Literal[0]")
23982398
x1: Literal[1] = self.f()

test-data/unit/check-type-promotion.test

+133
Original file line numberDiff line numberDiff line change
@@ -54,3 +54,136 @@ def f(x: Union[SupportsFloat, T]) -> Union[SupportsFloat, T]: pass
5454
f(0) # should not crash
5555
[builtins fixtures/primitives.pyi]
5656
[out]
57+
58+
[case testIntersectionUsingPromotion1]
59+
# flags: --warn-unreachable
60+
from typing import Union
61+
62+
x: complex = 1
63+
reveal_type(x) # N: Revealed type is "builtins.complex"
64+
if isinstance(x, int):
65+
reveal_type(x) # N: Revealed type is "builtins.int"
66+
else:
67+
reveal_type(x) # N: Revealed type is "builtins.complex"
68+
reveal_type(x) # N: Revealed type is "builtins.complex"
69+
70+
y: Union[int, float]
71+
if isinstance(y, float):
72+
reveal_type(y) # N: Revealed type is "builtins.float"
73+
else:
74+
reveal_type(y) # N: Revealed type is "builtins.int"
75+
76+
reveal_type(y) # N: Revealed type is "Union[builtins.int, builtins.float]"
77+
78+
if isinstance(y, int):
79+
reveal_type(y) # N: Revealed type is "builtins.int"
80+
else:
81+
reveal_type(y) # N: Revealed type is "builtins.float"
82+
[builtins fixtures/primitives.pyi]
83+
84+
[case testIntersectionUsingPromotion2]
85+
# flags: --warn-unreachable
86+
x: complex = 1
87+
reveal_type(x) # N: Revealed type is "builtins.complex"
88+
if isinstance(x, (int, float)):
89+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
90+
else:
91+
reveal_type(x) # N: Revealed type is "builtins.complex"
92+
93+
# Note we make type precise, since type promotions are involved
94+
reveal_type(x) # N: Revealed type is "Union[builtins.complex, builtins.int, builtins.float]"
95+
[builtins fixtures/primitives.pyi]
96+
97+
[case testIntersectionUsingPromotion3]
98+
# flags: --warn-unreachable
99+
x: object
100+
if isinstance(x, int) and isinstance(x, complex):
101+
reveal_type(x) # N: Revealed type is "builtins.int"
102+
if isinstance(x, complex) and isinstance(x, int):
103+
reveal_type(x) # N: Revealed type is "builtins.int"
104+
[builtins fixtures/primitives.pyi]
105+
106+
[case testIntersectionUsingPromotion4]
107+
# flags: --warn-unreachable
108+
x: object
109+
if isinstance(x, int):
110+
if isinstance(x, complex):
111+
reveal_type(x) # N: Revealed type is "builtins.int"
112+
else:
113+
reveal_type(x) # N: Revealed type is "builtins.int"
114+
if isinstance(x, complex):
115+
if isinstance(x, int):
116+
reveal_type(x) # N: Revealed type is "builtins.int"
117+
else:
118+
reveal_type(x) # N: Revealed type is "builtins.complex"
119+
[builtins fixtures/primitives.pyi]
120+
121+
[case testIntersectionUsingPromotion5]
122+
# flags: --warn-unreachable
123+
from typing import Union
124+
125+
x: Union[float, complex]
126+
if isinstance(x, int):
127+
reveal_type(x) # N: Revealed type is "builtins.int"
128+
else:
129+
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
130+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.complex]"
131+
[builtins fixtures/primitives.pyi]
132+
133+
[case testIntersectionUsingPromotion6]
134+
# flags: --warn-unreachable
135+
from typing import Union
136+
137+
x: Union[str, complex]
138+
if isinstance(x, int):
139+
reveal_type(x) # N: Revealed type is "builtins.int"
140+
else:
141+
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.complex]"
142+
reveal_type(x) # N: Revealed type is "Union[builtins.str, builtins.int, builtins.complex]"
143+
[builtins fixtures/primitives.pyi]
144+
145+
[case testIntersectionUsingPromotion7]
146+
# flags: --warn-unreachable
147+
from typing import Union
148+
149+
x: Union[int, float, complex]
150+
if isinstance(x, int):
151+
reveal_type(x) # N: Revealed type is "builtins.int"
152+
else:
153+
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
154+
155+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.complex]"
156+
157+
if isinstance(x, float):
158+
reveal_type(x) # N: Revealed type is "builtins.float"
159+
else:
160+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.complex]"
161+
162+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.complex]"
163+
164+
if isinstance(x, complex):
165+
reveal_type(x) # N: Revealed type is "builtins.complex"
166+
else:
167+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
168+
169+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float, builtins.complex]"
170+
[builtins fixtures/primitives.pyi]
171+
172+
[case testIntersectionUsingPromotion8]
173+
# flags: --warn-unreachable
174+
from typing import Union
175+
176+
x: Union[int, float, complex]
177+
if isinstance(x, (int, float)):
178+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.float]"
179+
else:
180+
reveal_type(x) # N: Revealed type is "builtins.complex"
181+
if isinstance(x, (int, complex)):
182+
reveal_type(x) # N: Revealed type is "Union[builtins.int, builtins.complex]"
183+
else:
184+
reveal_type(x) # N: Revealed type is "builtins.float"
185+
if isinstance(x, (float, complex)):
186+
reveal_type(x) # N: Revealed type is "Union[builtins.float, builtins.complex]"
187+
else:
188+
reveal_type(x) # N: Revealed type is "builtins.int"
189+
[builtins fixtures/primitives.pyi]

test-data/unit/fixtures/primitives.pyi

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# builtins stub with non-generic primitive types
2-
from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, overload
2+
from typing import Generic, TypeVar, Sequence, Iterator, Mapping, Iterable, Tuple, Union
33

44
T = TypeVar('T')
55
V = TypeVar('V')
@@ -20,7 +20,9 @@ class int:
2020
def __rmul__(self, x: int) -> int: pass
2121
class float:
2222
def __float__(self) -> float: pass
23-
class complex: pass
23+
def __add__(self, x: float) -> float: pass
24+
class complex:
25+
def __add__(self, x: complex) -> complex: pass
2426
class bool(int): pass
2527
class str(Sequence[str]):
2628
def __add__(self, s: str) -> str: pass
@@ -63,3 +65,5 @@ class range(Sequence[int]):
6365
def __getitem__(self, i: int) -> int: pass
6466
def __iter__(self) -> Iterator[int]: pass
6567
def __contains__(self, other: object) -> bool: pass
68+
69+
def isinstance(x: object, t: Union[type, Tuple]) -> bool: pass

0 commit comments

Comments
 (0)