Skip to content

Commit e1524d6

Browse files
Backport "Allow to beta reduce curried function applications in quotes reflect" to LTS (#21040)
Backports #18121 to the LTS branch. PR submitted by the release tooling. [skip ci]
2 parents 79495fd + 499bbf3 commit e1524d6

File tree

5 files changed

+183
-21
lines changed

5 files changed

+183
-21
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -377,17 +377,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
377377
end TermTypeTest
378378

379379
object Term extends TermModule:
380-
def betaReduce(tree: Term): Option[Term] =
381-
tree match
382-
case tpd.Block(Nil, expr) =>
383-
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
384-
case tpd.Inlined(_, Nil, expr) =>
385-
betaReduce(expr)
386-
case _ =>
387-
val tree1 = dotc.transform.BetaReduce(tree)
388-
if tree1 eq tree then None
389-
else Some(tree1.withSpan(tree.span))
390-
380+
def betaReduce(tree: Term): Option[Term] =
381+
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
382+
override def transform(tree: Tree)(using Context): Tree = tree match {
383+
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
384+
super.transform(tree)
385+
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
386+
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
387+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
388+
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
389+
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
390+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
391+
case _ =>
392+
dotc.transform.BetaReduce(tree).withSpan(tree.span)
393+
}
394+
}.transform(tree)
395+
if tree1 == tree then None else Some(tree1)
391396
end Term
392397

393398
given TermMethods: TermMethods with

library/src/scala/quoted/Expr.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
1010
object Expr {
1111

1212
/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
13-
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14-
* then it optimizes this the top most call by returning the result of beta-reducing the application.
15-
* Otherwise returns `expr`.
13+
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14+
* then it optimizes the top most call by returning the result of beta-reducing the application.
15+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
16+
* Otherwise returns `expr`.
1617
*
17-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
18-
* Some bindings may be elided as an early optimization.
18+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
19+
* Some bindings may be elided as an early optimization.
20+
*
21+
* Example:
22+
* ```scala sc:nocompile
23+
* ((a: Int, b: Int) => a + b).apply(x, y)
24+
* ```
25+
* will be reduced to
26+
* ```scala sc:nocompile
27+
* val a = x
28+
* val b = y
29+
* a + b
30+
* ```
31+
*
32+
* Generally:
33+
* ```scala sc:nocompile
34+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
35+
* ```
36+
* will be reduced to
37+
* ```scala sc:nocompile
38+
* type X1 = Tx1
39+
* type Y1 = Ty1
40+
* ...
41+
* val x1 = myX1
42+
* val y1 = myY1
43+
* ...
44+
* type Xn = Txn
45+
* type Yn = Tyn
46+
* ...
47+
* val xn = myXn
48+
* val yn = myYn
49+
* ...
50+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
51+
* ```
1952
*/
2053
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
2154
import quotes.reflect.*

library/src/scala/quoted/Quotes.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,14 +751,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
751751
/** Methods of the module object `val Term` */
752752
trait TermModule { this: Term.type =>
753753

754-
/** Returns a term that is functionally equivalent to `t`,
754+
/** Returns a term that is functionally equivalent to `t`,
755755
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
756-
* then it optimizes this the top most call by returning the `Some`
757-
* with the result of beta-reducing the application.
756+
* then it optimizes the top most call by returning `Some`
757+
* with the result of beta-reducing the function application.
758+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
758759
* Otherwise returns `None`.
759760
*
760-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
761-
* Some bindings may be elided as an early optimization.
761+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
762+
* Some bindings may be elided as an early optimization.
763+
*
764+
* Example:
765+
* ```scala sc:nocompile
766+
* ((a: Int, b: Int) => a + b).apply(x, y)
767+
* ```
768+
* will be reduced to
769+
* ```scala sc:nocompile
770+
* val a = x
771+
* val b = y
772+
* a + b
773+
* ```
774+
*
775+
* Generally:
776+
* ```scala sc:nocompile
777+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
778+
* ```
779+
* will be reduced to
780+
* ```scala sc:nocompile
781+
* type X1 = Tx1
782+
* type Y1 = Ty1
783+
* ...
784+
* val x1 = myX1
785+
* val y1 = myY1
786+
* ...
787+
* type Xn = Txn
788+
* type Yn = Tyn
789+
* ...
790+
* val xn = myXn
791+
* val yn = myYn
792+
* ...
793+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
794+
* ```
762795
*/
763796
def betaReduce(term: Term): Option[Term]
764797

tests/pos-macros/i17506/Macro_1.scala

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
class Foo
2+
class Bar
3+
class Baz
4+
5+
import scala.quoted._
6+
7+
def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
8+
import quotes.reflect._
9+
val reducedMaybe = Term.betaReduce(applied.asTerm)
10+
assert(reducedMaybe.isDefined)
11+
val reduced = reducedMaybe.get
12+
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
13+
reduced
14+
15+
inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
16+
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
17+
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
18+
val expected =
19+
"""|{
20+
| val evidence$3: Bar = new Bar()
21+
| val evidence$2: Foo = new Foo()
22+
| 123
23+
|}""".stripMargin
24+
val applied = '{$f(using new Foo())(using new Bar())}
25+
assertBetaReduction(applied, expected).asExprOf[Int]
26+
27+
inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
28+
${regularCurriedFun2BetaReduceTestImpl('f)}
29+
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
30+
val expected =
31+
"""|{
32+
| val b: Bar = new Bar()
33+
| val f: Foo = new Foo()
34+
| 123
35+
|}""".stripMargin
36+
val applied = '{$f(new Foo())(new Bar())}
37+
assertBetaReduction(applied, expected).asExprOf[Int]
38+
39+
inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
40+
${typeParamCurriedFun2BetaReduceTestImpl('f)}
41+
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
42+
val expected =
43+
"""|{
44+
| type Y = Bar
45+
| val y: Bar = new Bar()
46+
| type X = Foo
47+
| val x: Foo = new Foo()
48+
| typeParamFun2[Y, X](y, x)
49+
|}""".stripMargin
50+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
51+
assertBetaReduction(applied, expected).asExprOf[Unit]
52+
53+
inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
54+
${regularCurriedFun3BetaReduceTestImpl('f)}
55+
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
56+
val expected =
57+
"""|{
58+
| val i: Baz = new Baz()
59+
| val b: Bar = new Bar()
60+
| val f: Foo = new Foo()
61+
| 123
62+
|}""".stripMargin
63+
val applied = '{$f(new Foo())(new Bar())(new Baz())}
64+
assertBetaReduction(applied, expected).asExprOf[Int]
65+
66+
inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
67+
${typeParamCurriedFun3BetaReduceTestImpl('f)}
68+
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
69+
val expected =
70+
"""|{
71+
| type Z = Baz
72+
| val z: Baz = new Baz()
73+
| type Y = Bar
74+
| val y: Bar = new Bar()
75+
| type X = Foo
76+
| val x: Foo = new Foo()
77+
| typeParamFun3[Z, Y, X](z, y, x)
78+
|}""".stripMargin
79+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
80+
assertBetaReduction(applied, expected).asExprOf[Unit]

tests/pos-macros/i17506/Test_2.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@main def run() =
2+
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
3+
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)
4+
5+
regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
6+
regularCurriedCtxFun2BetaReduceTest(123)
7+
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
8+
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))
9+
10+
regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
11+
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 commit comments

Comments
 (0)