Skip to content

Place staged type captures in Quote AST #17424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/CompilationUnit.scala
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ object CompilationUnit {
if tree.symbol.is(Flags.Inline) then
containsInline = true
tree match
case tpd.Quote(_) =>
case _: tpd.Quote =>
containsQuote = true
case tree: tpd.Apply if tree.symbol == defn.QuotedTypeModule_of =>
containsQuote = true
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1976,7 +1976,7 @@ object desugar {
trees foreach collect
case Block(Nil, expr) =>
collect(expr)
case Quote(body) =>
case Quote(body, _) =>
new UntypedTreeTraverser {
def traverse(tree: untpd.Tree)(using Context): Unit = tree match {
case Splice(expr) => collect(expr)
Expand Down
21 changes: 13 additions & 8 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,14 @@ object Trees {
* when type checking. TASTy files will not contain type quotes. Type quotes are used again
* in the `staging` phase to represent the reification of `Type.of[T]]`.
*
* Type tags `tags` are always empty before the `staging` phase. Tags for stage inconsistent
* types are added in the `staging` phase to level 0 quotes. Tags for types that refer to
* definitions in an outer quote are added in the `splicing` phase
*
* @param body The tree that was quoted
* @param tags Term references to instances of `Type[T]` for `T`s that are used in the quote
*/
case class Quote[+T <: Untyped] private[ast] (body: Tree[T])(implicit @constructorOnly src: SourceFile)
case class Quote[+T <: Untyped] private[ast] (body: Tree[T], tags: List[Tree[T]])(implicit @constructorOnly src: SourceFile)
extends TermTree[T] {
type ThisTree[+T <: Untyped] = Quote[T]

Expand Down Expand Up @@ -1313,9 +1318,9 @@ object Trees {
case tree: Inlined if (call eq tree.call) && (bindings eq tree.bindings) && (expansion eq tree.expansion) => tree
case _ => finalize(tree, untpd.Inlined(call, bindings, expansion)(sourceFile(tree)))
}
def Quote(tree: Tree)(body: Tree)(using Context): Quote = tree match {
case tree: Quote if (body eq tree.body) => tree
case _ => finalize(tree, untpd.Quote(body)(sourceFile(tree)))
def Quote(tree: Tree)(body: Tree, tags: List[Tree])(using Context): Quote = tree match {
case tree: Quote if (body eq tree.body) && (tags eq tree.tags) => tree
case _ => finalize(tree, untpd.Quote(body, tags)(sourceFile(tree)))
}
def Splice(tree: Tree)(expr: Tree)(using Context): Splice = tree match {
case tree: Splice if (expr eq tree.expr) => tree
Expand Down Expand Up @@ -1558,8 +1563,8 @@ object Trees {
case Thicket(trees) =>
val trees1 = transform(trees)
if (trees1 eq trees) tree else Thicket(trees1)
case tree @ Quote(body) =>
cpy.Quote(tree)(transform(body)(using quoteContext))
case Quote(body, tags) =>
cpy.Quote(tree)(transform(body)(using quoteContext), transform(tags))
case tree @ Splice(expr) =>
cpy.Splice(tree)(transform(expr)(using spliceContext))
case tree @ Hole(isTerm, idx, args, content, tpt) =>
Expand Down Expand Up @@ -1703,8 +1708,8 @@ object Trees {
this(this(x, arg), annot)
case Thicket(ts) =>
this(x, ts)
case Quote(body) =>
this(x, body)(using quoteContext)
case Quote(body, tags) =>
this(this(x, body)(using quoteContext), tags)
case Splice(expr) =>
this(x, expr)(using spliceContext)
case Hole(_, _, args, content, tpt) =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/ast/tpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
def Inlined(call: Tree, bindings: List[MemberDef], expansion: Tree)(using Context): Inlined =
ta.assignType(untpd.Inlined(call, bindings, expansion), bindings, expansion)

def Quote(body: Tree)(using Context): Quote =
untpd.Quote(body).withBodyType(body.tpe)
def Quote(body: Tree, tags: List[Tree])(using Context): Quote =
untpd.Quote(body, tags).withBodyType(body.tpe)

def Splice(expr: Tree, tpe: Type)(using Context): Splice =
untpd.Splice(expr).withType(tpe)
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/ast/untpd.scala
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,7 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
def SeqLiteral(elems: List[Tree], elemtpt: Tree)(implicit src: SourceFile): SeqLiteral = new SeqLiteral(elems, elemtpt)
def JavaSeqLiteral(elems: List[Tree], elemtpt: Tree)(implicit src: SourceFile): JavaSeqLiteral = new JavaSeqLiteral(elems, elemtpt)
def Inlined(call: tpd.Tree, bindings: List[MemberDef], expansion: Tree)(implicit src: SourceFile): Inlined = new Inlined(call, bindings, expansion)
def Quote(body: Tree)(implicit src: SourceFile): Quote = new Quote(body)
def Quote(body: Tree, tags: List[Tree])(implicit src: SourceFile): Quote = new Quote(body, tags)
def Splice(expr: Tree)(implicit src: SourceFile): Splice = new Splice(expr)
def TypeTree()(implicit src: SourceFile): TypeTree = new TypeTree()
def InferredTypeTree()(implicit src: SourceFile): TypeTree = new InferredTypeTree()
Expand Down
4 changes: 4 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Phases.scala
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ object Phases {
private var mySbtExtractDependenciesPhase: Phase = _
private var myPicklerPhase: Phase = _
private var myInliningPhase: Phase = _
private var myStagingPhase: Phase = _
private var mySplicingPhase: Phase = _
private var myFirstTransformPhase: Phase = _
private var myCollectNullableFieldsPhase: Phase = _
Expand All @@ -235,6 +236,7 @@ object Phases {
final def sbtExtractDependenciesPhase: Phase = mySbtExtractDependenciesPhase
final def picklerPhase: Phase = myPicklerPhase
final def inliningPhase: Phase = myInliningPhase
final def stagingPhase: Phase = myStagingPhase
final def splicingPhase: Phase = mySplicingPhase
final def firstTransformPhase: Phase = myFirstTransformPhase
final def collectNullableFieldsPhase: Phase = myCollectNullableFieldsPhase
Expand Down Expand Up @@ -262,6 +264,7 @@ object Phases {
mySbtExtractDependenciesPhase = phaseOfClass(classOf[sbt.ExtractDependencies])
myPicklerPhase = phaseOfClass(classOf[Pickler])
myInliningPhase = phaseOfClass(classOf[Inlining])
myStagingPhase = phaseOfClass(classOf[Staging])
mySplicingPhase = phaseOfClass(classOf[Splicing])
myFirstTransformPhase = phaseOfClass(classOf[FirstTransform])
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
Expand Down Expand Up @@ -449,6 +452,7 @@ object Phases {
def sbtExtractDependenciesPhase(using Context): Phase = ctx.base.sbtExtractDependenciesPhase
def picklerPhase(using Context): Phase = ctx.base.picklerPhase
def inliningPhase(using Context): Phase = ctx.base.inliningPhase
def stagingPhase(using Context): Phase = ctx.base.stagingPhase
def splicingPhase(using Context): Phase = ctx.base.splicingPhase
def firstTransformPhase(using Context): Phase = ctx.base.firstTransformPhase
def refchecksPhase(using Context): Phase = ctx.base.refchecksPhase
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/tasty/TreePickler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -665,7 +665,7 @@ class TreePickler(pickler: TastyPickler) {
pickleTree(hi)
pickleTree(alias)
}
case tree @ Quote(body) =>
case tree @ Quote(body, Nil) =>
// TODO: Add QUOTE tag to TASTy
assert(body.isTerm,
"""Quote with type should not be pickled.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,7 +1269,7 @@ class TreeUnpickler(reader: TastyReader,

def quotedExpr(fn: Tree, args: List[Tree]): Tree =
val TypeApply(_, targs) = fn: @unchecked
untpd.Quote(args.head).withBodyType(targs.head.tpe)
untpd.Quote(args.head, Nil).withBodyType(targs.head.tpe)

def splicedExpr(fn: Tree, args: List[Tree]): Tree =
val TypeApply(_, targs) = fn: @unchecked
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/inlines/Inliner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -827,7 +827,7 @@ class Inliner(val call: tpd.Tree)(using Context):

override def typedQuote(tree: untpd.Quote, pt: Type)(using Context): Tree =
super.typedQuote(tree, pt) match
case Quote(Splice(inner)) => inner
case Quote(Splice(inner), _) => inner
case tree1 =>
ctx.compilationUnit.needsStaging = true
tree1
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1243,7 +1243,7 @@ object Parsers {
}
}
in.nextToken()
Quote(t)
Quote(t, Nil)
}
else
if !in.featureEnabled(Feature.symbolLiterals) then
Expand Down Expand Up @@ -2480,7 +2480,7 @@ object Parsers {
val body =
if (in.token == LBRACKET) inBrackets(typ())
else stagedBlock()
Quote(body)
Quote(body, Nil)
}
}
case NEW =>
Expand Down
5 changes: 3 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -726,11 +726,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
"Thicket {" ~~ toTextGlobal(trees, "\n") ~~ "}"
case MacroTree(call) =>
keywordStr("macro ") ~ toTextGlobal(call)
case tree @ Quote(body) =>
case tree @ Quote(body, tags) =>
val tagsText = (keywordStr("<") ~ toTextGlobal(tags, ", ") ~ keywordStr(">")).provided(tree.tags.nonEmpty)
val exprTypeText = (keywordStr("[") ~ toTextGlobal(tree.bodyType) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
val open = if (body.isTerm) keywordStr("{") else keywordStr("[")
val close = if (body.isTerm) keywordStr("}") else keywordStr("]")
keywordStr("'") ~ exprTypeText ~ open ~ toTextGlobal(body) ~ close
keywordStr("'") ~ tagsText ~ exprTypeText ~ open ~ toTextGlobal(body) ~ close
case Splice(expr) =>
val spliceTypeText = (keywordStr("[") ~ toTextGlobal(tree.typeOpt) ~ keywordStr("]")).provided(printDebug && tree.typeOpt.exists)
keywordStr("$") ~ spliceTypeText ~ keywordStr("{") ~ toTextGlobal(expr) ~ keywordStr("}")
Expand Down
80 changes: 40 additions & 40 deletions compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,36 @@ import dotty.tools.dotc.util.Property
import dotty.tools.dotc.util.Spans._
import dotty.tools.dotc.util.SrcPos

/** Checks that staging level consistency holds and heals staged types .
/** Checks that staging level consistency holds and heals staged types.
*
* Local term references are level consistent if and only if they are used at the same level as their definition.
*
* Local type references can be used at the level of their definition or lower. If used used at a higher level,
* it will be healed if possible, otherwise it is inconsistent.
*
* Type healing consists in transforming a level inconsistent type `T` into `summon[Type[T]].Underlying`.
* Healing a type consists in replacing locally defined types defined at staging level 0 and used in higher levels.
* For each type local `T` that is defined at level 0 and used in a quote, we summon a tag `t: Type[T]`. This `t`
* tag must be defined at level 0. The tags will be listed in the `tags` of the level 0 quote (`'<t>{ ... }`) and
* each reference to `T` will be replaced by `t.Underlying` in the body of the quote.
*
* We delay the healing of types in quotes at level 1 or higher until those quotes reach level 0. At this point
* more types will be statically known and fewer types will need to be healed. This also keeps the nested quotes
* in their original form, we do not want macro users to see any artifacts of this phase in quoted expressions
* they might inspect.
*
* Type heal example:
*
* As references to types do not necessarily have an associated tree it is not always possible to replace the types directly.
* Instead we always generate a type alias for it and place it at the start of the surrounding quote. This also avoids duplication.
* For example:
* '{
* val x: List[T] = List[T]()
* '{ .. T .. }
* ()
* }
*
* is transformed to
*
* '{
* type t$1 = summon[Type[T]].Underlying
* val x: List[t$1] = List[t$1]();
* '<t>{ // where `t` is a given term of type `Type[T]`
* val x: List[t.Underlying] = List[t.Underlying]();
* '{ .. t.Underlying .. }
* ()
* }
*
Expand All @@ -56,11 +64,18 @@ class CrossStageSafety extends TreeMapWithStages {
case tree: Quote =>
if (ctx.property(InAnnotation).isDefined)
report.error("Cannot have a quote in an annotation", tree.srcPos)
val body1 = transformQuoteBody(tree.body, tree.span)
val stripAnnotationsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
val bodyType1 = healType(tree.srcPos)(stripAnnotationsDeep(tree.bodyType))
cpy.Quote(tree)(body1).withBodyType(bodyType1)

val tree1 =
val stripAnnotationsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
val bodyType1 = healType(tree.srcPos)(stripAnnotationsDeep(tree.bodyType))
tree.withBodyType(bodyType1)

if level == 0 then
val (tags, body1) = inContextWithQuoteTypeTags { transform(tree1.body)(using quoteContext) }
cpy.Quote(tree1)(body1, tags)
else
super.transform(tree1)

case CancelledSplice(tree) =>
transform(tree) // Optimization: `${ 'x }` --> `x`
Expand All @@ -74,22 +89,18 @@ class CrossStageSafety extends TreeMapWithStages {
case tree @ QuotedTypeOf(body) =>
if (ctx.property(InAnnotation).isDefined)
report.error("Cannot have a quote in an annotation", tree.srcPos)
body.tpe match
case DirectTypeOf(termRef) =>
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
ref(termRef).withSpan(tree.span)
case _ =>
transformQuoteBody(body, tree.span) match
case DirectTypeOf.Healed(termRef) =>
// Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
ref(termRef).withSpan(tree.span)
case transformedBody =>
val quotes = transform(tree.args.head)
// `quoted.Type.of[<body>](quotes)` --> `quoted.Type.of[<body2>](quotes)`
val TypeApply(fun, _) = tree.fun: @unchecked
if level != 0 then cpy.Apply(tree)(cpy.TypeApply(tree.fun)(fun, transformedBody :: Nil), quotes :: Nil)
else tpd.Quote(transformedBody).select(nme.apply).appliedTo(quotes).withSpan(tree.span)

if level == 0 then
val (tags, body1) = inContextWithQuoteTypeTags { transform(body)(using quoteContext) }
val quotes = transform(tree.args.head)
tags match
case tag :: Nil if body1.isType && body1.tpe =:= tag.tpe.select(tpnme.Underlying) =>
tag // Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
case _ =>
// `quoted.Type.of[<body>](<quotes>)` --> `'[<body1>].apply(<quotes>)`
tpd.Quote(body1, tags).select(nme.apply).appliedTo(quotes).withSpan(tree.span)
else
super.transform(tree)
case _: DefDef if tree.symbol.isInlineMethod =>
tree

Expand Down Expand Up @@ -137,17 +148,6 @@ class CrossStageSafety extends TreeMapWithStages {
super.transform(tree)
end transform

private def transformQuoteBody(body: Tree, span: Span)(using Context): Tree = {
val taggedTypes = new QuoteTypeTags(span)
val contextWithQuote =
if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
else quoteContext
val transformedBody = transform(body)(using contextWithQuote)
taggedTypes.getTypeTags match
case Nil => transformedBody
case tags => tpd.Block(tags, transformedBody).withSpan(body.span)
}

def transformTypeAnnotationSplices(tp: Type)(using Context) = new TypeMap {
def apply(tp: Type): Type = tp match
case tp: AnnotatedType =>
Expand Down Expand Up @@ -234,7 +234,7 @@ class CrossStageSafety extends TreeMapWithStages {
def unapply(tree: Splice): Option[Tree] =
def rec(tree: Tree): Option[Tree] = tree match
case Block(Nil, expr) => rec(expr)
case Quote(inner) => Some(inner)
case Quote(inner, _) => Some(inner)
case _ => None
rec(tree.expr)
}
25 changes: 0 additions & 25 deletions compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala

This file was deleted.

12 changes: 5 additions & 7 deletions compiler/src/dotty/tools/dotc/staging/HealType.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
*
* If `T` is a reference to a type at the wrong level, try to heal it by replacing it with
* a type tag of type `quoted.Type[T]`.
* The tag is generated by an instance of `QuoteTypeTags` directly if the splice is explicit
* The tag is recorded by an instance of `QuoteTypeTags` directly if the splice is explicit
* or indirectly by `tryHeal`.
*/
def apply(tp: Type): Type =
Expand All @@ -43,11 +43,9 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

private def healTypeRef(tp: TypeRef): Type =
tp.prefix match
case NoPrefix if tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
tp
case prefix: TermRef if tp.symbol.isTypeSplice =>
checkNotWildcardSplice(tp)
if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix)
if level == 0 then tp else getTagRef(prefix)
case _: NamedType | _: ThisType | NoPrefix =>
if levelInconsistentRootOfPath(tp).exists then
tryHeal(tp)
Expand All @@ -58,7 +56,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

private object NonSpliceAlias:
def unapply(tp: TypeRef)(using Context): Option[Type] = tp.underlying match
case TypeAlias(alias) if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => Some(alias)
case TypeAlias(alias) if !tp.symbol.isTypeSplice => Some(alias)
case _ => None

private def checkNotWildcardSplice(splice: TypeRef): Unit =
Expand All @@ -78,7 +76,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {

/** Try to heal reference to type `T` used in a higher level than its definition.
* Returns a reference to a type tag generated by `QuoteTypeTags` that contains a
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`.
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}.Underlying`.
* Emits an error if `T` cannot be healed and returns `T`.
*/
protected def tryHeal(tp: TypeRef): Type = {
Expand All @@ -88,7 +86,7 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
case tp: TermRef =>
ctx.typer.checkStable(tp, pos, "type witness")
if levelOf(tp.symbol) > 0 then tp.select(tpnme.Underlying)
else getQuoteTypeTags.getTagRef(tp)
else getTagRef(tp)
case _: SearchFailureType =>
report.error(
ctx.typer.missingArgMsg(tag, reqType, "")
Expand Down
Loading