Skip to content

Implement compiletime.code with a macro #8859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 3 additions & 41 deletions compiler/src/dotty/tools/dotc/typer/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -705,41 +705,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
case _ =>
}

def issueCode()(using Context): Literal = {
def decompose(arg: Tree): String = arg match {
case Typed(arg, _) => decompose(arg)
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
case Block(Nil, expr) => decompose(expr)
case Inlined(_, Nil, expr) => decompose(expr)
case arg =>
arg.tpe.widenTermRefExpr match {
case ConstantType(Constant(c)) => c.toString
case _ => arg.show
}
}

def malformedString(): String = {
ctx.error("Malformed part `code` string interpolator", call.sourcePos)
""
}

callValueArgss match {
case List(List(Apply(_,List(Typed(SeqLiteral(Literal(headConst) :: parts,_),_)))), List(Typed(SeqLiteral(interpolatedParts,_),_)))
if parts.size == interpolatedParts.size =>
val constantParts = parts.map {
case Literal(const) => const.stringValue
case _ => malformedString()
}
val decomposedInterpolations = interpolatedParts.map(decompose)
val constantString = decomposedInterpolations.zip(constantParts)
.foldLeft(headConst.stringValue) { case (acc, (p1, p2)) => acc + p1 + p2 }

Literal(Constant(constantString)).withSpan(call.span)
case _ =>
Literal(Constant(malformedString()))
}
}

trace(i"inlining $call", inlining, show = true) {

// The normalized bindings collected in `bindingsBuf`
Expand All @@ -763,12 +728,9 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {

if (inlinedMethod == defn.Compiletime_error) issueError()

if (inlinedMethod == defn.Compiletime_code)
issueCode()(using ctx.fresh.setSetting(ctx.settings.color, "never"))
else
// Take care that only argument bindings go into `bindings`, since positions are
// different for bindings from arguments and bindings from body.
tpd.Inlined(call, finalBindings, finalExpansion)
// Take care that only argument bindings go into `bindings`, since positions are
// different for bindings from arguments and bindings from body.
tpd.Inlined(call, finalBindings, finalExpansion)
}
}

Expand Down
14 changes: 2 additions & 12 deletions library/src-bootstrapped/dotty/internal/StringContextMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -61,19 +61,9 @@ object StringContextMacro {
import qctx.tasty._
val sourceFile = strCtxExpr.unseal.pos.sourceFile

def notStatic =
qctx.throwError("Expected statically known String Context", strCtxExpr)
def splitParts(seq: Expr[Seq[String]]) = seq match {
case Varargs(p1) =>
p1 match
case Consts(p2) => (p1.toList, p2.toList)
case _ => notStatic
case _ => notStatic
}
val (partsExpr, parts) = strCtxExpr match {
case '{ StringContext($parts: _*) } => splitParts(parts)
case '{ new StringContext($parts: _*) } => splitParts(parts)
case _ => notStatic
case Expr.StringContext(p1 @ Consts(p2)) => (p1.toList, p2.toList)
case _ => qctx.throwError("Expected statically known String Context", strCtxExpr)
}

val args = argsExpr match {
Expand Down
14 changes: 11 additions & 3 deletions library/src/scala/compiletime/package.scala
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package scala

import scala.quoted._

package object compiletime {

erased def erasedValue[T]: T = ???
Expand All @@ -17,7 +19,7 @@ package object compiletime {
*/
inline def error(inline msg: String): Nothing = ???

/** Returns the string representation of interpolated values:
/** Returns the string representation of interpolated elaborated code:
*
* ```scala
* inline def logged(p1: => Any) = {
Expand All @@ -27,13 +29,19 @@ package object compiletime {
* }
* logged(identity("foo"))
* // above is equivalent to:
* // ("code: identity("foo")", identity("foo"))
* // ("code: scala.Predef.identity("foo")", identity("foo"))
* ```
*
* @note only by-name arguments will be displayed as "code".
* Other values may display unintutively.
*/
inline def (self: => StringContext) code (args: => Any*): String = ???
transparent inline def (inline self: StringContext) code (inline args: Any*): String = ${ codeExpr('self, 'args) }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not clear to me why we need transparent, given that the result type is always String.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is because we want this to return a string literal that can be constant folded with other string literals. Also, the error method used the type of the string know the value.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The intrinsic implementation was also transparent, but we forgot to mark the definition as such.

private def codeExpr(using qctx: QuoteContext)(sc: Expr[StringContext], args: Expr[Seq[Any]]): Expr[String] =
(sc, args) match
case (Expr.StringContext(Consts(parts)), Varargs(args2)) =>
Expr(StringContext(parts: _*).s(args2.map(_.show): _*))
case _ =>
qctx.throwError("compiletime.code must be used as a string interpolator `code\"...\"`")

inline def constValueOpt[T]: Option[T] = ???

Expand Down
11 changes: 11 additions & 0 deletions library/src/scala/quoted/Expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -219,4 +219,15 @@ object Expr {
}
}

object StringContext {
/** Matches a `StringContext(part0, part1, ...)` and extracts the parts of a call to if the
* parts are passed explicitly. Returns the equvalent to `Seq('{part0}, '{part1}, ...)`.
*/
def unapply(sc: Expr[StringContext])(using QuoteContext): Option[Seq[Expr[String]]] =
sc match
case '{ scala.StringContext(${Varargs(parts)}: _*) } => Some(parts)
case '{ new scala.StringContext(${Varargs(parts)}: _*) } => Some(parts)
case _ => None
}

}
2 changes: 1 addition & 1 deletion tests/neg/i6622f.check
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
-- Error: tests/neg/i6622f.scala:6:8 -----------------------------------------------------------------------------------
6 | fail(println("foo")) // error
| ^^^^^^^^^^^^^^^^^^^^
| failed: println("foo") ...
| failed: scala.Predef.println("foo") ...
4 changes: 2 additions & 2 deletions tests/run-macros/beta-reduce-inline-result.check
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
compile-time: 4
compile-time: ((3.+(1): scala.Int): scala.Int)
run-time: 4
compile-time: 1
compile-time: ((1: 1): scala.Int)
run-time: 1
run-time: 5
run-time: 7
Expand Down
8 changes: 4 additions & 4 deletions tests/run/i6622.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@ import scala.compiletime._
object Test {

def main(args: Array[String]): Unit = {
assert(code"abc ${println(34)} ..." == "abc println(34) ...")
assert(code"abc ${println(34)}" == "abc println(34)")
assert(code"${println(34)} ..." == "println(34) ...")
assert(code"${println(34)}" == "println(34)")
assert(code"abc ${println(34)} ..." == "abc scala.Predef.println(34) ...")
assert(code"abc ${println(34)}" == "abc scala.Predef.println(34)")
assert(code"${println(34)} ..." == "scala.Predef.println(34) ...")
assert(code"${println(34)}" == "scala.Predef.println(34)")
assert(code"..." == "...")
assert(testConstant(code"") == "")
}
Expand Down
75 changes: 52 additions & 23 deletions tests/run/i8306.check
Original file line number Diff line number Diff line change
@@ -1,25 +1,54 @@
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: {
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: (3: scala.Int)
run-time: 3
compile-time: ({
val $elem9: A = Test.a
val $elem10: Int = $elem9.i
val i: Int = $elem10
i:Int
}
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
run-time: 3
compile-time: 3
val $elem10: scala.Int = $elem9.i
val i: scala.Int = $elem10

(i: scala.Int)
}: scala.Int)
run-time: 3
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: ({
val i: 3 = 3

(i: 3)
}: scala.Int)
run-time: 3
compile-time: (3: scala.Int)
run-time: 3
compile-time: ({
val t: 3 = 3

(t: 3)
}: scala.Int)
run-time: 3