Skip to content

Commit 9f508df

Browse files
committed
Merge pull request #199 from dotty-staging/transform/lambdalift
Transform/lambdalift
2 parents e992cf9 + 1070499 commit 9f508df

29 files changed

+302
-125
lines changed

src/dotty/tools/dotc/Compiler.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ class Compiler {
5353
new Literalize,
5454
new GettersSetters),
5555
List(new Erasure),
56-
List(new CapturedVars, new Constructors)/*,
57-
List(new LambdaLift)*/
56+
List(new CapturedVars,
57+
new Constructors),
58+
List(new LambdaLift)
5859
)
5960

6061
var runId = 1

src/dotty/tools/dotc/TypeErasure.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,10 @@ class TypeErasure(isJava: Boolean, isSemi: Boolean, isConstructor: Boolean, wild
258258
else this(parent)
259259
case tp: TermRef =>
260260
this(tp.widen)
261-
case ThisType(_) | SuperType(_, _) =>
261+
case ThisType(_) =>
262262
tp
263+
case SuperType(thistpe, supertpe) =>
264+
SuperType(this(thistpe), this(supertpe))
263265
case ExprType(rt) =>
264266
MethodType(Nil, Nil, this(rt))
265267
case tp: TypeProxy =>

src/dotty/tools/dotc/ast/TreeTypeMap.scala

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,13 @@ final class TreeTypeMap(
7979

8080
override def transform(tree: tpd.Tree)(implicit ctx: Context): tpd.Tree = treeMap(tree) match {
8181
case impl @ Template(constr, parents, self, body) =>
82-
val tmap = withMappedSyms(impl.symbol :: impl.constr.symbol :: Nil)
83-
val parents1 = parents mapconserve transform
84-
val (_, constr1 :: self1 :: Nil) = transformDefs(constr :: self :: Nil)
85-
val body1 = tmap.transformStats(body)
86-
updateDecls(constr :: body, constr1 :: body1)
82+
val tmap = withMappedSyms(localSyms(impl :: self :: Nil))
8783
cpy.Template(impl)(
88-
constr1.asInstanceOf[DefDef], parents1, self1.asInstanceOf[ValDef], body1)
89-
.withType(tmap.mapType(impl.tpe))
84+
constr = tmap.transformSub(constr),
85+
parents = parents mapconserve transform,
86+
self = tmap.transformSub(self),
87+
body = body mapconserve tmap.transform
88+
).withType(tmap.mapType(impl.tpe))
9089
case tree1 =>
9190
tree1.withType(mapType(tree1.tpe)) match {
9291
case id: Ident if tpd.needsSelect(id.tpe) =>
@@ -160,8 +159,24 @@ final class TreeTypeMap(
160159
* and return a treemap that contains the substitution
161160
* between original and mapped symbols.
162161
*/
163-
def withMappedSyms(syms: List[Symbol]): TreeTypeMap = {
164-
val mapped = ctx.mapSymbols(syms, this)
165-
withSubstitution(syms, mapped)
162+
def withMappedSyms(syms: List[Symbol], mapAlways: Boolean = false): TreeTypeMap =
163+
withMappedSyms(syms, ctx.mapSymbols(syms, this, mapAlways))
164+
165+
/** The tree map with the substitution between originals `syms`
166+
* and mapped symbols `mapped`. Also goes into mapped classes
167+
* and substitutes their declarations.
168+
*/
169+
def withMappedSyms(syms: List[Symbol], mapped: List[Symbol]): TreeTypeMap = {
170+
val symsChanged = syms ne mapped
171+
val substMap = withSubstitution(syms, mapped)
172+
val fullMap = (substMap /: mapped.filter(_.isClass)) { (tmap, cls) =>
173+
val origDcls = cls.decls.toList
174+
val mappedDcls = ctx.mapSymbols(origDcls, tmap)
175+
val tmap1 = tmap.withMappedSyms(origDcls, mappedDcls)
176+
if (symsChanged) (origDcls, mappedDcls).zipped.foreach(cls.asClass.replace)
177+
tmap1
178+
}
179+
if (symsChanged || (fullMap eq substMap)) fullMap
180+
else withMappedSyms(syms, mapAlways = true)
166181
}
167182
}

src/dotty/tools/dotc/ast/tpd.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import SymDenotations._, Symbols._, StdNames._, Annotations._, Trees._, Symbols.
99
import Denotations._, Decorators._
1010
import config.Printers._
1111
import typer.Mode
12+
import collection.mutable
1213
import typer.ErrorReporting._
1314

1415
import scala.annotation.tailrec
@@ -620,6 +621,12 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo {
620621
}
621622
acc(false, tree)
622623
}
624+
625+
def filterSubTrees(f: Tree => Boolean): List[Tree] = {
626+
val buf = new mutable.ListBuffer[Tree]
627+
foreachSubTree { tree => if (f(tree)) buf += tree }
628+
buf.toList
629+
}
623630
}
624631

625632
implicit class ListOfTreeDecorator(val xs: List[tpd.Tree]) extends AnyVal {

src/dotty/tools/dotc/config/Printers.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@ object Printers {
44

55
class Printer {
66
def println(msg: => String): Unit = System.out.println(msg)
7+
def echo[T](msg: => String, value: T): T = { println(msg + value); value }
78
}
89

910
object noPrinter extends Printer {
1011
override def println(msg: => String): Unit = ()
12+
override def echo[T](msg: => String, value: T): T = value
1113
}
1214

1315
val default: Printer = new Printer

src/dotty/tools/dotc/core/Decorators.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,11 @@ object Decorators {
100100
else x1 :: xs1
101101
}
102102

103+
def foldRightBN[U](z: => U)(op: (T, => U) => U): U = xs match {
104+
case Nil => z
105+
case x :: xs1 => op(x, xs1.foldRightBN(z)(op))
106+
}
107+
103108
final def hasSameLengthAs[U](ys: List[U]): Boolean = {
104109
@tailrec def loop(xs: List[T], ys: List[U]): Boolean =
105110
if (xs.isEmpty) ys.isEmpty

src/dotty/tools/dotc/core/Flags.scala

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,9 @@ object Flags {
331331
final val JavaDefined = commonFlag(30, "<java>")
332332

333333
/** Symbol is implemented as a Java static */
334-
final val Static = commonFlag(31, "<static>")
334+
final val JavaStatic = commonFlag(31, "<static>")
335+
final val JavaStaticTerm = JavaStatic.toTermFlags
336+
final val JavaStaticType = JavaStatic.toTypeFlags
335337

336338
/** Variable is accessed from nested function. */
337339
final val Captured = termFlag(32, "<captured>")
@@ -421,7 +423,7 @@ object Flags {
421423
/** Flags representing source modifiers */
422424
final val SourceModifierFlags =
423425
commonFlags(Private, Protected, Abstract, Final,
424-
Sealed, Case, Implicit, Override, AbsOverride, Lazy, Static)
426+
Sealed, Case, Implicit, Override, AbsOverride, Lazy, JavaStatic)
425427

426428
/** Flags representing modifiers that can appear in trees */
427429
final val ModifierFlags =
@@ -436,7 +438,7 @@ object Flags {
436438
/** Flags guaranteed to be set upon symbol creation */
437439
final val FromStartFlags =
438440
AccessFlags | Module | Package | Deferred | MethodOrHKCommon | Param | ParamAccessor | Scala2ExistentialCommon |
439-
InSuperCall | Touched | Static | CovariantOrOuter | ContravariantOrLabel | ExpandedName | AccessorOrSealed |
441+
InSuperCall | Touched | JavaStatic | CovariantOrOuter | ContravariantOrLabel | ExpandedName | AccessorOrSealed |
440442
CaseAccessorOrTypeArgument | Fresh | Frozen | Erroneous | ImplicitCommon | Permanent |
441443
SelfNameOrImplClass
442444

@@ -473,7 +475,7 @@ object Flags {
473475
*/
474476
final val RetainedModuleValAndClassFlags: FlagSet =
475477
AccessFlags | Package | Case |
476-
Synthetic | ExpandedName | JavaDefined | Static | Artifact |
478+
Synthetic | ExpandedName | JavaDefined | JavaStatic | Artifact |
477479
Erroneous | Lifted | MixedIn | Specialized
478480

479481
/** Flags that can apply to a module val */
@@ -487,7 +489,7 @@ object Flags {
487489

488490
/** Packages and package classes always have these flags set */
489491
final val PackageCreationFlags =
490-
Module | Package | Final | JavaDefined | Static
492+
Module | Package | Final | JavaDefined
491493

492494
/** These flags are pickled */
493495
final val PickledFlags = flagRange(FirstFlag, FirstNotPickledFlag)
@@ -562,7 +564,7 @@ object Flags {
562564
final val ProtectedLocal = allOf(Protected, Local)
563565

564566
/** Java symbol which is `protected` and `static` */
565-
final val StaticProtected = allOf(JavaDefined, Protected, Static)
567+
final val StaticProtected = allOf(JavaDefined, Protected, JavaStatic)
566568

567569
final val AbstractFinal = allOf(Abstract, Final)
568570
final val AbstractSealed = allOf(Abstract, Sealed)

src/dotty/tools/dotc/core/Scopes.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ object Scopes {
278278
if (e.sym == prev) e.sym = replacement
279279
e = lookupNextEntry(e)
280280
}
281+
elemsCache = null
281282
}
282283

283284
/** Lookup a symbol entry matching given name.

src/dotty/tools/dotc/core/SymDenotations.scala

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,7 @@ object SymDenotations {
390390

391391
/** Is this denotation static (i.e. with no outer instance)? */
392392
final def isStatic(implicit ctx: Context) =
393-
(this is Static) || this.exists && owner.isStaticOwner
393+
(this is JavaStatic) || this.exists && owner.isStaticOwner
394394

395395
/** Is this a package class or module class that defines static symbols? */
396396
final def isStaticOwner(implicit ctx: Context): Boolean =
@@ -666,10 +666,16 @@ object SymDenotations {
666666
* for these definitions.
667667
*/
668668
final def enclosingClass(implicit ctx: Context): Symbol = {
669-
def enclClass(d: SymDenotation): Symbol =
670-
if (d.isClass || !d.exists) d.symbol else enclClass(d.owner)
671-
val cls = enclClass(this)
672-
if (this is InSuperCall) cls.owner.enclosingClass else cls
669+
def enclClass(sym: Symbol, skip: Boolean): Symbol = {
670+
def newSkip = sym.is(InSuperCall) || sym.is(JavaStaticTerm)
671+
if (!sym.exists)
672+
NoSymbol
673+
else if (sym.isClass)
674+
if (skip) enclClass(sym.owner, newSkip) else sym
675+
else
676+
enclClass(sym.owner, skip || newSkip)
677+
}
678+
enclClass(symbol, false)
673679
}
674680

675681
final def isEffectivelyFinal(implicit ctx: Context): Boolean = {
@@ -976,7 +982,7 @@ object SymDenotations {
976982
/** The type parameters of this class */
977983
override final def typeParams(implicit ctx: Context): List[TypeSymbol] = {
978984
def computeTypeParams = {
979-
if (ctx.erasedTypes && (symbol ne defn.ArrayClass)) Nil
985+
if (ctx.erasedTypes || is(Module)) Nil // fast return for modules to avoid scanning package decls
980986
else if (this ne initial) initial.asSymDenotation.typeParams
981987
else decls.filter(sym =>
982988
(sym is TypeParam) && sym.owner == symbol).asInstanceOf[List[TypeSymbol]]

src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,10 +267,10 @@ trait Symbols { this: Context =>
267267
* Cross symbol references are brought over from originals to copies.
268268
* Do not copy any symbols if all attributes of all symbols stay the same.
269269
*/
270-
def mapSymbols(originals: List[Symbol], ttmap: TreeTypeMap) =
271-
if (originals forall (sym =>
270+
def mapSymbols(originals: List[Symbol], ttmap: TreeTypeMap, mapAlways: Boolean = false): List[Symbol] =
271+
if (originals.forall(sym =>
272272
(ttmap.mapType(sym.info) eq sym.info) &&
273-
!(ttmap.oldOwners contains sym.owner)))
273+
!(ttmap.oldOwners contains sym.owner)) && !mapAlways)
274274
originals
275275
else {
276276
val copies: List[Symbol] = for (original <- originals) yield

0 commit comments

Comments
 (0)