@@ -17,6 +17,8 @@ import transform.SyntheticMembers._
17
17
import util .Property
18
18
import annotation .{tailrec , constructorOnly }
19
19
20
+ import scala .collection .mutable
21
+
20
22
/** Synthesize terms for special classes */
21
23
class Synthesizer (typer : Typer )(using @ constructorOnly c : Context ):
22
24
import ast .tpd ._
@@ -339,7 +341,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
339
341
if acceptable(mirroredType) && cls.isGenericSum(if useCompanion then cls.linkedClass else ctx.owner) then
340
342
val elemLabels = cls.children.map(c => ConstantType (Constant (c.name.toString)))
341
343
342
- def solve (sym : Symbol ): Type = sym match
344
+ def solve (target : Type )( sym : Symbol ): Type = sym match
343
345
case childClass : ClassSymbol =>
344
346
assert(childClass.isOneOf(Case | Sealed ))
345
347
if childClass.is(Module ) then
@@ -350,36 +352,50 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
350
352
// Compute the the full child type by solving the subtype constraint
351
353
// `C[X1, ..., Xn] <: P`, where
352
354
//
353
- // - P is the current `mirroredType `
355
+ // - P is the current `targetPart `
354
356
// - C is the child class, with type parameters X1, ..., Xn
355
357
//
356
358
// Contravariant type parameters are minimized, all other type parameters are maximized.
357
- def instantiate (using Context ) =
358
- val poly = constrained(info, untpd. EmptyTree )._1
359
+ def instantiate (targetPart : Type )( using Context ) =
360
+ val poly = constrained(info)
359
361
val resType = poly.finalResultType
360
- val target = mirroredType match
361
- case tp : HKTypeLambda => tp.resultType
362
- case tp => tp
363
- resType <:< target
362
+ resType <:< targetPart // record constraints
364
363
val tparams = poly.paramRefs
365
364
val variances = childClass.typeParams.map(_.paramVarianceSign)
366
365
val instanceTypes = tparams.lazyZip(variances).map((tparam, variance) =>
367
366
TypeComparer .instanceType(tparam, fromBelow = variance < 0 ))
368
367
resType.substParams(poly, instanceTypes)
369
- instantiate(using ctx.fresh.setExploreTyperState().setOwner(childClass))
368
+
369
+ def instantiateAll (using Context ): Type =
370
+
371
+ // instantiate for each part of a union type, compute lub of the results
372
+ def loop (explore : List [Type ], acc : mutable.ListBuffer [Type ]): Type = explore match
373
+ case OrType (tp1, tp2) :: rest => loop(tp1 :: tp2 :: rest, acc )
374
+ case tp :: rest => loop(rest , acc += instantiate(tp))
375
+ case _ => TypeComparer .lub(acc.toList)
376
+
377
+ def instantiateLub (tp1 : Type , tp2 : Type ): Type =
378
+ loop(tp1 :: tp2 :: Nil , new mutable.ListBuffer [Type ])
379
+
380
+ target match
381
+ case OrType (tp1, tp2) => instantiateLub(tp1, tp2)
382
+ case _ => instantiate(target)
383
+
384
+ instantiateAll(using ctx.fresh.setExploreTyperState().setOwner(childClass))
370
385
case _ =>
371
386
childClass.typeRef
372
387
case child => child.termRef
373
388
end solve
374
389
375
390
val (monoType, elemsType) = mirroredType match
376
391
case mirroredType : HKTypeLambda =>
392
+ val target = mirroredType.resultType
377
393
val elems = mirroredType.derivedLambdaType(
378
- resType = TypeOps .nestedPairs(cls.children.map(solve))
394
+ resType = TypeOps .nestedPairs(cls.children.map(solve(target) ))
379
395
)
380
396
(mkMirroredMonoType(mirroredType), elems)
381
- case _ =>
382
- val elems = TypeOps .nestedPairs(cls.children.map(solve))
397
+ case target =>
398
+ val elems = TypeOps .nestedPairs(cls.children.map(solve(target) ))
383
399
(mirroredType, elems)
384
400
385
401
val mirrorType =
0 commit comments