Skip to content

Commit 1d56a51

Browse files
committed
Fix InferExpectedTypeSuite.map/flatMap
1 parent a619128 commit 1d56a51

File tree

9 files changed

+97
-43
lines changed

9 files changed

+97
-43
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ package cc
55
import core.*
66
import Phases.*, DenotTransformers.*, SymDenotations.*
77
import Contexts.*, Names.*, Flags.*, Symbols.*, Decorators.*
8-
import Types.*, StdNames.*, Denotations.*
8+
import Types.*, StdNames.*, Denotations.*, NamerOps.linkConstructorParams
99
import config.Printers.{capt, recheckr, noPrinter}
1010
import config.{Config, Feature}
1111
import ast.{tpd, untpd, Trees}
@@ -1552,7 +1552,8 @@ class CheckCaptures extends Recheck, SymTransformer:
15521552
val checker = new TreeTraverser:
15531553
def traverse(tree: Tree)(using Context): Unit =
15541554
val lctx = tree match
1555-
case _: DefTree | _: TypeDef if tree.symbol.exists => ctx.withOwner(tree.symbol)
1555+
case _: DefDef => linkConstructorParams(tree.symbol)(using ctx.withOwner(tree.symbol))
1556+
case _: DefTree if tree.symbol.exists => ctx.withOwner(tree.symbol)
15561557
case _ => ctx
15571558
trace(i"post check $tree"):
15581559
traverseChildren(tree)(using lctx)

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,9 @@ object Contexts {
282282
/** AbstractFile with given path, memoized */
283283
def getFile(name: String): AbstractFile = getFile(name.toTermName)
284284

285-
private var related: SimpleIdentityMap[Phase | SourceFile, Context] | Null = null
285+
private var related: SimpleIdentityMap[Phase | SourceFile | GadtState, Context] | Null = null
286286

287-
private def lookup(key: Phase | SourceFile): Context | Null =
287+
private def lookup(key: Phase | SourceFile | GadtState): Context | Null =
288288
util.Stats.record("Context.related.lookup")
289289
if related == null then
290290
related = SimpleIdentityMap.empty
@@ -326,6 +326,16 @@ object Contexts {
326326
related = related.nn.updated(source, ctx2)
327327
ctx1
328328

329+
final def withGadtState(gadtState: GadtState): Context =
330+
if this.gadtState eq gadtState then
331+
this
332+
else
333+
var ctx1 = lookup(gadtState)
334+
if ctx1 == null then
335+
ctx1 = fresh.setGadtState(gadtState)
336+
related = related.nn.updated(gadtState, ctx1)
337+
ctx1
338+
329339
// `creationTrace`-related code. To enable, uncomment the code below and the
330340
// call to `setCreationTrace()` in this file.
331341
/*

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -904,8 +904,9 @@ class TreeUnpickler(reader: TastyReader,
904904

905905
def DefDef(paramss: List[ParamClause], tpt: Tree) =
906906
sym.setParamssFromDefs(paramss)
907+
val rhsCtx = linkConstructorParams(sym)(using localCtx)
907908
ta.assignType(
908-
untpd.DefDef(sym.name.asTermName, paramss, tpt, readRhs(using localCtx)),
909+
untpd.DefDef(sym.name.asTermName, paramss, tpt, readRhs(using rhsCtx)),
909910
sym)
910911

911912
def TypeDef(rhs: Tree) =
@@ -1127,7 +1128,7 @@ class TreeUnpickler(reader: TastyReader,
11271128
val mappedParents: LazyTreeList =
11281129
if parents.exists(_.isInstanceOf[InferredTypeTree]) then
11291130
// parents were not read fully, will need to be read again later on demand
1130-
new LazyReader(parentReader, localDummy, ctx.mode, ctx.source,
1131+
new LazyReader(parentReader, localDummy, ctx.mode, ctx.source, ctx.gadtState,
11311132
_.readParents(withArgs = true)
11321133
.map(_.changeOwner(localDummy, constr.symbol)))
11331134
else parents
@@ -1748,7 +1749,8 @@ class TreeUnpickler(reader: TastyReader,
17481749
goto(end)
17491750
val mode = ctx.mode
17501751
val source = ctx.source
1751-
owner => new LazyReader(localReader, owner, mode, source, op)
1752+
val gadtState = ctx.gadtState
1753+
owner => new LazyReader(localReader, owner, mode, source, gadtState, op)
17521754
}
17531755

17541756
// ------ Setting positions ------------------------------------------------
@@ -1810,15 +1812,16 @@ class TreeUnpickler(reader: TastyReader,
18101812
}
18111813

18121814
class LazyReader[T <: AnyRef](
1813-
reader: TreeReader, owner: Symbol, mode: Mode, source: SourceFile,
1815+
reader: TreeReader, owner: Symbol, mode: Mode, source: SourceFile, gadtState: GadtState,
18141816
op: TreeReader => Context ?=> T) extends Trees.Lazy[T] {
18151817
def complete(using Context): T = {
18161818
pickling.println(i"starting to read at ${reader.reader.currentAddr} with owner $owner")
18171819
atPhaseBeforeTransforms {
18181820
op(reader)(using ctx
18191821
.withOwner(owner)
18201822
.withModeBits(mode)
1821-
.withSource(source))
1823+
.withSource(source)
1824+
.withGadtState(gadtState))
18221825
}
18231826
}
18241827
}

compiler/src/dotty/tools/dotc/transform/PostTyper.scala

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import typer.ErrorReporting.errorTree
1212
import Types.*, Contexts.*, Names.*, Flags.*, DenotTransformers.*, Phases.*
1313
import SymDenotations.*, StdNames.*, Annotations.*, Trees.*, Scopes.*
1414
import Decorators.*
15-
import Symbols.*, NameOps.*
15+
import Symbols.*, NameOps.*, NamerOps.linkConstructorParams
1616
import ContextFunctionResults.annotateContextResults
1717
import config.Printers.typr
1818
import config.Feature
@@ -441,7 +441,10 @@ class PostTyper extends MacroTransform with InfoTransformer { thisPhase =>
441441
Checking.checkPolyFunctionType(tree.tpt)
442442
annotateContextResults(tree)
443443
val tree1 = cpy.DefDef(tree)(tpt = makeOverrideTypeDeclared(tree.symbol, tree.tpt), rhs = normalizeErasedRhs(tree.rhs, tree.symbol))
444-
processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef]))
444+
processValOrDefDef(superAcc.wrapDefDef(tree1):
445+
inContext(linkConstructorParams(tree1.symbol)):
446+
super.transform(tree1).asInstanceOf[DefDef]
447+
)
445448
case tree: TypeDef =>
446449
registerIfHasMacroAnnotations(tree)
447450
val sym = tree.symbol

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 51 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -240,25 +240,12 @@ object Inferencing {
240240
&& {
241241
var fail = false
242242
var skip = false
243-
val direction = instDirection(tvar.origin)
244-
if minimizeSelected then
245-
if direction <= 0 && tvar.hasLowerBound then
246-
skip = instantiate(tvar, fromBelow = true)
247-
else if direction >= 0 && tvar.hasUpperBound then
248-
skip = instantiate(tvar, fromBelow = false)
249-
// else hold off instantiating unbounded unconstrained variable
250-
else if direction != 0 then
251-
skip = instantiate(tvar, fromBelow = direction < 0)
252-
else if variance >= 0 && tvar.hasLowerBound then
253-
skip = instantiate(tvar, fromBelow = true)
254-
else if (variance > 0 || variance == 0 && !tvar.hasUpperBound)
255-
&& force.ifBottom == IfBottom.ok
256-
then // if variance == 0, prefer upper bound if one is given
257-
skip = instantiate(tvar, fromBelow = true)
258-
else if variance >= 0 && force.ifBottom == IfBottom.fail then
259-
fail = true
260-
else
261-
toMaximize = tvar :: toMaximize
243+
instDecision(tvar, variance, minimizeSelected, force.ifBottom) match
244+
case Decision.Min => skip = instantiate(tvar, fromBelow = true)
245+
case Decision.Max => skip = instantiate(tvar, fromBelow = false)
246+
case Decision.Skip => // hold off instantiating unbounded unconstrained variable
247+
case Decision.Fail => fail = true
248+
case Decision.ToMax => toMaximize ::= tvar
262249
!fail && (skip || foldOver(x, tvar))
263250
}
264251
case tp => foldOver(x, tp)
@@ -455,6 +442,41 @@ object Inferencing {
455442
approxAbove - approxBelow
456443
}
457444

445+
/** The instantiation decision for given poly param computed from the constraint. */
446+
enum Decision { case Min; case Max; case ToMax; case Skip; case Fail }
447+
private def instDecision(tvar: TypeVar, v: Int, minimizeSelected: Boolean, ifBottom: IfBottom)(using Context): Decision =
448+
import Decision.*
449+
val direction = instDirection(tvar.origin)
450+
val dec = if minimizeSelected then
451+
if direction <= 0 && tvar.hasLowerBound then Min
452+
else if direction >= 0 && tvar.hasUpperBound then Max
453+
else Skip
454+
else if direction != 0 then if direction < 0 then Min else Max
455+
else if tvar.hasLowerBound then if v >= 0 then Min else ToMax
456+
else ifBottom match
457+
// What's left are unconstrained tvars with at most a non-Any param upperbound:
458+
// * IfBottom.flip will always maximise to the param upperbound, for all variances
459+
// * IfBottom.fail will fail the IFD check, for covariant or invariant tvars, maximise contravariant tvars
460+
// * IfBottom.ok will minimise to Nothing covariant and unbounded invariant tvars, and max to Any the others
461+
case IfBottom.ok => if v > 0 || v == 0 && !tvar.hasUpperBound then Min else ToMax // prefer upper bound if one is given
462+
case IfBottom.fail => if v >= 0 then Fail else ToMax
463+
case ifBottom_flip => ToMax
464+
//println(i"instDecision($tvar, v=v, minimizedSelected=$minimizeSelected, $ifBottom) original=[$original] constrained=[$constrained] dir=$direction = $dec")
465+
dec
466+
467+
private def interpDecision(tvar: TypeVar, v: Int)(using Context): Decision =
468+
import Decision.*
469+
val dec = instDecision(tvar, v, minimizeSelected = false, IfBottom.fail) match
470+
case Min => Min
471+
case Fail => if v > 0 then Min else Max
472+
// like IfBottom.ok,
473+
// but only minimise unconstrained covariant tvars,
474+
// which means that unconstrained unbounded tvars
475+
// will be maximised to Any rather than minimised to Nothing
476+
case _ => Max
477+
//println(i"interpDecision($var, v=$v) = $dec")
478+
dec
479+
458480
/** Following type aliases and stripping refinements and annotations, if one arrives at a
459481
* class type reference where the class has a companion module, a reference to
460482
* that companion module. Otherwise NoType
@@ -651,12 +673,13 @@ trait Inferencing { this: Typer =>
651673

652674
val ownedVars = state.ownedVars
653675
if (ownedVars ne locked) && !ownedVars.isEmpty then
654-
val qualifying = ownedVars -- locked
676+
val qualifying = (ownedVars -- locked).toList
655677
if (!qualifying.isEmpty) {
656678
typr.println(i"interpolate $tree: ${tree.tpe.widen} in $state, pt = $pt, owned vars = ${state.ownedVars.toList}%, %, qualifying = ${qualifying.toList}%, %, previous = ${locked.toList}%, % / ${state.constraint}")
657679
val resultAlreadyConstrained =
658680
tree.isInstanceOf[Apply] || tree.tpe.isInstanceOf[MethodOrPoly]
659681
if (!resultAlreadyConstrained)
682+
trace(i"constrainResult($tree ${tree.symbol}, ${tree.tpe}, $pt)"):
660683
constrainResult(tree.symbol, tree.tpe, pt)
661684
// This is needed because it could establish singleton type upper bounds. See i2998.scala.
662685

@@ -687,6 +710,10 @@ trait Inferencing { this: Typer =>
687710

688711
def constraint = state.constraint
689712

713+
trace(i"interpolateTypeVars($tree: ${tree.tpe}, $pt, $qualifying)", typr, (_: Any) => i"$qualifying\n$constraint\n${ctx.gadt}") {
714+
//println(i"$constraint")
715+
//println(i"${ctx.gadt}")
716+
690717
/** Values of this type report type variables to instantiate with variance indication:
691718
* +1 variable appears covariantly, can be instantiated from lower bound
692719
* -1 variable appears contravariantly, can be instantiated from upper bound
@@ -782,12 +809,10 @@ trait Inferencing { this: Typer =>
782809
/** Try to instantiate `tvs`, return any suspended type variables */
783810
def tryInstantiate(tvs: ToInstantiate): ToInstantiate = tvs match
784811
case (hd @ (tvar, v)) :: tvs1 =>
785-
val fromBelow = v == 1 || (v == 0 && tvar.hasLowerBound)
786-
typr.println(
787-
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
788812
if tvar.isInstantiated then
789813
tryInstantiate(tvs1)
790814
else
815+
val fromBelow = interpDecision(tvar, v) == Decision.Min
791816
val suspend = tvs1.exists{ (following, _) =>
792817
if fromBelow
793818
then constraint.isLess(following.origin, tvar.origin)
@@ -797,13 +822,16 @@ trait Inferencing { this: Typer =>
797822
typr.println(i"suspended: $hd")
798823
hd :: tryInstantiate(tvs1)
799824
else
825+
typr.println(
826+
i"interpolate${if v == 0 then " non-occurring" else ""} $tvar in $state in $tree: $tp, fromBelow = $fromBelow, $constraint")
800827
tvar.instantiate(fromBelow)
801828
tryInstantiate(tvs1)
802829
case Nil => Nil
803830
if tvs.nonEmpty then doInstantiate(tryInstantiate(tvs))
804831
end doInstantiate
805832

806833
doInstantiate(filterByDeps(toInstantiate))
834+
}
807835
}
808836
end if
809837
tree

presentation-compiler/test/dotty/tools/pc/tests/InferExpectedTypeSuite.scala

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
221221
|""".stripMargin
222222
)
223223

224-
@Ignore("Generic functions are not handled correctly.")
225224
@Test def flatmap =
226225
check(
227226
"""|val _ : List[Int] = List().flatMap(_ => @@)
@@ -230,7 +229,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
230229
|""".stripMargin
231230
)
232231

233-
@Ignore("Generic functions are not handled correctly.")
234232
@Test def map =
235233
check(
236234
"""|val _ : List[Int] = List().map(_ => @@)
@@ -239,7 +237,6 @@ class InferExpectedTypeSuite extends BasePCSuite:
239237
|""".stripMargin
240238
)
241239

242-
@Ignore("Generic functions are not handled correctly.")
243240
@Test def `for-comprehension` =
244241
check(
245242
"""|val _ : List[Int] =

tests/neg-deep-subtype/i5877.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
object Main {
1+
object Main { // error // error
22
def main(a: Array[String]): Unit = {
33
println("you may not run `testHasThisType` - just check that it compiles")
44
// comment lines after "// this line of code makes" comments to make it compilable again
@@ -18,25 +18,25 @@ object Main {
1818

1919
// ---- ---- ---- ----
2020

21-
def testHasThisType(): Unit = {
21+
def testHasThisType(): Unit = { // error // error
2222
def testSelf[PThis <: HasThisType[_ <: PThis]](that: HasThisType[PThis]): Unit = {
2323
val thatSelf = that.self()
2424
// that.self().type <: that.This
2525
assert(implicitly[thatSelf.type <:< that.This] != null)
2626
}
2727
val that: HasThisType[_] = Foo() // null.asInstanceOf
28-
testSelf(that) // error: recursion limit exceeded
28+
testSelf(that) // error: recursion limit exceeded // error
2929
}
3030

3131

32-
def testHasThisType2(): Unit = {
32+
def testHasThisType2(): Unit = { // error // error
3333
def testSelf[PThis <: HasThisType[_ <: PThis]](that: PThis with HasThisType[PThis]): Unit = {
3434
// that.type <: that.This
3535
assert(implicitly[that.type <:< that.This] != null)
3636
}
3737
val that: HasThisType[_] = Foo() // null.asInstanceOf
3838
// this line of code makes Dotty compiler infinite recursion (stopped only by overflow) - comment it to make it compilable again
39-
testSelf(that) // error: recursion limit exceeded
39+
testSelf(that) // error: recursion limit exceeded // error
4040
}
4141

4242
// ---- ---- ---- ----

tests/neg/recursive-lower-constraint.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ class Bar extends Foo[Bar]
33

44
class A {
55
def foo[T <: Foo[T], U >: Foo[T] <: T](x: T): T = x
6-
foo(new Bar) // error // error
6+
foo(new Bar) // error
77
}

tests/pos/i21390.TrieMap.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Minimised from scala.collection.concurrent.LNode
2+
// Useful as a minimisation of how,
3+
// If we were to change the type interpolation
4+
// to minimise to the inferred "X" type,
5+
// then this is a minimisation of how the (ab)use of
6+
// GADT constraints to handle class type params
7+
// can fail PostTyper, -Ytest-pickler, and probably others.
8+
9+
import scala.language.experimental.captureChecking
10+
11+
class Foo[X](xs: List[X]):
12+
def this(a: X, b: X) = this(if (a == b) then a :: Nil else a :: b :: Nil)

0 commit comments

Comments
 (0)