@@ -2419,28 +2419,46 @@ def push_type_map(self, type_map: Optional[Dict[Expression, Type]]) -> None:
2419
2419
2420
2420
TypeMap = Optional [Dict [Expression , Type ]]
2421
2421
2422
+ # An object that represents either a precise type or a type with an upper bound;
2423
+ # it is important for correct type inference with isinstance.
2424
+ TypeRange = NamedTuple (
2425
+ 'TypeRange' ,
2426
+ [
2427
+ ('item' , Type ),
2428
+ ('is_upper_bound' , bool ), # False => precise type
2429
+ ])
2430
+
2422
2431
2423
2432
def conditional_type_map (expr : Expression ,
2424
2433
current_type : Optional [Type ],
2425
- proposed_type : Optional [Type ],
2434
+ proposed_type_ranges : Optional [List [ TypeRange ] ],
2426
2435
) -> Tuple [TypeMap , TypeMap ]:
2427
2436
"""Takes in an expression, the current type of the expression, and a
2428
2437
proposed type of that expression.
2429
2438
2430
2439
Returns a 2-tuple: The first element is a map from the expression to
2431
2440
the proposed type, if the expression can be the proposed type. The
2432
2441
second element is a map from the expression to the type it would hold
2433
- if it was not the proposed type, if any."""
2434
- if proposed_type :
2442
+ if it was not the proposed type, if any. None means bot, {} means top"""
2443
+ if proposed_type_ranges :
2444
+ if len (proposed_type_ranges ) == 1 :
2445
+ proposed_type = proposed_type_ranges [0 ].item # Union with a single type breaks tests
2446
+ else :
2447
+ proposed_type = UnionType ([type_range .item for type_range in proposed_type_ranges ])
2435
2448
if current_type :
2436
- if is_proper_subtype (current_type , proposed_type ):
2437
- # Expression is always of type proposed_type
2449
+ if not any (type_range .is_upper_bound for type_range in proposed_type_ranges ) \
2450
+ and is_proper_subtype (current_type , proposed_type ):
2451
+ # Expression is always of one of the types in proposed_type_ranges
2438
2452
return {}, None
2439
2453
elif not is_overlapping_types (current_type , proposed_type ):
2440
- # Expression is never of type proposed_type
2454
+ # Expression is never of any type in proposed_type_ranges
2441
2455
return None , {}
2442
2456
else :
2443
- remaining_type = restrict_subtype_away (current_type , proposed_type )
2457
+ # we can only restrict when the type is precise, not bounded
2458
+ proposed_precise_type = UnionType ([type_range .item
2459
+ for type_range in proposed_type_ranges
2460
+ if not type_range .is_upper_bound ])
2461
+ remaining_type = restrict_subtype_away (current_type , proposed_precise_type )
2444
2462
return {expr : proposed_type }, {expr : remaining_type }
2445
2463
else :
2446
2464
return {expr : proposed_type }, {}
@@ -2611,8 +2629,8 @@ def find_isinstance_check(node: Expression,
2611
2629
expr = node .args [0 ]
2612
2630
if expr .literal == LITERAL_TYPE :
2613
2631
vartype = type_map [expr ]
2614
- type = get_isinstance_type (node .args [1 ], type_map )
2615
- return conditional_type_map (expr , vartype , type )
2632
+ types = get_isinstance_type (node .args [1 ], type_map )
2633
+ return conditional_type_map (expr , vartype , types )
2616
2634
elif refers_to_fullname (node .callee , 'builtins.callable' ):
2617
2635
expr = node .args [0 ]
2618
2636
if expr .literal == LITERAL_TYPE :
@@ -2630,7 +2648,8 @@ def find_isinstance_check(node: Expression,
2630
2648
# two elements in node.operands, and at least one of them
2631
2649
# should represent a None.
2632
2650
vartype = type_map [expr ]
2633
- if_vars , else_vars = conditional_type_map (expr , vartype , NoneTyp ())
2651
+ none_typ = [TypeRange (NoneTyp (), is_upper_bound = False )]
2652
+ if_vars , else_vars = conditional_type_map (expr , vartype , none_typ )
2634
2653
break
2635
2654
2636
2655
if is_not :
@@ -2692,33 +2711,31 @@ def flatten(t: Expression) -> List[Expression]:
2692
2711
return [t ]
2693
2712
2694
2713
2695
- def get_isinstance_type (expr : Expression , type_map : Dict [Expression , Type ]) -> Type :
2714
+ def get_isinstance_type (expr : Expression ,
2715
+ type_map : Dict [Expression , Type ]) -> Optional [List [TypeRange ]]:
2696
2716
type = type_map [expr ]
2697
2717
2698
2718
if isinstance (type , TupleType ):
2699
2719
all_types = type .items
2700
2720
else :
2701
2721
all_types = [type ]
2702
2722
2703
- types = [] # type: List[Type ]
2723
+ types = [] # type: List[TypeRange ]
2704
2724
2705
2725
for type in all_types :
2706
2726
if isinstance (type , FunctionLike ):
2707
2727
if type .is_type_obj ():
2708
2728
# Type variables may be present -- erase them, which is the best
2709
2729
# we can do (outside disallowing them here).
2710
2730
type = erase_typevars (type .items ()[0 ].ret_type )
2711
- types .append (type )
2731
+ types .append (TypeRange ( type , is_upper_bound = False ) )
2712
2732
elif isinstance (type , TypeType ):
2713
- types .append (type .item )
2733
+ types .append (TypeRange ( type .item , is_upper_bound = True ) )
2714
2734
else : # we didn't see an actual type, but rather a variable whose value is unknown to us
2715
2735
return None
2716
2736
2717
2737
assert len (types ) != 0
2718
- if len (types ) == 1 :
2719
- return types [0 ]
2720
- else :
2721
- return UnionType (types )
2738
+ return types
2722
2739
2723
2740
2724
2741
def expand_func (defn : FuncItem , map : Dict [TypeVarId , Type ]) -> FuncItem :
0 commit comments