Skip to content

Detect and deal with non-RefTree captures #206

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions src/main/scala/scala/async/internal/AsyncTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -154,12 +154,26 @@ trait AsyncTransform {
sym.asModule.moduleClass.setOwner(stateMachineClass)
}
}

def adjustType(tree: Tree): Tree = {
val resultType = if (tree.tpe eq null) null else tree.tpe.map {
case TypeRef(pre, sym, args) if liftedSyms.contains(sym) =>
val tp1 = internal.typeRef(thisType(sym.owner.asClass), sym, args)
tp1
case SingleType(pre, sym) if liftedSyms.contains(sym) =>
val tp1 = internal.singleType(thisType(sym.owner.asClass), sym)
tp1
case tp => tp
}
setType(tree, resultType)
}

// Replace the ValDefs in the splicee with Assigns to the corresponding lifted
// fields. Similarly, replace references to them with references to the field.
//
// This transform will only be run on the RHS of `def foo`.
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match {
case _ if api.currentOwner == stateMachineClass =>
val useFields: (Tree, TypingTransformApi) => Tree = (tree, api) => tree match {
case _ if api.currentOwner == stateMachineClass =>
api.default(tree)
case ValDef(_, _, _, rhs) if liftedSyms(tree.symbol) =>
api.atOwner(api.currentOwner) {
Expand All @@ -172,14 +186,19 @@ trait AsyncTransform {
treeCopy.Assign(tree, lhs, api.recur(rhs)).setType(definitions.UnitTpe).changeOwner(fieldSym, api.currentOwner)
}
}
case _: DefTree if liftedSyms(tree.symbol) =>
case _: DefTree if liftedSyms(tree.symbol) =>
EmptyTree
case Ident(name) if liftedSyms(tree.symbol) =>
case Ident(name) if liftedSyms(tree.symbol) =>
val fieldSym = tree.symbol
atPos(tree.pos) {
gen.mkAttributedStableRef(thisType(fieldSym.owner.asClass), fieldSym).setType(tree.tpe)
}
case _ =>
case sel @ Select(n@New(tt: TypeTree), nme.CONSTRUCTOR) =>
adjustType(sel)
adjustType(n)
adjustType(tt)
sel
case _ =>
api.default(tree)
}

Expand Down
21 changes: 17 additions & 4 deletions src/main/scala/scala/async/internal/Lifter.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package scala.async.internal

import scala.collection.mutable
import scala.collection.mutable.ListBuffer

trait Lifter {
self: AsyncMacro =>
Expand Down Expand Up @@ -77,13 +78,25 @@ trait Lifter {
// The direct references of each block, excluding references of `DefTree`-s which
// are already accounted for.
val stateIdToDirectlyReferenced: mutable.LinkedHashMap[Int, List[Symbol]] = {
val refs: List[(Int, Symbol)] = asyncStates.flatMap(
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).flatMap(_.collect {
val result = new mutable.LinkedHashMap[Int, ListBuffer[Symbol]]()
asyncStates.foreach(
asyncState => asyncState.stats.filterNot(t => t.isDef && !isLabel(t.symbol)).foreach(_.foreach {
case rt: RefTree
if symToDefiningState.contains(rt.symbol) => (asyncState.state, rt.symbol)
if symToDefiningState.contains(rt.symbol) =>
result.getOrElseUpdate(asyncState.state, new ListBuffer) += rt.symbol
case tt: TypeTree =>
tt.tpe.foreach { tp =>
val termSym = tp.termSymbol
if (symToDefiningState.contains(termSym))
result.getOrElseUpdate(asyncState.state, new ListBuffer) += termSym
val typeSym = tp.typeSymbol
if (symToDefiningState.contains(typeSym))
result.getOrElseUpdate(asyncState.state, new ListBuffer) += typeSym
}
case _ =>
})
)
toMultiMap(refs)
result.map { case (a, b) => (a, b.result())}
}

def liftableSyms: mutable.LinkedHashSet[Symbol] = {
Expand Down
30 changes: 21 additions & 9 deletions src/test/scala/scala/async/TreeInterrogation.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,29 @@ object TreeInterrogationApp extends App {
val tree = tb.parse(
"""
| import scala.async.internal.AsyncId._
| async {
| var b = true
| while(await(b)) {
| b = false
| }
| (1, 1) match {
| case (x, y) => await(2); println(x)
| }
| await(b)
| trait QBound { type D; trait ResultType { case class Inner() }; def toResult: ResultType = ??? }
| trait QD[Q <: QBound] {
| val operation: Q
| type D = operation.D
| }
|
| async {
| if (!"".isEmpty) {
| val treeResult = null.asInstanceOf[QD[QBound]]
| await(0)
| val y = treeResult.operation
| type RD = treeResult.operation.D
| (null: Object) match {
| case (_, _: RD) => ???
| case _ => val x = y.toResult; x.Inner()
| }
| await(1)
| (y, null.asInstanceOf[RD])
| ""
| }
|
| }
|
| """.stripMargin)
println(tree)
val tree1 = tb.typeCheck(tree.duplicate)
Expand Down
71 changes: 64 additions & 7 deletions src/test/scala/scala/async/run/late/LateExpansion.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@ package scala.async.run.late
import java.io.File

import junit.framework.Assert.assertEquals
import org.junit.{Assert, Test}
import org.junit.{Assert, Ignore, Test}

import scala.annotation.StaticAnnotation
import scala.annotation.meta.{field, getter}
import scala.async.TreeInterrogation
import scala.async.internal.AsyncId
import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader
import scala.tools.nsc._
Expand All @@ -19,6 +18,57 @@ import scala.tools.nsc.transform.TypingTransformers
// calls it from a new phase that runs after patmat.
class LateExpansion {

@Test def testRewrittenApply(): Unit = {
val result = wrapAndRun(
"""
| object O {
| case class Foo(a: Any)
| }
| @autoawait def id(a: String) = a
| O.Foo
| id("foo") + id("bar")
| O.Foo(1)
| """.stripMargin)
assertEquals("Foo(1)", result.toString)
}

@Ignore("Need to use adjustType more pervasively in AsyncTransform, but that exposes bugs in {Type, ... }Symbol's cache invalidation")
@Test def testIsInstanceOfType(): Unit = {
val result = wrapAndRun(
"""
| class Outer
| @autoawait def id(a: String) = a
| val o = new Outer
| id("foo") + id("bar")
| ("": Object).isInstanceOf[o.type]
| """.stripMargin)
assertEquals(false, result)
}

@Test def testIsInstanceOfTerm(): Unit = {
val result = wrapAndRun(
"""
| class Outer
| @autoawait def id(a: String) = a
| val o = new Outer
| id("foo") + id("bar")
| o.isInstanceOf[Outer]
| """.stripMargin)
assertEquals(true, result)
}

@Test def testArrayLocalModule(): Unit = {
val result = wrapAndRun(
"""
| class Outer
| @autoawait def id(a: String) = a
| val O = ""
| id("foo") + id("bar")
| new Array[O.type](0)
| """.stripMargin)
assertEquals(classOf[Array[String]], result.getClass)
}

@Test def test0(): Unit = {
val result = wrapAndRun(
"""
Expand All @@ -27,6 +77,7 @@ class LateExpansion {
| """.stripMargin)
assertEquals("foobar", result)
}

@Test def testGuard(): Unit = {
val result = wrapAndRun(
"""
Expand Down Expand Up @@ -143,6 +194,7 @@ class LateExpansion {
|}
| """.stripMargin)
}

@Test def shadowing2(): Unit = {
val result = run(
"""
Expand Down Expand Up @@ -369,6 +421,7 @@ class LateExpansion {
}
""")
}

@Test def testNegativeArraySizeExceptionFine1(): Unit = {
val result = run(
"""
Expand All @@ -389,18 +442,20 @@ class LateExpansion {
}
""")
}

private def createTempDir(): File = {
val f = File.createTempFile("output", "")
f.delete()
f.mkdirs()
f
}

def run(code: String): Any = {
// settings.processArgumentString("-Xprint:patmat,postpatmat,jvm -Ybackend:GenASM -nowarn")
val out = createTempDir()
try {
val reporter = new StoreReporter
val settings = new Settings(println(_))
//settings.processArgumentString("-Xprint:refchecks,patmat,postpatmat,jvm -nowarn")
settings.outdir.value = out.getAbsolutePath
settings.embeddedDefaults(getClass.getClassLoader)
val isInSBT = !settings.classpath.isSetByUser
Expand Down Expand Up @@ -432,6 +487,7 @@ class LateExpansion {
}

abstract class LatePlugin extends Plugin {

import global._

override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers {
Expand All @@ -448,16 +504,16 @@ abstract class LatePlugin extends Plugin {
super.transform(tree) match {
case ap@Apply(fun, args) if fun.symbol.hasAnnotation(autoAwaitSym) =>
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(ap.tpe) :: Nil), ap :: Nil))
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi] ) =>
case sel@Select(fun, _) if sel.symbol.hasAnnotation(autoAwaitSym) && !(tree.tpe.isInstanceOf[MethodTypeApi] || tree.tpe.isInstanceOf[PolyTypeApi]) =>
localTyper.typed(Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, awaitSym), TypeTree(sel.tpe) :: Nil), sel :: Nil))
case dd: DefDef if dd.symbol.hasAnnotation(lateAsyncSym) => atOwner(dd.symbol) {
deriveDefDef(dd){ rhs: Tree =>
deriveDefDef(dd) { rhs: Tree =>
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
localTyper.typed(atPos(dd.pos)(invoke))
}
}
case vd: ValDef if vd.symbol.hasAnnotation(lateAsyncSym) => atOwner(vd.symbol) {
deriveValDef(vd){ rhs: Tree =>
deriveValDef(vd) { rhs: Tree =>
val invoke = Apply(TypeApply(gen.mkAttributedRef(asyncIdSym.typeOfThis, asyncSym), TypeTree(rhs.tpe) :: Nil), List(rhs))
localTyper.typed(atPos(vd.pos)(invoke))
}
Expand All @@ -468,6 +524,7 @@ abstract class LatePlugin extends Plugin {
}
}
}

override def newPhase(prev: Phase): Phase = new StdPhase(prev) {
override def apply(unit: CompilationUnit): Unit = {
val translated = newTransformer(unit).transformUnit(unit)
Expand All @@ -476,7 +533,7 @@ abstract class LatePlugin extends Plugin {
}
}

override val runsAfter: List[String] = "patmat" :: Nil
override val runsAfter: List[String] = "refchecks" :: Nil
override val phaseName: String = "postpatmat"

})
Expand Down