Skip to content

Commit aa8032d

Browse files
committed
Instantiate more type variables to hard unions
Fixes #14770
1 parent d670018 commit aa8032d

File tree

9 files changed

+130
-39
lines changed

9 files changed

+130
-39
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import config.Printers.typr
1212
import typer.ProtoTypes.{newTypeVar, representedParamRef}
1313
import UnificationDirection.*
1414
import NameKinds.AvoidNameKind
15+
import NullOpsDecorator.stripNull
1516

1617
/** Methods for adding constraints and solving them.
1718
*
@@ -518,8 +519,11 @@ trait ConstraintHandling {
518519
* 1. If `inst` is a singleton type, or a union containing some singleton types,
519520
* widen (all) the singleton type(s), provided the result is a subtype of `bound`.
520521
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
521-
* 2. If `inst` is a union type, approximate the union type from above by an intersection
522-
* of all common base types, provided the result is a subtype of `bound`.
522+
* 2a. If `inst` is a union type and `widenUnions` is true, approximate the union type
523+
* from above by an intersection of all common base types, provided the result
524+
* is a subtype of `bound`.
525+
* 2b. If `inst` is a union type and `widenUnions` is false, turn it into a hard
526+
* union type (except for unions | Null, which are kept in the state they were).
523527
* 3. Widen some irreducible applications of higher-kinded types to wildcard arguments
524528
* (see @widenIrreducible).
525529
* 4. Drop transparent traits from intersections (see @dropTransparentTraits).
@@ -532,10 +536,12 @@ trait ConstraintHandling {
532536
* At this point we also drop the @Repeated annotation to avoid inferring type arguments with it,
533537
* as those could leak the annotation to users (see run/inferred-repeated-result).
534538
*/
535-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
539+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
536540
def widenOr(tp: Type) =
537-
val tpw = tp.widenUnion
538-
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
541+
if widenUnions then
542+
val tpw = tp.widenUnion
543+
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
544+
else tp.hardenUnions
539545

540546
def widenSingle(tp: Type) =
541547
val tpw = tp.widenSingletons
@@ -555,24 +561,41 @@ trait ConstraintHandling {
555561
wideInst.dropRepeatedAnnot
556562
end widenInferred
557563

564+
/** Convert all toplevel union types in `tp` to hard unions */
565+
extension (tp: Type) private def hardenUnions(using Context): Type = tp.widen match
566+
case tp: AndType =>
567+
tp.derivedAndType(tp.tp1.hardenUnions, tp.tp2.hardenUnions)
568+
case tp: RefinedType =>
569+
tp.derivedRefinedType(tp.parent.hardenUnions, tp.refinedName, tp.refinedInfo)
570+
case tp: RecType =>
571+
tp.rebind(tp.parent.hardenUnions)
572+
case tp: HKTypeLambda =>
573+
tp.derivedLambdaType(resType = tp.resType.hardenUnions)
574+
case tp: OrType =>
575+
val tp1 = tp.stripNull
576+
if tp1 ne tp then tp.derivedOrType(tp1.hardenUnions, defn.NullType)
577+
else tp.derivedOrType(tp.tp1.hardenUnions, tp.tp2.hardenUnions, soft = false)
578+
case _ =>
579+
tp
580+
558581
/** The instance type of `param` in the current constraint (which contains `param`).
559582
* If `fromBelow` is true, the instance type is the lub of the parameter's
560583
* lower bounds; otherwise it is the glb of its upper bounds. However,
561584
* a lower bound instantiation can be a singleton type only if the upper bound
562585
* is also a singleton type.
563586
*/
564-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type = {
587+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type = {
565588
val approx = approximation(param, fromBelow).simplified
566589
if fromBelow then
567-
val widened = widenInferred(approx, param)
590+
val widened = widenInferred(approx, param, widenUnions)
568591
// Widening can add extra constraints, in particular the widened type might
569592
// be a type variable which is now instantiated to `param`, and therefore
570593
// cannot be used as an instantiation of `param` without creating a loop.
571594
// If that happens, we run `instanceType` again to find a new instantation.
572595
// (we do not check for non-toplevel occurences: those should never occur
573596
// since `addOneBound` disallows recursive lower bounds).
574597
if constraint.occursAtToplevel(param, widened) then
575-
instanceType(param, fromBelow)
598+
instanceType(param, fromBelow, widenUnions)
576599
else
577600
widened
578601
else

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 45 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -484,31 +484,54 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
484484

485485
// If LHS is a hard union, constrain any type variables of the RHS with it as lower bound
486486
// before splitting the LHS into its constituents. That way, the RHS variables are
487-
// constraint by the hard union and can be instantiated to it. If we just split and add
487+
// constrained by the hard union and can be instantiated to it. If we just split and add
488488
// the two parts of the LHS separately to the constraint, the lower bound would become
489489
// a soft union.
490490
def constrainRHSVars(tp2: Type): Boolean = tp2.dealiasKeepRefiningAnnots match
491491
case tp2: TypeParamRef if constraint contains tp2 => compareTypeParamRef(tp2)
492492
case AndType(tp21, tp22) => constrainRHSVars(tp21) && constrainRHSVars(tp22)
493493
case _ => true
494494

495-
widenOK
496-
|| joinOK
497-
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
498-
|| containsAnd(tp1)
499-
&& !joined
500-
&& {
501-
joined = true
502-
try inFrozenGadt(recur(tp1.join, tp2))
503-
finally joined = false
504-
}
505-
// An & on the left side loses information. We compensate by also trying the join.
506-
// This is less ad-hoc than it looks since we produce joins in type inference,
507-
// and then need to check that they are indeed supertypes of the original types
508-
// under -Ycheck. Test case is i7965.scala.
509-
// On the other hand, we could get a combinatorial explosion by applying such joins
510-
// recursively, so we do it only once. See i14870.scala as a test case, which would
511-
// loop for a very long time without the recursion brake.
495+
/** Mark toplevel type vars in `tp2` as hard in the current typerState */
496+
def hardenTypeVars(tp2: Type): Unit = tp2.dealiasKeepRefiningAnnots match
497+
case tvar: TypeVar if constraint.contains(tvar.origin) =>
498+
state.hardVars += tvar
499+
case tp2: TypeParamRef if constraint.contains(tp2) =>
500+
hardenTypeVars(constraint.typeVarOfParam(tp2))
501+
case tp2: AndOrType =>
502+
hardenTypeVars(tp2.tp1)
503+
hardenTypeVars(tp2.tp2)
504+
case _ =>
505+
506+
val res = widenOK
507+
|| joinOK
508+
|| (tp1.isSoft || constrainRHSVars(tp2)) && recur(tp11, tp2) && recur(tp12, tp2)
509+
|| containsAnd(tp1)
510+
&& !joined
511+
&& {
512+
joined = true
513+
try inFrozenGadt(recur(tp1.join, tp2))
514+
finally joined = false
515+
}
516+
// An & on the left side loses information. We compensate by also trying the join.
517+
// This is less ad-hoc than it looks since we produce joins in type inference,
518+
// and then need to check that they are indeed supertypes of the original types
519+
// under -Ycheck. Test case is i7965.scala.
520+
// On the other hand, we could get a combinatorial explosion by applying such joins
521+
// recursively, so we do it only once. See i14870.scala as a test case, which would
522+
// loop for a very long time without the recursion brake.
523+
524+
if res && !tp1.isSoft then
525+
// We use a heuristic here where every toplevel type variable on the right hand side
526+
// is marked so that it converts all soft unions in its lower bound to hard unions
527+
// before it is instantiated. The reason is that the union might have come from
528+
// (decomposed and reconstituted) `tp1`. But of course there might be false positives
529+
// where we also treat unions that come from elsewhere as hard unions. Or the constraint
530+
// that created the union is ultimately thrown away, but the type variable will
531+
// stay marked. So it is a coarse measure to take. But it works in the obvious cases.
532+
hardenTypeVars(tp2)
533+
534+
res
512535

513536
case tp1: MatchType =>
514537
val reduced = tp1.reduced
@@ -2863,8 +2886,8 @@ object TypeComparer {
28632886
def subtypeCheckInProgress(using Context): Boolean =
28642887
comparing(_.subtypeCheckInProgress)
28652888

2866-
def instanceType(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
2867-
comparing(_.instanceType(param, fromBelow))
2889+
def instanceType(param: TypeParamRef, fromBelow: Boolean, widenUnions: Boolean)(using Context): Type =
2890+
comparing(_.instanceType(param, fromBelow, widenUnions))
28682891

28692892
def approximation(param: TypeParamRef, fromBelow: Boolean)(using Context): Type =
28702893
comparing(_.approximation(param, fromBelow))
@@ -2884,8 +2907,8 @@ object TypeComparer {
28842907
def addToConstraint(tl: TypeLambda, tvars: List[TypeVar])(using Context): Boolean =
28852908
comparing(_.addToConstraint(tl, tvars))
28862909

2887-
def widenInferred(inst: Type, bound: Type)(using Context): Type =
2888-
comparing(_.widenInferred(inst, bound))
2910+
def widenInferred(inst: Type, bound: Type, widenUnions: Boolean)(using Context): Type =
2911+
comparing(_.widenInferred(inst, bound, widenUnions))
28892912

28902913
def dropTransparentTraits(tp: Type, bound: Type)(using Context): Type =
28912914
comparing(_.dropTransparentTraits(tp, bound))

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,9 @@ object TypeOps:
516516
override def apply(tp: Type): Type = tp match
517517
case tp: TypeVar if mapCtx.typerState.constraint.contains(tp) =>
518518
val lo = TypeComparer.instanceType(
519-
tp.origin, fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound)(using mapCtx)
519+
tp.origin,
520+
fromBelow = variance > 0 || variance == 0 && tp.hasLowerBound,
521+
widenUnions = tp.widenUnions)(using mapCtx)
520522
val lo1 = apply(lo)
521523
if (lo1 ne lo) lo1 else tp
522524
case _ =>

compiler/src/dotty/tools/dotc/core/TyperState.scala

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,21 +23,22 @@ object TyperState {
2323
.setReporter(new ConsoleReporter())
2424
.setCommittable(true)
2525

26-
opaque type Snapshot = (Constraint, TypeVars, TypeVars)
26+
opaque type Snapshot = (Constraint, TypeVars, TypeVars, TypeVars)
2727

2828
extension (ts: TyperState)
2929
def snapshot()(using Context): Snapshot =
3030
var previouslyInstantiated: TypeVars = SimpleIdentitySet.empty
3131
for tv <- ts.ownedVars do if tv.inst.exists then previouslyInstantiated += tv
32-
(ts.constraint, ts.ownedVars, previouslyInstantiated)
32+
(ts.constraint, ts.ownedVars, previouslyInstantiated, ts.hardVars)
3333

3434
def resetTo(state: Snapshot)(using Context): Unit =
35-
val (c, tvs, previouslyInstantiated) = state
35+
val (c, tvs, previouslyInstantiated, hvars) = state
3636
for tv <- tvs do
3737
if tv.inst.exists && !previouslyInstantiated.contains(tv) then
3838
tv.resetInst(ts)
3939
ts.ownedVars = tvs
4040
ts.constraint = c
41+
ts.hardVars = hvars
4142
}
4243

4344
class TyperState() {
@@ -89,6 +90,14 @@ class TyperState() {
8990
def ownedVars: TypeVars = myOwnedVars
9091
def ownedVars_=(vs: TypeVars): Unit = myOwnedVars = vs
9192

93+
/** The set of type variables `tv` such that, if `tv` is instantiated to
94+
* its lower bound, top-level soft unions in the instance type are converted
95+
* to hard unions instead of being widened in `widenOr`.
96+
*/
97+
private var myHardVars: TypeVars = _
98+
def hardVars: TypeVars = myHardVars
99+
def hardVars_=(tvs: TypeVars): Unit = myHardVars = tvs
100+
92101
/** Initializes all fields except reporter, isCommittable, which need to be
93102
* set separately.
94103
*/
@@ -99,16 +108,19 @@ class TyperState() {
99108
this.myConstraint = constraint
100109
this.previousConstraint = constraint
101110
this.myOwnedVars = SimpleIdentitySet.empty
111+
this.myHardVars = SimpleIdentitySet.empty
102112
this.isCommitted = false
103113
this
104114

105115
/** A fresh typer state with the same constraint as this one. */
106116
def fresh(reporter: Reporter = StoreReporter(this.reporter, fromTyperState = true),
107117
committable: Boolean = this.isCommittable): TyperState =
108118
util.Stats.record("TyperState.fresh")
109-
TyperState().init(this, this.constraint)
119+
val ts = TyperState().init(this, this.constraint)
110120
.setReporter(reporter)
111121
.setCommittable(committable)
122+
ts.hardVars = this.hardVars
123+
ts
112124

113125
/** The uninstantiated variables */
114126
def uninstVars: collection.Seq[TypeVar] = constraint.uninstVars
@@ -161,6 +173,7 @@ class TyperState() {
161173
constr.println(i"committing $this to $targetState, fromConstr = $constraint, toConstr = ${targetState.constraint}")
162174
if targetState.constraint eq previousConstraint then
163175
targetState.constraint = constraint
176+
targetState.hardVars = hardVars
164177
if !ownedVars.isEmpty then ownedVars.foreach(targetState.includeVar)
165178
else
166179
targetState.mergeConstraintWith(this)
@@ -213,6 +226,7 @@ class TyperState() {
213226
val otherLos = other.lower(p)
214227
val otherHis = other.upper(p)
215228
val otherEntry = other.entry(p)
229+
if that.hardVars.contains(tv) then this.myHardVars += tv
216230
( (otherLos eq constraint.lower(p)) || otherLos.forall(_ <:< p)) &&
217231
( (otherHis eq constraint.upper(p)) || otherHis.forall(p <:< _)) &&
218232
((otherEntry eq constraint.entry(p)) || otherEntry.match

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4574,12 +4574,16 @@ object Types {
45744574
* is also a singleton type.
45754575
*/
45764576
def instantiate(fromBelow: Boolean)(using Context): Type =
4577-
val tp = TypeComparer.instanceType(origin, fromBelow)
4577+
val tp = TypeComparer.instanceType(origin, fromBelow, widenUnions)
45784578
if myInst.exists then // The line above might have triggered instantiation of the current type variable
45794579
myInst
45804580
else
45814581
instantiateWith(tp)
45824582

4583+
/** Widen unions when instantiating this variable in the current context? */
4584+
def widenUnions(using Context): Boolean =
4585+
!ctx.typerState.hardVars.contains(this)
4586+
45834587
/** For uninstantiated type variables: the entry in the constraint (either bounds or
45844588
* provisional instance value)
45854589
*/

compiler/src/dotty/tools/dotc/typer/Namer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1884,7 +1884,7 @@ class Namer { typer: Typer =>
18841884
TypeOps.simplify(tp.widenTermRefExpr,
18851885
if defaultTp.exists then TypeOps.SimplifyKeepUnchecked() else null) match
18861886
case ctp: ConstantType if sym.isInlineVal => ctp
1887-
case tp => TypeComparer.widenInferred(tp, pt)
1887+
case tp => TypeComparer.widenInferred(tp, pt, widenUnions = true)
18881888

18891889
// Replace aliases to Unit by Unit itself. If we leave the alias in
18901890
// it would be erased to BoxedUnit.

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -514,7 +514,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
514514
val tparams = poly.paramRefs
515515
val variances = childClass.typeParams.map(_.paramVarianceSign)
516516
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
517-
TypeComparer.instanceType(tparam, fromBelow = variance < 0)
517+
TypeComparer.instanceType(tparam, fromBelow = variance < 0, widenUnions = true)
518518
)
519519
val instanceType = resType.substParams(poly, instanceTypes)
520520
// this is broken in tests/run/i13332intersection.scala,

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2842,7 +2842,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
28422842
if (ctx.mode.is(Mode.Pattern)) app1
28432843
else {
28442844
val elemTpes = elems.lazyZip(pts).map((elem, pt) =>
2845-
TypeComparer.widenInferred(elem.tpe, pt))
2845+
TypeComparer.widenInferred(elem.tpe, pt, widenUnions = true))
28462846
val resTpe = TypeOps.nestedPairs(elemTpes)
28472847
app1.cast(resTpe)
28482848
}

tests/pos/i14770.scala

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
type UndefOr[A] = A | Unit
2+
3+
extension [A](maybe: UndefOr[A])
4+
def foreach(f: A => Unit): Unit =
5+
maybe match
6+
case () => ()
7+
case a: A => f(a)
8+
9+
trait Foo
10+
trait Bar
11+
12+
object Baz:
13+
var booBap: Foo | Bar = _
14+
15+
def z: UndefOr[Foo | Bar] = ???
16+
17+
@main
18+
def main =
19+
z.foreach(x => Baz.booBap = x)
20+
21+
def test[A](v: A | Unit): A | Unit = v
22+
val x1 = test(5: Int | Unit)
23+
val x2 = test(5: String | Int | Unit)
24+
val _: Int | Unit = x1
25+
val _: String | Int | Unit = x2

0 commit comments

Comments
 (0)