Skip to content

Commit c4aefa1

Browse files
committed
Change closure handling
Constrain closure parameters and result from expected type before rechecking the closure's body. This gives more precise types and avoids the spurious duplication of some variables. It also avoids the unmotivated special case that we needed before to make tests pass.
1 parent 6339276 commit c4aefa1

File tree

7 files changed

+106
-105
lines changed

7 files changed

+106
-105
lines changed

compiler/src/dotty/tools/dotc/cc/CheckCaptures.scala

Lines changed: 27 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -408,10 +408,16 @@ class CheckCaptures extends Recheck, SymTransformer:
408408
else if meth == defn.Caps_unsafeUnbox then
409409
mapArgUsing(_.forceBoxStatus(false))
410410
else if meth == defn.Caps_unsafeBoxFunArg then
411-
mapArgUsing:
411+
def forceBox(tp: Type): Type = tp match
412412
case defn.FunctionOf(paramtpe :: Nil, restpe, isContextual) =>
413413
defn.FunctionOf(paramtpe.forceBoxStatus(true) :: Nil, restpe, isContextual)
414-
414+
case tp @ RefinedType(parent, rname, rinfo: MethodType) =>
415+
tp.derivedRefinedType(parent, rname,
416+
rinfo.derivedLambdaType(
417+
paramInfos = rinfo.paramInfos.map(_.forceBoxStatus(true))))
418+
case tp @ CapturingType(parent, refs) =>
419+
tp.derivedCapturingType(forceBox(parent), refs)
420+
mapArgUsing(forceBox)
415421
else
416422
super.recheckApply(tree, pt) match
417423
case appType @ CapturingType(appType1, refs) =>
@@ -485,63 +491,28 @@ class CheckCaptures extends Recheck, SymTransformer:
485491
else ownType
486492
end instantiate
487493

488-
override def recheckClosure(tree: Closure, pt: Type)(using Context): Type =
494+
override def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean)(using Context): Type =
489495
val cs = capturedVars(tree.meth.symbol)
490496
capt.println(i"typing closure $tree with cvs $cs")
491-
super.recheckClosure(tree, pt).capturing(cs)
492-
.showing(i"rechecked $tree / $pt = $result", capt)
493-
494-
/** Additionally to normal processing, update types of closures if the expected type
495-
* is a function with only pure parameters. In that case, make the anonymous function
496-
* also have the same parameters as the prototype.
497-
* TODO: Develop a clearer rationale for this.
498-
* TODO: Can we generalize this to arbitrary parameters?
499-
* Currently some tests fail if we do this. (e.g. neg.../stackAlloc.scala, others)
500-
*/
501-
override def recheckBlock(block: Block, pt: Type)(using Context): Type =
502-
block match
503-
case closureDef(mdef) =>
504-
pt.dealias match
505-
case defn.FunctionOf(ptformals, _, _)
506-
if ptformals.nonEmpty && ptformals.forall(_.captureSet.isAlwaysEmpty) =>
507-
// Redo setup of the anonymous function so that formal parameters don't
508-
// get capture sets. This is important to avoid false widenings to `cap`
509-
// when taking the base type of the actual closures's dependent function
510-
// type so that it conforms to the expected non-dependent function type.
511-
// See withLogFile.scala for a test case.
512-
val meth = mdef.symbol
513-
// First, undo the previous setup which installed a completer for `meth`.
514-
atPhase(preRecheckPhase.prev)(meth.denot.copySymDenotation())
515-
.installAfter(preRecheckPhase)
516-
517-
// Next, update all parameter symbols to match expected formals
518-
meth.paramSymss.head.lazyZip(ptformals).foreach: (psym, pformal) =>
519-
psym.updateInfoBetween(preRecheckPhase, thisPhase, pformal.mapExprType)
520-
521-
// Next, update types of parameter ValDefs
522-
mdef.paramss.head.lazyZip(ptformals).foreach: (param, pformal) =>
523-
val ValDef(_, tpt, _) = param: @unchecked
524-
tpt.rememberTypeAlways(pformal)
525-
526-
// Next, install a new completer reflecting the new parameters for the anonymous method
527-
val mt = meth.info.asInstanceOf[MethodType]
528-
val completer = new LazyType:
529-
def complete(denot: SymDenotation)(using Context) =
530-
denot.info = mt.companion(ptformals, mdef.tpt.knownType)
531-
.showing(i"simplify info of $meth to $result", capt)
532-
recheckDef(mdef, meth)
533-
meth.updateInfoBetween(preRecheckPhase, thisPhase, completer)
534-
case _ =>
535-
mdef.rhs match
536-
case rhs @ closure(_, _, _) =>
537-
// In a curried closure `x => y => e` don't leak capabilities retained by
538-
// the second closure `y => e` into the first one. This is an approximation
539-
// of the CC rule which says that a closure contributes captures to its
540-
// environment only if a let-bound reference to the closure is used.
541-
mdef.rhs.putAttachment(ClosureBodyValue, ())
542-
case _ =>
497+
super.recheckClosure(tree, pt, forceDependent).capturing(cs)
498+
.showing(i"rechecked closure $tree / $pt = $result", capt)
499+
500+
override def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
501+
mdef.rhs match
502+
case rhs @ closure(_, _, _) =>
503+
// In a curried closure `x => y => e` don't leak capabilities retained by
504+
// the second closure `y => e` into the first one. This is an approximation
505+
// of the CC rule which says that a closure contributes captures to its
506+
// environment only if a let-bound reference to the closure is used.
507+
mdef.rhs.putAttachment(ClosureBodyValue, ())
543508
case _ =>
544-
super.recheckBlock(block, pt)
509+
510+
// Constrain closure's parameters and result from the expected type before
511+
// rechecking the body.
512+
val res = recheckClosure(expr, pt, forceDependent = true)
513+
recheckDef(mdef, mdef.symbol)
514+
res
515+
end recheckClosureBlock
545516

546517
override def recheckValDef(tree: ValDef, sym: Symbol)(using Context): Unit =
547518
try

compiler/src/dotty/tools/dotc/cc/Setup.scala

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -350,11 +350,17 @@ extends tpd.TreeTraverser:
350350
val newInfo = integrateRT(sym.info, sym.paramSymss, Nil, Nil)
351351
.showing(i"update info $sym: ${sym.info} --> $result", capt)
352352
if newInfo ne sym.info then
353-
val completer = new LazyType:
354-
def complete(denot: SymDenotation)(using Context) =
355-
denot.info = newInfo
356-
recheckDef(tree, sym)
357-
updateInfo(sym, completer)
353+
updateInfo(sym,
354+
if sym.isAnonymousFunction then
355+
// closures are handled specially; the newInfo is constrained from
356+
// the expected type and only afterwards we recheck the definition
357+
newInfo
358+
else new LazyType:
359+
def complete(denot: SymDenotation)(using Context) =
360+
// infos other methods are determined from their definitions which
361+
// are checked on depand
362+
denot.info = newInfo
363+
recheckDef(tree, sym))
358364
case tree: Bind =>
359365
val sym = tree.symbol
360366
updateInfo(sym, transformInferredType(sym.info, boxed = false))

compiler/src/dotty/tools/dotc/transform/Recheck.scala

Lines changed: 60 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -52,17 +52,18 @@ object Recheck:
5252
*/
5353
def updateInfoBetween(prevPhase: DenotTransformer, lastPhase: DenotTransformer, newInfo: Type)(using Context): Unit =
5454
if sym.info ne newInfo then
55+
val flags = sym.flags
5556
sym.copySymDenotation(
5657
initFlags =
57-
if sym.flags.isAllOf(ResetPrivateParamAccessor)
58-
then sym.flags &~ ResetPrivate | Private
59-
else sym.flags
58+
if flags.isAllOf(ResetPrivateParamAccessor)
59+
then flags &~ ResetPrivate | Private
60+
else flags
6061
).installAfter(lastPhase) // reset
6162
sym.copySymDenotation(
6263
info = newInfo,
6364
initFlags =
64-
if newInfo.isInstanceOf[LazyType] then sym.flags &~ Touched
65-
else sym.flags
65+
if newInfo.isInstanceOf[LazyType] then flags &~ Touched
66+
else flags
6667
).installAfter(prevPhase)
6768

6869
/** Does symbol have a new denotation valid from phase.next that is different
@@ -96,17 +97,44 @@ object Recheck:
9697
case Some(tpe) => tree.withType(tpe).asInstanceOf[T]
9798
case None => tree
9899

99-
extension (tpe: Type)
100-
101-
/** Map ExprType => T to () ?=> T (and analogously for pure versions).
102-
* Even though this phase runs after ElimByName, ExprTypes can still occur
103-
* as by-name arguments of applied types. See note in doc comment for
104-
* ElimByName phase. Test case is bynamefun.scala.
105-
*/
106-
def mapExprType(using Context): Type = tpe match
107-
case ExprType(rt) => defn.ByNameFunction(rt)
108-
case _ => tpe
109100

101+
/** Map ExprType => T to () ?=> T (and analogously for pure versions).
102+
* Even though this phase runs after ElimByName, ExprTypes can still occur
103+
* as by-name arguments of applied types. See note in doc comment for
104+
* ElimByName phase. Test case is bynamefun.scala.
105+
*/
106+
private def mapExprType(tp: Type)(using Context): Type = tp match
107+
case ExprType(rt) => defn.ByNameFunction(rt)
108+
case _ => tp
109+
110+
/** Normalize `=> A` types to `() ?=> A` types
111+
* - at the top level
112+
* - in function and method parameter types
113+
* - under annotations
114+
*/
115+
def normalizeByName(tp: Type)(using Context): Type = tp match
116+
case tp: ExprType =>
117+
mapExprType(tp)
118+
case tp: PolyType =>
119+
tp.derivedLambdaType(resType = normalizeByName(tp.resType))
120+
case tp: MethodType =>
121+
tp.derivedLambdaType(
122+
paramInfos = tp.paramInfos.mapConserve(mapExprType),
123+
resType = normalizeByName(tp.resType))
124+
case tp @ RefinedType(parent, nme.apply, rinfo) if defn.isFunctionType(tp) =>
125+
tp.derivedRefinedType(parent, nme.apply, normalizeByName(rinfo))
126+
case tp @ defn.FunctionOf(pformals, restpe, isContextual) =>
127+
val pformals1 = pformals.mapConserve(mapExprType)
128+
val restpe1 = normalizeByName(restpe)
129+
if (pformals1 ne pformals) || (restpe1 ne restpe) then
130+
defn.FunctionOf(pformals1, restpe1, isContextual)
131+
else
132+
tp
133+
case tp @ AnnotatedType(parent, ann) =>
134+
tp.derivedAnnotatedType(normalizeByName(parent), ann)
135+
case _ =>
136+
tp
137+
end Recheck
110138

111139
/** A base class that runs a simplified typer pass over an already re-typed program. The pass
112140
* does not transform trees but returns instead the re-typed type of each tree as it is
@@ -183,27 +211,16 @@ abstract class Recheck extends Phase, SymTransformer:
183211
else AnySelectionProto
184212
recheckSelection(tree, recheck(qual, proto).widenIfUnstable, name, pt)
185213

186-
/** When we select the `apply` of a function with type such as `(=> A) => B`,
187-
* we need to convert the parameter type `=> A` to `() ?=> A`. See doc comment
188-
* of `mapExprType`.
189-
*/
190-
def normalizeByName(mbr: SingleDenotation)(using Context): SingleDenotation = mbr.info match
191-
case mt: MethodType if mt.paramInfos.exists(_.isInstanceOf[ExprType]) =>
192-
mbr.derivedSingleDenotation(mbr.symbol,
193-
mt.derivedLambdaType(paramInfos = mt.paramInfos.map(_.mapExprType)))
194-
case _ =>
195-
mbr
196-
197214
def recheckSelection(tree: Select, qualType: Type, name: Name,
198215
sharpen: Denotation => Denotation)(using Context): Type =
199216
if name.is(OuterSelectName) then tree.tpe
200217
else
201218
//val pre = ta.maybeSkolemizePrefix(qualType, name)
202-
val mbr = normalizeByName(
219+
val mbr =
203220
sharpen(
204221
qualType.findMember(name, qualType,
205222
excluded = if tree.symbol.is(Private) then EmptyFlags else Private
206-
)).suchThat(tree.symbol == _))
223+
)).suchThat(tree.symbol == _)
207224
val newType = tree.tpe match
208225
case prevType: NamedType =>
209226
val prevDenot = prevType.denot
@@ -281,7 +298,7 @@ abstract class Recheck extends Phase, SymTransformer:
281298
else fntpe.paramInfos
282299
def recheckArgs(args: List[Tree], formals: List[Type], prefs: List[ParamRef]): List[Type] = args match
283300
case arg :: args1 =>
284-
val argType = recheck(arg, formals.head.mapExprType)
301+
val argType = recheck(arg, normalizeByName(formals.head))
285302
val formals1 =
286303
if fntpe.isParamDependent
287304
then formals.tail.map(_.substParam(prefs.head, argType))
@@ -313,27 +330,33 @@ abstract class Recheck extends Phase, SymTransformer:
313330
recheck(tree.rhs, lhsType.widen)
314331
defn.UnitType
315332

316-
def recheckBlock(stats: List[Tree], expr: Tree, pt: Type)(using Context): Type =
333+
private def recheckBlock(stats: List[Tree], expr: Tree)(using Context): Type =
317334
recheckStats(stats)
318335
val exprType = recheck(expr)
336+
TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm))
337+
338+
def recheckBlock(tree: Block, pt: Type)(using Context): Type = tree match
339+
case Block(Nil, expr: Block) => recheckBlock(expr, pt)
340+
case Block((mdef : DefDef) :: Nil, closure: Closure) =>
341+
recheckClosureBlock(mdef, closure.withSpan(tree.span), pt)
342+
case Block(stats, expr) => recheckBlock(stats, expr)
319343
// The expected type `pt` is not propagated. Doing so would allow variables in the
320344
// expected type to contain references to local symbols of the block, so the
321345
// local symbols could escape that way.
322-
TypeOps.avoid(exprType, localSyms(stats).filterConserve(_.isTerm))
323346

324-
def recheckBlock(tree: Block, pt: Type)(using Context): Type =
325-
recheckBlock(tree.stats, tree.expr, pt)
347+
def recheckClosureBlock(mdef: DefDef, expr: Closure, pt: Type)(using Context): Type =
348+
recheckBlock(mdef :: Nil, expr)
326349

327350
def recheckInlined(tree: Inlined, pt: Type)(using Context): Type =
328-
recheckBlock(tree.bindings, tree.expansion, pt)(using inlineContext(tree))
351+
recheckBlock(tree.bindings, tree.expansion)(using inlineContext(tree))
329352

330353
def recheckIf(tree: If, pt: Type)(using Context): Type =
331354
recheck(tree.cond, defn.BooleanType)
332355
recheck(tree.thenp, pt) | recheck(tree.elsep, pt)
333356

334-
def recheckClosure(tree: Closure, pt: Type)(using Context): Type =
357+
def recheckClosure(tree: Closure, pt: Type, forceDependent: Boolean = false)(using Context): Type =
335358
if tree.tpt.isEmpty then
336-
tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined))
359+
tree.meth.tpe.widen.toFunctionType(tree.meth.symbol.is(JavaDefined), alwaysDependent = forceDependent)
337360
else
338361
recheck(tree.tpt)
339362

@@ -534,9 +557,7 @@ abstract class Recheck extends Phase, SymTransformer:
534557

535558
/** Check that widened types of `tpe` and `pt` are compatible. */
536559
def checkConforms(tpe: Type, pt: Type, tree: Tree)(using Context): Unit = tree match
537-
case _: DefTree | EmptyTree | _: TypeTree | _: Closure =>
538-
// Don't report closure nodes, since their span is a point; wait instead
539-
// for enclosing block to preduce an error
560+
case _: DefTree | EmptyTree | _: TypeTree =>
540561
case _ =>
541562
checkConformsExpr(tpe.widenExpr, pt.widenExpr, tree)
542563

tests/neg-custom-args/captures/capt1.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/capt1.scala:14:2 -----------------------------------------
1616
14 | def f(y: Int) = if x == null then y else y // error
1717
| ^
18-
| Found: Int ->{x} Int
18+
| Found: (y: Int) ->{x} Int
1919
| Required: Matchable
2020
15 | f
2121
|

tests/neg-custom-args/captures/try.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
| This is often caused by a local capability in an argument of method handle
77
| leaking as part of its result.
88
-- [E007] Type Mismatch Error: tests/neg-custom-args/captures/try.scala:29:43 ------------------------------------------
9-
29 | val b = handle[Exception, () -> Nothing] { // error
9+
29 | val b = handle[Exception, () -> Nothing] { // error
1010
| ^
1111
| Found: (x: CT[Exception]^) ->? () ->{x} Nothing
1212
| Required: (x$0: CanThrow[Exception]) => () -> Nothing

tests/neg-custom-args/captures/try.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def test =
2626
(ex: Exception) => ???
2727
}
2828

29-
val b = handle[Exception, () -> Nothing] { // error
29+
val b = handle[Exception, () -> Nothing] { // error
3030
(x: CanThrow[Exception]) => () => raise(new Exception)(using x)
3131
} {
3232
(ex: Exception) => ???
Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
object test:
22
class Plan(elem: Plan)
33
object SomePlan extends Plan(???)
4+
type PP = (-> Plan) -> Plan
45
def f1(expr: (-> Plan) -> Plan): Plan = expr(SomePlan)
56
f1 { onf => Plan(onf) }
67
def f2(expr: (=> Plan) -> Plan): Plan = ???
78
f2 { onf => Plan(onf) }
89
def f3(expr: (-> Plan) => Plan): Plan = ???
9-
f1 { onf => Plan(onf) }
10+
f3 { onf => Plan(onf) }
1011
def f4(expr: (=> Plan) => Plan): Plan = ???
11-
f2 { onf => Plan(onf) }
12+
f4 { onf => Plan(onf) }
13+
def f5(expr: PP): Plan = expr(SomePlan)
14+
f5 { onf => Plan(onf) }

0 commit comments

Comments
 (0)