Skip to content

Commit 5806bc7

Browse files
committed
Fix narrowing on match with function subject
Fixes #12998 mypy can't narrow match statements with functions subjects because the callexpr node is not a literal node. This adds a 'dummy' literal node that the match statement visitor can use to do the type narrowing. The python grammar describes the the match subject as a named expression so this uses that nameexpr node as it's literal.
1 parent e4c43cb commit 5806bc7

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

mypy/checker.py

+8-3
Original file line numberDiff line numberDiff line change
@@ -5043,8 +5043,13 @@ def visit_continue_stmt(self, s: ContinueStmt) -> None:
50435043
return None
50445044

50455045
def visit_match_stmt(self, s: MatchStmt) -> None:
5046+
# Create a dummy subject expression to handle cases where a match
5047+
# statement's subject is not a literal value which prevent us from correctly
5048+
# narrowing types and checking exhaustivity
5049+
named_subject = NameExpr("match") if isinstance(s.subject, CallExpr) else s.subject
50465050
with self.binder.frame_context(can_skip=False, fall_through=0):
50475051
subject_type = get_proper_type(self.expr_checker.accept(s.subject))
5052+
self.store_type(named_subject, subject_type)
50485053

50495054
if isinstance(subject_type, DeletedType):
50505055
self.msg.deleted_as_rvalue(subject_type, s)
@@ -5061,7 +5066,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
50615066
# The second pass narrows down the types and type checks bodies.
50625067
for p, g, b in zip(s.patterns, s.guards, s.bodies):
50635068
current_subject_type = self.expr_checker.narrow_type_from_binder(
5064-
s.subject, subject_type
5069+
named_subject, subject_type
50655070
)
50665071
pattern_type = self.pattern_checker.accept(p, current_subject_type)
50675072
with self.binder.frame_context(can_skip=True, fall_through=2):
@@ -5072,7 +5077,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
50725077
else_map: TypeMap = {}
50735078
else:
50745079
pattern_map, else_map = conditional_types_to_typemaps(
5075-
s.subject, pattern_type.type, pattern_type.rest_type
5080+
named_subject, pattern_type.type, pattern_type.rest_type
50765081
)
50775082
self.remove_capture_conflicts(pattern_type.captures, inferred_types)
50785083
self.push_type_map(pattern_map)
@@ -5100,7 +5105,7 @@ def visit_match_stmt(self, s: MatchStmt) -> None:
51005105
and expr.fullname == case_target.fullname
51015106
):
51025107
continue
5103-
type_map[s.subject] = type_map[expr]
5108+
type_map[named_subject] = type_map[expr]
51045109

51055110
self.push_type_map(guard_map)
51065111
self.accept(b)

test-data/unit/check-python310.test

+12
Original file line numberDiff line numberDiff line change
@@ -1139,6 +1139,18 @@ match m:
11391139

11401140
reveal_type(a) # N: Revealed type is "builtins.str"
11411141

1142+
[case testMatchCapturePatternFromFunctionReturningUnion]
1143+
def func(arg: bool) -> str | int:
1144+
if arg:
1145+
return 1
1146+
return "a"
1147+
1148+
match func(True):
1149+
case str(a):
1150+
reveal_type(a) # N: Revealed type is "builtins.str"
1151+
case a:
1152+
reveal_type(a) # N: Revealed type is "builtins.int"
1153+
11421154
-- Guards --
11431155

11441156
[case testMatchSimplePatternGuard]

0 commit comments

Comments
 (0)