Skip to content

Commit c6f793e

Browse files
committed
'in' can narrow TypedDict unions
1 parent d528bf2 commit c6f793e

File tree

2 files changed

+95
-18
lines changed

2 files changed

+95
-18
lines changed

mypy/checker.py

Lines changed: 54 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -5009,6 +5009,38 @@ def conditional_callable_type_map(
50095009

50105010
return None, {}
50115011

5012+
def contains_operator_right_operand_type_map(self, item_type: Type, collection_type: Type) -> tuple[Type, Type]:
5013+
"""
5014+
Deduces the type of the right operand of the `in` operator.
5015+
For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
5016+
"""
5017+
if_types, else_types = [collection_type], [collection_type]
5018+
item_strs = try_getting_str_literals_from_type(item_type)
5019+
if item_strs:
5020+
if_types, else_types = self._contains_string_right_operand_type_map(set(item_strs), collection_type)
5021+
return UnionType.make_union(if_types), UnionType.make_union(else_types)
5022+
5023+
def _contains_string_right_operand_type_map(self, item_strs: set[str], t: Type) -> tuple[list[Type], list[Type]]:
5024+
t = get_proper_type(t)
5025+
if_types: list[Type] = []
5026+
else_types: list[Type] = []
5027+
if isinstance(t, TypedDictType):
5028+
if item_strs <= t.items.keys():
5029+
if_types.append(t)
5030+
elif item_strs.isdisjoint(t.items.keys()):
5031+
else_types.append(t)
5032+
else:
5033+
if_types.append(t)
5034+
else_types.append(t)
5035+
elif isinstance(t, UnionType):
5036+
for union_item in t.items:
5037+
a, b = self._contains_string_right_operand_type_map(item_strs, union_item)
5038+
if_types.extend(a)
5039+
else_types.extend(b)
5040+
else:
5041+
if_types = else_types = [t]
5042+
return if_types, else_types
5043+
50125044
def _is_truthy_type(self, t: ProperType) -> bool:
50135045
return (
50145046
(
@@ -5316,28 +5348,28 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53165348
elif operator in {"in", "not in"}:
53175349
assert len(expr_indices) == 2
53185350
left_index, right_index = expr_indices
5319-
if left_index not in narrowable_operand_index_to_hash:
5320-
continue
5321-
53225351
item_type = operand_types[left_index]
53235352
collection_type = operand_types[right_index]
53245353

5325-
# We only try and narrow away 'None' for now
5326-
if not is_optional(item_type):
5327-
continue
5354+
if_map, else_map = {}, {}
5355+
5356+
if left_index in narrowable_operand_index_to_hash:
5357+
# We only try and narrow away 'None' for now
5358+
if is_optional(item_type):
5359+
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5360+
if (collection_item_type is not None
5361+
and not is_optional(collection_item_type)
5362+
and not (isinstance(collection_item_type, Instance) and collection_item_type.type.fullname == "builtins.object")
5363+
and is_overlapping_erased_types(item_type, collection_item_type)
5364+
):
5365+
if_map[operands[left_index]] = remove_optional(item_type)
5366+
5367+
if right_index in narrowable_operand_index_to_hash:
5368+
right_if_type, right_else_type = self.contains_operator_right_operand_type_map(item_type, collection_type)
5369+
expr = operands[right_index]
5370+
if_map[expr] = right_if_type
5371+
else_map[expr] = right_else_type
53285372

5329-
collection_item_type = get_proper_type(builtin_item_type(collection_type))
5330-
if collection_item_type is None or is_optional(collection_item_type):
5331-
continue
5332-
if (
5333-
isinstance(collection_item_type, Instance)
5334-
and collection_item_type.type.fullname == "builtins.object"
5335-
):
5336-
continue
5337-
if is_overlapping_erased_types(item_type, collection_item_type):
5338-
if_map, else_map = {operands[left_index]: remove_optional(item_type)}, {}
5339-
else:
5340-
continue
53415373
else:
53425374
if_map = {}
53435375
else_map = {}
@@ -5390,6 +5422,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
53905422
or_conditional_maps(left_if_vars, right_if_vars),
53915423
and_conditional_maps(left_else_vars, right_else_vars),
53925424
)
5425+
elif isinstance(node, OpExpr) and node.op == "in":
5426+
left_if_vars, left_else_vars = self.find_isinstance_check(node.left)
5427+
right_if_vars, right_else_vars = self.find_isinstance_check(node.right)
5428+
53935429
elif isinstance(node, UnaryExpr) and node.op == "not":
53945430
left, right = self.find_isinstance_check(node.expr)
53955431
return right, left

test-data/unit/check-typeddict.test

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,47 @@ v = {bad2: 2} # E: Extra key "bad" for TypedDict "Value"
20122012
[builtins fixtures/dict.pyi]
20132013
[typing fixtures/typing-typeddict.pyi]
20142014

2015+
[case testFinalTypedDictTagged]
2016+
from __future__ import annotations
2017+
from typing import Literal, TypedDict
2018+
from typing_extensions import final
2019+
2020+
@final
2021+
class D1(TypedDict):
2022+
foo: int
2023+
2024+
2025+
@final
2026+
class D2(TypedDict):
2027+
bar: int
2028+
2029+
d: D1 | D2
2030+
val: int
2031+
2032+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2033+
if 'foo' in d:
2034+
val = d['foo']
2035+
else:
2036+
val = d['bar']
2037+
2038+
foo_or_bar: Literal['foo', 'bar']
2039+
if foo_or_bar in d:
2040+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2041+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2042+
else:
2043+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2044+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2045+
2046+
foo_or_invalid: Literal['foo', 'invalid']
2047+
if foo_or_invalid in d:
2048+
val = d['foo']
2049+
else:
2050+
val = d['foo'] # E: TypedDict "D2" has no key "foo"
2051+
val = d['bar'] # E: TypedDict "D1" has no key "bar"
2052+
2053+
[builtins fixtures/dict.pyi]
2054+
[typing fixtures/typing-typeddict.pyi]
2055+
20152056
[case testCannotSubclassFinalTypedDict]
20162057
from typing import TypedDict
20172058
from typing_extensions import final

0 commit comments

Comments
 (0)