Skip to content

Commit 18e768c

Browse files
authored
Merge pull request #12875 from dotty-staging/shallow-capture-sets
Shallow capture sets
2 parents fe31d81 + 59151bd commit 18e768c

29 files changed

+346
-114
lines changed

compiler/src/dotty/tools/dotc/Run.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Types._
99
import Scopes._
1010
import Names.Name
1111
import Denotations.Denotation
12-
import typer.Typer
12+
import typer.{Typer, RefineTypes}
1313
import typer.ImportInfo._
1414
import Decorators._
1515
import io.{AbstractFile, PlainFile, VirtualFile}
@@ -204,7 +204,7 @@ class Run(comp: Compiler, ictx: Context) extends ImplicitRunInfo with Constraint
204204
val profileBefore = profiler.beforePhase(phase)
205205
units = phase.runOn(units)
206206
profiler.afterPhase(phase, profileBefore)
207-
if (ctx.settings.Xprint.value.containsPhase(phase))
207+
if ctx.settings.Xprint.value.containsPhase(phase) && !phase.isInstanceOf[RefineTypes] then
208208
for (unit <- units)
209209
lastPrintedTree =
210210
printTree(lastPrintedTree)(using ctx.fresh.setPhase(phase.next).setCompilationUnit(unit))

compiler/src/dotty/tools/dotc/config/ScalaSettings.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ trait AllScalaSettings extends CommonScalaSettings { self: Settings.SettingGroup
203203
val YrequireTargetName: Setting[Boolean] = BooleanSetting("-Yrequire-targetName", "Warn if an operator is defined without a @targetName annotation")
204204
val YrefineTypes: Setting[Boolean] = BooleanSetting("-Yrefine-types", "Run experimental type refiner (test only)")
205205
val Ycc: Setting[Boolean] = BooleanSetting("-Ycc", "Check captured references")
206+
val YccNoAbbrev: Setting[Boolean] = BooleanSetting("-Ycc-no-abbrev", "Used in conjunction with -Ycc, suppress type abbreviations")
206207

207208
/** Area-specific debug output */
208209
val YexplainLowlevel: Setting[Boolean] = BooleanSetting("-Yexplain-lowlevel", "When explaining type errors, show types at a lower level.")

compiler/src/dotty/tools/dotc/core/CaptureSet.scala

Lines changed: 40 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,17 @@ case class CaptureSet private (elems: CaptureSet.Refs) extends Showable:
1717
def isEmpty: Boolean = elems.isEmpty
1818
def nonEmpty: Boolean = !isEmpty
1919

20-
private var myClosure: Refs | Null = null
21-
22-
def closure(using Context): Refs =
23-
if myClosure == null then
24-
var cl = elems
25-
var seen: Refs = SimpleIdentitySet.empty
26-
while
27-
val prev = cl
28-
for ref <- cl do
29-
if !seen.contains(ref) then
30-
seen += ref
31-
cl = cl ++ ref.captureSetOfInfo.elems
32-
prev ne cl
33-
do ()
34-
myClosure = cl
35-
myClosure
36-
3720
def ++ (that: CaptureSet): CaptureSet =
38-
CaptureSet(elems ++ that.elems)
21+
if this.isEmpty then that
22+
else if that.isEmpty then this
23+
else CaptureSet(elems ++ that.elems)
24+
25+
def + (ref: CaptureRef) =
26+
if elems.contains(ref) then this
27+
else CaptureSet(elems + ref)
28+
29+
def intersect (that: CaptureSet): CaptureSet =
30+
CaptureSet(this.elems.intersect(that.elems))
3931

4032
/** {x} <:< this where <:< is subcapturing */
4133
def accountsFor(x: CaptureRef)(using Context) =
@@ -45,6 +37,15 @@ case class CaptureSet private (elems: CaptureSet.Refs) extends Showable:
4537
def <:< (that: CaptureSet)(using Context): Boolean =
4638
elems.isEmpty || elems.forall(that.accountsFor)
4739

40+
def flatMap(f: CaptureRef => CaptureSet)(using Context): CaptureSet =
41+
(empty /: elems)((cs, ref) => cs ++ f(ref))
42+
43+
def substParams(tl: BindingType, to: List[Type])(using Context) =
44+
flatMap {
45+
case ref: ParamRef if ref.binder eq tl => to(ref.paramNum).captureSet
46+
case ref => ref.singletonCaptureSet
47+
}
48+
4849
override def toString = elems.toString
4950

5051
override def toText(printer: Printer): Text =
@@ -82,46 +83,26 @@ object CaptureSet:
8283
css.foldLeft(empty)(_ ++ _)
8384

8485
def ofType(tp: Type)(using Context): CaptureSet =
85-
val collect = new TypeAccumulator[Refs]:
86-
var localBinders: SimpleIdentitySet[BindingType] = SimpleIdentitySet.empty
87-
var seenLazyRefs: SimpleIdentitySet[LazyRef] = SimpleIdentitySet.empty
88-
def apply(elems: Refs, tp: Type): Refs = trace(i"capt $elems, $tp", capt, show = true) {
89-
tp match
90-
case tp: NamedType =>
91-
if variance < 0 then elems
92-
else elems ++ tp.captureSet.elems
93-
case tp: ParamRef =>
94-
if variance < 0 || localBinders.contains(tp.binder) then elems
95-
else elems ++ tp.captureSet.elems
96-
case tp: LambdaType =>
97-
localBinders += tp
98-
try apply(elems, tp.resultType)
99-
finally localBinders -= tp
100-
case AndType(tp1, tp2) =>
101-
val elems1 = apply(SimpleIdentitySet.empty, tp1)
102-
val elems2 = apply(SimpleIdentitySet.empty, tp2)
103-
elems ++ elems1.intersect(elems2)
104-
case CapturingType(parent, ref) =>
105-
val elems1 = apply(elems, parent)
106-
if variance >= 0 then elems1 + ref else elems1
107-
case TypeBounds(_, hi) =>
108-
apply(elems, hi)
109-
case tp: ClassInfo =>
110-
elems ++ ofClass(tp, Nil).elems
111-
case tp: LazyRef =>
112-
if seenLazyRefs.contains(tp)
113-
|| tp.evaluating // shapeless gets an assertion error without this test
114-
then elems
115-
else
116-
seenLazyRefs += tp
117-
foldOver(elems, tp)
118-
// case tp: MatchType =>
119-
// val normed = tp.tryNormalize
120-
// if normed.exists then apply(elems, normed) else foldOver(elems, tp)
121-
case _ =>
122-
foldOver(elems, tp)
123-
}
124-
125-
CaptureSet(collect(empty.elems, tp))
86+
def recur(tp: Type): CaptureSet = tp match
87+
case tp: CaptureRef =>
88+
tp.captureSet
89+
case CapturingType(parent, ref) =>
90+
recur(parent) + ref
91+
case AppliedType(tycon, args) =>
92+
val cs = recur(tycon)
93+
tycon.typeParams match
94+
case tparams @ (LambdaParam(tl, _) :: _) => cs.substParams(tl, args)
95+
case _ => cs
96+
case tp: TypeProxy =>
97+
recur(tp.underlying)
98+
case AndType(tp1, tp2) =>
99+
recur(tp1).intersect(recur(tp2))
100+
case OrType(tp1, tp2) =>
101+
recur(tp1) ++ recur(tp2)
102+
case tp: ClassInfo =>
103+
ofClass(tp, Nil)
104+
case _ =>
105+
empty
106+
recur(tp)
126107
.showing(i"capture set of $tp = $result", capt)
127108

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1565,7 +1565,7 @@ class Definitions {
15651565
* - the upper bound of a TypeParamRef in the current constraint
15661566
*/
15671567
def asContextFunctionType(tp: Type)(using Context): Type =
1568-
tp.stripTypeVar.dealias match
1568+
tp.stripped.dealias match
15691569
case tp1: TypeParamRef if ctx.typerState.constraint.contains(tp1) =>
15701570
asContextFunctionType(TypeComparer.bounds(tp1).hiBound)
15711571
case tp1 =>

compiler/src/dotty/tools/dotc/core/TypeApplications.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,6 @@ class TypeApplications(val self: Type) extends AnyVal {
310310
*/
311311
final def appliedTo(args: List[Type])(using Context): Type = {
312312
record("appliedTo")
313-
val typParams = self.typeParams
314313
val stripped = self.stripTypeVar
315314
val dealiased = stripped.safeDealias
316315
if (args.isEmpty || ctx.erasedTypes) self

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3620,7 +3620,7 @@ object Types {
36203620
case tp: TermParamRef if tp.binder eq thisLambdaType => TrueDeps
36213621
case tp: CapturingType =>
36223622
val status1 = compute(status, tp.parent, theAcc)
3623-
tp.ref match
3623+
tp.ref.stripTypeVar match
36243624
case tp: TermParamRef if tp.binder eq thisLambdaType => combine(status1, CaptureDeps)
36253625
case _ => status1
36263626
case _: ThisType | _: BoundType | NoPrefix => status
@@ -4505,9 +4505,10 @@ object Types {
45054505
* @param origin The parameter that's tracked by the type variable.
45064506
* @param creatorState The typer state in which the variable was created.
45074507
*/
4508-
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int) extends CachedProxyType with ValueType {
4508+
final class TypeVar private(initOrigin: TypeParamRef, creatorState: TyperState, nestingLevel: Int)
4509+
extends CachedProxyType, CaptureRef {
45094510

4510-
private var currentOrigin = initOrigin
4511+
private var currentOrigin = initOrigin
45114512

45124513
def origin: TypeParamRef = currentOrigin
45134514

@@ -4689,6 +4690,26 @@ object Types {
46894690
if (inst.exists) inst else origin
46904691
}
46914692

4693+
// Capture ref methods
4694+
4695+
def canBeTracked(using Context): Boolean = underlying match
4696+
case ref: CaptureRef => ref.canBeTracked
4697+
case _ => false
4698+
4699+
override def normalizedRef(using Context): CaptureRef = instanceOpt match
4700+
case ref: CaptureRef => ref
4701+
case _ => this
4702+
4703+
override def singletonCaptureSet(using Context) = instanceOpt match
4704+
case ref: CaptureRef => ref.singletonCaptureSet
4705+
case _ => super.singletonCaptureSet
4706+
4707+
override def captureSetOfInfo(using Context): CaptureSet = instanceOpt match
4708+
case ref: CaptureRef => ref.captureSetOfInfo
4709+
case tp => tp.captureSet
4710+
4711+
// Object members
4712+
46924713
override def computeHash(bs: Binders): Int = identityHash(bs)
46934714
override def equals(that: Any): Boolean = this.eq(that.asInstanceOf[AnyRef])
46944715

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -158,14 +158,29 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
158158
argStr ~ " " ~ arrow(isGiven) ~ " " ~ argText(args.last)
159159
}
160160

161-
def toTextDependentFunction(appType: MethodType): Text =
162-
"("
163-
~ keywordText("erased ").provided(appType.isErasedMethod)
164-
~ paramsText(appType)
165-
~ ") "
166-
~ arrow(appType.isImplicitMethod)
167-
~ " "
168-
~ toText(appType.resultType)
161+
def toTextMethodAsFunction(info: Type): Text = info match
162+
case info: MethodType =>
163+
changePrec(GlobalPrec) {
164+
"("
165+
~ keywordText("erased ").provided(info.isErasedMethod)
166+
~ ( if info.isParamDependent || info.isResultDependent
167+
then paramsText(info)
168+
else argsText(info.paramInfos)
169+
)
170+
~ ") "
171+
~ arrow(info.isImplicitMethod)
172+
~ " "
173+
~ toTextMethodAsFunction(info.resultType)
174+
}
175+
case info: PolyType =>
176+
changePrec(GlobalPrec) {
177+
"["
178+
~ paramsText(info)
179+
~ "] => "
180+
~ toTextMethodAsFunction(info.resultType)
181+
}
182+
case _ =>
183+
toText(info)
169184

170185
def isInfixType(tp: Type): Boolean = tp match
171186
case AppliedType(tycon, args) =>
@@ -229,8 +244,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
229244
if !printDebug && appliedText(tp.asInstanceOf[HKLambda].resType).isEmpty =>
230245
// don't eta contract if the application would be printed specially
231246
toText(tycon)
232-
case tp: RefinedType if defn.isFunctionType(tp) && !printDebug =>
233-
toTextDependentFunction(tp.refinedInfo.asInstanceOf[MethodType])
247+
case tp: RefinedType
248+
if (defn.isFunctionType(tp) || (tp.parent.typeSymbol eq defn.PolyFunctionClass))
249+
&& !printDebug =>
250+
toTextMethodAsFunction(tp.refinedInfo)
234251
case tp: TypeRef =>
235252
if (tp.symbol.isAnonymousClass && !showUniqueIds)
236253
toText(tp.info)
@@ -244,6 +261,10 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
244261
case ErasedValueType(tycon, underlying) =>
245262
"ErasedValueType(" ~ toText(tycon) ~ ", " ~ toText(underlying) ~ ")"
246263
case tp: ClassInfo =>
264+
if tp.cls.derivesFrom(defn.PolyFunctionClass) then
265+
tp.member(nme.apply).info match
266+
case info: PolyType => return toTextMethodAsFunction(info)
267+
case _ =>
247268
toTextParents(tp.parents) ~~ "{...}"
248269
case JavaArrayType(elemtp) =>
249270
toText(elemtp) ~ "[]"
@@ -506,13 +527,16 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
506527
case RefinedTypeTree(tpt, refines) =>
507528
toTextLocal(tpt) ~ " " ~ blockText(refines)
508529
case AppliedTypeTree(tpt, args) =>
509-
if (tpt.symbol == defn.orType && args.length == 2)
530+
if tpt.symbol == defn.orType && args.length == 2 then
510531
changePrec(OrTypePrec) { toText(args(0)) ~ " | " ~ atPrec(OrTypePrec + 1) { toText(args(1)) } }
511-
else if (tpt.symbol == defn.andType && args.length == 2)
532+
else if tpt.symbol == defn.andType && args.length == 2 then
512533
changePrec(AndTypePrec) { toText(args(0)) ~ " & " ~ atPrec(AndTypePrec + 1) { toText(args(1)) } }
534+
else if tpt.symbol == defn.Predef_retainsType && args.length == 2 then
535+
changePrec(InfixPrec) { toText(args(0)) ~ " retains " ~ toText(args(1)) }
513536
else if defn.isFunctionClass(tpt.symbol)
514537
&& tpt.isInstanceOf[TypeTree] && tree.hasType && !printDebug
515-
then changePrec(GlobalPrec) { toText(tree.typeOpt) }
538+
then
539+
changePrec(GlobalPrec) { toText(tree.typeOpt) }
516540
else args match
517541
case arg :: _ if arg.isTerm =>
518542
toTextLocal(tpt) ~ "(" ~ Text(args.map(argText), ", ") ~ ")"

compiler/src/dotty/tools/dotc/transform/TreeChecker.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -375,14 +375,14 @@ class TreeChecker extends Phase with SymTransformer {
375375
val tpe = tree.typeOpt
376376

377377
// Polymorphic apply methods stay structural until Erasure
378-
val isPolyFunctionApply = (tree.name eq nme.apply) && (tree.qualifier.typeOpt <:< defn.PolyFunctionType)
378+
val isPolyFunctionApply = (tree.name eq nme.apply) && tree.qualifier.typeOpt.derivesFrom(defn.PolyFunctionClass)
379379
// Outer selects are pickled specially so don't require a symbol
380380
val isOuterSelect = tree.name.is(OuterSelectName)
381381
val isPrimitiveArrayOp = ctx.erasedTypes && nme.isPrimitiveName(tree.name)
382382
if !(tree.isType || isPolyFunctionApply || isOuterSelect || isPrimitiveArrayOp) then
383383
val denot = tree.denot
384384
assert(denot.exists, i"Selection $tree with type $tpe does not have a denotation")
385-
assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol")
385+
assert(denot.symbol.exists, i"Denotation $denot of selection $tree with type $tpe does not have a symbol, ${tree.qualifier.typeOpt}")
386386

387387
val sym = tree.symbol
388388
val symIsFixed = tpe match {

0 commit comments

Comments
 (0)