Skip to content

Commit 0f76dfc

Browse files
Michael0x2ailevkivskyi
authored andcommitted
Refine how overload selection handles *args, **kwargs, and Any (#5166)
This pull request implements the changes discussed in #5124. Specifically... 1. When two overload alternatives match due to Any, we return the last matching return type if it's a supertype of all of the previous ones. If it's not a supertype, we give up and return 'Any' as before. 2. If a user calls an overload with a starred expression, we try matching alternatives with a starred arg or kwarg first, even if those alternatives do not appear first in the list. If none of the starred alternatives are a valid match, we fall back to checking the other remaining alternatives in order.
1 parent 402a066 commit 0f76dfc

File tree

2 files changed

+359
-31
lines changed

2 files changed

+359
-31
lines changed

mypy/checkexpr.py

+60-20
Original file line numberDiff line numberDiff line change
@@ -1195,18 +1195,49 @@ def plausible_overload_call_targets(self,
11951195
arg_kinds: List[int],
11961196
arg_names: Optional[Sequence[Optional[str]]],
11971197
overload: Overloaded) -> List[CallableType]:
1198-
"""Returns all overload call targets that having matching argument counts."""
1198+
"""Returns all overload call targets that having matching argument counts.
1199+
1200+
If the given args contains a star-arg (*arg or **kwarg argument), this method
1201+
will ensure all star-arg overloads appear at the start of the list, instead
1202+
of their usual location.
1203+
1204+
The only exception is if the starred argument is something like a Tuple or a
1205+
NamedTuple, which has a definitive "shape". If so, we don't move the corresponding
1206+
alternative to the front since we can infer a more precise match using the original
1207+
order."""
1208+
1209+
def has_shape(typ: Type) -> bool:
1210+
# TODO: Once https://github.com/python/mypy/issues/5198 is fixed,
1211+
# add 'isinstance(typ, TypedDictType)' somewhere below.
1212+
return (isinstance(typ, TupleType)
1213+
or (isinstance(typ, Instance) and typ.type.is_named_tuple))
1214+
11991215
matches = [] # type: List[CallableType]
1216+
star_matches = [] # type: List[CallableType]
1217+
1218+
args_have_var_arg = False
1219+
args_have_kw_arg = False
1220+
for kind, typ in zip(arg_kinds, arg_types):
1221+
if kind == ARG_STAR and not has_shape(typ):
1222+
args_have_var_arg = True
1223+
if kind == ARG_STAR2 and not has_shape(typ):
1224+
args_have_kw_arg = True
1225+
12001226
for typ in overload.items():
12011227
formal_to_actual = map_actuals_to_formals(arg_kinds, arg_names,
12021228
typ.arg_kinds, typ.arg_names,
12031229
lambda i: arg_types[i])
12041230

12051231
if self.check_argument_count(typ, arg_types, arg_kinds, arg_names,
12061232
formal_to_actual, None, None):
1207-
matches.append(typ)
1233+
if args_have_var_arg and typ.is_var_arg:
1234+
star_matches.append(typ)
1235+
elif args_have_kw_arg and typ.is_kw_arg:
1236+
star_matches.append(typ)
1237+
else:
1238+
matches.append(typ)
12081239

1209-
return matches
1240+
return star_matches + matches
12101241

12111242
def infer_overload_return_type(self,
12121243
plausible_targets: List[CallableType],
@@ -1273,15 +1304,20 @@ def infer_overload_return_type(self,
12731304
return None
12741305
elif any_causes_overload_ambiguity(matches, return_types, arg_types, arg_kinds, arg_names):
12751306
# An argument of type or containing the type 'Any' caused ambiguity.
1276-
# We infer a type of 'Any'
1277-
return self.check_call(callee=AnyType(TypeOfAny.special_form),
1278-
args=args,
1279-
arg_kinds=arg_kinds,
1280-
arg_names=arg_names,
1281-
context=context,
1282-
arg_messages=arg_messages,
1283-
callable_name=callable_name,
1284-
object_type=object_type)
1307+
if all(is_subtype(ret_type, return_types[-1]) for ret_type in return_types[:-1]):
1308+
# The last match is a supertype of all the previous ones, so it's safe
1309+
# to return that inferred type.
1310+
return return_types[-1], inferred_types[-1]
1311+
else:
1312+
# We give up and return 'Any'.
1313+
return self.check_call(callee=AnyType(TypeOfAny.special_form),
1314+
args=args,
1315+
arg_kinds=arg_kinds,
1316+
arg_names=arg_names,
1317+
context=context,
1318+
arg_messages=arg_messages,
1319+
callable_name=callable_name,
1320+
object_type=object_type)
12851321
else:
12861322
# Success! No ambiguity; return the first match.
12871323
return return_types[0], inferred_types[0]
@@ -3177,16 +3213,20 @@ def any_causes_overload_ambiguity(items: List[CallableType],
31773213
matching_formals_unfiltered = [(item_idx, lookup[arg_idx])
31783214
for item_idx, lookup in enumerate(actual_to_formal)
31793215
if lookup[arg_idx]]
3216+
3217+
matching_returns = []
31803218
matching_formals = []
31813219
for item_idx, formals in matching_formals_unfiltered:
3182-
if len(formals) > 1:
3183-
# An actual maps to multiple formals -- give up as too
3184-
# complex, just assume it overlaps.
3185-
return True
3186-
matching_formals.append((item_idx, items[item_idx].arg_types[formals[0]]))
3187-
if (not all_same_types(t for _, t in matching_formals) and
3188-
not all_same_types(items[idx].ret_type
3189-
for idx, _ in matching_formals)):
3220+
matched_callable = items[item_idx]
3221+
matching_returns.append(matched_callable.ret_type)
3222+
3223+
# Note: if an actual maps to multiple formals of differing types within
3224+
# a single callable, then we know at least one of those formals must be
3225+
# a different type then the formal(s) in some other callable.
3226+
# So it's safe to just append everything to the same list.
3227+
for formal in formals:
3228+
matching_formals.append(matched_callable.arg_types[formal])
3229+
if not all_same_types(matching_formals) and not all_same_types(matching_returns):
31903230
# Any maps to multiple different types, and the return types of these items differ.
31913231
return True
31923232
return False

0 commit comments

Comments
 (0)