Skip to content

Commit 227ab22

Browse files
committed
compute mirror child types of a union
1 parent d484926 commit 227ab22

File tree

2 files changed

+71
-12
lines changed

2 files changed

+71
-12
lines changed

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 28 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
1717
import util.Property
1818
import annotation.{tailrec, constructorOnly}
1919

20+
import scala.collection.mutable
21+
2022
/** Synthesize terms for special classes */
2123
class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
2224
import ast.tpd._
@@ -339,7 +341,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
339341
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
340342
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
341343

342-
def solve(sym: Symbol): Type = sym match
344+
def solve(target: Type)(sym: Symbol): Type = sym match
343345
case childClass: ClassSymbol =>
344346
assert(childClass.isOneOf(Case | Sealed))
345347
if childClass.is(Module) then
@@ -350,36 +352,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
350352
// Compute the the full child type by solving the subtype constraint
351353
// `C[X1, ..., Xn] <: P`, where
352354
//
353-
// - P is the current `mirroredType`
355+
// - P is the current `targetPart`
354356
// - C is the child class, with type parameters X1, ..., Xn
355357
//
356358
// Contravariant type parameters are minimized, all other type parameters are maximized.
357-
def instantiate(using Context) =
358-
val poly = constrained(info, untpd.EmptyTree)._1
359+
def instantiate(targetPart: Type)(using Context) =
360+
val poly = constrained(info)
359361
val resType = poly.finalResultType
360-
val target = mirroredType match
361-
case tp: HKTypeLambda => tp.resultType
362-
case tp => tp
363-
resType <:< target
362+
resType <:< targetPart // record constraints
364363
val tparams = poly.paramRefs
365364
val variances = childClass.typeParams.map(_.paramVarianceSign)
366365
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
367366
TypeComparer.instanceType(tparam, fromBelow = variance < 0))
368367
resType.substParams(poly, instanceTypes)
369-
instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368+
369+
def instantiateAll(using Context): Type =
370+
371+
// instantiate for each part of a union type, compute lub of the results
372+
def loop(explore: List[Type], acc: mutable.ListBuffer[Type]): Type = explore match
373+
case OrType(tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
374+
case tp :: rest => loop(rest , acc += instantiate(tp))
375+
case _ => TypeComparer.lub(acc.toList)
376+
377+
def instantiateLub(tp1: Type, tp2: Type): Type =
378+
loop(tp1 :: tp2 :: Nil, new mutable.ListBuffer[Type])
379+
380+
target match
381+
case OrType(tp1, tp2) => instantiateLub(tp1, tp2)
382+
case _ => instantiate(target)
383+
384+
instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
370385
case _ =>
371386
childClass.typeRef
372387
case child => child.termRef
373388
end solve
374389

375390
val (monoType, elemsType) = mirroredType match
376391
case mirroredType: HKTypeLambda =>
392+
val target = mirroredType.resultType
377393
val elems = mirroredType.derivedLambdaType(
378-
resType = TypeOps.nestedPairs(cls.children.map(solve))
394+
resType = TypeOps.nestedPairs(cls.children.map(solve(target)))
379395
)
380396
(mkMirroredMonoType(mirroredType), elems)
381-
case _ =>
382-
val elems = TypeOps.nestedPairs(cls.children.map(solve))
397+
case target =>
398+
val elems = TypeOps.nestedPairs(cls.children.map(solve(target)))
383399
(mirroredType, elems)
384400

385401
val mirrorType =

tests/pos/i13493.scala

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import deriving.Mirror
2+
3+
sealed trait Box[T]
4+
object Box
5+
6+
case class Child[T](t: T) extends Box[T]
7+
8+
object MirrorK1:
9+
type Of[F[_]] = Mirror { type MirroredType[A] = F[A] }
10+
11+
def testSums =
12+
13+
val foo = summon[Mirror.Of[Option[Int] | Option[String]]]
14+
summon[foo.MirroredElemTypes =:= (None.type, Some[Int] | Some[String])]
15+
16+
val bar = summon[Mirror.Of[Box[Int] | Box[String]]]
17+
summon[bar.MirroredElemTypes =:= ((Child[Int] | Child[String]) *: EmptyTuple)]
18+
19+
val qux = summon[Mirror.Of[Option[Int | String]]]
20+
summon[qux.MirroredElemTypes =:= (None.type, Some[Int | String])]
21+
22+
val bip = summon[Mirror.Of[Box[Int | String]]]
23+
summon[bip.MirroredElemTypes =:= (Child[Int | String] *: EmptyTuple)]
24+
25+
val bap = summon[MirrorK1.Of[[X] =>> Box[X] | Box[Int] | Box[String]]]
26+
summon[bap.MirroredElemTypes[Boolean] =:= ((Child[Boolean] | Child[Int] | Child[String]) *: EmptyTuple)]
27+
28+
29+
def testProducts =
30+
val foo = summon[Mirror.Of[Some[Int] | Some[String]]]
31+
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
32+
33+
val bar = summon[Mirror.Of[Child[Int] | Child[String]]]
34+
summon[bar.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
35+
36+
val qux = summon[Mirror.Of[Some[Int | String]]]
37+
summon[foo.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
38+
39+
val bip = summon[Mirror.Of[Child[Int | String]]]
40+
summon[bip.MirroredElemTypes =:= ((Int | String) *: EmptyTuple)]
41+
42+
val bap = summon[MirrorK1.Of[[X] =>> Child[X] | Child[Int] | Child[String]]]
43+
summon[bap.MirroredElemTypes[Boolean] =:= ((Boolean | Int | String) *: EmptyTuple)]

0 commit comments

Comments
 (0)