Skip to content

Replace quoted type variables in signature of HOAS pattern result #16951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions compiler/src/dotty/tools/dotc/util/optional.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package dotty.tools.dotc.util

import scala.util.boundary

/** Return type that indicates that the method returns a T or aborts to the enclosing boundary with a `None` */
type optional[T] = boundary.Label[None.type] ?=> T

/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */
object optional:
inline def apply[T](inline body: optional[T]): Option[T] =
boundary(Some(body))

extension [T](r: Option[T])
inline def ? (using label: boundary.Label[None.type]): T = r match
case Some(x) => x
case None => boundary.break(None)

inline def break()(using label: boundary.Label[None.type]): Nothing =
boundary.break(None)
140 changes: 79 additions & 61 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
package scala.quoted
package runtime.impl


import dotty.tools.dotc.ast.tpd
import dotty.tools.dotc.core.Contexts.*
import dotty.tools.dotc.core.Flags.*
import dotty.tools.dotc.core.Names.*
import dotty.tools.dotc.core.Types.*
import dotty.tools.dotc.core.StdNames.nme
import dotty.tools.dotc.core.Symbols.*
import dotty.tools.dotc.util.optional

/** Matches a quoted tree against a quoted pattern tree.
* A quoted pattern tree may have type and term holes in addition to normal terms.
Expand Down Expand Up @@ -103,12 +103,13 @@ import dotty.tools.dotc.core.Symbols.*
object QuoteMatcher {
import tpd.*

// TODO improve performance

// TODO use flag from Context. Maybe -debug or add -debug-macros
private inline val debug = false

import Matching._
/** Sequence of matched expressions.
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices.
*/
type MatchingExprs = Seq[MatchResult]

/** A map relating equivalent symbols from the scrutinee and the pattern
* For example in
Expand All @@ -121,32 +122,34 @@ object QuoteMatcher {

private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env)

def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] =
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[MatchingExprs] =
given Env = Map.empty
scrutineeTree =?= patternTree
optional:
scrutineeTree =?= patternTree

/** Check that all trees match with `mtch` and concatenate the results with &&& */
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match {
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => MatchingExprs): optional[MatchingExprs] = (l1, l2) match {
case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch)
case (Nil, Nil) => matched
case _ => notMatched
}

extension (scrutinees: List[Tree])
private def =?= (patterns: List[Tree])(using Env, Context): Matching =
private def =?= (patterns: List[Tree])(using Env, Context): optional[MatchingExprs] =
matchLists(scrutinees, patterns)(_ =?= _)

extension (scrutinee0: Tree)

/** Check that the trees match and return the contents from the pattern holes.
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes.
* Return a sequence containing all the contents in the holes.
* If it does not match, continues to the `optional` with `None`.
*
* @param scrutinee The tree being matched
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes.
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`.
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes.
* @return The sequence with the contents of the holes of the matched expression.
*/
private def =?= (pattern0: Tree)(using Env, Context): Matching =
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] =

/* Match block flattening */ // TODO move to cases
/** Normalize the tree */
Expand Down Expand Up @@ -203,31 +206,12 @@ object QuoteMatcher {
// Matches an open term and wraps it into a lambda that provides the free variables
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
def hoasClosure = {
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val argTypes = args.map(x => x.tpe.widenTermRefExpr)
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe)
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(scrutinee)
TreeOps(body).changeNonLocalOwners(meth)
}
Closure(meth, bodyFn)
}
val env = summon[Env]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like the summon[Env] two lines below could become env with this addition.

val capturedArgs = args.map(_.symbol)
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v))
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v))
withEnv(captureEnv) {
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(hoasClosure)
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env)
case _ => notMatched
}

Expand Down Expand Up @@ -431,7 +415,6 @@ object QuoteMatcher {
case _ => scrutinee
val pattern = patternTree.symbol


devirtualizedScrutinee == pattern
|| summon[Env].get(devirtualizedScrutinee).contains(pattern)
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern)
Expand All @@ -452,32 +435,67 @@ object QuoteMatcher {
accumulator.apply(Set.empty, term)
}

/** Result of matching a part of an expression */
private type Matching = Option[Tuple]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The documentation comment for =?= still mention that we return a Some(tup: Tuple). That should be updated, but I also think that the meaning of None vs Some(Seq()) vs Some(...) should be explained in a documentation comment for type Matching itself.


private object Matching {

def notMatched: Matching = None

val matched: Matching = Some(Tuple())

def matched(tree: Tree)(using Context): Matching =
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent)))

extension (self: Matching)
def asOptionOfTuple: Option[Tuple] = self

/** Concatenates the contents of two successful matchings or return a `notMatched` */
def &&& (that: => Matching): Matching = self match {
case Some(x) =>
that match {
case Some(y) => Some(x ++ y)
case _ => None
}
case _ => None
}
end extension

}
enum MatchResult:
/** Closed pattern extracted value
* @param tree Scrutinee sub-tree that matched
*/
case ClosedTree(tree: Tree)
/** HOAS pattern extracted value
*
* @param tree Scrutinee sub-tree that matched
* @param patternTpe Type of the pattern hole (from the pattern)
* @param args HOAS arguments (from the pattern)
* @param env Mapping between scrutinee and pattern variables
*/
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Woudl be nice to document the expectation on the valid values of the arguments here.


/** Return the expression that was extracted from a hole.
*
* If it was a closed expression it returns that expression. Otherwise,
* if it is a HOAS pattern, the surrounding lambda is generated using
* `mapTypeHoles` to create the signature of the lambda.
*
* This expression is assumed to be a valid expression in the given splice scope.
*/
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match
case MatchResult.ClosedTree(tree) =>
new ExprImpl(tree, spliceScope)
case MatchResult.OpenTree(tree, patternTpe, args, env) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
case arg => arg.symbol.name.asTermName
}
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr))
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe))
val meth = newAnonFun(ctx.owner, methTpe)
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = {
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap
val body = new TreeMap {
override def transform(tree: Tree)(using Context): Tree =
tree match
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
case tree => super.transform(tree)
}.transform(tree)
TreeOps(body).changeNonLocalOwners(meth)
}
val hoasClosure = Closure(meth, bodyFn)
new ExprImpl(hoasClosure, spliceScope)

private inline def notMatched: optional[MatchingExprs] =
optional.break()

private inline def matched: MatchingExprs =
Seq.empty

private inline def matched(tree: Tree)(using Context): MatchingExprs =
Seq(MatchResult.ClosedTree(tree))

private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs =
Seq(MatchResult.OpenTree(tree, patternTpe, args, env))

extension (self: MatchingExprs)
/** Concatenates the contents of two successful matchings */
def &&& (that: MatchingExprs): MatchingExprs = self ++ that
end extension

}
33 changes: 20 additions & 13 deletions compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler
ctx1.gadtState.addToConstraint(typeHoles)
ctx1

val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1)

if typeHoles.isEmpty then matchings
else {
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo
reflect.TypeReprMethods.asType(tp)
matchings.map { tup =>
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup
// After matching and doing all subtype checks, we have to approximate all the type bindings
// that we have found, seal them in a quoted.Type and add them to the result
def typeHoleApproximation(sym: Symbol) =
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot)
val fullBounds = ctx1.gadt.fullBounds(sym)
if fromAboveAnnot then fullBounds.hi else fullBounds.lo

QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings =>
import QuoteMatcher.MatchResult.*
lazy val spliceScope = SpliceScope.getCurrent
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

spliceScope appears to be unused?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I intended to use pass it to toExpr. I will change that.

val typeHoleApproximations = typeHoles.map(typeHoleApproximation)
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*)
val typeHoleMap = new Types.TypeMap {
def apply(tp: Types.Type): Types.Type = tp match
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp)
case _ => mapOver(tp)
}
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope))
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType)
val results = matchedTypes ++ matchedExprs
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray))
}
}

Expand Down
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165a/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165a/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
16 changes: 16 additions & 0 deletions tests/pos-macros/i15165b/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ { val ident = ($a: α); $rest(ident): T } } =>
'{
{ (y: α) =>
${
val bound = '{ ${ rest }(y) }
Expr.betaReduce(bound)
}
}.apply($a)
}
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165b/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}
9 changes: 9 additions & 0 deletions tests/pos-macros/i15165c/Macro_1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import scala.quoted.*

inline def valToFun[T](inline expr: T): T =
${ impl('expr) }

def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] =
expr match
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } =>
'{ { (y: α) => $rest(y) }.apply(???) }
4 changes: 4 additions & 0 deletions tests/pos-macros/i15165c/Test_2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
def test = valToFun {
val a: Int = 1
a + 1
}