Skip to content

Commit 9ab1a70

Browse files
committed
Add syntax cap[qual] for outer capture roots
1 parent 7913391 commit 9ab1a70

File tree

11 files changed

+114
-54
lines changed

11 files changed

+114
-54
lines changed

compiler/src/dotty/tools/dotc/ast/TreeInfo.scala

+11
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,17 @@ trait TreeInfo[T <: Untyped] { self: Trees.Instance[T] =>
376376
case _ =>
377377
tree.tpe.isInstanceOf[ThisType]
378378
}
379+
380+
/** Under capture checking, an extractor for qualified roots `cap[Q]`.
381+
*/
382+
object QualifiedRoot:
383+
384+
def unapply(tree: Apply)(using Context): Option[String] = tree match
385+
case Apply(fn, Literal(lit) :: Nil) if fn.symbol == defn.Caps_capIn =>
386+
Some(lit.value.asInstanceOf[String])
387+
case _ =>
388+
None
389+
end QualifiedRoot
379390
}
380391

381392
trait UntypedTreeInfo extends TreeInfo[Untyped] { self: Trees.Instance[Untyped] =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

+7-1
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,10 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
149149
case Floating
150150
}
151151

152-
/** {x1, ..., xN} T (only relevant under captureChecking) */
152+
/** {x1, ..., xN} T (only relevant under captureChecking)
153+
* Created when parsing function types so that capture set and result type
154+
* is combined in a single node.
155+
*/
153156
case class CapturesAndResult(refs: List[Tree], parent: Tree)(implicit @constructorOnly src: SourceFile) extends TypTree
154157

155158
/** A type tree appearing somewhere in the untyped DefDef of a lambda, it will be typed using `tpFun`.
@@ -512,6 +515,9 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
512515
def captureRoot(using Context): Select =
513516
Select(scalaDot(nme.caps), nme.CAPTURE_ROOT)
514517

518+
def captureRootIn(using Context): Select =
519+
Select(scalaDot(nme.caps), nme.capIn)
520+
515521
def makeRetaining(parent: Tree, refs: List[Tree], annotName: TypeName)(using Context): Annotated =
516522
Annotated(parent, New(scalaAnnotationDot(annotName), List(refs)))
517523

compiler/src/dotty/tools/dotc/cc/CaptureOps.scala

+27-36
Original file line numberDiff line numberDiff line change
@@ -103,9 +103,14 @@ end mapRoots
103103
extension (tree: Tree)
104104

105105
/** Map tree with CaptureRef type to its type, throw IllegalCaptureRef otherwise */
106-
def toCaptureRef(using Context): CaptureRef = tree.tpe match
107-
case ref: CaptureRef => ref
108-
case tpe => throw IllegalCaptureRef(tpe)
106+
def toCaptureRef(using Context): CaptureRef = tree match
107+
case QualifiedRoot(outer) =>
108+
ctx.owner.levelOwnerNamed(outer)
109+
.orElse(defn.captureRoot) // non-existing outer roots are reported in Setup's checkQualifiedRoots
110+
.localRoot.termRef
111+
case _ => tree.tpe match
112+
case ref: CaptureRef => ref
113+
case tpe => throw IllegalCaptureRef(tpe) // if this was compiled from cc syntax, problem should have been reported at Typer
109114

110115
/** Convert a @retains or @retainsByName annotation tree to the capture set it represents.
111116
* For efficience, the result is cached as an Attachment on the tree.
@@ -266,39 +271,6 @@ extension (tp: Type)
266271
tp.tp1.isAlwaysPure && tp.tp2.isAlwaysPure
267272
case _ =>
268273
false
269-
/*!!!
270-
def capturedLocalRoot(using Context): Symbol =
271-
tp.captureSet.elems.toList
272-
.filter(_.isLocalRootCapability)
273-
.map(_.termSymbol)
274-
.maxByOption(_.ccNestingLevel)
275-
.getOrElse(NoSymbol)
276-
277-
/** Remap roots defined in `cls` to the ... */
278-
def remapRoots(pre: Type, cls: Symbol)(using Context): Type =
279-
if cls.isStaticOwner then tp
280-
else
281-
val from =
282-
if cls.source == ctx.compilationUnit.source then cls.localRoot
283-
else defn.captureRoot
284-
mapRoots(from, capturedLocalRoot)(tp)
285-
286-
287-
def containsRoot(root: Symbol)(using Context): Boolean =
288-
val search = new TypeAccumulator[Boolean]:
289-
def apply(x: Boolean, t: Type): Boolean =
290-
if x then true
291-
else t.dealias match
292-
case t1: TermRef if t1.symbol == root => true
293-
case t1: TypeRef if t1.classSymbol.hasAnnotation(defn.CapabilityAnnot) => true
294-
case t1: MethodType =>
295-
!foldOver(x, t1.paramInfos) && this(x, t1.resType)
296-
case t1 @ AppliedType(tycon, args) if defn.isFunctionSymbol(tycon.typeSymbol) =>
297-
val (inits, last :: Nil) = args.splitAt(args.length - 1): @unchecked
298-
!foldOver(x, inits) && this(x, last)
299-
case t1 => foldOver(x, t1)
300-
search(false, tp)
301-
*/
302274

303275
extension (cls: ClassSymbol)
304276

@@ -405,6 +377,7 @@ extension (sym: Symbol)
405377
case psyms :: _ => psyms.find(_.info.typeSymbol == defn.Caps_Cap).getOrElse(NoSymbol)
406378
case _ => NoSymbol
407379

380+
/** The local root corresponding to sym's level owner */
408381
def localRoot(using Context): Symbol =
409382
val owner = sym.levelOwner
410383
assert(owner.exists)
@@ -415,6 +388,24 @@ extension (sym: Symbol)
415388
else newRoot
416389
ccState.localRoots.getOrElseUpdate(owner, lclRoot)
417390

391+
/** The level owner enclosing `sym` which has the given name, or NoSymbol if none exists.
392+
* If name refers to a val that has a closure as rhs, we return the closure as level
393+
* owner.
394+
*/
395+
def levelOwnerNamed(name: String)(using Context): Symbol =
396+
def recur(owner: Symbol, prev: Symbol): Symbol =
397+
if owner.name.toString == name then
398+
if owner.isLevelOwner then owner
399+
else if owner.isTerm && !owner.isOneOf(Method | Module) && prev.exists then prev
400+
else NoSymbol
401+
else if owner == defn.RootClass then
402+
NoSymbol
403+
else
404+
val prev1 = if owner.isAnonymousFunction && owner.isLevelOwner then owner else NoSymbol
405+
recur(owner.owner, prev1)
406+
recur(sym, NoSymbol)
407+
.showing(i"find outer $sym [ $name ] = $result", capt)
408+
418409
def maxNested(other: Symbol)(using Context): Symbol =
419410
if sym.ccNestingLevel < other.ccNestingLevel then other else sym
420411
/* does not work yet, we do mix sets with different levels, for instance in cc-this.scala.

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

+22-12
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,15 @@ object CheckCaptures:
138138
report.error(em"Singleton type $parent cannot have capture set", parent.srcPos)
139139
case _ =>
140140
for elem <- retainedElems(ann) do
141-
elem.tpe match
142-
case ref: CaptureRef =>
143-
if !ref.isTrackableRef then
144-
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
145-
case tpe =>
146-
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)
141+
elem match
142+
case QualifiedRoot(outer) =>
143+
// Will be checked by Setup's checkOuterRoots
144+
case _ => elem.tpe match
145+
case ref: CaptureRef =>
146+
if !ref.isTrackableRef then
147+
report.error(em"$elem cannot be tracked since it is not a parameter or local value", elem.srcPos)
148+
case tpe =>
149+
report.error(em"$elem: $tpe is not a legal element of a capture set", elem.srcPos)
147150

148151
/** If `tp` is a capturing type, check that all references it mentions have non-empty
149152
* capture sets. Also: warn about redundant capture annotations.
@@ -155,7 +158,7 @@ object CheckCaptures:
155158
if ref.captureSetOfInfo.elems.isEmpty then
156159
report.error(em"$ref cannot be tracked since its capture set is empty", pos)
157160
else if parent.captureSet.accountsFor(ref) then
158-
report.warning(em"redundant capture: $parent already accounts for $ref", pos)
161+
report.warning(em"redundant capture: $parent already accounts for $ref in $tp", pos)
159162
case _ =>
160163

161164
/** Warn if `ann`, which is the tree of a @retains annotation, defines some elements that
@@ -166,11 +169,15 @@ object CheckCaptures:
166169
def warnIfRedundantCaptureSet(ann: Tree, tpt: Tree)(using Context): Unit =
167170
var retained = retainedElems(ann).toArray
168171
for i <- 0 until retained.length do
169-
val ref = retained(i).toCaptureRef
172+
val refTree = retained(i)
173+
val ref = refTree.toCaptureRef
170174
val others = for j <- 0 until retained.length if j != i yield retained(j).toCaptureRef
171175
val remaining = CaptureSet(others*)
172176
if remaining.accountsFor(ref) then
173-
val srcTree = if ann.span.exists then ann else tpt
177+
val srcTree =
178+
if refTree.span.exists then refTree
179+
else if ann.span.exists then ann
180+
else tpt
174181
report.warning(em"redundant capture: $remaining already accounts for $ref", srcTree.srcPos)
175182

176183
/** Attachment key for bodies of closures, provided they are values */
@@ -1192,9 +1199,12 @@ class CheckCaptures extends Recheck, SymTransformer:
11921199
def postCheck(unit: tpd.Tree)(using Context): Unit =
11931200
val checker = new TreeTraverser:
11941201
def traverse(tree: Tree)(using Context): Unit =
1195-
traverseChildren(tree)
1202+
val lctx = tree match
1203+
case _: DefTree | _: TypeDef if tree.symbol.exists => ctx.withOwner(tree.symbol)
1204+
case _ => ctx
1205+
traverseChildren(tree)(using lctx)
11961206
check(tree)
1197-
def check(tree: Tree) = tree match
1207+
def check(tree: Tree)(using Context) = tree match
11981208
case _: InferredTypeTree =>
11991209
case tree: TypeTree if !tree.span.isZeroExtent =>
12001210
tree.knownType.foreachPart { tp =>
@@ -1253,7 +1263,7 @@ class CheckCaptures extends Recheck, SymTransformer:
12531263
case _ =>
12541264
end check
12551265
end checker
1256-
checker.traverse(unit)
1266+
checker.traverse(unit)(using ctx.withOwner(defn.RootClass))
12571267
if !ctx.reporter.errorsReported then
12581268
// We dont report errors here if previous errors were reported, because other
12591269
// errors often result in bad applied types, but flagging these bad types gives

compiler/src/dotty/tools/dotc/cc/Setup.scala

+8-2
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,11 @@ extends tpd.TreeTraverser:
217217
then CapturingType(tp, CaptureSet.universal, boxed = false)
218218
else tp
219219

220+
private def checkQualifiedRoots(tree: Tree)(using Context): Unit =
221+
for case elem @ QualifiedRoot(outer) <- retainedElems(tree) do
222+
if !ctx.owner.levelOwnerNamed(outer).exists then
223+
report.error(em"`$outer` does not name an outer definition that represents a capture level", elem.srcPos)
224+
220225
private def expandAliases(using Context) = new TypeMap with FollowAliases:
221226
override def toString = "expand aliases"
222227
def apply(t: Type) =
@@ -226,12 +231,13 @@ extends tpd.TreeTraverser:
226231
if t2 ne t then return t2
227232
t match
228233
case t @ AnnotatedType(t1, ann) =>
229-
val t2 =
234+
checkQualifiedRoots(ann.tree)
235+
val t3 =
230236
if ann.symbol == defn.RetainsAnnot && isCapabilityClassRef(t1) then t1
231237
else this(t1)
232238
// Don't map capture sets, since that would implicitly normalize sets that
233239
// are not well-formed.
234-
t.derivedAnnotatedType(t2, ann)
240+
t.derivedAnnotatedType(t3, ann)
235241
case _ =>
236242
mapOverFollowingAliases(t)
237243

compiler/src/dotty/tools/dotc/core/Definitions.scala

+1
Original file line numberDiff line numberDiff line change
@@ -972,6 +972,7 @@ class Definitions {
972972
@tu lazy val CapsModule: Symbol = requiredModule("scala.caps")
973973
@tu lazy val captureRoot: TermSymbol = CapsModule.requiredValue("cap")
974974
@tu lazy val Caps_Cap: TypeSymbol = CapsModule.requiredType("Cap")
975+
@tu lazy val Caps_capIn: TermSymbol = CapsModule.requiredMethod("capIn")
975976
@tu lazy val CapsUnsafeModule: Symbol = requiredModule("scala.caps.unsafe")
976977
@tu lazy val Caps_unsafeAssumePure: Symbol = CapsUnsafeModule.requiredMethod("unsafeAssumePure")
977978
@tu lazy val Caps_unsafeBox: Symbol = CapsUnsafeModule.requiredMethod("unsafeBox")

compiler/src/dotty/tools/dotc/core/StdNames.scala

+1
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,7 @@ object StdNames {
434434
val bytes: N = "bytes"
435435
val canEqual_ : N = "canEqual"
436436
val canEqualAny : N = "canEqualAny"
437+
val capIn: N = "capIn"
437438
val caps: N = "caps"
438439
val captureChecking: N = "captureChecking"
439440
val checkInitialized: N = "checkInitialized"

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+13-3
Original file line numberDiff line numberDiff line change
@@ -1423,13 +1423,23 @@ object Parsers {
14231423
case _ => None
14241424
}
14251425

1426-
/** CaptureRef ::= ident | `this`
1426+
/** CaptureRef ::= ident | `this` | `cap` [`[` ident `]`]
14271427
*/
14281428
def captureRef(): Tree =
14291429
if in.token == THIS then simpleRef()
14301430
else termIdent() match
1431-
case Ident(nme.CAPTURE_ROOT) => captureRoot
1432-
case id => id
1431+
case id @ Ident(nme.CAPTURE_ROOT) =>
1432+
if in.token == LBRACKET then
1433+
val ref = atSpan(id.span.start)(captureRootIn)
1434+
val qual =
1435+
inBrackets:
1436+
atSpan(in.offset):
1437+
Literal(Constant(ident().toString))
1438+
atSpan(id.span.start)(Apply(ref, qual :: Nil))
1439+
else
1440+
atSpan(id.span.start)(captureRoot)
1441+
case id =>
1442+
id
14331443

14341444
/** CaptureSet ::= `{` CaptureRef {`,` CaptureRef} `}` -- under captureChecking
14351445
*/

library/src/scala/caps.scala

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@ import annotation.experimental
1515

1616
given Cap = cap
1717

18+
def capIn(scope: String): Cap = ()
19+
1820
object unsafe:
1921

2022
extension [T](x: T)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
class C:
2+
def x: C^{cap[d]} = ??? // error
3+
4+
def y: C^{cap[C]} = ??? // ok
5+
private val z = (x: Int) => (c: C^{cap[z]}) => x // ok
6+
7+
private val z2 = identity((x: Int) => (c: C^{cap[z2]}) => x) // error

tests/pos-custom-args/captures/pairs.scala

+15
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,18 @@ object Monomorphic:
3131
val x1c: Cap ->{c} Unit = x1
3232
val y1 = p.snd
3333
val y1c: Cap ->{d} Unit = y1
34+
35+
object Monomorphic2:
36+
37+
class Pair(x: Cap => Unit, y: Cap => Unit):
38+
def fst: Cap^{cap[Pair]} ->{x} Unit = x
39+
def snd: Cap^{cap[Pair]} ->{y} Unit = y
40+
41+
def test(c: Cap, d: Cap) =
42+
def f(x: Cap): Unit = if c == x then ()
43+
def g(x: Cap): Unit = if d == x then ()
44+
val p = Pair(f, g)
45+
val x1 = p.fst
46+
val x1c: Cap ->{c} Unit = x1
47+
val y1 = p.snd
48+
val y1c: Cap ->{d} Unit = y1

0 commit comments

Comments
 (0)