Skip to content

Commit 5d6d6aa

Browse files
Merge pull request #8859 from dotty-staging/reimplement-compiletime-code
Implement compiletime.code with a macro
2 parents 9d6036b + b6995d5 commit 5d6d6aa

File tree

8 files changed

+86
-86
lines changed

8 files changed

+86
-86
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 3 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -705,41 +705,6 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
705705
case _ =>
706706
}
707707

708-
def issueCode()(using Context): Literal = {
709-
def decompose(arg: Tree): String = arg match {
710-
case Typed(arg, _) => decompose(arg)
711-
case SeqLiteral(elems, _) => elems.map(decompose).mkString(", ")
712-
case Block(Nil, expr) => decompose(expr)
713-
case Inlined(_, Nil, expr) => decompose(expr)
714-
case arg =>
715-
arg.tpe.widenTermRefExpr match {
716-
case ConstantType(Constant(c)) => c.toString
717-
case _ => arg.show
718-
}
719-
}
720-
721-
def malformedString(): String = {
722-
ctx.error("Malformed part `code` string interpolator", call.sourcePos)
723-
""
724-
}
725-
726-
callValueArgss match {
727-
case List(List(Apply(_,List(Typed(SeqLiteral(Literal(headConst) :: parts,_),_)))), List(Typed(SeqLiteral(interpolatedParts,_),_)))
728-
if parts.size == interpolatedParts.size =>
729-
val constantParts = parts.map {
730-
case Literal(const) => const.stringValue
731-
case _ => malformedString()
732-
}
733-
val decomposedInterpolations = interpolatedParts.map(decompose)
734-
val constantString = decomposedInterpolations.zip(constantParts)
735-
.foldLeft(headConst.stringValue) { case (acc, (p1, p2)) => acc + p1 + p2 }
736-
737-
Literal(Constant(constantString)).withSpan(call.span)
738-
case _ =>
739-
Literal(Constant(malformedString()))
740-
}
741-
}
742-
743708
trace(i"inlining $call", inlining, show = true) {
744709

745710
// The normalized bindings collected in `bindingsBuf`
@@ -763,12 +728,9 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
763728

764729
if (inlinedMethod == defn.Compiletime_error) issueError()
765730

766-
if (inlinedMethod == defn.Compiletime_code)
767-
issueCode()(using ctx.fresh.setSetting(ctx.settings.color, "never"))
768-
else
769-
// Take care that only argument bindings go into `bindings`, since positions are
770-
// different for bindings from arguments and bindings from body.
771-
tpd.Inlined(call, finalBindings, finalExpansion)
731+
// Take care that only argument bindings go into `bindings`, since positions are
732+
// different for bindings from arguments and bindings from body.
733+
tpd.Inlined(call, finalBindings, finalExpansion)
772734
}
773735
}
774736

library/src-bootstrapped/dotty/internal/StringContextMacro.scala

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -61,19 +61,9 @@ object StringContextMacro {
6161
import qctx.tasty._
6262
val sourceFile = strCtxExpr.unseal.pos.sourceFile
6363

64-
def notStatic =
65-
qctx.throwError("Expected statically known String Context", strCtxExpr)
66-
def splitParts(seq: Expr[Seq[String]]) = seq match {
67-
case Varargs(p1) =>
68-
p1 match
69-
case Consts(p2) => (p1.toList, p2.toList)
70-
case _ => notStatic
71-
case _ => notStatic
72-
}
7364
val (partsExpr, parts) = strCtxExpr match {
74-
case '{ StringContext($parts: _*) } => splitParts(parts)
75-
case '{ new StringContext($parts: _*) } => splitParts(parts)
76-
case _ => notStatic
65+
case Expr.StringContext(p1 @ Consts(p2)) => (p1.toList, p2.toList)
66+
case _ => qctx.throwError("Expected statically known String Context", strCtxExpr)
7767
}
7868

7969
val args = argsExpr match {

library/src/scala/compiletime/package.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package scala
22

3+
import scala.quoted._
4+
35
package object compiletime {
46

57
erased def erasedValue[T]: T = ???
@@ -17,7 +19,7 @@ package object compiletime {
1719
*/
1820
inline def error(inline msg: String): Nothing = ???
1921

20-
/** Returns the string representation of interpolated values:
22+
/** Returns the string representation of interpolated elaborated code:
2123
*
2224
* ```scala
2325
* inline def logged(p1: => Any) = {
@@ -27,13 +29,19 @@ package object compiletime {
2729
* }
2830
* logged(identity("foo"))
2931
* // above is equivalent to:
30-
* // ("code: identity("foo")", identity("foo"))
32+
* // ("code: scala.Predef.identity("foo")", identity("foo"))
3133
* ```
3234
*
3335
* @note only by-name arguments will be displayed as "code".
3436
* Other values may display unintutively.
3537
*/
36-
inline def (self: => StringContext) code (args: => Any*): String = ???
38+
transparent inline def (inline self: StringContext) code (inline args: Any*): String = ${ codeExpr('self, 'args) }
39+
private def codeExpr(using qctx: QuoteContext)(sc: Expr[StringContext], args: Expr[Seq[Any]]): Expr[String] =
40+
(sc, args) match
41+
case (Expr.StringContext(Consts(parts)), Varargs(args2)) =>
42+
Expr(StringContext(parts: _*).s(args2.map(_.show): _*))
43+
case _ =>
44+
qctx.throwError("compiletime.code must be used as a string interpolator `code\"...\"`")
3745

3846
inline def constValueOpt[T]: Option[T] = ???
3947

library/src/scala/quoted/Expr.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,4 +219,15 @@ object Expr {
219219
}
220220
}
221221

222+
object StringContext {
223+
/** Matches a `StringContext(part0, part1, ...)` and extracts the parts of a call to if the
224+
* parts are passed explicitly. Returns the equvalent to `Seq('{part0}, '{part1}, ...)`.
225+
*/
226+
def unapply(sc: Expr[StringContext])(using QuoteContext): Option[Seq[Expr[String]]] =
227+
sc match
228+
case '{ scala.StringContext(${Varargs(parts)}: _*) } => Some(parts)
229+
case '{ new scala.StringContext(${Varargs(parts)}: _*) } => Some(parts)
230+
case _ => None
231+
}
232+
222233
}

tests/neg/i6622f.check

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
-- Error: tests/neg/i6622f.scala:6:8 -----------------------------------------------------------------------------------
22
6 | fail(println("foo")) // error
33
| ^^^^^^^^^^^^^^^^^^^^
4-
| failed: println("foo") ...
4+
| failed: scala.Predef.println("foo") ...

tests/run-macros/beta-reduce-inline-result.check

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
compile-time: 4
1+
compile-time: ((3.+(1): scala.Int): scala.Int)
22
run-time: 4
3-
compile-time: 1
3+
compile-time: ((1: 1): scala.Int)
44
run-time: 1
55
run-time: 5
66
run-time: 7

tests/run/i6622.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@ import scala.compiletime._
33
object Test {
44

55
def main(args: Array[String]): Unit = {
6-
assert(code"abc ${println(34)} ..." == "abc println(34) ...")
7-
assert(code"abc ${println(34)}" == "abc println(34)")
8-
assert(code"${println(34)} ..." == "println(34) ...")
9-
assert(code"${println(34)}" == "println(34)")
6+
assert(code"abc ${println(34)} ..." == "abc scala.Predef.println(34) ...")
7+
assert(code"abc ${println(34)}" == "abc scala.Predef.println(34)")
8+
assert(code"${println(34)} ..." == "scala.Predef.println(34) ...")
9+
assert(code"${println(34)}" == "scala.Predef.println(34)")
1010
assert(code"..." == "...")
1111
assert(testConstant(code"") == "")
1212
}

tests/run/i8306.check

Lines changed: 52 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,54 @@
1-
compile-time: 3
2-
run-time: 3
3-
compile-time: 3
4-
run-time: 3
5-
compile-time: 3
6-
run-time: 3
7-
compile-time: 3
8-
run-time: 3
9-
compile-time: {
1+
compile-time: ({
2+
val i: 3 = 3
3+
4+
(i: 3)
5+
}: scala.Int)
6+
run-time: 3
7+
compile-time: ({
8+
val i: 3 = 3
9+
10+
(i: 3)
11+
}: scala.Int)
12+
run-time: 3
13+
compile-time: ({
14+
val i: 3 = 3
15+
16+
(i: 3)
17+
}: scala.Int)
18+
run-time: 3
19+
compile-time: (3: scala.Int)
20+
run-time: 3
21+
compile-time: ({
1022
val $elem9: A = Test.a
11-
val $elem10: Int = $elem9.i
12-
val i: Int = $elem10
13-
i:Int
14-
}
15-
run-time: 3
16-
compile-time: 3
17-
run-time: 3
18-
compile-time: 3
19-
run-time: 3
20-
compile-time: 3
21-
run-time: 3
22-
compile-time: 3
23-
run-time: 3
24-
compile-time: 3
23+
val $elem10: scala.Int = $elem9.i
24+
val i: scala.Int = $elem10
25+
26+
(i: scala.Int)
27+
}: scala.Int)
28+
run-time: 3
29+
compile-time: ({
30+
val i: 3 = 3
31+
32+
(i: 3)
33+
}: scala.Int)
34+
run-time: 3
35+
compile-time: ({
36+
val i: 3 = 3
37+
38+
(i: 3)
39+
}: scala.Int)
40+
run-time: 3
41+
compile-time: ({
42+
val i: 3 = 3
43+
44+
(i: 3)
45+
}: scala.Int)
46+
run-time: 3
47+
compile-time: (3: scala.Int)
48+
run-time: 3
49+
compile-time: ({
50+
val t: 3 = 3
51+
52+
(t: 3)
53+
}: scala.Int)
2554
run-time: 3

0 commit comments

Comments
 (0)