From 4e8b59629cc6cb749c1186513402678c538e74a1 Mon Sep 17 00:00:00 2001 From: Ang9876 Date: Tue, 15 Feb 2022 09:20:10 +0100 Subject: [PATCH 1/4] [WIP] Macro Annotation --- .../dotty/tools/dotc/CompilationUnit.scala | 7 + compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- .../dotty/tools/dotc/core/Definitions.scala | 2 + .../dotty/tools/dotc/transform/Inlining.scala | 70 +++- .../tools/dotc/transform/Interpreter.scala | 392 ++++++++++++++++++ .../dotc/transform/MacroAnnotations.scala | 189 +++++++++ .../tools/dotc/transform/PostTyper.scala | 9 + .../dotty/tools/dotc/transform/Splicer.scala | 328 +-------------- .../src/dotty/tools/dotc/typer/Typer.scala | 2 + .../quoted/runtime/impl/QuotesImpl.scala | 4 +- .../scala/annotation/MacroAnnotation.scala | 16 + project/MiMaFilters.scala | 5 +- .../annot-basic-accessIndirect/Macro_1.scala | 13 + .../annot-basic-accessIndirect/Macro_2.scala | 20 + .../annot-basic-accessIndirect/Test.scala | 3 + .../annot-basic-overload/Macro.scala | 17 + .../annot-basic-overload/Test.scala | 4 + .../annot-basic-changeClass/Macro.scala | 21 + .../annot-basic-changeClass/Test.scala | 7 + .../annot-basic-changeVal/Macro.scala | 12 + .../annot-basic-changeVal/Test.scala | 9 + .../annot-basic-classAddDef/Macro.scala | 23 + .../annot-basic-classAddDef/Test.scala | 7 + tests/run-macros/annot-basic-classDef.check | 1 + .../annot-basic-classDef/Macro.scala | 30 ++ .../annot-basic-classDef/Test.scala | 6 + tests/run-macros/annot-basic-companion.check | 0 .../annot-basic-complicated/Macro_1.scala | 32 ++ .../annot-basic-complicated/Macro_2.scala | 14 + .../annot-basic-complicated/Test_3.scala | 21 + tests/run-macros/annot-basic-fib.check | 5 + tests/run-macros/annot-basic-fib/Macro.scala | 27 ++ tests/run-macros/annot-basic-fib/Test.scala | 13 + .../run-macros/annot-basic-gen2/Macro_1.scala | 17 + .../run-macros/annot-basic-gen2/Macro_2.scala | 22 + .../run-macros/annot-basic-gen2/Test_3.scala | 5 + .../annot-basic-generate/Macro_1.scala | 13 + .../annot-basic-generate/Macro_2.scala | 18 + .../annot-basic-generate/Test_3.scala | 5 + tests/run-macros/annot-basic-multiAnnot.check | 1 + .../annot-basic-multiAnnot/Macro.scala | 28 ++ .../annot-basic-multiAnnot/Test.scala | 7 + .../run-macros/annot-basic-nested/Macro.scala | 13 + .../run-macros/annot-basic-nested/Test.scala | 10 + .../run-macros/annot-basic-object/Macro.scala | 27 ++ .../run-macros/annot-basic-object/Test.scala | 8 + .../run-macros/annot-basic-param/Macro.scala | 19 + tests/run-macros/annot-basic-param/Test.scala | 7 + 48 files changed, 1178 insertions(+), 333 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/transform/Interpreter.scala create mode 100644 compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala create mode 100644 library/src/scala/annotation/MacroAnnotation.scala create mode 100644 tests/neg-macros/annot-basic-accessIndirect/Macro_1.scala create mode 100644 tests/neg-macros/annot-basic-accessIndirect/Macro_2.scala create mode 100644 tests/neg-macros/annot-basic-accessIndirect/Test.scala create mode 100644 tests/pos-macros/annot-basic-overload/Macro.scala create mode 100644 tests/pos-macros/annot-basic-overload/Test.scala create mode 100644 tests/run-macros/annot-basic-changeClass/Macro.scala create mode 100644 tests/run-macros/annot-basic-changeClass/Test.scala create mode 100644 tests/run-macros/annot-basic-changeVal/Macro.scala create mode 100644 tests/run-macros/annot-basic-changeVal/Test.scala create mode 100644 tests/run-macros/annot-basic-classAddDef/Macro.scala create mode 100644 tests/run-macros/annot-basic-classAddDef/Test.scala create mode 100644 tests/run-macros/annot-basic-classDef.check create mode 100644 tests/run-macros/annot-basic-classDef/Macro.scala create mode 100644 tests/run-macros/annot-basic-classDef/Test.scala create mode 100644 tests/run-macros/annot-basic-companion.check create mode 100644 tests/run-macros/annot-basic-complicated/Macro_1.scala create mode 100644 tests/run-macros/annot-basic-complicated/Macro_2.scala create mode 100644 tests/run-macros/annot-basic-complicated/Test_3.scala create mode 100644 tests/run-macros/annot-basic-fib.check create mode 100644 tests/run-macros/annot-basic-fib/Macro.scala create mode 100644 tests/run-macros/annot-basic-fib/Test.scala create mode 100644 tests/run-macros/annot-basic-gen2/Macro_1.scala create mode 100644 tests/run-macros/annot-basic-gen2/Macro_2.scala create mode 100644 tests/run-macros/annot-basic-gen2/Test_3.scala create mode 100644 tests/run-macros/annot-basic-generate/Macro_1.scala create mode 100644 tests/run-macros/annot-basic-generate/Macro_2.scala create mode 100644 tests/run-macros/annot-basic-generate/Test_3.scala create mode 100644 tests/run-macros/annot-basic-multiAnnot.check create mode 100644 tests/run-macros/annot-basic-multiAnnot/Macro.scala create mode 100644 tests/run-macros/annot-basic-multiAnnot/Test.scala create mode 100644 tests/run-macros/annot-basic-nested/Macro.scala create mode 100644 tests/run-macros/annot-basic-nested/Test.scala create mode 100644 tests/run-macros/annot-basic-object/Macro.scala create mode 100644 tests/run-macros/annot-basic-object/Test.scala create mode 100644 tests/run-macros/annot-basic-param/Macro.scala create mode 100644 tests/run-macros/annot-basic-param/Test.scala diff --git a/compiler/src/dotty/tools/dotc/CompilationUnit.scala b/compiler/src/dotty/tools/dotc/CompilationUnit.scala index a6069e2749a9..76439d7f5129 100644 --- a/compiler/src/dotty/tools/dotc/CompilationUnit.scala +++ b/compiler/src/dotty/tools/dotc/CompilationUnit.scala @@ -43,6 +43,8 @@ class CompilationUnit protected (val source: SourceFile) { */ var needsInlining: Boolean = false + var hasMacroAnnotations: Boolean = false + /** Set to `true` if inliner added anonymous mirrors that need to be completed */ var needsMirrorSupport: Boolean = false @@ -111,6 +113,7 @@ object CompilationUnit { force.traverse(unit1.tpdTree) unit1.needsStaging = force.containsQuote unit1.needsInlining = force.containsInline + unit1.hasMacroAnnotations = force.containsMacroAnnotation } unit1 } @@ -138,11 +141,15 @@ object CompilationUnit { private class Force extends TreeTraverser { var containsQuote = false var containsInline = false + var containsMacroAnnotation = false def traverse(tree: Tree)(using Context): Unit = { if (tree.symbol.isQuote) containsQuote = true if tree.symbol.is(Flags.Inline) then containsInline = true + for annot <- tree.symbol.annotations do + if annot.tree.symbol.owner.derivesFrom(defn.QuotedMacroAnnotationClass) then + ctx.compilationUnit.hasMacroAnnotations = true traverseChildren(tree) } } diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index ce4ed2d4e4e8..afc77159f509 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -34,7 +34,7 @@ class Compiler { protected def frontendPhases: List[List[Phase]] = List(new Parser) :: // Compiler frontend: scanner, parser List(new TyperPhase) :: // Compiler frontend: namer, typer - List(new YCheckPositions) :: // YCheck positions + // List(new YCheckPositions) :: // YCheck positions List(new sbt.ExtractDependencies) :: // Sends information on classes' dependencies to sbt via callbacks List(new semanticdb.ExtractSemanticDB) :: // Extract info into .semanticdb files List(new PostTyper) :: // Additional checks and cleanups after type checking diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 3e2373d3bd4b..dc55dd82900e 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -865,6 +865,8 @@ class Definitions { @tu lazy val QuotedTypeModule: Symbol = QuotedTypeClass.companionModule @tu lazy val QuotedTypeModule_of: Symbol = QuotedTypeModule.requiredMethod("of") + @tu lazy val QuotedMacroAnnotationClass: ClassSymbol = requiredClass("scala.annotation.MacroAnnotation") + @tu lazy val CanEqualClass: ClassSymbol = getClassIfDefined("scala.Eql").orElse(requiredClass("scala.CanEqual")).asClass def CanEqual_canEqualAny(using Context): TermSymbol = val methodName = if CanEqualClass.name == tpnme.Eql then nme.eqlAny else nme.canEqualAny diff --git a/compiler/src/dotty/tools/dotc/transform/Inlining.scala b/compiler/src/dotty/tools/dotc/transform/Inlining.scala index 5ddcf600c63a..95e583b5095e 100644 --- a/compiler/src/dotty/tools/dotc/transform/Inlining.scala +++ b/compiler/src/dotty/tools/dotc/transform/Inlining.scala @@ -7,11 +7,15 @@ import Contexts._ import Symbols._ import SymUtils._ import dotty.tools.dotc.ast.tpd - +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.quoted._ import dotty.tools.dotc.core.StagingContext._ import dotty.tools.dotc.inlines.Inlines import dotty.tools.dotc.ast.TreeMapWithImplicits +import scala.annotation.tailrec +import scala.collection.mutable + /** Inlines all calls to inline methods that are not in an inline method or a quote */ class Inlining extends MacroTransform { @@ -23,9 +27,11 @@ class Inlining extends MacroTransform { override def allowsImplicitSearch: Boolean = true - override def run(using Context): Unit = - if ctx.compilationUnit.needsInlining then - try super.run + override def run(using ctx0: Context): Unit = + if ctx0.compilationUnit.needsInlining || ctx0.compilationUnit.hasMacroAnnotations then + try + val ctx = QuotesCache.init(ctx0.fresh) + super.run(using ctx) catch case _: CompilationUnit.SuspendException => () override def runOn(units: List[CompilationUnit])(using Context): List[CompilationUnit] = @@ -61,7 +67,21 @@ class Inlining extends MacroTransform { tree match case tree: DefTree => if tree.symbol.is(Inline) then tree - else super.transform(tree) + else + tree match + case _: Bind => super.transform(tree) + case tree if tree.symbol.is(Param) => super.transform(tree) + case tree if !tree.symbol.isPrimaryConstructor => + val trees = MacroAnnotationTransformer.transform(List(tree), Set(tree.symbol)) + flatTree(trees.map(super.transform(_))) + case tree => super.transform(tree) + case ObjectTrees(valT, clsT) => + val trees = MacroAnnotationTransformer.transform(List(Thicket(valT, clsT)), Set(valT.symbol, clsT.symbol)) + assert(trees.size >= 2) + flatTree(trees.map{ tree => + if tree.symbol.is(Inline) then tree + else super.transform(tree) + }) case _: Typed | _: Block => super.transform(tree) case _ if Inlines.needsInlining(tree) => @@ -75,6 +95,46 @@ class Inlining extends MacroTransform { case _ => super.transform(tree) } + + override def transformStats[T](trees: List[Tree], exprOwner: Symbol, wrapResult: List[Tree] => Context ?=> T)(using Context): T = + @tailrec + def loop(mapped: mutable.ListBuffer[Tree] | Null, unchanged: List[Tree], pending: List[Tree])(using Context): T = + inline def recur(unchange: Boolean, stat1: Tree, rest: List[Tree])(using Context): T = + if unchange then + loop(mapped, unchanged, rest) + else + val buf = if mapped == null then new mutable.ListBuffer[Tree] else mapped + var xc = unchanged + while xc ne pending do + buf += xc.head + xc = xc.tail + stat1 match + case Thicket(stats1) => buf ++= stats1 + case _ => buf += stat1 + loop(buf, rest, rest) + + pending match + case valT :: clsT :: rest if valT.symbol.is(ModuleVal) && clsT.symbol.is(ModuleClass) && + valT.symbol.moduleClass == clsT.symbol => + val stat1 = transform(Thicket(List(valT, clsT)))(using ctx) + val unchange = stat1 match + case Thicket(List(valT1, clsT1)) => (valT eq valT1) && (clsT eq clsT1) + case _ => false + recur(unchange, stat1, rest)(using ctx) + case stat :: rest => + val statCtx = stat match + case _: DefTree | _: ImportOrExport => ctx + case _ => ctx.exprContext(stat, exprOwner) + val stat1 = transform(stat)(using statCtx) + val restCtx = stat match + case stat: Import => ctx.importContext(stat, stat.symbol) + case _ => ctx + recur(stat1 eq stat, stat1, rest)(using restCtx) + case nil => + wrapResult( + if mapped == null then unchanged + else mapped.prependToList(unchanged)) + loop(null, trees, trees) } } diff --git a/compiler/src/dotty/tools/dotc/transform/Interpreter.scala b/compiler/src/dotty/tools/dotc/transform/Interpreter.scala new file mode 100644 index 000000000000..ccbd6c13942d --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/Interpreter.scala @@ -0,0 +1,392 @@ +package dotty.tools.dotc +package transform + +import scala.language.unsafeNulls + +import java.io.{PrintWriter, StringWriter} +import java.lang.reflect.{InvocationTargetException, Method => JLRMethod} + +import core._ +import Decorators._ +import Flags._ +import Types._ +import Contexts._ +import Symbols._ +import Constants._ +import ast.Trees._ +import ast.{TreeTypeMap, untpd} +import util.Spans._ +import SymUtils._ +import NameKinds._ +import dotty.tools.dotc.ast.tpd +import typer.Implicits.SearchFailureType +import typer.PrepareInlineable +import SymDenotations.NoDenotation + +import scala.collection.mutable +import dotty.tools.dotc.core.Annotations._ +import dotty.tools.dotc.core.Names._ +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.StagingContext._ +import dotty.tools.dotc.quoted._ +import dotty.tools.dotc.transform.TreeMapWithStages._ +import dotty.tools.dotc.typer.Inliner +import dotty.tools.dotc.typer.ImportInfo.withRootImports +import dotty.tools.dotc.ast.TreeMapWithImplicits + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Flags._ +import dotty.tools.dotc.core.NameKinds.FlatName +import dotty.tools.dotc.core.Names.{Name, TermName} +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Denotations.staticRef +import dotty.tools.dotc.core.{NameKinds, TypeErasure} +import dotty.tools.dotc.core.Constants.Constant + +import scala.util.control.NonFatal +import dotty.tools.dotc.util.SrcPos +import dotty.tools.repl.AbstractFileClassLoader + +import scala.annotation.constructorOnly +import scala.annotation.tailrec +import scala.collection.mutable.ListBuffer + +import scala.quoted.runtime.impl.QuotesImpl + +import scala.reflect.ClassTag +/** List of classes of the parameters of the signature of `sym` */ +abstract class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context) { + import tpd._ + type Env = Map[Symbol, Object] + + /** Returns the interpreted result of interpreting the code a call to the symbol with default arguments. + * Return Some of the result or None if some error happen during the interpretation. + */ + def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] = + interpretTree(tree)(Map.empty) match { + case obj: T => Some(obj) + case obj => + // TODO upgrade to a full type tag check or something similar + report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos) + None + } + + def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { + case Literal(Constant(value)) => + interpretLiteral(value) + + case tree: Ident if tree.symbol.is(Inline, butNot = Method) => + tree.tpe.widenTermRefExpr match + case ConstantType(c) => c.value.asInstanceOf[Object] + case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos) + + // TODO disallow interpreted method calls as arguments + case Call(fn, args) => + if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) + interpretNew(fn.symbol, args.flatten.map(interpretTree)) + else if (fn.symbol.is(Module)) + interpretModuleAccess(fn.symbol) + else if (fn.symbol.is(Method) && fn.symbol.isStatic) { + val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol) + staticMethodCall(interpretArgs(args, fn.symbol.info)) + } + else if fn.symbol.isStatic then + assert(args.isEmpty) + interpretedStaticFieldAccess(fn.symbol) + else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) + if (fn.name == nme.asInstanceOfPM) + interpretModuleAccess(fn.qualifier.symbol) + else { + val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol) + staticMethodCall(interpretArgs(args, fn.symbol.info)) + } + else if (env.contains(fn.symbol)) + env(fn.symbol) + else if (tree.symbol.is(InlineProxy)) + interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs) + else + unexpectedTree(tree) + + case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) => + (obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj)) + + // Interpret `foo(j = x, i = y)` which it is expanded to + // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)` + case Block(stats, expr) => interpretBlock(stats, expr) + case NamedArg(_, arg) => interpretTree(arg) + + case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion) + + case Typed(expr, _) => + interpretTree(expr) + + case SeqLiteral(elems, _) => + interpretVarargs(elems.map(e => interpretTree(e))) + + case _ => + unexpectedTree(tree) + } + + private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = { + def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] = + assert(args.size == argTypes.size) + val view = + for (arg, info) <- args.lazyZip(argTypes) yield + info match + case _: ExprType => () => interpretTree(arg) // by-name argument + case _ => interpretTree(arg) // by-value argument + view.toList + + fnType.dealias match + case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType) + case fnType: MethodType => + val argTypes = fnType.paramInfos + assert(argss.head.size == argTypes.size) + interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType) + case fnType: AppliedType if defn.isContextFunctionType(fnType) => + val argTypes :+ resType = fnType.args: @unchecked + interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType) + case fnType: PolyType => interpretArgs(argss, fnType.resType) + case fnType: ExprType => interpretArgs(argss, fnType.resType) + case _ => + assert(argss.isEmpty) + Nil + } + + private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = { + var unexpected: Option[Object] = None + val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match { + case stat: ValDef => + accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv)) + case stat => + if (unexpected.isEmpty) + unexpected = Some(unexpectedTree(stat)) + accEnv + }) + unexpected.getOrElse(interpretTree(expr)(newEnv)) + } + + private def interpretLiteral(value: Any)(implicit env: Env): Object = + value.asInstanceOf[Object] + + private def interpretVarargs(args: List[Object])(implicit env: Env): Object = + args.toSeq + + private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = { + val (inst, clazz) = + try + if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) + (null, loadReplLineClass(moduleClass)) + else { + val inst = loadModule(moduleClass) + (inst, inst.getClass) + } + catch + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => + if (ctx.settings.XprintSuspension.value) + report.echo(i"suspension triggered by a dependency on $sym", pos) + ctx.compilationUnit.suspend() // this throws a SuspendException + + val name = fn.name.asTermName + val method = getMethod(clazz, name, paramsSig(fn)) + (args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method) + } + + private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = { + val clazz = loadClass(sym.owner.fullName.toString) + val field = clazz.getField(sym.name.toString) + field.get(null) + } + + private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = + loadModule(fn.moduleClass) + + def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { + val clazz = loadClass(fn.owner.fullName.toString) + val constr = clazz.getConstructor(paramsSig(fn): _*) + constr.newInstance(args: _*).asInstanceOf[Object] + } + + private def unexpectedTree(tree: Tree)(implicit env: Env): Object = + throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos) + + private def loadModule(sym: Symbol): Object = + if (sym.owner.is(Package)) { + // is top level object + val moduleClass = loadClass(sym.fullName.toString) + moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null) + } + else { + // nested object in an object + val className = { + val pack = sym.topLevelClass.owner + if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString + else pack.showFullName + "." + sym.flatName + } + val clazz = loadClass(className) + clazz.getConstructor().newInstance().asInstanceOf[Object] + } + + private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = { + val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader) + lineClassloader.loadClass(moduleClass.name.firstPart.toString) + } + + protected def loadClass(name: String): Class[?] = + try classLoader.loadClass(name) + catch { + case _: ClassNotFoundException if ctx.compilationUnit.isSuspendable => + if (ctx.settings.XprintSuspension.value) + report.echo(i"suspension triggered by a dependency on $name", pos) + ctx.compilationUnit.suspend() + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => + if (ctx.settings.XprintSuspension.value) + report.echo(i"suspension triggered by a dependency on $sym", pos) + ctx.compilationUnit.suspend() // this throws a SuspendException + } + + protected def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod = + try clazz.getMethod(name.toString, paramClasses: _*) + catch { + case _: NoSuchMethodException => + val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)" + throw new StopInterpretation(msg, pos) + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => + if (ctx.settings.XprintSuspension.value) + report.echo(i"suspension triggered by a dependency on $sym", pos) + ctx.compilationUnit.suspend() // this throws a SuspendException + } + + private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T = + try thunk + catch { + case ex: RuntimeException => + val sw = new StringWriter() + sw.write("A runtime exception occurred while executing macro expansion\n") + sw.write(ex.getMessage) + sw.write("\n") + ex.printStackTrace(new PrintWriter(sw)) + sw.write("\n") + throw new StopInterpretation(sw.toString, pos) + case ex: InvocationTargetException => + ex.getTargetException match { + case ex: scala.quoted.runtime.StopMacroExpansion => + throw ex + case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => + if (ctx.settings.XprintSuspension.value) + report.echo(i"suspension triggered by a dependency on $sym", pos) + ctx.compilationUnit.suspend() // this throws a SuspendException + case targetException => + val sw = new StringWriter() + sw.write("Exception occurred while executing macro expansion.\n") + if (!ctx.settings.Ydebug.value) { + val end = targetException.getStackTrace.lastIndexWhere { x => + x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName + } + val shortStackTrace = targetException.getStackTrace.take(end + 1) + targetException.setStackTrace(shortStackTrace) + } + targetException.printStackTrace(new PrintWriter(sw)) + sw.write("\n") + throw new StopInterpretation(sw.toString, pos) + } + } + + private object MissingClassDefinedInCurrentRun { + def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = { + val className = targetException.getMessage + if (className eq null) None + else { + val sym = staticRef(className.toTypeName).symbol + if (sym.isDefinedInCurrentRun) Some(sym) else None + } + } + } + + /** List of classes of the parameters of the signature of `sym` */ + protected def paramsSig(sym: Symbol): List[Class[?]] = { + def paramClass(param: Type): Class[?] = { + def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match { + case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1) + case _ => (tpe, depth) + } + def javaArraySig(tpe: Type): String = { + val (elemType, depth) = arrayDepth(tpe, 0) + val sym = elemType.classSymbol + val suffix = + if (sym == defn.BooleanClass) "Z" + else if (sym == defn.ByteClass) "B" + else if (sym == defn.ShortClass) "S" + else if (sym == defn.IntClass) "I" + else if (sym == defn.LongClass) "J" + else if (sym == defn.FloatClass) "F" + else if (sym == defn.DoubleClass) "D" + else if (sym == defn.CharClass) "C" + else "L" + javaSig(elemType) + ";" + ("[" * depth) + suffix + } + def javaSig(tpe: Type): String = tpe match { + case tpe: JavaArrayType => javaArraySig(tpe) + case _ => + // Take the flatten name of the class and the full package name + val pack = tpe.classSymbol.topLevelClass.owner + val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}." + packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString + } + + val sym = param.classSymbol + if (sym == defn.BooleanClass) classOf[Boolean] + else if (sym == defn.ByteClass) classOf[Byte] + else if (sym == defn.CharClass) classOf[Char] + else if (sym == defn.ShortClass) classOf[Short] + else if (sym == defn.IntClass) classOf[Int] + else if (sym == defn.LongClass) classOf[Long] + else if (sym == defn.FloatClass) classOf[Float] + else if (sym == defn.DoubleClass) classOf[Double] + else java.lang.Class.forName(javaSig(param), false, classLoader) + } + def getExtraParams(tp: Type): List[Type] = tp.widenDealias match { + case tp: AppliedType if defn.isContextFunctionType(tp) => + // Call context function type direct method + tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last) + case _ => Nil + } + val extraParams = getExtraParams(sym.info.finalResultType) + val allParams = TypeErasure.erasure(sym.info) match { + case meth: MethodType => meth.paramInfos ::: extraParams + case _ => extraParams + } + allParams.map(paramClass) + } +} + +/** Exception that stops interpretation if some issue is found */ +class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception + +object Call { + import tpd._ + /** Matches an expression that is either a field access or an application + * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it. + */ + def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = + Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) + + private object Call0 { + def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match { + case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) => + Some((fn, args)) + case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil)) + case fn: Select => Some((fn, Nil)) + case Apply(f @ Call0(fn, args1), args2) => + if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1)) + else Some((fn, args2 :: args1)) + case TypeApply(Call0(fn, args), _) => Some((fn, args)) + case _ => None + } + } +} \ No newline at end of file diff --git a/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala b/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala new file mode 100644 index 000000000000..d2e3a12ea52f --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala @@ -0,0 +1,189 @@ +package dotty.tools.dotc +package transform + +import scala.language.unsafeNulls + +import java.io.{PrintWriter, StringWriter} +import java.lang.reflect.{InvocationTargetException, Method => JLRMethod} + +import core._ +import Decorators._ +import Flags._ +import Types._ +import Contexts._ +import Symbols._ +import Constants._ +import ast.Trees._ +import ast.{TreeTypeMap, untpd} +import util.Spans._ +import SymUtils._ +import NameKinds._ +import dotty.tools.dotc.ast.tpd +import typer.Implicits.SearchFailureType +import typer.PrepareInlineable +import SymDenotations.NoDenotation + +import scala.collection.mutable +import dotty.tools.dotc.core.Annotations._ +import dotty.tools.dotc.core.Names._ +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.StagingContext._ +import dotty.tools.dotc.quoted._ +import dotty.tools.dotc.transform.TreeMapWithStages._ +import dotty.tools.dotc.typer.Inliner +import dotty.tools.dotc.typer.ImportInfo.withRootImports +import dotty.tools.dotc.ast.TreeMapWithImplicits + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Flags._ +import dotty.tools.dotc.core.NameKinds.FlatName +import dotty.tools.dotc.core.Names.{Name, TermName} +import dotty.tools.dotc.core.StdNames._ +import dotty.tools.dotc.core.Types._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Denotations.staticRef +import dotty.tools.dotc.core.{NameKinds, TypeErasure} +import dotty.tools.dotc.core.Constants.Constant + +import dotty.tools.dotc.util.SrcPos + +import scala.annotation.constructorOnly +import scala.annotation.tailrec +import scala.annotation.MacroAnnotation +import scala.collection.mutable.ListBuffer +import scala.quoted._ +import scala.quoted.runtime.impl.{QuotesImpl, SpliceScope} +import java.lang.reflect.Method + + +object ObjectTrees: + import tpd._ + def unapply(tree: Tree)(using Context): Option[(Tree, Tree)] = tree match + case Thicket(List(valT, clsT)) if valT.symbol.is(ModuleVal) && clsT.symbol.is(ModuleClass) && + valT.symbol.moduleClass == clsT.symbol => + Some((valT, clsT)) + case _ => None + +object MacroAnnotationTransformer { + import tpd._ + + def instantiateAnnot(tree: Tree, annot: Annotation)(using Context): MacroAnnotation = + val interpreter = new InterpreterMacroAnnot(tree.srcPos, MacroClassLoader.fromContext) + interpreter.interpret[Object](annot.tree) match { + case Some(obj) => obj.asInstanceOf[MacroAnnotation] + case None => + throw new Exception(s"Macro annotation ${annot.tree.symbol} cannot be instantiated") + } + + @tailrec + def transform(trees: List[Tree], pending: Set[Symbol])(using Context): List[Tree] = pending match + case _ if pending.size == 0 => + assert(!trees.isEmpty) + trees + case _ => + def getSymbol(tree: Tree)(using Context): List[Symbol] = tree match + case ObjectTrees(valT, clsT) => List(valT.symbol, clsT.symbol) + case _ => List(tree.symbol) + + val expandedTrees = trees.flatMap { + case ObjectTrees(valT, clsT) => transformDef(valT, Some(clsT), pending) + case tree => transformDef(tree, None, pending) + } + transform(expandedTrees, pending union expandedTrees.flatMap(getSymbol(_)).toSet diff trees.flatMap(getSymbol(_)).toSet) + + def transformDef(tree: Tree, modCls: Option[Tree], pending: Set[Symbol])(using Context): List[Tree] = + + def isMacroAnnotation(annot: Annotation)(using Context): Boolean = + val sym = annot.tree.symbol + sym.denot != NoDenotation && sym.owner.derivesFrom(defn.QuotedMacroAnnotationClass) + + def enterNewDefInClass(tree: Tree)(using Context): Tree = + tree match + case tree @ TypeDef(_, tmpl: Template) => + tmpl.body.foreach{case t: DefTree => + if !t.symbol.owner.info.decls.contains(t.symbol.name, t.symbol) then + t.symbol.entered + } + case _ => () + tree + + if !pending.contains(tree.symbol) then + return List(tree) + if level != 0 then + return List(tree) + + val annotTree = tree.symbol.annotations.filter(isMacroAnnotation(_)) + val annotParam = tree match + case tree @ DefDef(_, paramss, _, _) => + paramss.flatMap{ params => params.flatMap{ param => + param.symbol.annotations.filter(isMacroAnnotation(_)).map((_, param))}} + case _ => Nil + + val annots: List[Annotation | (Annotation, DefTree)] = annotTree ++ annotParam + if annots.isEmpty then + modCls match + case Some(clsT) => List(tree, clsT) + case None => List(tree) + + var modClsTmp: Option[Tree] = modCls + var newTreesBefore = new ListBuffer[Tree]() // ListBuilder + var newTreesAfter = new ListBuffer[Tree]() + + val transformedTree: Tree = annots.foldLeft(tree){ (tree, annot) => + + given quotes: Quotes = QuotesImpl()(using SpliceScope.contextWithNewSpliceScope(tree.symbol.sourcePos)(using MacroExpansion.context(tree)).withOwner(tree.symbol)) + given Conversion[Tree, quotes.reflect.Definition] = _.asInstanceOf[quotes.reflect.Definition] + val transformedTrees = annot match + case (annot, param) => + instantiateAnnot(param, annot).asInstanceOf[MacroAnnotation]. + transformParam(param, tree).asInstanceOf[List[Tree]] + case annot: Annotation => modClsTmp match + case Some(modCls) => + val convertedModVal = tree.asInstanceOf[quotes.reflect.ValDef] + val convertedModCls = modCls.asInstanceOf[quotes.reflect.TypeDef] + instantiateAnnot(tree, annot).asInstanceOf[MacroAnnotation]. + transformObject(convertedModVal, convertedModCls).asInstanceOf[List[Tree]] + case None => + instantiateAnnot(tree, annot).asInstanceOf[MacroAnnotation]. + transform(tree).asInstanceOf[List[Tree]] + + val (before, selfAndAfter) = transformedTrees.splitAt(transformedTrees.map(_.symbol).indexOf(tree.symbol)) + val after = modClsTmp match + case Some(_) => + modClsTmp = Some(selfAndAfter.tail.head) + enterNewDefInClass(modClsTmp.get) + selfAndAfter.tail.tail + case None => selfAndAfter.tail + (before ++ after).foreach{case t: DefTree => + // should not have this check, will be removed afterwards + if !t.symbol.owner.info.decls.contains(t.symbol.name, t.symbol) then + t.symbol.entered + } + newTreesBefore.appendAll(before) + newTreesAfter.prependAll(after) + enterNewDefInClass(selfAndAfter.head) + } + newTreesBefore.toList ++ List(transformedTree) ++ modCls.toList ++ newTreesAfter.toList +} + +class InterpreterMacroAnnot(pos: SrcPos, classLoader: ClassLoader)(using Context) extends Interpreter(pos, classLoader): + import tpd._ + override def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { + case Apply(Select(New(annot), _), args) => + val interpretedArgs = args.map(interpret[Object](_).get) + interpretNew(tree.symbol, interpretedArgs) + + case _ => super.interpretTree(tree) + } + + override def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { + val clazz = loadClass(fn.owner.fullName.toString.replaceAll("\\$\\.", "\\$")) + val constr = clazz.getConstructor(paramsSig(fn): _*) + constr.newInstance(args: _*).asInstanceOf[Object] + } + + def getMethod(clazz: Class[_], methodSym: Symbol): Method = + super.getMethod(clazz, methodSym.name.asTermName, paramsSig(methodSym)) diff --git a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala index 2b72b6c8bdbd..16f5de9e78f9 100644 --- a/compiler/src/dotty/tools/dotc/transform/PostTyper.scala +++ b/compiler/src/dotty/tools/dotc/transform/PostTyper.scala @@ -376,17 +376,20 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase ) } case tree: ValDef => + checkAnnotationMacros(tree) checkErasedDef(tree) val tree1 = cpy.ValDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) if tree1.removeAttachment(desugar.UntupledParam).isDefined then checkStableSelection(tree.rhs) processValOrDefDef(super.transform(tree1)) case tree: DefDef => + checkAnnotationMacros(tree) checkErasedDef(tree) annotateContextResults(tree) val tree1 = cpy.DefDef(tree)(rhs = normalizeErasedRhs(tree.rhs, tree.symbol)) processValOrDefDef(superAcc.wrapDefDef(tree1)(super.transform(tree1).asInstanceOf[DefDef])) case tree: TypeDef => + checkAnnotationMacros(tree) val sym = tree.symbol if (sym.isClass) VarianceChecker.check(tree) @@ -477,6 +480,12 @@ class PostTyper extends MacroTransform with IdentityDenotTransformer { thisPhase private def normalizeErasedRhs(rhs: Tree, sym: Symbol)(using Context) = if (sym.isEffectivelyErased) dropInlines.transform(rhs) else rhs + private def checkAnnotationMacros(tree: Tree)(using Context) = + if !ctx.compilationUnit.hasMacroAnnotations then + for annot <- tree.symbol.annotations do + if annot.tree.symbol.denot != NoDenotation && annot.tree.symbol.owner.derivesFrom(defn.QuotedMacroAnnotationClass) then + ctx.compilationUnit.hasMacroAnnotations = true + private def checkErasedDef(tree: ValOrDefDef)(using Context): Unit = if tree.symbol.is(Erased, butNot = Macro) then val tpe = tree.rhs.tpe diff --git a/compiler/src/dotty/tools/dotc/transform/Splicer.scala b/compiler/src/dotty/tools/dotc/transform/Splicer.scala index 31c28d7b1854..903534cbe73e 100644 --- a/compiler/src/dotty/tools/dotc/transform/Splicer.scala +++ b/compiler/src/dotty/tools/dotc/transform/Splicer.scala @@ -50,7 +50,7 @@ object Splicer { val oldContextClassLoader = Thread.currentThread().getContextClassLoader Thread.currentThread().setContextClassLoader(classLoader) try { - val interpreter = new Interpreter(splicePos, classLoader) + val interpreter = new InterpreterSplicer(splicePos, classLoader) // Some parts of the macro are evaluated during the unpickling performed in quotedExprToTree val interpretedExpr = interpreter.interpret[Quotes => scala.quoted.Expr[Any]](tree) @@ -220,23 +220,9 @@ object Splicer { } /** Tree interpreter that evaluates the tree */ - private class Interpreter(pos: SrcPos, classLoader: ClassLoader)(using Context) { - - type Env = Map[Symbol, Object] - - /** Returns the interpreted result of interpreting the code a call to the symbol with default arguments. - * Return Some of the result or None if some error happen during the interpretation. - */ - def interpret[T](tree: Tree)(implicit ct: ClassTag[T]): Option[T] = - interpretTree(tree)(Map.empty) match { - case obj: T => Some(obj) - case obj => - // TODO upgrade to a full type tag check or something similar - report.error(s"Interpreted tree returned a result of an unexpected type. Expected ${ct.runtimeClass} but was ${obj.getClass}", pos) - None - } + private class InterpreterSplicer(pos: SrcPos, classLoader: ClassLoader)(using Context) extends Interpreter(pos, classLoader) { - def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { + override def interpretTree(tree: Tree)(implicit env: Env): Object = tree match { case Apply(Select(Apply(TypeApply(fn, _), quoted :: Nil), nme.apply), _) if fn.symbol == defn.QuotedRuntime_exprQuote => val quoted1 = quoted match { case quoted: Ident if quoted.symbol.isAllOf(InlineByNameProxy) => @@ -250,98 +236,7 @@ object Splicer { case Apply(TypeApply(fn, quoted :: Nil), _) if fn.symbol == defn.QuotedTypeModule_of => interpretTypeQuote(quoted) - case Literal(Constant(value)) => - interpretLiteral(value) - - case tree: Ident if tree.symbol.is(Inline, butNot = Method) => - tree.tpe.widenTermRefExpr match - case ConstantType(c) => c.value.asInstanceOf[Object] - case _ => throw new StopInterpretation(em"${tree.symbol} could not be inlined", tree.srcPos) - - // TODO disallow interpreted method calls as arguments - case Call(fn, args) => - if (fn.symbol.isConstructor && fn.symbol.owner.owner.is(Package)) - interpretNew(fn.symbol, args.flatten.map(interpretTree)) - else if (fn.symbol.is(Module)) - interpretModuleAccess(fn.symbol) - else if (fn.symbol.is(Method) && fn.symbol.isStatic) { - val staticMethodCall = interpretedStaticMethodCall(fn.symbol.owner, fn.symbol) - staticMethodCall(interpretArgs(args, fn.symbol.info)) - } - else if fn.symbol.isStatic then - assert(args.isEmpty) - interpretedStaticFieldAccess(fn.symbol) - else if (fn.qualifier.symbol.is(Module) && fn.qualifier.symbol.isStatic) - if (fn.name == nme.asInstanceOfPM) - interpretModuleAccess(fn.qualifier.symbol) - else { - val staticMethodCall = interpretedStaticMethodCall(fn.qualifier.symbol.moduleClass, fn.symbol) - staticMethodCall(interpretArgs(args, fn.symbol.info)) - } - else if (env.contains(fn.symbol)) - env(fn.symbol) - else if (tree.symbol.is(InlineProxy)) - interpretTree(tree.symbol.defTree.asInstanceOf[ValOrDefDef].rhs) - else - unexpectedTree(tree) - - case closureDef((ddef @ DefDef(_, ValDefs(arg :: Nil) :: Nil, _, _))) => - (obj: AnyRef) => interpretTree(ddef.rhs)(using env.updated(arg.symbol, obj)) - - // Interpret `foo(j = x, i = y)` which it is expanded to - // `val j$1 = x; val i$1 = y; foo(i = i$1, j = j$1)` - case Block(stats, expr) => interpretBlock(stats, expr) - case NamedArg(_, arg) => interpretTree(arg) - - case Inlined(_, bindings, expansion) => interpretBlock(bindings, expansion) - - case Typed(expr, _) => - interpretTree(expr) - - case SeqLiteral(elems, _) => - interpretVarargs(elems.map(e => interpretTree(e))) - - case _ => - unexpectedTree(tree) - } - - private def interpretArgs(argss: List[List[Tree]], fnType: Type)(using Env): List[Object] = { - def interpretArgsGroup(args: List[Tree], argTypes: List[Type]): List[Object] = - assert(args.size == argTypes.size) - val view = - for (arg, info) <- args.lazyZip(argTypes) yield - info match - case _: ExprType => () => interpretTree(arg) // by-name argument - case _ => interpretTree(arg) // by-value argument - view.toList - - fnType.dealias match - case fnType: MethodType if fnType.isErasedMethod => interpretArgs(argss, fnType.resType) - case fnType: MethodType => - val argTypes = fnType.paramInfos - assert(argss.head.size == argTypes.size) - interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, fnType.resType) - case fnType: AppliedType if defn.isContextFunctionType(fnType) => - val argTypes :+ resType = fnType.args: @unchecked - interpretArgsGroup(argss.head, argTypes) ::: interpretArgs(argss.tail, resType) - case fnType: PolyType => interpretArgs(argss, fnType.resType) - case fnType: ExprType => interpretArgs(argss, fnType.resType) - case _ => - assert(argss.isEmpty) - Nil - } - - private def interpretBlock(stats: List[Tree], expr: Tree)(implicit env: Env) = { - var unexpected: Option[Object] = None - val newEnv = stats.foldLeft(env)((accEnv, stat) => stat match { - case stat: ValDef => - accEnv.updated(stat.symbol, interpretTree(stat.rhs)(accEnv)) - case stat => - if (unexpected.isEmpty) - unexpected = Some(unexpectedTree(stat)) - accEnv - }) - unexpected.getOrElse(interpretTree(expr)(newEnv)) + case _ => super.interpretTree(tree) } private def interpretQuote(tree: Tree)(implicit env: Env): Object = @@ -349,220 +244,5 @@ object Splicer { private def interpretTypeQuote(tree: Tree)(implicit env: Env): Object = new TypeImpl(QuoteUtils.changeOwnerOfTree(tree, ctx.owner), SpliceScope.getCurrent) - - private def interpretLiteral(value: Any)(implicit env: Env): Object = - value.asInstanceOf[Object] - - private def interpretVarargs(args: List[Object])(implicit env: Env): Object = - args.toSeq - - private def interpretedStaticMethodCall(moduleClass: Symbol, fn: Symbol)(implicit env: Env): List[Object] => Object = { - val (inst, clazz) = - try - if (moduleClass.name.startsWith(str.REPL_SESSION_LINE)) - (null, loadReplLineClass(moduleClass)) - else { - val inst = loadModule(moduleClass) - (inst, inst.getClass) - } - catch - case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => - if (ctx.settings.XprintSuspension.value) - report.echo(i"suspension triggered by a dependency on $sym", pos) - ctx.compilationUnit.suspend() // this throws a SuspendException - - val name = fn.name.asTermName - val method = getMethod(clazz, name, paramsSig(fn)) - (args: List[Object]) => stopIfRuntimeException(method.invoke(inst, args: _*), method) - } - - private def interpretedStaticFieldAccess(sym: Symbol)(implicit env: Env): Object = { - val clazz = loadClass(sym.owner.fullName.toString) - val field = clazz.getField(sym.name.toString) - field.get(null) - } - - private def interpretModuleAccess(fn: Symbol)(implicit env: Env): Object = - loadModule(fn.moduleClass) - - private def interpretNew(fn: Symbol, args: => List[Object])(implicit env: Env): Object = { - val clazz = loadClass(fn.owner.fullName.toString) - val constr = clazz.getConstructor(paramsSig(fn): _*) - constr.newInstance(args: _*).asInstanceOf[Object] - } - - private def unexpectedTree(tree: Tree)(implicit env: Env): Object = - throw new StopInterpretation("Unexpected tree could not be interpreted: " + tree, tree.srcPos) - - private def loadModule(sym: Symbol): Object = - if (sym.owner.is(Package)) { - // is top level object - val moduleClass = loadClass(sym.fullName.toString) - moduleClass.getField(str.MODULE_INSTANCE_FIELD).get(null) - } - else { - // nested object in an object - val className = { - val pack = sym.topLevelClass.owner - if (pack == defn.RootPackage || pack == defn.EmptyPackageClass) sym.flatName.toString - else pack.showFullName + "." + sym.flatName - } - val clazz = loadClass(className) - clazz.getConstructor().newInstance().asInstanceOf[Object] - } - - private def loadReplLineClass(moduleClass: Symbol)(implicit env: Env): Class[?] = { - val lineClassloader = new AbstractFileClassLoader(ctx.settings.outputDir.value, classLoader) - lineClassloader.loadClass(moduleClass.name.firstPart.toString) - } - - private def loadClass(name: String): Class[?] = - try classLoader.loadClass(name) - catch { - case _: ClassNotFoundException => - val msg = s"Could not find class $name in classpath" - throw new StopInterpretation(msg, pos) - } - - private def getMethod(clazz: Class[?], name: Name, paramClasses: List[Class[?]]): JLRMethod = - try clazz.getMethod(name.toString, paramClasses: _*) - catch { - case _: NoSuchMethodException => - val msg = em"Could not find method ${clazz.getCanonicalName}.$name with parameters ($paramClasses%, %)" - throw new StopInterpretation(msg, pos) - case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => - if (ctx.settings.XprintSuspension.value) - report.echo(i"suspension triggered by a dependency on $sym", pos) - ctx.compilationUnit.suspend() // this throws a SuspendException - } - - private def stopIfRuntimeException[T](thunk: => T, method: JLRMethod): T = - try thunk - catch { - case ex: RuntimeException => - val sw = new StringWriter() - sw.write("A runtime exception occurred while executing macro expansion\n") - sw.write(ex.getMessage) - sw.write("\n") - ex.printStackTrace(new PrintWriter(sw)) - sw.write("\n") - throw new StopInterpretation(sw.toString, pos) - case ex: InvocationTargetException => - ex.getTargetException match { - case ex: scala.quoted.runtime.StopMacroExpansion => - throw ex - case MissingClassDefinedInCurrentRun(sym) if ctx.compilationUnit.isSuspendable => - if (ctx.settings.XprintSuspension.value) - report.echo(i"suspension triggered by a dependency on $sym", pos) - ctx.compilationUnit.suspend() // this throws a SuspendException - case targetException => - val sw = new StringWriter() - sw.write("Exception occurred while executing macro expansion.\n") - if (!ctx.settings.Ydebug.value) { - val end = targetException.getStackTrace.lastIndexWhere { x => - x.getClassName == method.getDeclaringClass.getCanonicalName && x.getMethodName == method.getName - } - val shortStackTrace = targetException.getStackTrace.take(end + 1) - targetException.setStackTrace(shortStackTrace) - } - targetException.printStackTrace(new PrintWriter(sw)) - sw.write("\n") - throw new StopInterpretation(sw.toString, pos) - } - } - - private object MissingClassDefinedInCurrentRun { - def unapply(targetException: NoClassDefFoundError)(using Context): Option[Symbol] = { - val className = targetException.getMessage - if (className == null) None - else { - val sym = staticRef(className.toTypeName).symbol - if (sym.isDefinedInCurrentRun) Some(sym) else None - } - } - } - - /** List of classes of the parameters of the signature of `sym` */ - private def paramsSig(sym: Symbol): List[Class[?]] = { - def paramClass(param: Type): Class[?] = { - def arrayDepth(tpe: Type, depth: Int): (Type, Int) = tpe match { - case JavaArrayType(elemType) => arrayDepth(elemType, depth + 1) - case _ => (tpe, depth) - } - def javaArraySig(tpe: Type): String = { - val (elemType, depth) = arrayDepth(tpe, 0) - val sym = elemType.classSymbol - val suffix = - if (sym == defn.BooleanClass) "Z" - else if (sym == defn.ByteClass) "B" - else if (sym == defn.ShortClass) "S" - else if (sym == defn.IntClass) "I" - else if (sym == defn.LongClass) "J" - else if (sym == defn.FloatClass) "F" - else if (sym == defn.DoubleClass) "D" - else if (sym == defn.CharClass) "C" - else "L" + javaSig(elemType) + ";" - ("[" * depth) + suffix - } - def javaSig(tpe: Type): String = tpe match { - case tpe: JavaArrayType => javaArraySig(tpe) - case _ => - // Take the flatten name of the class and the full package name - val pack = tpe.classSymbol.topLevelClass.owner - val packageName = if (pack == defn.EmptyPackageClass) "" else s"${pack.fullName}." - packageName + tpe.classSymbol.fullNameSeparated(FlatName).toString - } - - val sym = param.classSymbol - if (sym == defn.BooleanClass) classOf[Boolean] - else if (sym == defn.ByteClass) classOf[Byte] - else if (sym == defn.CharClass) classOf[Char] - else if (sym == defn.ShortClass) classOf[Short] - else if (sym == defn.IntClass) classOf[Int] - else if (sym == defn.LongClass) classOf[Long] - else if (sym == defn.FloatClass) classOf[Float] - else if (sym == defn.DoubleClass) classOf[Double] - else java.lang.Class.forName(javaSig(param), false, classLoader) - } - def getExtraParams(tp: Type): List[Type] = tp.widenDealias match { - case tp: AppliedType if defn.isContextFunctionType(tp) => - // Call context function type direct method - tp.args.init.map(arg => TypeErasure.erasure(arg)) ::: getExtraParams(tp.args.last) - case _ => Nil - } - val extraParams = getExtraParams(sym.info.finalResultType) - val allParams = TypeErasure.erasure(sym.info) match { - case meth: MethodType => meth.paramInfos ::: extraParams - case _ => extraParams - } - allParams.map(paramClass) - } - } - - - - /** Exception that stops interpretation if some issue is found */ - private class StopInterpretation(val msg: String, val pos: SrcPos) extends Exception - - object Call { - /** Matches an expression that is either a field access or an application - * It retruns a TermRef containing field accessed or a method reference and the arguments passed to it. - */ - def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = - Call0.unapply(arg).map((fn, args) => (fn, args.reverse)) - - private object Call0 { - def unapply(arg: Tree)(using Context): Option[(RefTree, List[List[Tree]])] = arg match { - case Select(Call0(fn, args), nme.apply) if defn.isContextFunctionType(fn.tpe.widenDealias.finalResultType) => - Some((fn, args)) - case fn: Ident => Some((tpd.desugarIdent(fn).withSpan(fn.span), Nil)) - case fn: Select => Some((fn, Nil)) - case Apply(f @ Call0(fn, args1), args2) => - if (f.tpe.widenDealias.isErasedMethod) Some((fn, args1)) - else Some((fn, args2 :: args1)) - case TypeApply(Call0(fn, args), _) => Some((fn, args)) - case _ => None - } - } } } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index fa8cbaf7ff5a..68061e97426d 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -67,6 +67,8 @@ object Typer { def assertPositioned(tree: untpd.Tree)(using Context): Unit = if (!tree.isEmpty && !tree.isInstanceOf[untpd.TypedSplice] && ctx.typerState.isGlobalCommittable) assert(tree.span.exists, i"position not set for $tree # ${tree.uniqueId} of ${tree.getClass} in ${tree.source}") + assert(1 == 1) + /** An attachment for GADT constraints that were inferred for a pattern. */ val InferredGadtConstraints = new Property.StickyKey[core.GadtConstraint] diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 30f4dc29aeea..bdc3413dc096 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -430,10 +430,10 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler object Ref extends RefModule: def term(tp: TermRef): Ref = - withDefaultPos(tpd.ref(tp).asInstanceOf[tpd.RefTree]) + withDefaultPos(withDefaultPos(tpd.ref(tp).asInstanceOf[tpd.RefTree])) def apply(sym: Symbol): Ref = assert(sym.isTerm) - withDefaultPos(tpd.ref(sym).asInstanceOf[tpd.RefTree]) + withDefaultPos(withDefaultPos(tpd.ref(sym).asInstanceOf[tpd.RefTree])) end Ref type Ident = tpd.Ident diff --git a/library/src/scala/annotation/MacroAnnotation.scala b/library/src/scala/annotation/MacroAnnotation.scala new file mode 100644 index 000000000000..0fd4e01457e3 --- /dev/null +++ b/library/src/scala/annotation/MacroAnnotation.scala @@ -0,0 +1,16 @@ +// TODO in which package should this class be located? +package scala +package annotation + +import scala.quoted._ +import scala.annotation.{StaticAnnotation, experimental} + +// @experimental +trait MacroAnnotation extends StaticAnnotation { + def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + quotes.reflect.report.errorAndAbort(tree.show, tree.pos) + def transformObject(using Quotes)(valTree: quotes.reflect.ValDef, classTree: quotes.reflect.TypeDef): List[quotes.reflect.Definition] = + quotes.reflect.report.errorAndAbort(classTree.show, classTree.pos) + def transformParam(using Quotes)(paramTree: quotes.reflect.Definition, ownerTree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + quotes.reflect.report.errorAndAbort(paramTree.show, paramTree.pos) +} diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 846856adc2c8..bebd2001edca 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -28,7 +28,10 @@ object MiMaFilters { ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.typeRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#SymbolMethods.termRef"), ProblemFilters.exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule#TypeTreeModule.ref"), - + // Private to the compiler - needed for forward binary compatibility ProblemFilters.exclude[MissingClassProblem]("scala.annotation.since"), + + // Macro annotations + ProblemFilters.exclude[MissingClassProblem]("scala.quoted.annotation.MacroAnnotation"), ) } diff --git a/tests/neg-macros/annot-basic-accessIndirect/Macro_1.scala b/tests/neg-macros/annot-basic-accessIndirect/Macro_1.scala new file mode 100644 index 000000000000..a185b1702067 --- /dev/null +++ b/tests/neg-macros/annot-basic-accessIndirect/Macro_1.scala @@ -0,0 +1,13 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val helloSymbol = Symbol.newVal(tree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!")))) + List(helloVal, tree) +} diff --git a/tests/neg-macros/annot-basic-accessIndirect/Macro_2.scala b/tests/neg-macros/annot-basic-accessIndirect/Macro_2.scala new file mode 100644 index 000000000000..59982f72ab36 --- /dev/null +++ b/tests/neg-macros/annot-basic-accessIndirect/Macro_2.scala @@ -0,0 +1,20 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class foo extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val s = '{@hello def foo1(x: Int): Int = x + 1;()}.asTerm + val fooDef = s.asInstanceOf[Inlined].body.asInstanceOf[Block].statements.head.asInstanceOf[DefDef] + val hello = Ref(tree.symbol.owner.declaredFields("hello").head).asExprOf[String] // error + tree match + case DefDef(name, params, tpt, Some(t)) => + val rhs = '{ + ${t.asExprOf[String]} + $hello + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(fooDef, newDef) +} diff --git a/tests/neg-macros/annot-basic-accessIndirect/Test.scala b/tests/neg-macros/annot-basic-accessIndirect/Test.scala new file mode 100644 index 000000000000..6e2bbd3d3361 --- /dev/null +++ b/tests/neg-macros/annot-basic-accessIndirect/Test.scala @@ -0,0 +1,3 @@ +class Bar: + @foo def bar(x: String): String = x // error + bar("a") diff --git a/tests/pos-macros/annot-basic-overload/Macro.scala b/tests/pos-macros/annot-basic-overload/Macro.scala new file mode 100644 index 000000000000..2a51721c0e13 --- /dev/null +++ b/tests/pos-macros/annot-basic-overload/Macro.scala @@ -0,0 +1,17 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation + +@experimental +class add extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case ClassDef(name, constr, parents, self, body) => + given Quotes = tree.symbol.asQuotes + val defSym = Symbol.newMethod(tree.symbol, "foo", + MethodType(List("n"))(_ => List(TypeTree.of[String].tpe), _ => TypeTree.of[String].tpe)) + val overloadedDef = DefDef(defSym, + { case List(List(n)) => Some('{${n.asExprOf[String]} + "1"}.asTerm)}) + List(ClassDef.copy(tree)(name, constr, parents, self, overloadedDef +: body)) +} diff --git a/tests/pos-macros/annot-basic-overload/Test.scala b/tests/pos-macros/annot-basic-overload/Test.scala new file mode 100644 index 000000000000..47fd00a474c8 --- /dev/null +++ b/tests/pos-macros/annot-basic-overload/Test.scala @@ -0,0 +1,4 @@ +@add +class A: + def foo(n: Int): Int = n + 1 + \ No newline at end of file diff --git a/tests/run-macros/annot-basic-changeClass/Macro.scala b/tests/run-macros/annot-basic-changeClass/Macro.scala new file mode 100644 index 000000000000..8f9088964ed1 --- /dev/null +++ b/tests/run-macros/annot-basic-changeClass/Macro.scala @@ -0,0 +1,21 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation + +@experimental +class change extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case ClassDef(name, constr, parents, self, body) => + val newBody = body.map { + case stat @ DefDef("toString", paramss, tpt, Some(t)) => + given Quotes = tree.symbol.asQuotes + val rhs = '{ + ${t.asExprOf[String]} + " changed by macro annotation" + }.asTerm + DefDef.copy(stat)("toString", paramss, tpt, Some(rhs)) + case stat => stat + } + List(ClassDef.copy(tree)(name, constr, parents, self, newBody)) +} diff --git a/tests/run-macros/annot-basic-changeClass/Test.scala b/tests/run-macros/annot-basic-changeClass/Test.scala new file mode 100644 index 000000000000..735ecff59d5a --- /dev/null +++ b/tests/run-macros/annot-basic-changeClass/Test.scala @@ -0,0 +1,7 @@ +@change +class A: + override def toString = "It is A" + +@main def Test = + val a = new A + assert(a.toString == "It is A changed by macro annotation") diff --git a/tests/run-macros/annot-basic-changeVal/Macro.scala b/tests/run-macros/annot-basic-changeVal/Macro.scala new file mode 100644 index 000000000000..684af1b4e0e5 --- /dev/null +++ b/tests/run-macros/annot-basic-changeVal/Macro.scala @@ -0,0 +1,12 @@ +import scala.annotation.experimental +import scala.quoted.* +import scala.annotation.MacroAnnotation + +object ChangeVal: + @experimental + class change(i: Int) extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect.* + tree match + case ValDef(n, t, _) => List(ValDef.copy(tree)(n, t, Some(Literal(IntConstant(i))))) + } diff --git a/tests/run-macros/annot-basic-changeVal/Test.scala b/tests/run-macros/annot-basic-changeVal/Test.scala new file mode 100644 index 000000000000..6e0e44ad8885 --- /dev/null +++ b/tests/run-macros/annot-basic-changeVal/Test.scala @@ -0,0 +1,9 @@ +import ChangeVal._ + +class Bar: + @change(5) + val foo: Int = 3 + +@main def Test = + val t = new Bar + assert(t.foo == 5) diff --git a/tests/run-macros/annot-basic-classAddDef/Macro.scala b/tests/run-macros/annot-basic-classAddDef/Macro.scala new file mode 100644 index 000000000000..f6fa50d1f041 --- /dev/null +++ b/tests/run-macros/annot-basic-classAddDef/Macro.scala @@ -0,0 +1,23 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation + +@experimental +class change extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case ClassDef(name, constr, parents, self, body) => + val changeSymbol = Symbol.newVal(tree.symbol, "changed", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val changeVal = ValDef(changeSymbol, Some(Literal(StringConstant(" changed by macro annotation")))) + val newBody = changeVal +: body.map { + case stat @ DefDef("toString", paramss, tpt, Some(t)) => + given Quotes = tree.symbol.asQuotes + val rhs = '{ + ${t.asExprOf[String]} + ${Ref(changeSymbol).asExprOf[String]} + }.asTerm + DefDef.copy(stat)("toString", paramss, tpt, Some(rhs)) + case stat => stat + } + List(ClassDef.copy(tree)(name, constr, parents, self, newBody)) +} diff --git a/tests/run-macros/annot-basic-classAddDef/Test.scala b/tests/run-macros/annot-basic-classAddDef/Test.scala new file mode 100644 index 000000000000..735ecff59d5a --- /dev/null +++ b/tests/run-macros/annot-basic-classAddDef/Test.scala @@ -0,0 +1,7 @@ +@change +class A: + override def toString = "It is A" + +@main def Test = + val a = new A + assert(a.toString == "It is A changed by macro annotation") diff --git a/tests/run-macros/annot-basic-classDef.check b/tests/run-macros/annot-basic-classDef.check new file mode 100644 index 000000000000..abb28452a53c --- /dev/null +++ b/tests/run-macros/annot-basic-classDef.check @@ -0,0 +1 @@ +Calling Check with 3 diff --git a/tests/run-macros/annot-basic-classDef/Macro.scala b/tests/run-macros/annot-basic-classDef/Macro.scala new file mode 100644 index 000000000000..30868a8acd59 --- /dev/null +++ b/tests/run-macros/annot-basic-classDef/Macro.scala @@ -0,0 +1,30 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation + +@experimental +class check extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(fooTree)) => + val name: String = "myClass" + val parents = List(TypeTree.of[Object]) + def decls(cls: Symbol): List[Symbol] = + List(Symbol.newMethod(cls, "check", MethodType(List("x"))(_ => List(TypeRepr.of[Int]), _ => TypeRepr.of[Unit]))) + val cls = Symbol.newClass(Symbol.spliceOwner.owner, name, parents = parents.map(_.tpe), decls, selfType = None) + val checkSym = cls.declaredMethod("check").head + def checkRhs(args: List[List[Tree]]): Option[Term] = + val x = args.head.head.asExprOf[Int] + Some('{println(s"Calling Check with ${$x}")}.asTerm) + val checkDef = DefDef(checkSym, checkRhs) + val clsDef = ClassDef(cls, parents, body = List(checkDef)) + + val x = Ref(params.head.params.head.symbol) + val callCheck = Apply(Select.unique(Apply(Select(New(TypeIdent(cls)), cls.primaryConstructor), Nil), "check"), List(x)).asExprOf[Unit] + val rhs = '{ + $callCheck + ${fooTree.asExprOf[Int]} + }.asTerm + List(clsDef, DefDef.copy(tree)(name, params, tpt, Some(rhs))) +} diff --git a/tests/run-macros/annot-basic-classDef/Test.scala b/tests/run-macros/annot-basic-classDef/Test.scala new file mode 100644 index 000000000000..6f938b246928 --- /dev/null +++ b/tests/run-macros/annot-basic-classDef/Test.scala @@ -0,0 +1,6 @@ +class A: + @check + def foo(x: Int) = 1 + +@main def Test = + (new A).foo(3) diff --git a/tests/run-macros/annot-basic-companion.check b/tests/run-macros/annot-basic-companion.check new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/run-macros/annot-basic-complicated/Macro_1.scala b/tests/run-macros/annot-basic-complicated/Macro_1.scala new file mode 100644 index 000000000000..7a9b4bcb6902 --- /dev/null +++ b/tests/run-macros/annot-basic-complicated/Macro_1.scala @@ -0,0 +1,32 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class newS(s: String, up: Boolean) extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val newSymbol = Symbol.newVal(tree.symbol.owner, s, TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val newVal = ValDef(newSymbol, Some(Literal(StringConstant(s)))) + if up then List(newVal, tree) else List(tree, newVal) +} + +@experimental +class add(s: String) extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val hello = Ref(tree.symbol.owner.declaredField(s)).asExprOf[String] + + tree match + case DefDef(name, params, tpt, Some(t)) => + val rhs = '{ + ${t.asExprOf[String]} + $hello + }.asTerm + List(DefDef.copy(tree)(name, params, tpt, Some(rhs))) + case ValDef(name, tpt, Some(t)) => + val rhs = '{ + ${t.asExprOf[String]} + $hello + }.asTerm + List(ValDef.copy(tree)(name, tpt, Some(rhs))) +} diff --git a/tests/run-macros/annot-basic-complicated/Macro_2.scala b/tests/run-macros/annot-basic-complicated/Macro_2.scala new file mode 100644 index 000000000000..19fad95a0628 --- /dev/null +++ b/tests/run-macros/annot-basic-complicated/Macro_2.scala @@ -0,0 +1,14 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + given Quotes = tree.symbol.owner.asQuotes + val s = '{@add("world") @newS("world", true) val helloString = "hello";()}.asTerm + val newDef = s.asInstanceOf[Inlined].body.asInstanceOf[Block].statements.head.asInstanceOf[ValDef] + List(newDef, tree) +} diff --git a/tests/run-macros/annot-basic-complicated/Test_3.scala b/tests/run-macros/annot-basic-complicated/Test_3.scala new file mode 100644 index 000000000000..456d57fbc048 --- /dev/null +++ b/tests/run-macros/annot-basic-complicated/Test_3.scala @@ -0,0 +1,21 @@ + +// val world = "world" +// @new("world", true) @add("world") +// val hello = "Hello" + world +// @hello @new("good", false) @add("hello") val info = "A:" + hello +// val good = "good" + +// @add("good") +// def foo(x: String) = x + good + +class A: + @newS("good", false) @add("helloString") @hello + val info = "A:" + + @add("info") @add("good") + def foo(x: String) = x + +@main def Test: Unit = + val a = new A + assert(a.info == "A:helloworld") + assert(a.foo("good") == "goodgoodA:helloworld") diff --git a/tests/run-macros/annot-basic-fib.check b/tests/run-macros/annot-basic-fib.check new file mode 100644 index 000000000000..2f68352db392 --- /dev/null +++ b/tests/run-macros/annot-basic-fib.check @@ -0,0 +1,5 @@ +compute fib of 3 +compute fib of 2 +compute fib of 1 +compute fib of 0 +compute fib of 4 diff --git a/tests/run-macros/annot-basic-fib/Macro.scala b/tests/run-macros/annot-basic-fib/Macro.scala new file mode 100644 index 000000000000..0d9800573ce7 --- /dev/null +++ b/tests/run-macros/annot-basic-fib/Macro.scala @@ -0,0 +1,27 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class memoize extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(fibTree)) => + val cacheRhs = '{Map.empty[Int, Int]}.asTerm + val cacheSymbol = Symbol.newVal(tree.symbol.owner, "fibCache", TypeRepr.of[Map[Int, Int]], Flags.EmptyFlags, Symbol.noSymbol) + val cacheVal = ValDef(cacheSymbol, Some(cacheRhs)) + val fibCache = Ref(cacheSymbol).asExprOf[Map[Int, Int]] + val n = Ref(params.head.params.head.symbol).asExprOf[Int] + val rhs = '{ + if $fibCache.contains($n) then + $fibCache($n) + else + val res = ${fibTree.asExprOf[Int]} + $fibCache($n) = res + res + }.asTerm + val newFib = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(cacheVal, newFib) +} diff --git a/tests/run-macros/annot-basic-fib/Test.scala b/tests/run-macros/annot-basic-fib/Test.scala new file mode 100644 index 000000000000..b9ff8a3440de --- /dev/null +++ b/tests/run-macros/annot-basic-fib/Test.scala @@ -0,0 +1,13 @@ +import scala.collection.mutable.Map + +class Bar: + @memoize + def fib(n: Int): Int = + println(s"compute fib of $n") + if n <= 1 then n + else fib(n - 1) + fib(n - 2) + +@main def Test = + val t = new Bar + assert(t.fib(3) == 2) + assert(t.fib(4) == 3) diff --git a/tests/run-macros/annot-basic-gen2/Macro_1.scala b/tests/run-macros/annot-basic-gen2/Macro_1.scala new file mode 100644 index 000000000000..5b3d9d43fff8 --- /dev/null +++ b/tests/run-macros/annot-basic-gen2/Macro_1.scala @@ -0,0 +1,17 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(t)) => + val rhs = '{ + ${t.asExprOf[String]} + "hello" + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(newDef) +} diff --git a/tests/run-macros/annot-basic-gen2/Macro_2.scala b/tests/run-macros/annot-basic-gen2/Macro_2.scala new file mode 100644 index 000000000000..d3e1b5ba43a2 --- /dev/null +++ b/tests/run-macros/annot-basic-gen2/Macro_2.scala @@ -0,0 +1,22 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class foo extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(t)) => + val s = Ref(params.head.params.head.symbol).asExprOf[String] + val rhs = '{ + @hello def foo1(s: String): String = ${ + @hello def foo(s: String) = s + "a" + Expr(foo("a")) + } + foo1($s) + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(newDef) +} diff --git a/tests/run-macros/annot-basic-gen2/Test_3.scala b/tests/run-macros/annot-basic-gen2/Test_3.scala new file mode 100644 index 000000000000..1a7ca80f6479 --- /dev/null +++ b/tests/run-macros/annot-basic-gen2/Test_3.scala @@ -0,0 +1,5 @@ +class Bar: + @foo def bar(s: String) = s + +@main def Test = + assert((new Bar).bar("bar") == "aahellohello") diff --git a/tests/run-macros/annot-basic-generate/Macro_1.scala b/tests/run-macros/annot-basic-generate/Macro_1.scala new file mode 100644 index 000000000000..a185b1702067 --- /dev/null +++ b/tests/run-macros/annot-basic-generate/Macro_1.scala @@ -0,0 +1,13 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val helloSymbol = Symbol.newVal(tree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!")))) + List(helloVal, tree) +} diff --git a/tests/run-macros/annot-basic-generate/Macro_2.scala b/tests/run-macros/annot-basic-generate/Macro_2.scala new file mode 100644 index 000000000000..0a348dcf4148 --- /dev/null +++ b/tests/run-macros/annot-basic-generate/Macro_2.scala @@ -0,0 +1,18 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class foo extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(t)) => + val rhs = '{ + @hello def foo(x: Int): Int = x + 1 + ${t.asExprOf[Int]} + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(newDef) +} diff --git a/tests/run-macros/annot-basic-generate/Test_3.scala b/tests/run-macros/annot-basic-generate/Test_3.scala new file mode 100644 index 000000000000..7077fd544111 --- /dev/null +++ b/tests/run-macros/annot-basic-generate/Test_3.scala @@ -0,0 +1,5 @@ +class Bar: + @foo def bar(x: Int) = x + 1 + +@main def Test = + assert((new Bar).bar(1) == 2) diff --git a/tests/run-macros/annot-basic-multiAnnot.check b/tests/run-macros/annot-basic-multiAnnot.check new file mode 100644 index 000000000000..8ab686eafeb1 --- /dev/null +++ b/tests/run-macros/annot-basic-multiAnnot.check @@ -0,0 +1 @@ +Hello, World! diff --git a/tests/run-macros/annot-basic-multiAnnot/Macro.scala b/tests/run-macros/annot-basic-multiAnnot/Macro.scala new file mode 100644 index 000000000000..bf968fa9a257 --- /dev/null +++ b/tests/run-macros/annot-basic-multiAnnot/Macro.scala @@ -0,0 +1,28 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val helloSymbol = Symbol.newVal(tree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!")))) + List(helloVal, tree) +} + +@experimental +class callHello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(t)) => + val hello = Ref(tree.symbol.owner.declaredField("hello")).asExprOf[String] + val rhs = '{ + println($hello) + ${t.asExprOf[Int]} + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(newDef) +} diff --git a/tests/run-macros/annot-basic-multiAnnot/Test.scala b/tests/run-macros/annot-basic-multiAnnot/Test.scala new file mode 100644 index 000000000000..1613205ca5e8 --- /dev/null +++ b/tests/run-macros/annot-basic-multiAnnot/Test.scala @@ -0,0 +1,7 @@ +class Foo: + @callHello @hello + def bar = 3 + +@main def Test = + val foo = new Foo + assert(foo.bar == 3) diff --git a/tests/run-macros/annot-basic-nested/Macro.scala b/tests/run-macros/annot-basic-nested/Macro.scala new file mode 100644 index 000000000000..a185b1702067 --- /dev/null +++ b/tests/run-macros/annot-basic-nested/Macro.scala @@ -0,0 +1,13 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val helloSymbol = Symbol.newVal(tree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!")))) + List(helloVal, tree) +} diff --git a/tests/run-macros/annot-basic-nested/Test.scala b/tests/run-macros/annot-basic-nested/Test.scala new file mode 100644 index 000000000000..3e4df2397a5c --- /dev/null +++ b/tests/run-macros/annot-basic-nested/Test.scala @@ -0,0 +1,10 @@ +import scala.collection.mutable.Map + +class D: + @hello + class A: + @hello + class B + +@main def Test = + assert(2 == 2) diff --git a/tests/run-macros/annot-basic-object/Macro.scala b/tests/run-macros/annot-basic-object/Macro.scala new file mode 100644 index 000000000000..1fcab1f92f12 --- /dev/null +++ b/tests/run-macros/annot-basic-object/Macro.scala @@ -0,0 +1,27 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class hello extends MacroAnnotation { + override def transformObject(using Quotes)(valTree: quotes.reflect.ValDef, classTree: quotes.reflect.TypeDef): List[quotes.reflect.Definition] = + import quotes.reflect._ + val helloSymbol = Symbol.newVal(classTree.symbol.owner, "hello", TypeRepr.of[String], Flags.EmptyFlags, Symbol.noSymbol) + val helloVal = ValDef(helloSymbol, Some(Literal(StringConstant("Hello, World!")))) + List(helloVal, valTree, classTree) +} + +@experimental +class double extends MacroAnnotation { + override def transform(using Quotes)(tree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + tree match + case DefDef(name, params, tpt, Some(t)) => + val n = Ref(params.head.params.head.symbol).asExprOf[Int] + val rhs = '{ + $n * 2 + }.asTerm + val newDef = DefDef.copy(tree)(name, params, tpt, Some(rhs)) + List(newDef) +} diff --git a/tests/run-macros/annot-basic-object/Test.scala b/tests/run-macros/annot-basic-object/Test.scala new file mode 100644 index 000000000000..65d2dc9d2dd2 --- /dev/null +++ b/tests/run-macros/annot-basic-object/Test.scala @@ -0,0 +1,8 @@ +class Bar: + @hello + object Foo: + @double + def foo(x: Int) = x + 1 + +@main def Test = + assert((new Bar).Foo.foo(3) == 6) diff --git a/tests/run-macros/annot-basic-param/Macro.scala b/tests/run-macros/annot-basic-param/Macro.scala new file mode 100644 index 000000000000..9e6074d0daed --- /dev/null +++ b/tests/run-macros/annot-basic-param/Macro.scala @@ -0,0 +1,19 @@ +import scala.annotation.experimental +import scala.quoted._ +import scala.annotation.MacroAnnotation +import scala.collection.mutable.Map + +@experimental +class pos extends MacroAnnotation { + override def transformParam(using Quotes)(paramTree: quotes.reflect.Definition, ownerTree: quotes.reflect.Definition): List[quotes.reflect.Definition] = + import quotes.reflect._ + val n = Ref(paramTree.symbol).asExprOf[Int] + val rhs = '{ + // assert($n >= 0) + ${ownerTree.asInstanceOf[DefDef].rhs.get.asExprOf[List[Int]]} + }.asTerm + val newRepeat = ownerTree match + case DefDef(name, params, tpt, _) => + DefDef.copy(ownerTree)(name, params, tpt, Some(rhs)) + List(newRepeat) +} diff --git a/tests/run-macros/annot-basic-param/Test.scala b/tests/run-macros/annot-basic-param/Test.scala new file mode 100644 index 000000000000..e5b29b730710 --- /dev/null +++ b/tests/run-macros/annot-basic-param/Test.scala @@ -0,0 +1,7 @@ +class Foo: + def repeat(@pos n: Int, a: Int): List[Int] = + if (n == 0) Nil + else a :: repeat(n - 1, a) + +@main def Test = + assert((new Foo).repeat(3, 1) == List(1, 1, 1)) From 0a9ef97fdf1d5f676c26cdc5988b205376187f98 Mon Sep 17 00:00:00 2001 From: Ang9876 Date: Fri, 8 Jul 2022 22:39:07 +0200 Subject: [PATCH 2/4] fix unused import --- compiler/src/dotty/tools/dotc/transform/Interpreter.scala | 2 -- compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala | 2 -- 2 files changed, 4 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/Interpreter.scala b/compiler/src/dotty/tools/dotc/transform/Interpreter.scala index ccbd6c13942d..0557a93c84c7 100644 --- a/compiler/src/dotty/tools/dotc/transform/Interpreter.scala +++ b/compiler/src/dotty/tools/dotc/transform/Interpreter.scala @@ -20,7 +20,6 @@ import SymUtils._ import NameKinds._ import dotty.tools.dotc.ast.tpd import typer.Implicits.SearchFailureType -import typer.PrepareInlineable import SymDenotations.NoDenotation import scala.collection.mutable @@ -30,7 +29,6 @@ import dotty.tools.dotc.core.StdNames._ import dotty.tools.dotc.core.StagingContext._ import dotty.tools.dotc.quoted._ import dotty.tools.dotc.transform.TreeMapWithStages._ -import dotty.tools.dotc.typer.Inliner import dotty.tools.dotc.typer.ImportInfo.withRootImports import dotty.tools.dotc.ast.TreeMapWithImplicits diff --git a/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala b/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala index d2e3a12ea52f..2dad70446128 100644 --- a/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala +++ b/compiler/src/dotty/tools/dotc/transform/MacroAnnotations.scala @@ -20,7 +20,6 @@ import SymUtils._ import NameKinds._ import dotty.tools.dotc.ast.tpd import typer.Implicits.SearchFailureType -import typer.PrepareInlineable import SymDenotations.NoDenotation import scala.collection.mutable @@ -30,7 +29,6 @@ import dotty.tools.dotc.core.StdNames._ import dotty.tools.dotc.core.StagingContext._ import dotty.tools.dotc.quoted._ import dotty.tools.dotc.transform.TreeMapWithStages._ -import dotty.tools.dotc.typer.Inliner import dotty.tools.dotc.typer.ImportInfo.withRootImports import dotty.tools.dotc.ast.TreeMapWithImplicits From f085bd0339facf89f0a899bcee29eab64a3a0b39 Mon Sep 17 00:00:00 2001 From: Ang9876 Date: Fri, 8 Jul 2022 23:05:10 +0200 Subject: [PATCH 3/4] add no denotation condition --- compiler/src/dotty/tools/dotc/CompilationUnit.scala | 2 +- compiler/src/dotty/tools/dotc/typer/Typer.scala | 2 -- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/CompilationUnit.scala b/compiler/src/dotty/tools/dotc/CompilationUnit.scala index 76439d7f5129..e50400686ac2 100644 --- a/compiler/src/dotty/tools/dotc/CompilationUnit.scala +++ b/compiler/src/dotty/tools/dotc/CompilationUnit.scala @@ -148,7 +148,7 @@ object CompilationUnit { if tree.symbol.is(Flags.Inline) then containsInline = true for annot <- tree.symbol.annotations do - if annot.tree.symbol.owner.derivesFrom(defn.QuotedMacroAnnotationClass) then + if annot.tree.symbol.denot != NoDenotation && annot.tree.symbol.owner.derivesFrom(defn.QuotedMacroAnnotationClass) then ctx.compilationUnit.hasMacroAnnotations = true traverseChildren(tree) } diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 68061e97426d..fa8cbaf7ff5a 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -67,8 +67,6 @@ object Typer { def assertPositioned(tree: untpd.Tree)(using Context): Unit = if (!tree.isEmpty && !tree.isInstanceOf[untpd.TypedSplice] && ctx.typerState.isGlobalCommittable) assert(tree.span.exists, i"position not set for $tree # ${tree.uniqueId} of ${tree.getClass} in ${tree.source}") - assert(1 == 1) - /** An attachment for GADT constraints that were inferred for a pattern. */ val InferredGadtConstraints = new Property.StickyKey[core.GadtConstraint] From 7567db5377fdc55b83e67ff6cf4f81dead4ead48 Mon Sep 17 00:00:00 2001 From: Ang9876 Date: Fri, 8 Jul 2022 23:11:25 +0200 Subject: [PATCH 4/4] add no denotation condition --- compiler/src/dotty/tools/dotc/CompilationUnit.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compiler/src/dotty/tools/dotc/CompilationUnit.scala b/compiler/src/dotty/tools/dotc/CompilationUnit.scala index e50400686ac2..23acd44774e0 100644 --- a/compiler/src/dotty/tools/dotc/CompilationUnit.scala +++ b/compiler/src/dotty/tools/dotc/CompilationUnit.scala @@ -3,7 +3,7 @@ package dotc import core._ import Contexts._ -import SymDenotations.ClassDenotation +import SymDenotations.{ClassDenotation, NoDenotation} import Symbols._ import util.{FreshNameCreator, SourceFile, NoSource} import util.Spans.Span