Skip to content

Commit 144803b

Browse files
committed
Allow to beta reduce curried function applications in quotes reflect
Previously, the curried functions with multiple applications were not able to be beta-reduced in any way, which was unexpected. Now we allow reducing any number of top-level function applications for a curried function. This was also made clearer in the documentation for the affected (Expr.betaReduce and Term.betaReduce) methods.
1 parent 97677cc commit 144803b

File tree

5 files changed

+183
-21
lines changed

5 files changed

+183
-21
lines changed

compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -373,17 +373,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
373373
end TermTypeTest
374374

375375
object Term extends TermModule:
376-
def betaReduce(tree: Term): Option[Term] =
377-
tree match
378-
case tpd.Block(Nil, expr) =>
379-
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
380-
case tpd.Inlined(_, Nil, expr) =>
381-
betaReduce(expr)
382-
case _ =>
383-
val tree1 = dotc.transform.BetaReduce(tree)
384-
if tree1 eq tree then None
385-
else Some(tree1.withSpan(tree.span))
386-
376+
def betaReduce(tree: Term): Option[Term] =
377+
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
378+
override def transform(tree: Tree)(using Context): Tree = tree match {
379+
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
380+
super.transform(tree)
381+
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
382+
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
383+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
384+
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
385+
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
386+
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
387+
case _ =>
388+
dotc.transform.BetaReduce(tree).withSpan(tree.span)
389+
}
390+
}.transform(tree)
391+
if tree1 == tree then None else Some(tree1)
387392
end Term
388393

389394
given TermMethods: TermMethods with

library/src/scala/quoted/Expr.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
1010
object Expr {
1111

1212
/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
13-
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14-
* then it optimizes this the top most call by returning the result of beta-reducing the application.
15-
* Otherwise returns `expr`.
13+
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
14+
* then it optimizes the top most call by returning the result of beta-reducing the application.
15+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
16+
* Otherwise returns `expr`.
1617
*
17-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
18-
* Some bindings may be elided as an early optimization.
18+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
19+
* Some bindings may be elided as an early optimization.
20+
*
21+
* Example:
22+
* ```scala sc:nocompile
23+
* ((a: Int, b: Int) => a + b).apply(x, y)
24+
* ```
25+
* will be reduced to
26+
* ```scala sc:nocompile
27+
* val a = x
28+
* val b = y
29+
* a + b
30+
* ```
31+
*
32+
* Generally:
33+
* ```scala sc:nocompile
34+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
35+
* ```
36+
* will be reduced to
37+
* ```scala sc:nocompile
38+
* type X1 = Tx1
39+
* type Y1 = Ty1
40+
* ...
41+
* val x1 = myX1
42+
* val y1 = myY1
43+
* ...
44+
* type Xn = Txn
45+
* type Yn = Tyn
46+
* ...
47+
* val xn = myXn
48+
* val yn = myYn
49+
* ...
50+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
51+
* ```
1952
*/
2053
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
2154
import quotes.reflect._

library/src/scala/quoted/Quotes.scala

Lines changed: 38 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -751,14 +751,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
751751
/** Methods of the module object `val Term` */
752752
trait TermModule { this: Term.type =>
753753

754-
/** Returns a term that is functionally equivalent to `t`,
754+
/** Returns a term that is functionally equivalent to `t`,
755755
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
756-
* then it optimizes this the top most call by returning the `Some`
757-
* with the result of beta-reducing the application.
756+
* then it optimizes the top most call by returning `Some`
757+
* with the result of beta-reducing the function application.
758+
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
758759
* Otherwise returns `None`.
759760
*
760-
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
761-
* Some bindings may be elided as an early optimization.
761+
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
762+
* Some bindings may be elided as an early optimization.
763+
*
764+
* Example:
765+
* ```scala sc:nocompile
766+
* ((a: Int, b: Int) => a + b).apply(x, y)
767+
* ```
768+
* will be reduced to
769+
* ```scala sc:nocompile
770+
* val a = x
771+
* val b = y
772+
* a + b
773+
* ```
774+
*
775+
* Generally:
776+
* ```scala sc:nocompile
777+
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
778+
* ```
779+
* will be reduced to
780+
* ```scala sc:nocompile
781+
* type X1 = Tx1
782+
* type Y1 = Ty1
783+
* ...
784+
* val x1 = myX1
785+
* val y1 = myY1
786+
* ...
787+
* type Xn = Txn
788+
* type Yn = Tyn
789+
* ...
790+
* val xn = myXn
791+
* val yn = myYn
792+
* ...
793+
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
794+
* ```
762795
*/
763796
def betaReduce(term: Term): Option[Term]
764797

tests/pos-macros/i17506/Macro_1.scala

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
class Foo
2+
class Bar
3+
class Baz
4+
5+
import scala.quoted._
6+
7+
def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
8+
import quotes.reflect._
9+
val reducedMaybe = Term.betaReduce(applied.asTerm)
10+
assert(reducedMaybe.isDefined)
11+
val reduced = reducedMaybe.get
12+
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
13+
reduced
14+
15+
inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
16+
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
17+
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
18+
val expected =
19+
"""|{
20+
| val contextual$3: Bar = new Bar()
21+
| val contextual$2: Foo = new Foo()
22+
| 123
23+
|}""".stripMargin
24+
val applied = '{$f(using new Foo())(using new Bar())}
25+
assertBetaReduction(applied, expected).asExprOf[Int]
26+
27+
inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
28+
${regularCurriedFun2BetaReduceTestImpl('f)}
29+
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
30+
val expected =
31+
"""|{
32+
| val b: Bar = new Bar()
33+
| val f: Foo = new Foo()
34+
| 123
35+
|}""".stripMargin
36+
val applied = '{$f(new Foo())(new Bar())}
37+
assertBetaReduction(applied, expected).asExprOf[Int]
38+
39+
inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
40+
${typeParamCurriedFun2BetaReduceTestImpl('f)}
41+
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
42+
val expected =
43+
"""|{
44+
| type Y = Bar
45+
| val y: Bar = new Bar()
46+
| type X = Foo
47+
| val x: Foo = new Foo()
48+
| typeParamFun2[Y, X](y, x)
49+
|}""".stripMargin
50+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
51+
assertBetaReduction(applied, expected).asExprOf[Unit]
52+
53+
inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
54+
${regularCurriedFun3BetaReduceTestImpl('f)}
55+
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
56+
val expected =
57+
"""|{
58+
| val i: Baz = new Baz()
59+
| val b: Bar = new Bar()
60+
| val f: Foo = new Foo()
61+
| 123
62+
|}""".stripMargin
63+
val applied = '{$f(new Foo())(new Bar())(new Baz())}
64+
assertBetaReduction(applied, expected).asExprOf[Int]
65+
66+
inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
67+
${typeParamCurriedFun3BetaReduceTestImpl('f)}
68+
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
69+
val expected =
70+
"""|{
71+
| type Z = Baz
72+
| val z: Baz = new Baz()
73+
| type Y = Bar
74+
| val y: Bar = new Bar()
75+
| type X = Foo
76+
| val x: Foo = new Foo()
77+
| typeParamFun3[Z, Y, X](z, y, x)
78+
|}""".stripMargin
79+
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
80+
assertBetaReduction(applied, expected).asExprOf[Unit]

tests/pos-macros/i17506/Test_2.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
@main def run() =
2+
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
3+
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)
4+
5+
regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
6+
regularCurriedCtxFun2BetaReduceTest(123)
7+
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
8+
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))
9+
10+
regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
11+
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))

0 commit comments

Comments
 (0)