Skip to content

Fall back to satisfiable constraints in unions #13467

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 41 additions & 13 deletions mypy/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TypeQuery,
TypeType,
TypeVarId,
TypeVarLikeType,
TypeVarTupleType,
TypeVarType,
TypeVisitor,
Expand Down Expand Up @@ -73,10 +74,11 @@ class Constraint:
op = 0 # SUBTYPE_OF or SUPERTYPE_OF
target: Type

def __init__(self, type_var: TypeVarId, op: int, target: Type) -> None:
self.type_var = type_var
def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None:
self.type_var = type_var.id
self.op = op
self.target = target
self.origin_type_var = type_var

def __repr__(self) -> str:
op_str = "<:"
Expand Down Expand Up @@ -190,7 +192,7 @@ def _infer_constraints(template: Type, actual: Type, direction: int) -> list[Con
# T :> U2", but they are not equivalent to the constraint solver,
# which never introduces new Union types (it uses join() instead).
if isinstance(template, TypeVarType):
return [Constraint(template.id, direction, actual)]
return [Constraint(template, direction, actual)]

# Now handle the case of either template or actual being a Union.
# For a Union to be a subtype of another type, every item of the Union
Expand Down Expand Up @@ -286,7 +288,7 @@ def merge_with_any(constraint: Constraint) -> Constraint:
# TODO: if we will support multiple sources Any, use this here instead.
any_type = AnyType(TypeOfAny.implementation_artifact)
return Constraint(
constraint.type_var,
constraint.origin_type_var,
constraint.op,
UnionType.make_union([target, any_type], target.line, target.column),
)
Expand Down Expand Up @@ -345,11 +347,37 @@ def any_constraints(options: list[list[Constraint] | None], eager: bool) -> list
merged_option = None
merged_options.append(merged_option)
return any_constraints(list(merged_options), eager)

# If normal logic didn't work, try excluding trivially unsatisfiable constraint (due to
# upper bounds) from each option, and comparing them again.
filtered_options = [filter_satisfiable(o) for o in options]
if filtered_options != options:
return any_constraints(filtered_options, eager=eager)

# Otherwise, there are either no valid options or multiple, inconsistent valid
# options. Give up and deduce nothing.
return []


def filter_satisfiable(option: list[Constraint] | None) -> list[Constraint] | None:
"""Keep only constraints that can possibly be satisfied.

Currently, we filter out constraints where target is not a subtype of the upper bound.
Since those can be never satisfied. We may add more cases in future if it improves type
inference.
"""
if not option:
return option
satisfiable = []
for c in option:
# TODO: add similar logic for TypeVar values (also in various other places)?
if mypy.subtypes.is_subtype(c.target, c.origin_type_var.upper_bound):
satisfiable.append(c)
if not satisfiable:
return None
return satisfiable


def is_same_constraints(x: list[Constraint], y: list[Constraint]) -> bool:
for c1 in x:
if not any(is_same_constraint(c1, c2) for c2 in y):
Expand Down Expand Up @@ -560,9 +588,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
)
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(mapped_arg.id, SUPERTYPE_OF, suffix))
res.append(Constraint(mapped_arg, SUPERTYPE_OF, suffix))
elif isinstance(tvar, TypeVarTupleType):
raise NotImplementedError

Expand All @@ -583,7 +611,7 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
if isinstance(template_unpack, TypeVarTupleType):
res.append(
Constraint(
template_unpack.id, SUPERTYPE_OF, TypeList(list(mapped_middle))
template_unpack, SUPERTYPE_OF, TypeList(list(mapped_middle))
)
)
elif (
Expand Down Expand Up @@ -644,9 +672,9 @@ def visit_instance(self, template: Instance) -> list[Constraint]:
suffix.arg_kinds[len(prefix.arg_kinds) :],
suffix.arg_names[len(prefix.arg_names) :],
)
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
elif isinstance(suffix, ParamSpecType):
res.append(Constraint(template_arg.id, SUPERTYPE_OF, suffix))
res.append(Constraint(template_arg, SUPERTYPE_OF, suffix))
return res
if (
template.type.is_protocol
Expand Down Expand Up @@ -763,7 +791,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
prefix_len = min(prefix_len, max_prefix_len)
res.append(
Constraint(
param_spec.id,
param_spec,
SUBTYPE_OF,
cactual.copy_modified(
arg_types=cactual.arg_types[prefix_len:],
Expand All @@ -774,7 +802,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
)
)
else:
res.append(Constraint(param_spec.id, SUBTYPE_OF, cactual_ps))
res.append(Constraint(param_spec, SUBTYPE_OF, cactual_ps))

# compare prefixes
cactual_prefix = cactual.copy_modified(
Expand Down Expand Up @@ -805,7 +833,7 @@ def visit_callable_type(self, template: CallableType) -> list[Constraint]:
else:
res = [
Constraint(
param_spec.id,
param_spec,
SUBTYPE_OF,
callable_with_ellipsis(any_type, any_type, template.fallback),
)
Expand Down Expand Up @@ -877,7 +905,7 @@ def visit_tuple_type(self, template: TupleType) -> list[Constraint]:
modified_actual = actual.copy_modified(items=list(actual_items))
return [
Constraint(
type_var=unpacked_type.id, op=self.direction, target=modified_actual
type_var=unpacked_type, op=self.direction, target=modified_actual
)
]

Expand Down
12 changes: 6 additions & 6 deletions mypy/test/testconstraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,21 +19,21 @@ def test_basic_type_variable(self) -> None:
fx = self.fx
for direction in [SUBTYPE_OF, SUPERTYPE_OF]:
assert infer_constraints(fx.gt, fx.ga, direction) == [
Constraint(type_var=fx.t.id, op=direction, target=fx.a)
Constraint(type_var=fx.t, op=direction, target=fx.a)
]

@pytest.mark.xfail
def test_basic_type_var_tuple_subtype(self) -> None:
fx = self.fx
assert infer_constraints(
Instance(fx.gvi, [UnpackType(fx.ts)]), Instance(fx.gvi, [fx.a, fx.b]), SUBTYPE_OF
) == [Constraint(type_var=fx.ts.id, op=SUBTYPE_OF, target=TypeList([fx.a, fx.b]))]
) == [Constraint(type_var=fx.ts, op=SUBTYPE_OF, target=TypeList([fx.a, fx.b]))]

def test_basic_type_var_tuple(self) -> None:
fx = self.fx
assert infer_constraints(
Instance(fx.gvi, [UnpackType(fx.ts)]), Instance(fx.gvi, [fx.a, fx.b]), SUPERTYPE_OF
) == [Constraint(type_var=fx.ts.id, op=SUPERTYPE_OF, target=TypeList([fx.a, fx.b]))]
) == [Constraint(type_var=fx.ts, op=SUPERTYPE_OF, target=TypeList([fx.a, fx.b]))]

def test_type_var_tuple_with_prefix_and_suffix(self) -> None:
fx = self.fx
Expand All @@ -44,7 +44,7 @@ def test_type_var_tuple_with_prefix_and_suffix(self) -> None:
SUPERTYPE_OF,
)
) == {
Constraint(type_var=fx.t.id, op=SUPERTYPE_OF, target=fx.a),
Constraint(type_var=fx.ts.id, op=SUPERTYPE_OF, target=TypeList([fx.b, fx.c])),
Constraint(type_var=fx.s.id, op=SUPERTYPE_OF, target=fx.d),
Constraint(type_var=fx.t, op=SUPERTYPE_OF, target=fx.a),
Constraint(type_var=fx.ts, op=SUPERTYPE_OF, target=TypeList([fx.b, fx.c])),
Constraint(type_var=fx.s, op=SUPERTYPE_OF, target=fx.d),
}
4 changes: 2 additions & 2 deletions mypy/test/testsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def assert_solve(
assert_equal(str(actual), str(res))

def supc(self, type_var: TypeVarType, bound: Type) -> Constraint:
return Constraint(type_var.id, SUPERTYPE_OF, bound)
return Constraint(type_var, SUPERTYPE_OF, bound)

def subc(self, type_var: TypeVarType, bound: Type) -> Constraint:
return Constraint(type_var.id, SUBTYPE_OF, bound)
return Constraint(type_var, SUBTYPE_OF, bound)
10 changes: 10 additions & 0 deletions test-data/unit/check-functions.test
Original file line number Diff line number Diff line change
Expand Up @@ -2674,3 +2674,13 @@ class A:
def h(self, *args, **kwargs) -> int: pass # OK
[builtins fixtures/property.pyi]
[out]

[case testSubtypingUnionGenericBounds]
from typing import Callable, TypeVar, Union, Sequence

TI = TypeVar("TI", bound=int)
TS = TypeVar("TS", bound=str)

f: Callable[[Sequence[TI]], None]
g: Callable[[Union[Sequence[TI], Sequence[TS]]], None]
f = g
17 changes: 17 additions & 0 deletions test-data/unit/check-overloading.test
Original file line number Diff line number Diff line change
Expand Up @@ -6493,3 +6493,20 @@ def foo(x: List[T]) -> str: ...
@overload
def foo(x: Sequence[int]) -> int: ...
[builtins fixtures/list.pyi]

[case testOverloadUnionGenericBounds]
from typing import overload, TypeVar, Sequence, Union

class Entity: ...
class Assoc: ...

E = TypeVar("E", bound=Entity)
A = TypeVar("A", bound=Assoc)

class Test:
@overload
def foo(self, arg: Sequence[E]) -> None: ...
@overload
def foo(self, arg: Sequence[A]) -> None: ...
def foo(self, arg: Union[Sequence[E], Sequence[A]]) -> None:
...