Skip to content

Commit 69eaf88

Browse files
authored
Support indexing unions containing tuples (#6475)
Fixes #4286 Fixes #2509 The fix is straightforward, map the indexing through union types. I decided to use the existing error message mentioning the full type. IMO `Value of type "Union[int, List[int]]" is not indexable` is better than `Value of type "int" is not idexable`.
1 parent 72a6b1b commit 69eaf88

File tree

4 files changed

+56
-11
lines changed

4 files changed

+56
-11
lines changed

mypy/checkexpr.py

+24-6
Original file line numberDiff line numberDiff line change
@@ -2012,14 +2012,18 @@ def check_method_call_by_name(self,
20122012
arg_kinds: List[int],
20132013
context: Context,
20142014
local_errors: Optional[MessageBuilder] = None,
2015+
original_type: Optional[Type] = None
20152016
) -> Tuple[Type, Type]:
20162017
"""Type check a call to a named method on an object.
20172018
2018-
Return tuple (result type, inferred method type).
2019+
Return tuple (result type, inferred method type). The 'original_type'
2020+
is used for error messages.
20192021
"""
20202022
local_errors = local_errors or self.msg
2023+
original_type = original_type or base_type
20212024
method_type = analyze_member_access(method, base_type, context, False, False, True,
2022-
local_errors, original_type=base_type, chk=self.chk,
2025+
local_errors, original_type=original_type,
2026+
chk=self.chk,
20232027
in_literal_context=self.is_literal_context())
20242028
return self.check_method_call(
20252029
method, base_type, method_type, args, arg_kinds, context, local_errors)
@@ -2441,18 +2445,31 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
24412445
# It's actually a type application.
24422446
return self.accept(e.analyzed)
24432447
left_type = self.accept(e.base)
2444-
if isinstance(left_type, TupleType) and self.chk.in_checked_function():
2448+
return self.visit_index_with_type(left_type, e)
2449+
2450+
def visit_index_with_type(self, left_type: Type, e: IndexExpr,
2451+
original_type: Optional[Type] = None) -> Type:
2452+
"""Analyze type of an index expression for a given type of base expression.
2453+
2454+
The 'original_type' is used for error messages (currently used for union types).
2455+
"""
2456+
index = e.index
2457+
if isinstance(left_type, UnionType):
2458+
original_type = original_type or left_type
2459+
return UnionType.make_simplified_union([self.visit_index_with_type(typ, e,
2460+
original_type)
2461+
for typ in left_type.relevant_items()])
2462+
elif isinstance(left_type, TupleType) and self.chk.in_checked_function():
24452463
# Special case for tuples. They return a more specific type when
24462464
# indexed by an integer literal.
2447-
index = e.index
24482465
if isinstance(index, SliceExpr):
24492466
return self.visit_tuple_slice_helper(left_type, index)
24502467

24512468
n = self._get_value(index)
24522469
if n is not None:
24532470
if n < 0:
24542471
n += len(left_type.items)
2455-
if n >= 0 and n < len(left_type.items):
2472+
if 0 <= n < len(left_type.items):
24562473
return left_type.items[n]
24572474
else:
24582475
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
@@ -2466,7 +2483,8 @@ def visit_index_expr_helper(self, e: IndexExpr) -> Type:
24662483
return self.visit_enum_index_expr(left_type.type_object(), e.index, e)
24672484
else:
24682485
result, method_type = self.check_method_call_by_name(
2469-
'__getitem__', left_type, [e.index], [ARG_POS], e)
2486+
'__getitem__', left_type, [e.index], [ARG_POS], e,
2487+
original_type=original_type)
24702488
e.method_type = method_type
24712489
return result
24722490

test-data/unit/check-classes.test

+1-3
Original file line numberDiff line numberDiff line change
@@ -4548,9 +4548,7 @@ class Bar(TypedDict):
45484548

45494549
def foo(node: NodeType) -> int:
45504550
x = node
4551-
# TODO: This is incorrect (https://github.com/python/mypy/issues/5930), but ensure that it
4552-
# doesn't crash at least
4553-
return x['x'] # E: Incompatible return value type (got "object", expected "int")
4551+
return x['x']
45544552
[builtins fixtures/isinstancelist.pyi]
45554553
[out]
45564554

test-data/unit/check-tuples.test

+24
Original file line numberDiff line numberDiff line change
@@ -1209,3 +1209,27 @@ else:
12091209
reveal_type(x) # E: Revealed type is 'Union[Tuple[__main__.B], None]'
12101210

12111211
[builtins fixtures/tuple.pyi]
1212+
1213+
[case testUnionOfTupleIndex]
1214+
from typing import Union, Tuple
1215+
1216+
tup: Union[Tuple[int, str], Tuple[int, int, str]]
1217+
reveal_type(tup[0]) # E: Revealed type is 'builtins.int'
1218+
reveal_type(tup[1]) # E: Revealed type is 'Union[builtins.str, builtins.int]'
1219+
reveal_type(tup[2]) # E: Revealed type is 'Union[Any, builtins.str]' \
1220+
# E: Tuple index out of range
1221+
reveal_type(tup[:]) # E: Revealed type is 'Union[Tuple[builtins.int, builtins.str], Tuple[builtins.int, builtins.int, builtins.str]]'
1222+
1223+
[builtins fixtures/tuple.pyi]
1224+
1225+
[case testUnionOfTupleIndexMixed]
1226+
from typing import Union, Tuple, List
1227+
1228+
tup: Union[Tuple[int, str], List[int]]
1229+
reveal_type(tup[0]) # E: Revealed type is 'builtins.int'
1230+
reveal_type(tup[1]) # E: Revealed type is 'Union[builtins.str, builtins.int*]'
1231+
reveal_type(tup[2]) # E: Revealed type is 'Union[Any, builtins.int*]' \
1232+
# E: Tuple index out of range
1233+
reveal_type(tup[:]) # E: Revealed type is 'Union[Tuple[builtins.int, builtins.str], builtins.list[builtins.int*]]'
1234+
1235+
[builtins fixtures/tuple.pyi]

test-data/unit/fixtures/tuple.pyi

+7-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Builtins stub used in tuple-related test cases.
22

3-
from typing import Iterable, Iterator, TypeVar, Generic, Sequence, Any
3+
from typing import Iterable, Iterator, TypeVar, Generic, Sequence, Any, overload
44

55
Tco = TypeVar('Tco', covariant=True)
66

@@ -28,7 +28,12 @@ class unicode: pass
2828

2929
T = TypeVar('T')
3030

31-
class list(Sequence[T], Generic[T]): pass
31+
class list(Sequence[T], Generic[T]):
32+
@overload
33+
def __getitem__(self, i: int) -> T: ...
34+
@overload
35+
def __getitem__(self, s: slice) -> list[T]: ...
36+
3237
def isinstance(x: object, t: type) -> bool: pass
3338

3439
def sum(iterable: Iterable[T], start: T = None) -> T: pass

0 commit comments

Comments
 (0)