Skip to content

Backport "Allow to beta reduce curried function applications in quotes reflect" to LTS #21040

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -377,17 +377,22 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
end TermTypeTest

object Term extends TermModule:
def betaReduce(tree: Term): Option[Term] =
tree match
case tpd.Block(Nil, expr) =>
for e <- betaReduce(expr) yield tpd.cpy.Block(tree)(Nil, e)
case tpd.Inlined(_, Nil, expr) =>
betaReduce(expr)
case _ =>
val tree1 = dotc.transform.BetaReduce(tree)
if tree1 eq tree then None
else Some(tree1.withSpan(tree.span))

def betaReduce(tree: Term): Option[Term] =
val tree1 = new dotty.tools.dotc.ast.tpd.TreeMap {
override def transform(tree: Tree)(using Context): Tree = tree match {
case tpd.Block(Nil, _) | tpd.Inlined(_, Nil, _) =>
super.transform(tree)
case tpd.Apply(sel @ tpd.Select(expr, nme), args) =>
val tree1 = cpy.Apply(tree)(cpy.Select(sel)(transform(expr), nme), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case tpd.Apply(ta @ tpd.TypeApply(sel @ tpd.Select(expr: Apply, nme), tpts), args) =>
val tree1 = cpy.Apply(tree)(cpy.TypeApply(ta)(cpy.Select(sel)(transform(expr), nme), tpts), args)
dotc.transform.BetaReduce(tree1).withSpan(tree.span)
case _ =>
dotc.transform.BetaReduce(tree).withSpan(tree.span)
}
}.transform(tree)
if tree1 == tree then None else Some(tree1)
end Term

given TermMethods: TermMethods with
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,45 @@ abstract class Expr[+T] private[scala] ()
object Expr {

/** `e.betaReduce` returns an expression that is functionally equivalent to `e`,
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the result of beta-reducing the application.
* Otherwise returns `expr`.
* however if `e` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes the top most call by returning the result of beta-reducing the application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `expr`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce[T](expr: Expr[T])(using Quotes): Expr[T] =
import quotes.reflect.*
Expand Down
43 changes: 38 additions & 5 deletions library/src/scala/quoted/Quotes.scala
Original file line number Diff line number Diff line change
Expand Up @@ -751,14 +751,47 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching =>
/** Methods of the module object `val Term` */
trait TermModule { this: Term.type =>

/** Returns a term that is functionally equivalent to `t`,
/** Returns a term that is functionally equivalent to `t`,
* however if `t` is of the form `((y1, ..., yn) => e2)(e1, ..., en)`
* then it optimizes this the top most call by returning the `Some`
* with the result of beta-reducing the application.
* then it optimizes the top most call by returning `Some`
* with the result of beta-reducing the function application.
* Similarly, all outermost curried function applications will be beta-reduced, if possible.
* Otherwise returns `None`.
*
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
* To retain semantics the argument `ei` is bound as `val yi = ei` and by-name arguments to `def yi = ei`.
* Some bindings may be elided as an early optimization.
*
* Example:
* ```scala sc:nocompile
* ((a: Int, b: Int) => a + b).apply(x, y)
* ```
* will be reduced to
* ```scala sc:nocompile
* val a = x
* val b = y
* a + b
* ```
*
* Generally:
* ```scala sc:nocompile
* ([X1, Y1, ...] => (x1, y1, ...) => ... => [Xn, Yn, ...] => (xn, yn, ...) => f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...))).apply[Tx1, Ty1, ...](myX1, myY1, ...)....apply[Txn, Tyn, ...](myXn, myYn, ...)
* ```
* will be reduced to
* ```scala sc:nocompile
* type X1 = Tx1
* type Y1 = Ty1
* ...
* val x1 = myX1
* val y1 = myY1
* ...
* type Xn = Txn
* type Yn = Tyn
* ...
* val xn = myXn
* val yn = myYn
* ...
* f[X1, Y1, ..., Xn, Yn, ...](x1, y1, ..., xn, yn, ...)
* ```
*/
def betaReduce(term: Term): Option[Term]

Expand Down
80 changes: 80 additions & 0 deletions tests/pos-macros/i17506/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
class Foo
class Bar
class Baz

import scala.quoted._

def assertBetaReduction(using Quotes)(applied: Expr[Any], expected: String): quotes.reflect.Term =
import quotes.reflect._
val reducedMaybe = Term.betaReduce(applied.asTerm)
assert(reducedMaybe.isDefined)
val reduced = reducedMaybe.get
assert(reduced.show == expected,s"obtained: ${reduced.show}, expected: ${expected}")
reduced

inline def regularCurriedCtxFun2BetaReduceTest(inline f: Foo ?=> Bar ?=> Int): Unit =
${regularCurriedCtxFun2BetaReduceTestImpl('f)}
def regularCurriedCtxFun2BetaReduceTestImpl(f: Expr[Foo ?=> Bar ?=> Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val evidence$3: Bar = new Bar()
| val evidence$2: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(using new Foo())(using new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def regularCurriedFun2BetaReduceTest(inline f: Foo => Bar => Int): Int =
${regularCurriedFun2BetaReduceTestImpl('f)}
def regularCurriedFun2BetaReduceTestImpl(f: Expr[Foo => Bar => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun2BetaReduceTest(inline f: [A] => A => [B] => B => Unit): Unit =
${typeParamCurriedFun2BetaReduceTestImpl('f)}
def typeParamCurriedFun2BetaReduceTestImpl(f: Expr[[A] => (a: A) => [B] => (b: B) => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun2[Y, X](y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar())}
assertBetaReduction(applied, expected).asExprOf[Unit]

inline def regularCurriedFun3BetaReduceTest(inline f: Foo => Bar => Baz => Int): Int =
${regularCurriedFun3BetaReduceTestImpl('f)}
def regularCurriedFun3BetaReduceTestImpl(f: Expr[Foo => Bar => Baz => Int])(using Quotes): Expr[Int] =
val expected =
"""|{
| val i: Baz = new Baz()
| val b: Bar = new Bar()
| val f: Foo = new Foo()
| 123
|}""".stripMargin
val applied = '{$f(new Foo())(new Bar())(new Baz())}
assertBetaReduction(applied, expected).asExprOf[Int]

inline def typeParamCurriedFun3BetaReduceTest(inline f: [A] => A => [B] => B => [C] => C => Unit): Unit =
${typeParamCurriedFun3BetaReduceTestImpl('f)}
def typeParamCurriedFun3BetaReduceTestImpl(f: Expr[[A] => A => [B] => B => [C] => C => Unit])(using Quotes): Expr[Unit] =
val expected =
"""|{
| type Z = Baz
| val z: Baz = new Baz()
| type Y = Bar
| val y: Bar = new Bar()
| type X = Foo
| val x: Foo = new Foo()
| typeParamFun3[Z, Y, X](z, y, x)
|}""".stripMargin
val applied = '{$f.apply[Foo](new Foo()).apply[Bar](new Bar()).apply[Baz](new Baz())}
assertBetaReduction(applied, expected).asExprOf[Unit]
11 changes: 11 additions & 0 deletions tests/pos-macros/i17506/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
@main def run() =
def typeParamFun2[A, B](a: A, b: B): Unit = println(a.toString + " " + b.toString)
def typeParamFun3[A, B, C](a: A, b: B, c: C): Unit = println(a.toString + " " + b.toString)

regularCurriedCtxFun2BetaReduceTest((f: Foo) ?=> (b: Bar) ?=> 123)
regularCurriedCtxFun2BetaReduceTest(123)
regularCurriedFun2BetaReduceTest(((f: Foo) => (b: Bar) => 123))
typeParamCurriedFun2BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => typeParamFun2[Y, X](y, x))

regularCurriedFun3BetaReduceTest((f: Foo) => (b: Bar) => (i: Baz) => 123)
typeParamCurriedFun3BetaReduceTest([X] => (x: X) => [Y] => (y: Y) => [Z] => (z: Z) => typeParamFun3[Z, Y, X](z, y, x))