@@ -1195,18 +1195,49 @@ def plausible_overload_call_targets(self,
1195
1195
arg_kinds : List [int ],
1196
1196
arg_names : Optional [Sequence [Optional [str ]]],
1197
1197
overload : Overloaded ) -> List [CallableType ]:
1198
- """Returns all overload call targets that having matching argument counts."""
1198
+ """Returns all overload call targets that having matching argument counts.
1199
+
1200
+ If the given args contains a star-arg (*arg or **kwarg argument), this method
1201
+ will ensure all star-arg overloads appear at the start of the list, instead
1202
+ of their usual location.
1203
+
1204
+ The only exception is if the starred argument is something like a Tuple or a
1205
+ NamedTuple, which has a definitive "shape". If so, we don't move the corresponding
1206
+ alternative to the front since we can infer a more precise match using the original
1207
+ order."""
1208
+
1209
+ def has_shape (typ : Type ) -> bool :
1210
+ # TODO: Once https://github.com/python/mypy/issues/5198 is fixed,
1211
+ # add 'isinstance(typ, TypedDictType)' somewhere below.
1212
+ return (isinstance (typ , TupleType )
1213
+ or (isinstance (typ , Instance ) and typ .type .is_named_tuple ))
1214
+
1199
1215
matches = [] # type: List[CallableType]
1216
+ star_matches = [] # type: List[CallableType]
1217
+
1218
+ args_have_var_arg = False
1219
+ args_have_kw_arg = False
1220
+ for kind , typ in zip (arg_kinds , arg_types ):
1221
+ if kind == ARG_STAR and not has_shape (typ ):
1222
+ args_have_var_arg = True
1223
+ if kind == ARG_STAR2 and not has_shape (typ ):
1224
+ args_have_kw_arg = True
1225
+
1200
1226
for typ in overload .items ():
1201
1227
formal_to_actual = map_actuals_to_formals (arg_kinds , arg_names ,
1202
1228
typ .arg_kinds , typ .arg_names ,
1203
1229
lambda i : arg_types [i ])
1204
1230
1205
1231
if self .check_argument_count (typ , arg_types , arg_kinds , arg_names ,
1206
1232
formal_to_actual , None , None ):
1207
- matches .append (typ )
1233
+ if args_have_var_arg and typ .is_var_arg :
1234
+ star_matches .append (typ )
1235
+ elif args_have_kw_arg and typ .is_kw_arg :
1236
+ star_matches .append (typ )
1237
+ else :
1238
+ matches .append (typ )
1208
1239
1209
- return matches
1240
+ return star_matches + matches
1210
1241
1211
1242
def infer_overload_return_type (self ,
1212
1243
plausible_targets : List [CallableType ],
@@ -1273,15 +1304,20 @@ def infer_overload_return_type(self,
1273
1304
return None
1274
1305
elif any_causes_overload_ambiguity (matches , return_types , arg_types , arg_kinds , arg_names ):
1275
1306
# An argument of type or containing the type 'Any' caused ambiguity.
1276
- # We infer a type of 'Any'
1277
- return self .check_call (callee = AnyType (TypeOfAny .special_form ),
1278
- args = args ,
1279
- arg_kinds = arg_kinds ,
1280
- arg_names = arg_names ,
1281
- context = context ,
1282
- arg_messages = arg_messages ,
1283
- callable_name = callable_name ,
1284
- object_type = object_type )
1307
+ if all (is_subtype (ret_type , return_types [- 1 ]) for ret_type in return_types [:- 1 ]):
1308
+ # The last match is a supertype of all the previous ones, so it's safe
1309
+ # to return that inferred type.
1310
+ return return_types [- 1 ], inferred_types [- 1 ]
1311
+ else :
1312
+ # We give up and return 'Any'.
1313
+ return self .check_call (callee = AnyType (TypeOfAny .special_form ),
1314
+ args = args ,
1315
+ arg_kinds = arg_kinds ,
1316
+ arg_names = arg_names ,
1317
+ context = context ,
1318
+ arg_messages = arg_messages ,
1319
+ callable_name = callable_name ,
1320
+ object_type = object_type )
1285
1321
else :
1286
1322
# Success! No ambiguity; return the first match.
1287
1323
return return_types [0 ], inferred_types [0 ]
@@ -3177,16 +3213,20 @@ def any_causes_overload_ambiguity(items: List[CallableType],
3177
3213
matching_formals_unfiltered = [(item_idx , lookup [arg_idx ])
3178
3214
for item_idx , lookup in enumerate (actual_to_formal )
3179
3215
if lookup [arg_idx ]]
3216
+
3217
+ matching_returns = []
3180
3218
matching_formals = []
3181
3219
for item_idx , formals in matching_formals_unfiltered :
3182
- if len (formals ) > 1 :
3183
- # An actual maps to multiple formals -- give up as too
3184
- # complex, just assume it overlaps.
3185
- return True
3186
- matching_formals .append ((item_idx , items [item_idx ].arg_types [formals [0 ]]))
3187
- if (not all_same_types (t for _ , t in matching_formals ) and
3188
- not all_same_types (items [idx ].ret_type
3189
- for idx , _ in matching_formals )):
3220
+ matched_callable = items [item_idx ]
3221
+ matching_returns .append (matched_callable .ret_type )
3222
+
3223
+ # Note: if an actual maps to multiple formals of differing types within
3224
+ # a single callable, then we know at least one of those formals must be
3225
+ # a different type then the formal(s) in some other callable.
3226
+ # So it's safe to just append everything to the same list.
3227
+ for formal in formals :
3228
+ matching_formals .append (matched_callable .arg_types [formal ])
3229
+ if not all_same_types (matching_formals ) and not all_same_types (matching_returns ):
3190
3230
# Any maps to multiple different types, and the return types of these items differ.
3191
3231
return True
3192
3232
return False
0 commit comments