Skip to content

Improve quote matcher performance #12418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
318 changes: 182 additions & 136 deletions compiler/src/scala/quoted/runtime/impl/QuoteMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -181,26 +181,30 @@ object QuoteMatcher {
case _ => None
end TypeTreeTypeTest

(scrutinee, pattern) match
val res = pattern match

/* Term hole */
// Match a scala.internal.Quoted.patternHole typed as a repeated argument and return the scrutinee tree
case (scrutinee @ Typed(s, tpt1), Typed(TypeApply(patternHole, tpt :: Nil), tpt2))
case Typed(TypeApply(patternHole, tpt :: Nil), tpt2)
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
s.tpe <:< tpt.tpe &&
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
matched(scrutinee)
tpt2.tpe.derivesFrom(defn.RepeatedParamClass) =>
scrutinee match
case Typed(s, tpt1) if s.tpe <:< tpt.tpe => matched(scrutinee)
case _ => notMatched

/* Term hole */
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
case (ClosedPatternTerm(scrutinee), TypeApply(patternHole, tpt :: Nil))
case TypeApply(patternHole, tpt :: Nil)
if patternHole.symbol.eq(defn.QuotedRuntimePatterns_patternHole) &&
scrutinee.tpe <:< tpt.tpe =>
matched(scrutinee)
scrutinee match
case ClosedPatternTerm(scrutinee) => matched(scrutinee)
case _ => notMatched


/* Higher order term hole */
// Matches an open term and wraps it into a lambda that provides the free variables
case (scrutinee, pattern @ Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil))
case Apply(TypeApply(Ident(_), List(TypeTree())), SeqLiteral(args, _) :: Nil)
if pattern.symbol.eq(defn.QuotedRuntimePatterns_higherOrderHole) =>
val names: List[TermName] = args.map {
case Block(List(DefDef(nme.ANON_FUN, _, _, Apply(Ident(name), _))), _) => name.asTermName
Expand All @@ -221,138 +225,180 @@ object QuoteMatcher {
}
matched(Closure(meth, bodyFn))

//
// Match two equivalent trees
//

/* Match literal */
case (Literal(constant1), Literal(constant2)) if constant1 == constant2 =>
matched

/* Match type ascription (a) */
case (Typed(expr1, _), pattern) =>
expr1 =?= pattern

/* Match type ascription (b) */
case (scrutinee, Typed(expr2, _)) =>
case Typed(expr2, _) =>
scrutinee =?= expr2

/* Match selection */
case (ref: RefTree, Select(qual2, _)) if symbolMatch(scrutinee, pattern) =>
ref match
case Select(qual1, _) => qual1 =?= qual2
case ref: Ident =>
ref.tpe match
case TermRef(qual: TermRef, _) => tpd.ref(qual) =?= qual2
case _ => matched

/* Match reference */
case (_: RefTree, _: Ident) if symbolMatch(scrutinee, pattern) =>
matched

/* Match application */
case (Apply(fn1, args1), Apply(fn2, args2)) =>
fn1 =?= fn2 &&& args1 =?= args2

/* Match type application */
case (TypeApply(fn1, args1), TypeApply(fn2, args2)) =>
fn1 =?= fn2 &&& args1 =?= args2

/* Match block */
case (Block(stat1 :: stats1, expr1), Block(stat2 :: stats2, expr2)) =>
val newEnv = (stat1, stat2) match {
case (stat1: MemberDef, stat2: MemberDef) =>
summon[Env] + (stat1.symbol -> stat2.symbol)
case _ =>
summon[Env]
}
withEnv(newEnv) {
stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2)
}

/* Match if */
case (If(cond1, thenp1, elsep1), If(cond2, thenp2, elsep2)) =>
cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2

/* Match while */
case (WhileDo(cond1, body1), WhileDo(cond2, body2)) =>
cond1 =?= cond2 &&& body1 =?= body2

/* Match assign */
case (Assign(lhs1, rhs1), Assign(lhs2, rhs2)) =>
lhs1 =?= lhs2 &&& rhs1 =?= rhs2

/* Match new */
case (New(tpt1), New(tpt2)) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol =>
matched

/* Match this */
case (This(_), This(_)) if scrutinee.symbol == pattern.symbol =>
matched

/* Match super */
case (Super(qual1, mix1), Super(qual2, mix2)) if mix1 == mix2 =>
qual1 =?= qual2

/* Match varargs */
case (SeqLiteral(elems1, _), SeqLiteral(elems2, _)) if elems1.size == elems2.size =>
elems1 =?= elems2

/* Match type */
// TODO remove this?
case (TypeTreeTypeTest(scrutinee), TypeTreeTypeTest(pattern)) if scrutinee.tpe <:< pattern.tpe =>
matched

/* Match val */
case (scrutinee @ ValDef(_, tpt1, _), pattern @ ValDef(_, tpt2, _)) if checkValFlags() =>
def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol)
tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)

/* Match def */
case (scrutinee @ DefDef(_, paramss1, tpt1, _), pattern @ DefDef(_, paramss2, tpt2, _)) =>
def rhsEnv: Env =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.zip(clause2)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv

matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)

case (Closure(_, _, tpt1), Closure(_, _, tpt2)) =>
// TODO match tpt1 with tpt2?
matched

case (NamedArg(name1, arg1), NamedArg(name2, arg2)) if name1 == name2 =>
arg1 =?= arg2

case (EmptyTree, EmptyTree) =>
matched

// No Match
case _ =>
if (debug)
val quotes = QuotesImpl()
println(
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|Scrutinee
| ${scrutinee.show}
|did not match pattern
| ${pattern.show}
|
|with environment: ${summon[Env]}
|
|Scrutinee: ${quotes.reflect.Printer.TreeStructure.show(scrutinee.asInstanceOf)}
|Pattern: ${quotes.reflect.Printer.TreeStructure.show(pattern.asInstanceOf)}
|
|""".stripMargin)
notMatched
scrutinee match
/* Match type ascription (a) */
case Typed(expr1, _) =>
expr1 =?= pattern

/* Match literal */
case Literal(constant1) =>
pattern match
case Literal(constant2) if constant1 == constant2 => matched
case _ => notMatched

case ref: RefTree =>
pattern match
/* Match selection */
case Select(qual2, _) if symbolMatch(scrutinee, pattern) =>
ref match
case Select(qual1, _) => qual1 =?= qual2
case ref: Ident =>
ref.tpe match
case TermRef(qual: TermRef, _) => tpd.ref(qual) =?= qual2
case _ => matched
/* Match reference */
case _: Ident if symbolMatch(scrutinee, pattern) => matched
/* Match type */
case TypeTreeTypeTest(pattern) if scrutinee.tpe <:< pattern.tpe => matched
case _ => notMatched

/* Match application */
case Apply(fn1, args1) =>
pattern match
case Apply(fn2, args2) =>
fn1 =?= fn2 &&& args1 =?= args2
case _ => notMatched

/* Match type application */
case TypeApply(fn1, args1) =>
pattern match
case TypeApply(fn2, args2) =>
fn1 =?= fn2 &&& args1 =?= args2
case _ => notMatched

/* Match block */
case Block(stat1 :: stats1, expr1) =>
pattern match
case Block(stat2 :: stats2, expr2) =>
val newEnv = (stat1, stat2) match {
case (stat1: MemberDef, stat2: MemberDef) =>
summon[Env] + (stat1.symbol -> stat2.symbol)
case _ =>
summon[Env]
}
withEnv(newEnv) {
stat1 =?= stat2 &&& Block(stats1, expr1) =?= Block(stats2, expr2)
}
case _ => notMatched

/* Match if */
case If(cond1, thenp1, elsep1) =>
pattern match
case If(cond2, thenp2, elsep2) =>
cond1 =?= cond2 &&& thenp1 =?= thenp2 &&& elsep1 =?= elsep2
case _ => notMatched

/* Match while */
case WhileDo(cond1, body1) =>
pattern match
case WhileDo(cond2, body2) => cond1 =?= cond2 &&& body1 =?= body2
case _ => notMatched

/* Match assign */
case Assign(lhs1, rhs1) =>
pattern match
case Assign(lhs2, rhs2) => lhs1 =?= lhs2 &&& rhs1 =?= rhs2
case _ => notMatched

/* Match new */
case New(tpt1) =>
pattern match
case New(tpt2) if tpt1.tpe.typeSymbol == tpt2.tpe.typeSymbol => matched
case _ => notMatched

/* Match this */
case This(_) =>
pattern match
case This(_) if scrutinee.symbol == pattern.symbol => matched
case _ => notMatched

/* Match super */
case Super(qual1, mix1) =>
pattern match
case Super(qual2, mix2) if mix1 == mix2 => qual1 =?= qual2
case _ => notMatched

/* Match varargs */
case SeqLiteral(elems1, _) =>
pattern match
case SeqLiteral(elems2, _) if elems1.size == elems2.size => elems1 =?= elems2
case _ => notMatched

/* Match type */
// TODO remove this?
case TypeTreeTypeTest(scrutinee) =>
pattern match
case TypeTreeTypeTest(pattern) if scrutinee.tpe <:< pattern.tpe => matched
case _ => notMatched

/* Match val */
case scrutinee @ ValDef(_, tpt1, _) =>
pattern match
case pattern @ ValDef(_, tpt2, _) if checkValFlags() =>
def rhsEnv = summon[Env] + (scrutinee.symbol -> pattern.symbol)
tpt1 =?= tpt2 &&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
case _ => notMatched

/* Match def */
case scrutinee @ DefDef(_, paramss1, tpt1, _) =>
pattern match
case pattern @ DefDef(_, paramss2, tpt2, _) =>
def rhsEnv: Env =
val paramSyms: List[(Symbol, Symbol)] =
for
(clause1, clause2) <- paramss1.zip(paramss2)
(param1, param2) <- clause1.zip(clause2)
yield
param1.symbol -> param2.symbol
val oldEnv: Env = summon[Env]
val newEnv: List[(Symbol, Symbol)] = (scrutinee.symbol -> pattern.symbol) :: paramSyms
oldEnv ++ newEnv
matchLists(paramss1, paramss2)(_ =?= _)
&&& tpt1 =?= tpt2
&&& withEnv(rhsEnv)(scrutinee.rhs =?= pattern.rhs)
case _ => notMatched

case Closure(_, _, tpt1) =>
pattern match
case Closure(_, _, tpt2) => matched // TODO match tpt1 with tpt2?
case _ => notMatched

case NamedArg(name1, arg1) =>
pattern match
case NamedArg(name2, arg2) if name1 == name2 => arg1 =?= arg2
case _ => notMatched

case EmptyTree =>
if pattern.isEmpty then matched
else notMatched

// No Match
case _ =>
notMatched

if (debug && res == notMatched)
val quotes = QuotesImpl()
println(
s""">>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>
|Scrutinee
| ${scrutinee.show}
|did not match pattern
| ${pattern.show}
|
|with environment: ${summon[Env]}
|
|Scrutinee: ${quotes.reflect.Printer.TreeStructure.show(scrutinee.asInstanceOf)}
|Pattern: ${quotes.reflect.Printer.TreeStructure.show(pattern.asInstanceOf)}
|
|""".stripMargin)

res
end =?=

end extension

Expand Down