Skip to content

Commit 65a0d8c

Browse files
committed
compute mirror child types of a union
1 parent 16ed844 commit 65a0d8c

File tree

2 files changed

+71
-12
lines changed

2 files changed

+71
-12
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
1717
import util.Property
1818
import annotation.{tailrec, constructorOnly}
1919

20+
import scala.collection.mutable
21+
2022
/** Synthesize terms for special classes */
2123
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2224
import ast.tpd._
@@ -337,7 +339,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
337339
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
338340
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
339341

340-
def solve(sym: Symbol): Type = sym match
342+
def solve(target: Type)(sym: Symbol): Type = sym match
341343
case childClass: ClassSymbol =>
342344
assert(childClass.isOneOf(Case | Sealed))
343345
if childClass.is(Module) then
@@ -348,36 +350,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
348350
// Compute the the full child type by solving the subtype constraint
349351
// `C[X1, ..., Xn] <: P`, where
350352
//
351-
// - P is the current `mirroredType`
353+
// - P is the current `targetPart`
352354
// - C is the child class, with type parameters X1, ..., Xn
353355
//
354356
// Contravariant type parameters are minimized, all other type parameters are maximized.
355-
def instantiate(using Context) =
356-
val poly = constrained(info, untpd.EmptyTree)._1
357+
def instantiate(targetPart: Type)(using Context) =
358+
val poly = constrained(info)
357359
val resType = poly.finalResultType
358-
val target = mirroredType match
359-
case tp: HKTypeLambda => tp.resultType
360-
case tp => tp
361-
resType <:< target
360+
resType <:< targetPart // record constraints
362361
val tparams = poly.paramRefs
363362
val variances = childClass.typeParams.map(_.paramVarianceSign)
364363
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
365364
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
366365
resType.substParams(poly, instanceTypes)
367-
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
366+
367+
def instantiateAll(using Context): Type =
368+
369+
// instantiate for each part of a union type, compute lub of the results
370+
def loop(explore: List[Type], acc: mutable.ListBuffer[Type]): Type = explore match
371+
case OrType(tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
372+
case tp :: rest => loop(rest , acc += instantiate(tp))
373+
case _ => TypeComparer.lub(acc.toList)
374+
375+
def instantiateLub(tp1: Type, tp2: Type): Type =
376+
loop(tp1 :: tp2 :: Nil, new mutable.ListBuffer[Type])
377+
378+
target match
379+
case OrType(tp1, tp2) => instantiateLub(tp1, tp2)
380+
case _ => instantiate(target)
381+
382+
instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368383
case _ =>
369384
childClass.typeRef
370385
case child => child.termRef
371386
end solve
372387

373388
val (monoType, elemsType) = mirroredType match
374389
case mirroredType: HKTypeLambda =>
390+
val target = mirroredType.resultType
375391
val elems = mirroredType.derivedLambdaType(
376-
resType = TypeOps.nestedPairs(cls.children.map(solve))
392+
resType = TypeOps.nestedPairs(cls.children.map(solve(target)))
377393
)
378394
(mkMirroredMonoType(mirroredType), elems)
379-
case _ =>
380-
val elems = TypeOps.nestedPairs(cls.children.map(solve))
395+
case target =>
396+
val elems = TypeOps.nestedPairs(cls.children.map(solve(target)))
381397
(mirroredType, elems)
382398

383399
val mirrorType =

tests/pos/i13493.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import deriving.Mirror
2+
3+
sealed trait Box[T]
4+
object Box
5+
6+
case class Child[T](t: T) extends Box[T]
7+
8+
object MirrorK1:
9+
type Of[F[_]] = Mirror { type MirroredType[A] = F[A] }
10+
11+
def testSums =
12+
13+
val foo = summon[Mirror.Of[Option[Int] | Option[String]]]
14+
summon[foo.MirroredElemTypes =:= (None.type, Some[Int] | Some[String])]
15+
16+
val bar = summon[Mirror.Of[Box[Int] | Box[String]]]
17+
summon[bar.MirroredElemTypes =:= ((Child[Int] | Child[String]) *: EmptyTuple)]
18+
19+
val qux = summon[Mirror.Of[Option[Int | String]]]
20+
summon[qux.MirroredElemTypes =:= (None.type, Some[Int | String])]
21+
22+
val bip = summon[Mirror.Of[Box[Int | String]]]
23+
summon[bip.MirroredElemTypes =:= (Child[Int | String] *: EmptyTuple)]
24+
25+
val bap = summon[MirrorK1.Of[[X] =>> Box[X] | Box[Int] | Box[String]]]
26+
summon[bap.MirroredElemTypes[Boolean] =:= ((Child[Boolean] | Child[Int] | Child[String]) *: EmptyTuple)]
27+
28+
29+
def testProducts =
30+
val foo = summon[Mirror.Of[Some[Int] | Some[String]]]
31+
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
32+
33+
val bar = summon[Mirror.Of[Child[Int] | Child[String]]]
34+
summon[bar.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
35+
36+
val qux = summon[Mirror.Of[Some[Int | String]]]
37+
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
38+
39+
val bip = summon[Mirror.Of[Child[Int | String]]]
40+
summon[bip.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
41+
42+
val bap = summon[MirrorK1.Of[[X] =>> Child[X] | Child[Int] | Child[String]]]
43+
summon[bap.MirroredElemTypes[Boolean] =:= ((Boolean | Int | String) *: EmptyTuple)]

0 commit comments

Comments
 (0)