@@ -5009,6 +5009,38 @@ def conditional_callable_type_map(
5009
5009
5010
5010
return None , {}
5011
5011
5012
+ def contains_operator_right_operand_type_map (self , item_type : Type , collection_type : Type ) -> tuple [Type , Type ]:
5013
+ """
5014
+ Deduces the type of the right operand of the `in` operator.
5015
+ For now, we only support narrowing unions of TypedDicts based on left operand being literal string(s).
5016
+ """
5017
+ if_types , else_types = [collection_type ], [collection_type ]
5018
+ item_strs = try_getting_str_literals_from_type (item_type )
5019
+ if item_strs :
5020
+ if_types , else_types = self ._contains_string_right_operand_type_map (set (item_strs ), collection_type )
5021
+ return UnionType .make_union (if_types ), UnionType .make_union (else_types )
5022
+
5023
+ def _contains_string_right_operand_type_map (self , item_strs : set [str ], t : Type ) -> tuple [list [Type ], list [Type ]]:
5024
+ t = get_proper_type (t )
5025
+ if_types : list [Type ] = []
5026
+ else_types : list [Type ] = []
5027
+ if isinstance (t , TypedDictType ):
5028
+ if item_strs <= t .items .keys ():
5029
+ if_types .append (t )
5030
+ elif item_strs .isdisjoint (t .items .keys ()):
5031
+ else_types .append (t )
5032
+ else :
5033
+ if_types .append (t )
5034
+ else_types .append (t )
5035
+ elif isinstance (t , UnionType ):
5036
+ for union_item in t .items :
5037
+ a , b = self ._contains_string_right_operand_type_map (item_strs , union_item )
5038
+ if_types .extend (a )
5039
+ else_types .extend (b )
5040
+ else :
5041
+ if_types = else_types = [t ]
5042
+ return if_types , else_types
5043
+
5012
5044
def _is_truthy_type (self , t : ProperType ) -> bool :
5013
5045
return (
5014
5046
(
@@ -5316,28 +5348,28 @@ def has_no_custom_eq_checks(t: Type) -> bool:
5316
5348
elif operator in {"in" , "not in" }:
5317
5349
assert len (expr_indices ) == 2
5318
5350
left_index , right_index = expr_indices
5319
- if left_index not in narrowable_operand_index_to_hash :
5320
- continue
5321
-
5322
5351
item_type = operand_types [left_index ]
5323
5352
collection_type = operand_types [right_index ]
5324
5353
5325
- # We only try and narrow away 'None' for now
5326
- if not is_optional (item_type ):
5327
- continue
5354
+ if_map , else_map = {}, {}
5355
+
5356
+ if left_index in narrowable_operand_index_to_hash :
5357
+ # We only try and narrow away 'None' for now
5358
+ if is_optional (item_type ):
5359
+ collection_item_type = get_proper_type (builtin_item_type (collection_type ))
5360
+ if (collection_item_type is not None
5361
+ and not is_optional (collection_item_type )
5362
+ and not (isinstance (collection_item_type , Instance ) and collection_item_type .type .fullname == "builtins.object" )
5363
+ and is_overlapping_erased_types (item_type , collection_item_type )
5364
+ ):
5365
+ if_map [operands [left_index ]] = remove_optional (item_type )
5366
+
5367
+ if right_index in narrowable_operand_index_to_hash :
5368
+ right_if_type , right_else_type = self .contains_operator_right_operand_type_map (item_type , collection_type )
5369
+ expr = operands [right_index ]
5370
+ if_map [expr ] = right_if_type
5371
+ else_map [expr ] = right_else_type
5328
5372
5329
- collection_item_type = get_proper_type (builtin_item_type (collection_type ))
5330
- if collection_item_type is None or is_optional (collection_item_type ):
5331
- continue
5332
- if (
5333
- isinstance (collection_item_type , Instance )
5334
- and collection_item_type .type .fullname == "builtins.object"
5335
- ):
5336
- continue
5337
- if is_overlapping_erased_types (item_type , collection_item_type ):
5338
- if_map , else_map = {operands [left_index ]: remove_optional (item_type )}, {}
5339
- else :
5340
- continue
5341
5373
else :
5342
5374
if_map = {}
5343
5375
else_map = {}
@@ -5390,6 +5422,10 @@ def has_no_custom_eq_checks(t: Type) -> bool:
5390
5422
or_conditional_maps (left_if_vars , right_if_vars ),
5391
5423
and_conditional_maps (left_else_vars , right_else_vars ),
5392
5424
)
5425
+ elif isinstance (node , OpExpr ) and node .op == "in" :
5426
+ left_if_vars , left_else_vars = self .find_isinstance_check (node .left )
5427
+ right_if_vars , right_else_vars = self .find_isinstance_check (node .right )
5428
+
5393
5429
elif isinstance (node , UnaryExpr ) and node .op == "not" :
5394
5430
left , right = self .find_isinstance_check (node .expr )
5395
5431
return right , left
0 commit comments