Skip to content

Commit 9f6b373

Browse files
Consider variance when joining generic types (#9994)
Co-authored-by: Shantanu <[email protected]>
1 parent 55bd489 commit 9f6b373

10 files changed

+196
-90
lines changed

mypy/join.py

+100-58
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Calculation of the least upper bound types (joins)."""
22

33
from mypy.backports import OrderedDict
4-
from typing import List, Optional
4+
from typing import List, Optional, Tuple
55

66
from mypy.types import (
77
Type, AnyType, NoneType, TypeVisitor, Instance, UnboundType, TypeVarType, CallableType,
@@ -14,9 +14,96 @@
1414
is_subtype, is_equivalent, is_subtype_ignoring_tvars, is_proper_subtype,
1515
is_protocol_implementation, find_member
1616
)
17-
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT
17+
from mypy.nodes import ARG_NAMED, ARG_NAMED_OPT, INVARIANT, COVARIANT, CONTRAVARIANT
1818
import mypy.typeops
1919
from mypy import state
20+
from mypy import meet
21+
22+
23+
class InstanceJoiner:
24+
def __init__(self) -> None:
25+
self.seen_instances = [] # type: List[Tuple[Instance, Instance]]
26+
27+
def join_instances(self, t: Instance, s: Instance) -> ProperType:
28+
if (t, s) in self.seen_instances or (s, t) in self.seen_instances:
29+
return object_from_instance(t)
30+
31+
self.seen_instances.append((t, s))
32+
33+
"""Calculate the join of two instance types."""
34+
if t.type == s.type:
35+
# Simplest case: join two types with the same base type (but
36+
# potentially different arguments).
37+
38+
# Combine type arguments.
39+
args = [] # type: List[Type]
40+
# N.B: We use zip instead of indexing because the lengths might have
41+
# mismatches during daemon reprocessing.
42+
for ta, sa, type_var in zip(t.args, s.args, t.type.defn.type_vars):
43+
ta_proper = get_proper_type(ta)
44+
sa_proper = get_proper_type(sa)
45+
new_type = None # type: Optional[Type]
46+
if isinstance(ta_proper, AnyType):
47+
new_type = AnyType(TypeOfAny.from_another_any, ta_proper)
48+
elif isinstance(sa_proper, AnyType):
49+
new_type = AnyType(TypeOfAny.from_another_any, sa_proper)
50+
elif type_var.variance == COVARIANT:
51+
new_type = join_types(ta, sa, self)
52+
if len(type_var.values) != 0 and new_type not in type_var.values:
53+
self.seen_instances.pop()
54+
return object_from_instance(t)
55+
if not is_subtype(new_type, type_var.upper_bound):
56+
self.seen_instances.pop()
57+
return object_from_instance(t)
58+
elif type_var.variance == CONTRAVARIANT:
59+
new_type = meet.meet_types(ta, sa)
60+
if len(type_var.values) != 0 and new_type not in type_var.values:
61+
self.seen_instances.pop()
62+
return object_from_instance(t)
63+
# No need to check subtype, as ta and sa already have to be subtypes of
64+
# upper_bound
65+
elif type_var.variance == INVARIANT:
66+
new_type = join_types(ta, sa)
67+
if not is_equivalent(ta, sa):
68+
self.seen_instances.pop()
69+
return object_from_instance(t)
70+
assert new_type is not None
71+
args.append(new_type)
72+
result = Instance(t.type, args) # type: ProperType
73+
elif t.type.bases and is_subtype_ignoring_tvars(t, s):
74+
result = self.join_instances_via_supertype(t, s)
75+
else:
76+
# Now t is not a subtype of s, and t != s. Now s could be a subtype
77+
# of t; alternatively, we need to find a common supertype. This works
78+
# in of the both cases.
79+
result = self.join_instances_via_supertype(s, t)
80+
81+
self.seen_instances.pop()
82+
return result
83+
84+
def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
85+
# Give preference to joins via duck typing relationship, so that
86+
# join(int, float) == float, for example.
87+
if t.type._promote and is_subtype(t.type._promote, s):
88+
return join_types(t.type._promote, s, self)
89+
elif s.type._promote and is_subtype(s.type._promote, t):
90+
return join_types(t, s.type._promote, self)
91+
# Compute the "best" supertype of t when joined with s.
92+
# The definition of "best" may evolve; for now it is the one with
93+
# the longest MRO. Ties are broken by using the earlier base.
94+
best = None # type: Optional[ProperType]
95+
for base in t.type.bases:
96+
mapped = map_instance_to_supertype(t, base.type)
97+
res = self.join_instances(mapped, s)
98+
if best is None or is_better(res, best):
99+
best = res
100+
assert best is not None
101+
promote = get_proper_type(t.type._promote)
102+
if isinstance(promote, Instance):
103+
res = self.join_instances(promote, s)
104+
if is_better(res, best):
105+
best = res
106+
return best
20107

21108

22109
def join_simple(declaration: Optional[Type], s: Type, t: Type) -> ProperType:
@@ -69,7 +156,7 @@ def trivial_join(s: Type, t: Type) -> ProperType:
69156
return object_or_any_from_type(get_proper_type(t))
70157

71158

72-
def join_types(s: Type, t: Type) -> ProperType:
159+
def join_types(s: Type, t: Type, instance_joiner: Optional[InstanceJoiner] = None) -> ProperType:
73160
"""Return the least upper bound of s and t.
74161
75162
For example, the join of 'int' and 'object' is 'object'.
@@ -110,7 +197,7 @@ def join_types(s: Type, t: Type) -> ProperType:
110197
return AnyType(TypeOfAny.from_error)
111198

112199
# Use a visitor to handle non-trivial cases.
113-
return t.accept(TypeJoinVisitor(s))
200+
return t.accept(TypeJoinVisitor(s, instance_joiner))
114201

115202

116203
class TypeJoinVisitor(TypeVisitor[ProperType]):
@@ -120,8 +207,9 @@ class TypeJoinVisitor(TypeVisitor[ProperType]):
120207
s: The other (left) type operand.
121208
"""
122209

123-
def __init__(self, s: ProperType) -> None:
210+
def __init__(self, s: ProperType, instance_joiner: Optional[InstanceJoiner] = None) -> None:
124211
self.s = s
212+
self.instance_joiner = instance_joiner
125213

126214
def visit_unbound_type(self, t: UnboundType) -> ProperType:
127215
return AnyType(TypeOfAny.special_form)
@@ -163,7 +251,9 @@ def visit_type_var(self, t: TypeVarType) -> ProperType:
163251

164252
def visit_instance(self, t: Instance) -> ProperType:
165253
if isinstance(self.s, Instance):
166-
nominal = join_instances(t, self.s)
254+
if self.instance_joiner is None:
255+
self.instance_joiner = InstanceJoiner()
256+
nominal = self.instance_joiner.join_instances(t, self.s)
167257
structural = None # type: Optional[Instance]
168258
if t.type.is_protocol and is_protocol_implementation(self.s, t):
169259
structural = t
@@ -282,8 +372,10 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
282372
# * Joining with any Sequence also returns a Sequence:
283373
# Tuple[int, bool] + List[bool] becomes Sequence[int]
284374
if isinstance(self.s, TupleType) and self.s.length() == t.length():
285-
fallback = join_instances(mypy.typeops.tuple_fallback(self.s),
286-
mypy.typeops.tuple_fallback(t))
375+
if self.instance_joiner is None:
376+
self.instance_joiner = InstanceJoiner()
377+
fallback = self.instance_joiner.join_instances(mypy.typeops.tuple_fallback(self.s),
378+
mypy.typeops.tuple_fallback(t))
287379
assert isinstance(fallback, Instance)
288380
if self.s.length() == t.length():
289381
items = [] # type: List[Type]
@@ -364,56 +456,6 @@ def default(self, typ: Type) -> ProperType:
364456
return AnyType(TypeOfAny.special_form)
365457

366458

367-
def join_instances(t: Instance, s: Instance) -> ProperType:
368-
"""Calculate the join of two instance types."""
369-
if t.type == s.type:
370-
# Simplest case: join two types with the same base type (but
371-
# potentially different arguments).
372-
if is_subtype(t, s) or is_subtype(s, t):
373-
# Compatible; combine type arguments.
374-
args = [] # type: List[Type]
375-
# N.B: We use zip instead of indexing because the lengths might have
376-
# mismatches during daemon reprocessing.
377-
for ta, sa in zip(t.args, s.args):
378-
args.append(join_types(ta, sa))
379-
return Instance(t.type, args)
380-
else:
381-
# Incompatible; return trivial result object.
382-
return object_from_instance(t)
383-
elif t.type.bases and is_subtype_ignoring_tvars(t, s):
384-
return join_instances_via_supertype(t, s)
385-
else:
386-
# Now t is not a subtype of s, and t != s. Now s could be a subtype
387-
# of t; alternatively, we need to find a common supertype. This works
388-
# in of the both cases.
389-
return join_instances_via_supertype(s, t)
390-
391-
392-
def join_instances_via_supertype(t: Instance, s: Instance) -> ProperType:
393-
# Give preference to joins via duck typing relationship, so that
394-
# join(int, float) == float, for example.
395-
if t.type._promote and is_subtype(t.type._promote, s):
396-
return join_types(t.type._promote, s)
397-
elif s.type._promote and is_subtype(s.type._promote, t):
398-
return join_types(t, s.type._promote)
399-
# Compute the "best" supertype of t when joined with s.
400-
# The definition of "best" may evolve; for now it is the one with
401-
# the longest MRO. Ties are broken by using the earlier base.
402-
best = None # type: Optional[ProperType]
403-
for base in t.type.bases:
404-
mapped = map_instance_to_supertype(t, base.type)
405-
res = join_instances(mapped, s)
406-
if best is None or is_better(res, best):
407-
best = res
408-
assert best is not None
409-
promote = get_proper_type(t.type._promote)
410-
if isinstance(promote, Instance):
411-
res = join_instances(promote, s)
412-
if is_better(res, best):
413-
best = res
414-
return best
415-
416-
417459
def is_better(t: Type, s: Type) -> bool:
418460
# Given two possible results from join_instances_via_supertype(),
419461
# indicate whether t is the better one.

mypy/meet.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from mypy.backports import OrderedDict
22
from typing import List, Optional, Tuple, Callable
33

4-
from mypy.join import (
5-
is_similar_callables, combine_similar_callables, join_type_list, unpack_callback_protocol
6-
)
74
from mypy.types import (
85
Type, AnyType, TypeVisitor, UnboundType, NoneType, TypeVarType, Instance, CallableType,
96
TupleType, TypedDictType, ErasedType, UnionType, PartialType, DeletedType,
@@ -15,6 +12,7 @@
1512
from mypy.maptype import map_instance_to_supertype
1613
from mypy.typeops import tuple_fallback, make_simplified_union, is_recursive_pair
1714
from mypy import state
15+
from mypy import join
1816

1917
# TODO Describe this module.
2018

@@ -518,7 +516,7 @@ def visit_instance(self, t: Instance) -> ProperType:
518516
else:
519517
return NoneType()
520518
elif isinstance(self.s, FunctionLike) and t.type.is_protocol:
521-
call = unpack_callback_protocol(t)
519+
call = join.unpack_callback_protocol(t)
522520
if call:
523521
return meet_types(call, self.s)
524522
elif isinstance(self.s, FunctionLike) and self.s.is_type_obj() and t.type.is_metaclass():
@@ -536,9 +534,9 @@ def visit_instance(self, t: Instance) -> ProperType:
536534
return self.default(self.s)
537535

538536
def visit_callable_type(self, t: CallableType) -> ProperType:
539-
if isinstance(self.s, CallableType) and is_similar_callables(t, self.s):
537+
if isinstance(self.s, CallableType) and join.is_similar_callables(t, self.s):
540538
if is_equivalent(t, self.s):
541-
return combine_similar_callables(t, self.s)
539+
return join.combine_similar_callables(t, self.s)
542540
result = meet_similar_callables(t, self.s)
543541
# We set the from_type_type flag to suppress error when a collection of
544542
# concrete class objects gets inferred as their common abstract superclass.
@@ -556,7 +554,7 @@ def visit_callable_type(self, t: CallableType) -> ProperType:
556554
return TypeType.make_normalized(res)
557555
return self.default(self.s)
558556
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
559-
call = unpack_callback_protocol(self.s)
557+
call = join.unpack_callback_protocol(self.s)
560558
if call:
561559
return meet_types(t, call)
562560
return self.default(self.s)
@@ -575,7 +573,7 @@ def visit_overloaded(self, t: Overloaded) -> ProperType:
575573
else:
576574
return meet_types(t.fallback, s.fallback)
577575
elif isinstance(self.s, Instance) and self.s.type.is_protocol:
578-
call = unpack_callback_protocol(self.s)
576+
call = join.unpack_callback_protocol(self.s)
579577
if call:
580578
return meet_types(t, call)
581579
return meet_types(t.fallback, s)
@@ -611,7 +609,7 @@ def visit_typeddict_type(self, t: TypedDictType) -> ProperType:
611609
assert t_item_type is not None
612610
item_list.append((item_name, t_item_type))
613611
items = OrderedDict(item_list)
614-
mapping_value_type = join_type_list(list(items.values()))
612+
mapping_value_type = join.join_type_list(list(items.values()))
615613
fallback = self.s.create_anonymous_fallback(value_type=mapping_value_type)
616614
required_keys = t.required_keys | self.s.required_keys
617615
return TypedDictType(items, required_keys, fallback)

mypy/test/testtypes.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,9 @@ def callable(self, vars: List[str], *a: Type) -> CallableType:
510510

511511
class JoinSuite(Suite):
512512
def setUp(self) -> None:
513-
self.fx = TypeFixture()
513+
self.fx = TypeFixture(INVARIANT)
514+
self.fx_co = TypeFixture(COVARIANT)
515+
self.fx_contra = TypeFixture(CONTRAVARIANT)
514516

515517
def test_trivial_cases(self) -> None:
516518
for simple in self.fx.a, self.fx.o, self.fx.b:
@@ -630,35 +632,50 @@ def test_other_mixed_types(self) -> None:
630632
self.assert_join(t1, t2, self.fx.o)
631633

632634
def test_simple_generics(self) -> None:
633-
self.assert_join(self.fx.ga, self.fx.ga, self.fx.ga)
634-
self.assert_join(self.fx.ga, self.fx.gb, self.fx.ga)
635-
self.assert_join(self.fx.ga, self.fx.gd, self.fx.o)
636-
self.assert_join(self.fx.ga, self.fx.g2a, self.fx.o)
637-
638635
self.assert_join(self.fx.ga, self.fx.nonet, self.fx.ga)
639636
self.assert_join(self.fx.ga, self.fx.anyt, self.fx.anyt)
640637

641638
for t in [self.fx.a, self.fx.o, self.fx.t, self.tuple(),
642639
self.callable(self.fx.a, self.fx.b)]:
643640
self.assert_join(t, self.fx.ga, self.fx.o)
644641

642+
def test_generics_invariant(self) -> None:
643+
self.assert_join(self.fx.ga, self.fx.ga, self.fx.ga)
644+
self.assert_join(self.fx.ga, self.fx.gb, self.fx.o)
645+
self.assert_join(self.fx.ga, self.fx.gd, self.fx.o)
646+
self.assert_join(self.fx.ga, self.fx.g2a, self.fx.o)
647+
648+
def test_generics_covariant(self) -> None:
649+
self.assert_join(self.fx_co.ga, self.fx_co.ga, self.fx_co.ga)
650+
self.assert_join(self.fx_co.ga, self.fx_co.gb, self.fx_co.ga)
651+
self.assert_join(self.fx_co.ga, self.fx_co.gd, self.fx_co.go)
652+
self.assert_join(self.fx_co.ga, self.fx_co.g2a, self.fx_co.o)
653+
654+
def test_generics_contravariant(self) -> None:
655+
self.assert_join(self.fx_contra.ga, self.fx_contra.ga, self.fx_contra.ga)
656+
self.assert_join(self.fx_contra.ga, self.fx_contra.gb, self.fx_contra.gb)
657+
self.assert_join(self.fx_contra.ga, self.fx_contra.gd, self.fx_contra.gn)
658+
self.assert_join(self.fx_contra.ga, self.fx_contra.g2a, self.fx_contra.o)
659+
645660
def test_generics_with_multiple_args(self) -> None:
646-
self.assert_join(self.fx.hab, self.fx.hab, self.fx.hab)
647-
self.assert_join(self.fx.hab, self.fx.hbb, self.fx.hab)
648-
self.assert_join(self.fx.had, self.fx.haa, self.fx.o)
661+
self.assert_join(self.fx_co.hab, self.fx_co.hab, self.fx_co.hab)
662+
self.assert_join(self.fx_co.hab, self.fx_co.hbb, self.fx_co.hab)
663+
self.assert_join(self.fx_co.had, self.fx_co.haa, self.fx_co.hao)
649664

650665
def test_generics_with_inheritance(self) -> None:
651-
self.assert_join(self.fx.gsab, self.fx.gb, self.fx.gb)
652-
self.assert_join(self.fx.gsba, self.fx.gb, self.fx.ga)
653-
self.assert_join(self.fx.gsab, self.fx.gd, self.fx.o)
666+
self.assert_join(self.fx_co.gsab, self.fx_co.gb, self.fx_co.gb)
667+
self.assert_join(self.fx_co.gsba, self.fx_co.gb, self.fx_co.ga)
668+
self.assert_join(self.fx_co.gsab, self.fx_co.gd, self.fx_co.go)
654669

655670
def test_generics_with_inheritance_and_shared_supertype(self) -> None:
656-
self.assert_join(self.fx.gsba, self.fx.gs2a, self.fx.ga)
657-
self.assert_join(self.fx.gsab, self.fx.gs2a, self.fx.ga)
658-
self.assert_join(self.fx.gsab, self.fx.gs2d, self.fx.o)
671+
self.assert_join(self.fx_co.gsba, self.fx_co.gs2a, self.fx_co.ga)
672+
self.assert_join(self.fx_co.gsab, self.fx_co.gs2a, self.fx_co.ga)
673+
self.assert_join(self.fx_co.gsab, self.fx_co.gs2d, self.fx_co.go)
659674

660675
def test_generic_types_and_any(self) -> None:
661676
self.assert_join(self.fx.gdyn, self.fx.ga, self.fx.gdyn)
677+
self.assert_join(self.fx_co.gdyn, self.fx_co.ga, self.fx_co.gdyn)
678+
self.assert_join(self.fx_contra.gdyn, self.fx_contra.ga, self.fx_contra.gdyn)
662679

663680
def test_callables_with_any(self) -> None:
664681
self.assert_join(self.callable(self.fx.a, self.fx.a, self.fx.anyt,

mypy/test/typefixture.py

+2
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
129129
self.gtf2 = Instance(self.gi, [self.tf2]) # G[T`-2]
130130
self.gs = Instance(self.gi, [self.s]) # G[S]
131131
self.gdyn = Instance(self.gi, [self.anyt]) # G[Any]
132+
self.gn = Instance(self.gi, [NoneType()]) # G[None]
132133

133134
self.g2a = Instance(self.g2i, [self.a]) # G2[A]
134135

@@ -145,6 +146,7 @@ def make_type_var(name: str, id: int, values: List[Type], upper_bound: Type,
145146
self.hbb = Instance(self.hi, [self.b, self.b]) # H[B, B]
146147
self.hts = Instance(self.hi, [self.t, self.s]) # H[T, S]
147148
self.had = Instance(self.hi, [self.a, self.d]) # H[A, D]
149+
self.hao = Instance(self.hi, [self.a, self.o]) # H[A, object]
148150

149151
self.lsta = Instance(self.std_listi, [self.a]) # List[A]
150152
self.lstb = Instance(self.std_listi, [self.b]) # List[B]

test-data/unit/check-classes.test

+1-1
Original file line numberDiff line numberDiff line change
@@ -5436,7 +5436,7 @@ main:8:6: error: Type argument "builtins.int" of "G" must be a subtype of "built
54365436
[case testExtremeForwardReferencing]
54375437
from typing import TypeVar, Generic
54385438

5439-
T = TypeVar('T')
5439+
T = TypeVar('T', covariant=True)
54405440
class B(Generic[T]): ...
54415441

54425442
y: A

0 commit comments

Comments
 (0)