-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Replace quoted type variables in signature of HOAS pattern result #16951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
package dotty.tools.dotc.util | ||
|
||
import scala.util.boundary | ||
|
||
/** Return type that indicates that the method returns a T or aborts to the enclosing boundary with a `None` */ | ||
type optional[T] = boundary.Label[None.type] ?=> T | ||
|
||
/** A prompt for `Option`, which establishes a boundary which `_.?` on `Option` can return */ | ||
object optional: | ||
inline def apply[T](inline body: optional[T]): Option[T] = | ||
boundary(Some(body)) | ||
|
||
extension [T](r: Option[T]) | ||
inline def ? (using label: boundary.Label[None.type]): T = r match | ||
case Some(x) => x | ||
case None => boundary.break(None) | ||
|
||
inline def break()(using label: boundary.Label[None.type]): Nothing = | ||
boundary.break(None) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,14 @@ | ||
package scala.quoted | ||
package runtime.impl | ||
|
||
|
||
import dotty.tools.dotc.ast.tpd | ||
import dotty.tools.dotc.core.Contexts.* | ||
import dotty.tools.dotc.core.Flags.* | ||
import dotty.tools.dotc.core.Names.* | ||
import dotty.tools.dotc.core.Types.* | ||
import dotty.tools.dotc.core.StdNames.nme | ||
import dotty.tools.dotc.core.Symbols.* | ||
import dotty.tools.dotc.util.optional | ||
|
||
/** Matches a quoted tree against a quoted pattern tree. | ||
* A quoted pattern tree may have type and term holes in addition to normal terms. | ||
|
@@ -103,12 +103,13 @@ import dotty.tools.dotc.core.Symbols.* | |
object QuoteMatcher { | ||
import tpd.* | ||
|
||
// TODO improve performance | ||
|
||
// TODO use flag from Context. Maybe -debug or add -debug-macros | ||
private inline val debug = false | ||
|
||
import Matching._ | ||
/** Sequence of matched expressions. | ||
* These expressions are part of the scrutinee and will be bound to the quote pattern term splices. | ||
*/ | ||
type MatchingExprs = Seq[MatchResult] | ||
|
||
/** A map relating equivalent symbols from the scrutinee and the pattern | ||
* For example in | ||
|
@@ -121,32 +122,34 @@ object QuoteMatcher { | |
|
||
private def withEnv[T](env: Env)(body: Env ?=> T): T = body(using env) | ||
|
||
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[Tuple] = | ||
def treeMatch(scrutineeTree: Tree, patternTree: Tree)(using Context): Option[MatchingExprs] = | ||
given Env = Map.empty | ||
scrutineeTree =?= patternTree | ||
optional: | ||
scrutineeTree =?= patternTree | ||
|
||
/** Check that all trees match with `mtch` and concatenate the results with &&& */ | ||
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => Matching): Matching = (l1, l2) match { | ||
private def matchLists[T](l1: List[T], l2: List[T])(mtch: (T, T) => MatchingExprs): optional[MatchingExprs] = (l1, l2) match { | ||
case (x :: xs, y :: ys) => mtch(x, y) &&& matchLists(xs, ys)(mtch) | ||
case (Nil, Nil) => matched | ||
case _ => notMatched | ||
} | ||
|
||
extension (scrutinees: List[Tree]) | ||
private def =?= (patterns: List[Tree])(using Env, Context): Matching = | ||
private def =?= (patterns: List[Tree])(using Env, Context): optional[MatchingExprs] = | ||
matchLists(scrutinees, patterns)(_ =?= _) | ||
|
||
extension (scrutinee0: Tree) | ||
|
||
/** Check that the trees match and return the contents from the pattern holes. | ||
* Return None if the trees do not match otherwise return Some of a tuple containing all the contents in the holes. | ||
* Return a sequence containing all the contents in the holes. | ||
* If it does not match, continues to the `optional` with `None`. | ||
* | ||
* @param scrutinee The tree being matched | ||
* @param pattern The pattern tree that the scrutinee should match. Contains `patternHole` holes. | ||
* @param `summon[Env]` Set of tuples containing pairs of symbols (s, p) where s defines a symbol in `scrutinee` which corresponds to symbol p in `pattern`. | ||
* @return `None` if it did not match or `Some(tup: Tuple)` if it matched where `tup` contains the contents of the holes. | ||
* @return The sequence with the contents of the holes of the matched expression. | ||
*/ | ||
private def =?= (pattern0: Tree)(using Env, Context): Matching = | ||
private def =?= (pattern0: Tree)(using Env, Context): optional[MatchingExprs] = | ||
|
||
/* Match block flattening */ // TODO move to cases | ||
/** Normalize the tree */ | ||
|
@@ -203,31 +206,12 @@ object QuoteMatcher { | |
// Matches an open term and wraps it into a lambda that provides the free variables | ||
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil) | ||
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) => | ||
def hoasClosure = { | ||
val names: List[TermName] = args.map { | ||
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName | ||
case arg => arg.symbol.name.asTermName | ||
} | ||
val argTypes = args.map(x => x.tpe.widenTermRefExpr) | ||
val methTpe = MethodType(names)(_ => argTypes, _ => pattern.tpe) | ||
val meth = newAnonFun(ctx.owner, methTpe) | ||
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { | ||
val argsMap = args.map(_.symbol).zip(lambdaArgss.head).toMap | ||
val body = new TreeMap { | ||
override def transform(tree: Tree)(using Context): Tree = | ||
tree match | ||
case tree: Ident => summon[Env].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) | ||
case tree => super.transform(tree) | ||
}.transform(scrutinee) | ||
TreeOps(body).changeNonLocalOwners(meth) | ||
} | ||
Closure(meth, bodyFn) | ||
} | ||
val env = summon[Env] | ||
val capturedArgs = args.map(_.symbol) | ||
val captureEnv = summon[Env].filter((k, v) => !capturedArgs.contains(v)) | ||
val captureEnv = env.filter((k, v) => !capturedArgs.contains(v)) | ||
withEnv(captureEnv) { | ||
scrutinee match | ||
case ClosedPatternTerm(scrutinee) => matched(hoasClosure) | ||
case ClosedPatternTerm(scrutinee) => matchedOpen(scrutinee, pattern.tpe, args, env) | ||
case _ => notMatched | ||
} | ||
|
||
|
@@ -431,7 +415,6 @@ object QuoteMatcher { | |
case _ => scrutinee | ||
val pattern = patternTree.symbol | ||
|
||
|
||
devirtualizedScrutinee == pattern | ||
|| summon[Env].get(devirtualizedScrutinee).contains(pattern) | ||
|| devirtualizedScrutinee.allOverriddenSymbols.contains(pattern) | ||
|
@@ -452,32 +435,67 @@ object QuoteMatcher { | |
accumulator.apply(Set.empty, term) | ||
} | ||
|
||
/** Result of matching a part of an expression */ | ||
private type Matching = Option[Tuple] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The documentation comment for |
||
|
||
private object Matching { | ||
|
||
def notMatched: Matching = None | ||
|
||
val matched: Matching = Some(Tuple()) | ||
|
||
def matched(tree: Tree)(using Context): Matching = | ||
Some(Tuple1(new ExprImpl(tree, SpliceScope.getCurrent))) | ||
|
||
extension (self: Matching) | ||
def asOptionOfTuple: Option[Tuple] = self | ||
|
||
/** Concatenates the contents of two successful matchings or return a `notMatched` */ | ||
def &&& (that: => Matching): Matching = self match { | ||
case Some(x) => | ||
that match { | ||
case Some(y) => Some(x ++ y) | ||
case _ => None | ||
} | ||
case _ => None | ||
} | ||
end extension | ||
|
||
} | ||
enum MatchResult: | ||
/** Closed pattern extracted value | ||
* @param tree Scrutinee sub-tree that matched | ||
*/ | ||
case ClosedTree(tree: Tree) | ||
/** HOAS pattern extracted value | ||
* | ||
* @param tree Scrutinee sub-tree that matched | ||
* @param patternTpe Type of the pattern hole (from the pattern) | ||
* @param args HOAS arguments (from the pattern) | ||
* @param env Mapping between scrutinee and pattern variables | ||
*/ | ||
case OpenTree(tree: Tree, patternTpe: Type, args: List[Tree], env: Env) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Woudl be nice to document the expectation on the valid values of the arguments here. |
||
|
||
/** Return the expression that was extracted from a hole. | ||
* | ||
* If it was a closed expression it returns that expression. Otherwise, | ||
* if it is a HOAS pattern, the surrounding lambda is generated using | ||
* `mapTypeHoles` to create the signature of the lambda. | ||
* | ||
* This expression is assumed to be a valid expression in the given splice scope. | ||
*/ | ||
def toExpr(mapTypeHoles: TypeMap, spliceScope: Scope)(using Context): Expr[Any] = this match | ||
case MatchResult.ClosedTree(tree) => | ||
new ExprImpl(tree, spliceScope) | ||
case MatchResult.OpenTree(tree, patternTpe, args, env) => | ||
val names: List[TermName] = args.map { | ||
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName | ||
case arg => arg.symbol.name.asTermName | ||
} | ||
val paramTypes = args.map(x => mapTypeHoles(x.tpe.widenTermRefExpr)) | ||
val methTpe = MethodType(names)(_ => paramTypes, _ => mapTypeHoles(patternTpe)) | ||
val meth = newAnonFun(ctx.owner, methTpe) | ||
def bodyFn(lambdaArgss: List[List[Tree]]): Tree = { | ||
val argsMap = args.view.map(_.symbol).zip(lambdaArgss.head).toMap | ||
val body = new TreeMap { | ||
override def transform(tree: Tree)(using Context): Tree = | ||
tree match | ||
case tree: Ident => env.get(tree.symbol).flatMap(argsMap.get).getOrElse(tree) | ||
case tree => super.transform(tree) | ||
}.transform(tree) | ||
TreeOps(body).changeNonLocalOwners(meth) | ||
} | ||
val hoasClosure = Closure(meth, bodyFn) | ||
new ExprImpl(hoasClosure, spliceScope) | ||
|
||
private inline def notMatched: optional[MatchingExprs] = | ||
optional.break() | ||
|
||
private inline def matched: MatchingExprs = | ||
Seq.empty | ||
|
||
private inline def matched(tree: Tree)(using Context): MatchingExprs = | ||
Seq(MatchResult.ClosedTree(tree)) | ||
|
||
private def matchedOpen(tree: Tree, patternTpe: Type, args: List[Tree], env: Env)(using Context): MatchingExprs = | ||
Seq(MatchResult.OpenTree(tree, patternTpe, args, env)) | ||
|
||
extension (self: MatchingExprs) | ||
/** Concatenates the contents of two successful matchings */ | ||
def &&& (that: MatchingExprs): MatchingExprs = self ++ that | ||
end extension | ||
|
||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3137,20 +3137,27 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler | |
ctx1.gadtState.addToConstraint(typeHoles) | ||
ctx1 | ||
|
||
val matchings = QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1) | ||
|
||
if typeHoles.isEmpty then matchings | ||
else { | ||
// After matching and doing all subtype checks, we have to approximate all the type bindings | ||
// that we have found, seal them in a quoted.Type and add them to the result | ||
def typeHoleApproximation(sym: Symbol) = | ||
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) | ||
val fullBounds = ctx1.gadt.fullBounds(sym) | ||
val tp = if fromAboveAnnot then fullBounds.hi else fullBounds.lo | ||
reflect.TypeReprMethods.asType(tp) | ||
matchings.map { tup => | ||
Tuple.fromIArray(typeHoles.map(typeHoleApproximation).toArray.asInstanceOf[IArray[Object]]) ++ tup | ||
// After matching and doing all subtype checks, we have to approximate all the type bindings | ||
// that we have found, seal them in a quoted.Type and add them to the result | ||
def typeHoleApproximation(sym: Symbol) = | ||
val fromAboveAnnot = sym.hasAnnotation(dotc.core.Symbols.defn.QuotedRuntimePatterns_fromAboveAnnot) | ||
val fullBounds = ctx1.gadt.fullBounds(sym) | ||
if fromAboveAnnot then fullBounds.hi else fullBounds.lo | ||
|
||
QuoteMatcher.treeMatch(scrutinee, pat1)(using ctx1).map { matchings => | ||
import QuoteMatcher.MatchResult.* | ||
lazy val spliceScope = SpliceScope.getCurrent | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I intended to use pass it to |
||
val typeHoleApproximations = typeHoles.map(typeHoleApproximation) | ||
val typeHoleMapping = Map(typeHoles.zip(typeHoleApproximations)*) | ||
val typeHoleMap = new Types.TypeMap { | ||
def apply(tp: Types.Type): Types.Type = tp match | ||
case Types.TypeRef(Types.NoPrefix, _) => typeHoleMapping.getOrElse(tp.typeSymbol, tp) | ||
case _ => mapOver(tp) | ||
} | ||
val matchedExprs = matchings.map(_.toExpr(typeHoleMap, spliceScope)) | ||
val matchedTypes = typeHoleApproximations.map(reflect.TypeReprMethods.asType) | ||
val results = matchedTypes ++ matchedExprs | ||
Tuple.fromIArray(IArray.unsafeFromArray(results.toArray)) | ||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import scala.quoted.* | ||
|
||
inline def valToFun[T](inline expr: T): T = | ||
${ impl('expr) } | ||
|
||
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = | ||
expr match | ||
case '{ { val ident = ($a: α); $rest(ident): T } } => | ||
'{ { (y: α) => $rest(y) }.apply(???) } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
def test = valToFun { | ||
val a: Int = 1 | ||
a + 1 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
import scala.quoted.* | ||
|
||
inline def valToFun[T](inline expr: T): T = | ||
${ impl('expr) } | ||
|
||
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = | ||
expr match | ||
case '{ { val ident = ($a: α); $rest(ident): T } } => | ||
'{ | ||
{ (y: α) => | ||
${ | ||
val bound = '{ ${ rest }(y) } | ||
Expr.betaReduce(bound) | ||
} | ||
}.apply($a) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
def test = valToFun { | ||
val a: Int = 1 | ||
a + 1 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
import scala.quoted.* | ||
|
||
inline def valToFun[T](inline expr: T): T = | ||
${ impl('expr) } | ||
|
||
def impl[T: Type](expr: Expr[T])(using quotes: Quotes): Expr[T] = | ||
expr match | ||
case '{ type α; { val ident = ($a: `α`); $rest(ident): `α` & T } } => | ||
'{ { (y: α) => $rest(y) }.apply(???) } |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
def test = valToFun { | ||
val a: Int = 1 | ||
a + 1 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like the
summon[Env]
two lines below could becomeenv
with this addition.