@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
17
17
import util .Property
18
18
import annotation .{tailrec , constructorOnly }
19
19
20
+ import scala .collection .mutable
21
+
20
22
/** Synthesize terms for special classes */
21
23
class Synthesizer (typer : Typer )(using @ constructorOnly c : Context ):
22
24
import ast .tpd ._
@@ -337,7 +339,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
337
339
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
338
340
val elemLabels = cls.children.map(c => ConstantType (Constant (c.name.toString)))
339
341
340
- def solve (sym : Symbol ): Type = sym match
342
+ def solve (target : Type )( sym : Symbol ): Type = sym match
341
343
case childClass : ClassSymbol =>
342
344
assert(childClass.isOneOf(Case | Sealed ))
343
345
if childClass.is(Module ) then
@@ -348,36 +350,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
348
350
// Compute the the full child type by solving the subtype constraint
349
351
// `C[X1, ..., Xn] <: P`, where
350
352
//
351
- // - P is the current `mirroredType `
353
+ // - P is the current `targetPart `
352
354
// - C is the child class, with type parameters X1, ..., Xn
353
355
//
354
356
// Contravariant type parameters are minimized, all other type parameters are maximized.
355
- def instantiate (using Context ) =
356
- val poly = constrained(info, untpd. EmptyTree )._1
357
+ def instantiate (targetPart : Type )( using Context ) =
358
+ val poly = constrained(info)
357
359
val resType = poly.finalResultType
358
- val target = mirroredType match
359
- case tp : HKTypeLambda => tp.resultType
360
- case tp => tp
361
- resType <:< target
360
+ resType <:< targetPart // record constraints
362
361
val tparams = poly.paramRefs
363
362
val variances = childClass.typeParams.map(_.paramVarianceSign)
364
363
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
365
364
TypeComparer .instanceType(tparam, fromBelow = variance < 0 ))
366
365
resType.substParams(poly, instanceTypes)
367
- instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
366
+
367
+ def instantiateAll (using Context ): Type =
368
+
369
+ // instantiate for each part of a union type, compute lub of the results
370
+ def loop (explore : List [Type ], acc : mutable.ListBuffer [Type ]): Type = explore match
371
+ case OrType (tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
372
+ case tp :: rest => loop(rest , acc += instantiate(tp))
373
+ case _ => TypeComparer .lub(acc.toList)
374
+
375
+ def instantiateLub (tp1 : Type , tp2 : Type ): Type =
376
+ loop(tp1 :: tp2 :: Nil , new mutable.ListBuffer [Type ])
377
+
378
+ target match
379
+ case OrType (tp1, tp2) => instantiateLub(tp1, tp2)
380
+ case _ => instantiate(target)
381
+
382
+ instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368
383
case _ =>
369
384
childClass.typeRef
370
385
case child => child.termRef
371
386
end solve
372
387
373
388
val (monoType, elemsType) = mirroredType match
374
389
case mirroredType : HKTypeLambda =>
390
+ val target = mirroredType.resultType
375
391
val elems = mirroredType.derivedLambdaType(
376
- resType = TypeOps .nestedPairs(cls.children.map(solve))
392
+ resType = TypeOps .nestedPairs(cls.children.map(solve(target) ))
377
393
)
378
394
(mkMirroredMonoType(mirroredType), elems)
379
- case _ =>
380
- val elems = TypeOps .nestedPairs(cls.children.map(solve))
395
+ case target =>
396
+ val elems = TypeOps .nestedPairs(cls.children.map(solve(target) ))
381
397
(mirroredType, elems)
382
398
383
399
val mirrorType =
0 commit comments