@@ -10,19 +10,20 @@ private[quoted] object Matcher {
10
10
class QuoteMatcher [QCtx <: QuoteContext & Singleton ](given val qctx : QCtx ) {
11
11
// TODO improve performance
12
12
13
+ // TODO use flag from qctx.tasty.rootContext. Maybe -debug or add -debug-macros
13
14
private final val debug = false
14
15
15
16
import qctx .tasty .{_ , given }
16
17
import Matching ._
17
18
18
- private type Env = Set [( Symbol , Symbol ) ]
19
+ private type Env = Map [ Symbol , Symbol ]
19
20
20
21
inline private def withEnv [T ](env : Env )(body : => (given Env ) => T ): T = body(given env )
21
22
22
23
class SymBinding (val sym : Symbol , val fromAbove : Boolean )
23
24
24
25
def termMatch (scrutineeTerm : Term , patternTerm : Term , hasTypeSplices : Boolean ): Option [Tuple ] = {
25
- implicit val env : Env = Set .empty
26
+ implicit val env : Env = Map .empty
26
27
if (hasTypeSplices) {
27
28
implicit val ctx : Context = internal.Context_GADT_setFreshGADTBounds (rootContext)
28
29
val matchings = scrutineeTerm.underlyingArgument =?= patternTerm.underlyingArgument
@@ -42,7 +43,7 @@ private[quoted] object Matcher {
42
43
43
44
// TODO factor out common logic with `termMatch`
44
45
def typeTreeMatch (scrutineeTypeTree : TypeTree , patternTypeTree : TypeTree , hasTypeSplices : Boolean ): Option [Tuple ] = {
45
- implicit val env : Env = Set .empty
46
+ implicit val env : Env = Map .empty
46
47
if (hasTypeSplices) {
47
48
implicit val ctx : Context = internal.Context_GADT_setFreshGADTBounds (rootContext)
48
49
val matchings = scrutineeTypeTree =?= patternTypeTree
@@ -138,11 +139,28 @@ private[quoted] object Matcher {
138
139
matched(scrutinee.seal)
139
140
140
141
// Match a scala.internal.Quoted.patternHole and return the scrutinee tree
141
- case (scrutinee : Term , TypeApply (patternHole, tpt :: Nil ))
142
+ case (ClosedTerm ( scrutinee) , TypeApply (patternHole, tpt :: Nil ))
142
143
if patternHole.symbol == internal.Definitions_InternalQuoted_patternHole &&
143
144
scrutinee.tpe <:< tpt.tpe =>
144
145
matched(scrutinee.seal)
145
146
147
+ // Matches an open term and wraps it into a lambda that provides the free variables
148
+ case (scrutinee, pattern @ Apply (Select (TypeApply (Ident (" patternHole" ), List (Inferred ())), " apply" ), args0 @ IdentArgs (args))) =>
149
+ def bodyFn (lambdaArgs : List [Tree ]): Tree = {
150
+ val argsMap = args.map(_.symbol).zip(lambdaArgs.asInstanceOf [List [Term ]]).toMap
151
+ new TreeMap {
152
+ override def transformTerm (tree : Term )(given ctx : Context ): Term =
153
+ tree match
154
+ case tree : Ident => summon[Env ].get(tree.symbol).flatMap(argsMap.get).getOrElse(tree)
155
+ case tree => super .transformTerm(tree)
156
+ }.transformTree(scrutinee)
157
+ }
158
+ val names = args.map(_.name)
159
+ val argTypes = args0.map(x => x.tpe.widenTermRefExpr)
160
+ val resType = pattern.tpe
161
+ val res = Lambda (MethodType (names)(_ => argTypes, _ => resType), bodyFn)
162
+ matched(res.seal)
163
+
146
164
//
147
165
// Match two equivalent trees
148
166
//
@@ -156,7 +174,7 @@ private[quoted] object Matcher {
156
174
case (scrutinee, Typed (expr2, _)) =>
157
175
scrutinee =?= expr2
158
176
159
- case (Ident (_), Ident (_)) if scrutinee.symbol == pattern.symbol || summon[Env ].apply(( scrutinee.symbol, pattern.symbol) ) =>
177
+ case (Ident (_), Ident (_)) if scrutinee.symbol == pattern.symbol || summon[Env ].get( scrutinee.symbol).contains( pattern.symbol) =>
160
178
matched
161
179
162
180
case (Select (qual1, _), Select (qual2, _)) if scrutinee.symbol == pattern.symbol =>
@@ -165,18 +183,24 @@ private[quoted] object Matcher {
165
183
case (_ : Ref , _ : Ref ) if scrutinee.symbol == pattern.symbol =>
166
184
matched
167
185
168
- case (Apply (fn1, args1), Apply (fn2, args2)) if fn1.symbol == fn2.symbol =>
186
+ case (Apply (fn1, args1), Apply (fn2, args2)) if fn1.symbol == fn2.symbol || summon[ Env ].get(fn1.symbol).contains(fn2.symbol) =>
169
187
fn1 =?= fn2 && args1 =?= args2
170
188
171
- case (TypeApply (fn1, args1), TypeApply (fn2, args2)) if fn1.symbol == fn2.symbol =>
189
+ case (TypeApply (fn1, args1), TypeApply (fn2, args2)) if fn1.symbol == fn2.symbol || summon[ Env ].get(fn1.symbol).contains(fn2.symbol) =>
172
190
fn1 =?= fn2 && args1 =?= args2
173
191
174
192
case (Block (stats1, expr1), Block (binding :: stats2, expr2)) if isTypeBinding(binding) =>
175
193
qctx.tasty.internal.Context_GADT_addToConstraint (summon[Context ])(binding.symbol :: Nil )
176
194
matched(new SymBinding (binding.symbol, hasFromAboveAnnotation(binding.symbol))) && Block (stats1, expr1) =?= Block (stats2, expr2)
177
195
178
196
case (Block (stat1 :: stats1, expr1), Block (stat2 :: stats2, expr2)) =>
179
- withEnv(summon[Env ] + (stat1.symbol -> stat2.symbol)) {
197
+ val newEnv = (stat1, stat2) match {
198
+ case (stat1 : Definition , stat2 : Definition ) =>
199
+ summon[Env ] + (stat1.symbol -> stat2.symbol)
200
+ case _ =>
201
+ summon[Env ]
202
+ }
203
+ withEnv(newEnv) {
180
204
stat1 =?= stat2 && Block (stats1, expr1) =?= Block (stats2, expr2)
181
205
}
182
206
@@ -268,7 +292,7 @@ private[quoted] object Matcher {
268
292
|
269
293
| ${pattern.showExtractors}
270
294
|
271
- |
295
+ |with environment: ${summon[ Env ]}
272
296
|
273
297
|
274
298
| """ .stripMargin)
@@ -277,6 +301,31 @@ private[quoted] object Matcher {
277
301
}
278
302
}
279
303
304
+ private object ClosedTerm {
305
+ def unapply (term : Term )(given Context , Env ): Option [term.type ] =
306
+ if freeVars(term).isEmpty then Some (term) else None
307
+
308
+ def freeVars (tree : Tree )(given qctx : Context , env : Env ): Set [Symbol ] =
309
+ val accumulator = new TreeAccumulator [Set [Symbol ]] {
310
+ def foldTree (x : Set [Symbol ], tree : Tree )(given ctx : Context ): Set [Symbol ] =
311
+ tree match
312
+ case tree : Ident if env.contains(tree.symbol) => foldOverTree(x + tree.symbol, tree)
313
+ case _ => foldOverTree(x, tree)
314
+ }
315
+ accumulator.foldTree(Set .empty, tree)
316
+ }
317
+
318
+ private object IdentArgs {
319
+ def unapply (args : List [Term ])(given Context ): Option [List [Ident ]] =
320
+ args.foldRight(Option (List .empty[Ident ])) {
321
+ case (id : Ident , Some (acc)) => Some (id :: acc)
322
+ case (Block (List (DefDef (" $anonfun" , Nil , List (params), Inferred (), Some (Apply (id : Ident , args)))), Closure (Ident (" $anonfun" ), None )), Some (acc))
323
+ if params.zip(args).forall(_.symbol == _.symbol) =>
324
+ Some (id :: acc)
325
+ case _ => None
326
+ }
327
+ }
328
+
280
329
private def treeOptMatches (scrutinee : Option [Tree ], pattern : Option [Tree ])(given Context , Env ): Matching = {
281
330
(scrutinee, pattern) match {
282
331
case (Some (x), Some (y)) => x =?= y
@@ -344,7 +393,7 @@ private[quoted] object Matcher {
344
393
|
345
394
| ${pattern.showExtractors}
346
395
|
347
- |
396
+ |with environment: ${summon[ Env ]}
348
397
|
349
398
|
350
399
| """ .stripMargin)
0 commit comments