Skip to content

Commit 67d8995

Browse files
committed
Visit all trees
1 parent a349775 commit 67d8995

File tree

3 files changed

+70
-47
lines changed

3 files changed

+70
-47
lines changed

compiler/src/dotty/tools/dotc/transform/TailRec.scala

+24-47
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@ import scala.collection.mutable
2121
*
2222
* What it does:
2323
*
24-
* Finds method calls in tail-position and replaces them with jumps.
25-
* A call is in a tail-position if it is the last instruction to be
26-
* executed in the body of a method. This includes being in
27-
* tail-position of a `return` from a `Labeled` block which is itself
28-
* in tail-position (which is critical for tail-recursive calls in the
29-
* cases of a `match`). To identify tail positions, we recurse over
30-
* the trees that may contain calls in tail-position (trees that can't
31-
* contain such calls are not transformed).
24+
* Finds method calls in tail-position and replaces them with jumps. A call is
25+
* in a tail-position if it is the last instruction to be executed in the body
26+
* of a method. This includes being in tail-position inside a `return`
27+
* expression. If the `return` targets a `Labeled` block, then the target block
28+
* must itself be in tail-position (which is critical for tail-recursive calls
29+
* in the cases of a `match`).
3230
*
3331
* When a method contains at least one tail-recursive call, its rhs
3432
* is wrapped in the following structure:
@@ -49,7 +47,7 @@ import scala.collection.mutable
4947
* reassigning the local `var`s substituting formal parameters and
5048
* (b) a `return` from the `tailResult` labeled block, which has the
5149
* net effect of looping back to the beginning of the method.
52-
* If the receiver is modifed in a recursive call, an additional `var`
50+
* If the receiver is modified in a recursive call, an additional `var`
5351
* is used to replace `this`.
5452
*
5553
* As a complete example of the transformation, the classical `fact`
@@ -118,7 +116,7 @@ class TailRec extends MiniPhase {
118116
override def transformDefDef(tree: DefDef)(using Context): Tree = {
119117
val method = tree.symbol
120118
val mandatory = method.hasAnnotation(defn.TailrecAnnot)
121-
def noTailTransform(failureReported: Boolean) = {
119+
def transform(failureReported: Boolean) = {
122120
// FIXME: want to report this error on `tree.nameSpan`, but
123121
// because of extension method getting a weird position, it is
124122
// better to report on method symbol so there's no overlap.
@@ -212,9 +210,9 @@ class TailRec extends MiniPhase {
212210
)
213211
)
214212
}
215-
else noTailTransform(failureReported = transformer.failureReported)
213+
else transform(failureReported = transformer.failureReported)
216214
}
217-
else noTailTransform(failureReported = false)
215+
else transform(failureReported = false)
218216
}
219217

220218
class TailRecElimination(method: Symbol, enclosingClass: ClassSymbol, paramSyms: List[Symbol], isMandatory: Boolean) extends TreeMap {
@@ -274,34 +272,13 @@ class TailRec extends MiniPhase {
274272
finally inTailPosition = saved
275273
}
276274

277-
def yesTailTransform(tree: Tree)(using Context): Tree =
278-
transform(tree, tailPosition = true)
279-
280-
/** If not in tail position a tree traversal may not be needed.
281-
*
282-
* A recursive call may still be in tail position if within the return
283-
* expression of a labeled block.
284-
* A tree traversal may also be needed to report a failure to transform
285-
* a recursive call of a @tailrec annotated method (i.e. `isMandatory`).
286-
*/
287-
private def isTraversalNeeded =
288-
isMandatory || tailPositionLabeledSyms.size > 0
289-
290-
def noTailTransform(tree: Tree)(using Context): Tree =
291-
if (isTraversalNeeded) transform(tree, tailPosition = false)
292-
else tree
293-
294-
def noTailTransforms[Tr <: Tree](trees: List[Tr])(using Context): List[Tr] =
295-
if (isTraversalNeeded) trees.mapConserve(noTailTransform).asInstanceOf[List[Tr]]
296-
else trees
297-
298275
override def transform(tree: Tree)(using Context): Tree = {
299276
/* Rewrite an Apply to be considered for tail call transformation. */
300277
def rewriteApply(tree: Apply): Tree = {
301-
val arguments = noTailTransforms(tree.args)
278+
val arguments = transform(tree.args)
302279

303280
def continue =
304-
cpy.Apply(tree)(noTailTransform(tree.fun), arguments)
281+
cpy.Apply(tree)(transform(tree.fun), arguments)
305282

306283
def fail(reason: String) = {
307284
if (isMandatory) {
@@ -344,7 +321,7 @@ class TailRec extends MiniPhase {
344321
if (prefix eq EmptyTree) assignParamPairs
345322
else
346323
// TODO Opt: also avoid assigning `this` if the prefix is `this.`
347-
(getVarForRewrittenThis(), noTailTransform(prefix)) :: assignParamPairs
324+
(getVarForRewrittenThis(), transform(prefix)) :: assignParamPairs
348325

349326
val assignments = assignThisAndParamPairs match {
350327
case (lhs, rhs) :: Nil =>
@@ -377,22 +354,22 @@ class TailRec extends MiniPhase {
377354
case tree @ Apply(fun, args) =>
378355
val meth = fun.symbol
379356
if (meth == defn.Boolean_|| || meth == defn.Boolean_&&)
380-
cpy.Apply(tree)(noTailTransform(fun), transform(args))
357+
cpy.Apply(tree)(transform(fun), transform(args))
381358
else
382359
rewriteApply(tree)
383360

384361
case tree @ Select(qual, name) =>
385-
cpy.Select(tree)(noTailTransform(qual), name)
362+
cpy.Select(tree)(transform(qual), name)
386363

387364
case tree @ Block(stats, expr) =>
388365
cpy.Block(tree)(
389-
noTailTransforms(stats),
366+
transform(stats),
390367
transform(expr)
391368
)
392369

393370
case tree @ If(cond, thenp, elsep) =>
394371
cpy.If(tree)(
395-
noTailTransform(cond),
372+
transform(cond),
396373
transform(thenp),
397374
transform(elsep)
398375
)
@@ -402,33 +379,33 @@ class TailRec extends MiniPhase {
402379

403380
case tree @ Match(selector, cases) =>
404381
cpy.Match(tree)(
405-
noTailTransform(selector),
382+
transform(selector),
406383
transformSub(cases)
407384
)
408385

409386
case tree: Try =>
410-
val expr = noTailTransform(tree.expr)
387+
val expr = transform(tree.expr)
411388
if (tree.finalizer eq EmptyTree)
412389
// SI-1672 Catches are in tail position when there is no finalizer
413390
cpy.Try(tree)(expr, transformSub(tree.cases), EmptyTree)
414391
else cpy.Try(tree)(
415392
expr,
416-
noTailTransforms(tree.cases),
417-
noTailTransform(tree.finalizer)
393+
transformSub(tree.cases),
394+
transform(tree.finalizer)
418395
)
419396

420397
case tree @ WhileDo(cond, body) =>
421398
cpy.WhileDo(tree)(
422-
noTailTransform(cond),
423-
noTailTransform(body)
399+
transform(cond),
400+
transform(body)
424401
)
425402

426403
case _: Alternative | _: Bind =>
427404
assert(false, "We should never have gotten inside a pattern")
428405
tree
429406

430407
case tree: ValOrDefDef =>
431-
if (isMandatory) noTailTransform(tree.rhs)
408+
if (isMandatory) transform(tree.rhs)
432409
tree
433410

434411
case _: Super | _: This | _: Literal | _: TypeTree | _: TypeDef | EmptyTree =>

tests/run/tailrec-return.check

+4
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
6
22
false
3+
true
4+
false
5+
true
6+
Ada Lovelace, Alan Turing

tests/run/tailrec-return.scala

+42
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,48 @@ object Test:
1111
if n == 1 then return false
1212
true
1313

14+
@annotation.tailrec
15+
def isEvenApply(n: Int): Boolean =
16+
// Return inside an `Apply.fun`
17+
(
18+
if n != 0 && n != 1 then return isEvenApply(n - 2)
19+
else if n == 1 then return false
20+
else (x: Boolean) => x
21+
)(true)
22+
23+
@annotation.tailrec
24+
def isEvenWhile(n: Int): Boolean =
25+
// Return inside a `WhileDo.cond`
26+
while(
27+
if n != 0 && n != 1 then return isEvenWhile(n - 2)
28+
else if n == 1 then return false
29+
else true
30+
) {}
31+
true
32+
33+
@annotation.tailrec
34+
def isEvenReturn(n: Int): Boolean =
35+
// Return inside a `Return`
36+
return
37+
if n != 0 && n != 1 then return isEvenReturn(n - 2)
38+
else if n == 1 then return false
39+
else true
40+
41+
@annotation.tailrec
42+
def names(l: List[(String, String) | Null], acc: List[String] = Nil): List[String] =
43+
l match
44+
case Nil => acc.reverse
45+
case x :: xs =>
46+
if x == null then return names(xs, acc)
47+
48+
val displayName = x._1 + " " + x._2
49+
names(xs, displayName :: acc)
50+
51+
1452
def main(args: Array[String]): Unit =
1553
println(sum(3))
1654
println(isEven(5))
55+
println(isEvenApply(6))
56+
println(isEvenWhile(7))
57+
println(isEvenReturn(8))
58+
println(names(List(("Ada", "Lovelace"), null, ("Alan", "Turing"))).mkString(", "))

0 commit comments

Comments
 (0)