Skip to content

Performance improvements to the macro #198

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 6, 2018
25 changes: 14 additions & 11 deletions src/main/scala/scala/async/internal/AnfTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ private[async] trait AnfTransform {
stats :+ expr :+ api.typecheck(atPos(expr.pos)(Throw(Apply(Select(New(gen.mkAttributedRef(defn.IllegalStateExceptionClass)), nme.CONSTRUCTOR), Nil))))
expr match {
case Apply(fun, args) if isAwait(fun) =>
val valDef = defineVal(name.await, expr, tree.pos)
val valDef = defineVal(name.await(), expr, tree.pos)
val ref = gen.mkAttributedStableRef(valDef.symbol).setType(tree.tpe)
val ref1 = if (ref.tpe =:= definitions.UnitTpe)
// https://github.com/scala/async/issues/74
Expand Down Expand Up @@ -109,7 +109,7 @@ private[async] trait AnfTransform {
} else if (expr.tpe =:= definitions.NothingTpe) {
statsExprThrow
} else {
val varDef = defineVar(name.ifRes, expr.tpe, tree.pos)
val varDef = defineVar(name.ifRes(), expr.tpe, tree.pos)
def typedAssign(lhs: Tree) =
api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))

Expand Down Expand Up @@ -140,7 +140,7 @@ private[async] trait AnfTransform {
} else if (expr.tpe =:= definitions.NothingTpe) {
statsExprThrow
} else {
val varDef = defineVar(name.matchRes, expr.tpe, tree.pos)
val varDef = defineVar(name.matchRes(), expr.tpe, tree.pos)
def typedAssign(lhs: Tree) =
api.typecheck(atPos(lhs.pos)(Assign(Ident(varDef.symbol), mkAttributedCastPreservingAnnotations(lhs, tpe(varDef.symbol)))))
val casesWithAssign = cases map {
Expand All @@ -163,14 +163,14 @@ private[async] trait AnfTransform {
}
}

def defineVar(prefix: String, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
def defineVar(name: TermName, tp: Type, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name, pos, MUTABLE | SYNTHETIC).setInfo(uncheckedBounds(tp))
valDef(sym, mkZero(uncheckedBounds(tp))).setType(NoType).setPos(pos)
}
}

def defineVal(prefix: String, lhs: Tree, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name.fresh(prefix), pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
def defineVal(name: TermName, lhs: Tree, pos: Position): ValDef = {
val sym = api.currentOwner.newTermSymbol(name, pos, SYNTHETIC).setInfo(uncheckedBounds(lhs.tpe))
internal.valDef(sym, internal.changeOwner(lhs, api.currentOwner, sym)).setType(NoType).setPos(pos)
}

Expand Down Expand Up @@ -212,7 +212,7 @@ private[async] trait AnfTransform {
case Arg(expr, _, argName) =>
linearize.transformToList(expr) match {
case stats :+ expr1 =>
val valDef = defineVal(argName, expr1, expr1.pos)
val valDef = defineVal(name.freshen(argName), expr1, expr1.pos)
require(valDef.tpe != null, valDef)
val stats1 = stats :+ valDef
(stats1, atPos(tree.pos.makeTransparent)(gen.stabilize(gen.mkAttributedIdent(valDef.symbol))))
Expand Down Expand Up @@ -279,8 +279,9 @@ private[async] trait AnfTransform {
// TODO we can move this into ExprBuilder once we get rid of `AsyncDefinitionUseAnalyzer`.
val block = linearize.transformToBlock(body)
val (valDefs, mappings) = (pat collect {
case b@Bind(name, _) =>
val vd = defineVal(name.toTermName + AnfTransform.this.name.bindSuffix, gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos)
case b@Bind(bindName, _) =>
val vd = defineVal(name.freshen(bindName.toTermName), gen.mkAttributedStableRef(b.symbol).setPos(b.pos), b.pos)
vd.symbol.updateAttachment(SyntheticBindVal)
(vd, (b.symbol, vd.symbol))
}).unzip
val (from, to) = mappings.unzip
Expand Down Expand Up @@ -333,7 +334,7 @@ private[async] trait AnfTransform {
// Otherwise, create the matchres var. We'll callers of the label def below.
// Remember: we're iterating through the statement sequence in reverse, so we'll get
// to the LabelDef and mutate `matchResults` before we'll get to its callers.
val matchResult = linearize.defineVar(name.matchRes, param.tpe, ld.pos)
val matchResult = linearize.defineVar(name.matchRes(), param.tpe, ld.pos)
matchResults += matchResult
caseDefToMatchResult(ld.symbol) = matchResult.symbol
val rhs2 = ld.rhs.substituteSymbols(param.symbol :: Nil, matchResult.symbol :: Nil)
Expand Down Expand Up @@ -408,3 +409,5 @@ private[async] trait AnfTransform {
}).asInstanceOf[Block]
}
}

object SyntheticBindVal
11 changes: 11 additions & 0 deletions src/main/scala/scala/async/internal/AsyncMacro.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,18 @@ package scala.async.internal
object AsyncMacro {
def apply(c0: reflect.macros.Context, base: AsyncBase)(body0: c0.Tree): AsyncMacro { val c: c0.type } = {
import language.reflectiveCalls

// Use an attachment on RootClass as a sneaky place for a per-Global cache
val att = c0.internal.attachments(c0.universe.rootMirror.RootClass)
val names = att.get[AsyncNames[_]].getOrElse {
val names = new AsyncNames[c0.universe.type](c0.universe)
att.update(names)
names
}

new AsyncMacro { self =>
val c: c0.type = c0
val asyncNames: AsyncNames[c.universe.type] = names.asInstanceOf[AsyncNames[c.universe.type]]
val body: c.Tree = body0
// This member is required by `AsyncTransform`:
val asyncBase: AsyncBase = base
Expand All @@ -23,6 +33,7 @@ private[async] trait AsyncMacro
val c: scala.reflect.macros.Context
val body: c.Tree
var containsAwait: c.Tree => Boolean
val asyncNames: AsyncNames[c.universe.type]

lazy val macroPos: c.universe.Position = c.macroApplication.pos.makeTransparent
def atMacroPos(t: c.Tree): c.Tree = c.universe.atPos(macroPos)(t)
Expand Down
109 changes: 109 additions & 0 deletions src/main/scala/scala/async/internal/AsyncNames.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package scala.async.internal

import java.util.concurrent.atomic.AtomicInteger

import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import scala.reflect.api.Names

/**
* A per-global cache of names needed by the Async macro.
*/
final class AsyncNames[U <: Names with Singleton](val u: U) {
self =>
import u._

abstract class NameCache[N <: U#Name](base: String) {
val cached = new ArrayBuffer[N]()
protected def newName(s: String): N
def apply(i: Int): N = {
if (cached.isDefinedAt(i)) cached(i)
else {
assert(cached.length == i)
val name = newName(freshenString(base, i))
cached += name
name
}
}
}

final class TermNameCache(base: String) extends NameCache[U#TermName](base) {
override protected def newName(s: String): U#TermName = newTermName(s)
}
final class TypeNameCache(base: String) extends NameCache[U#TypeName](base) {
override protected def newName(s: String): U#TypeName = newTypeName(s)
}
private val matchRes: TermNameCache = new TermNameCache("match")
private val ifRes: TermNameCache = new TermNameCache("if")
private val await: TermNameCache = new TermNameCache("await")

private val result = newTermName("result$async")
private val completed: TermName = newTermName("completed$async")
private val apply = newTermName("apply")
private val stateMachine = newTermName("stateMachine$async")
private val stateMachineT = stateMachine.toTypeName
private val state: u.TermName = newTermName("state$async")
private val execContext = newTermName("execContext$async")
private val tr: u.TermName = newTermName("tr$async")
private val t: u.TermName = newTermName("throwable$async")

final class NameSource[N <: U#Name](cache: NameCache[N]) {
private val count = new AtomicInteger(0)
def apply(): N = cache(count.getAndIncrement())
}

class AsyncName {
final val matchRes = new NameSource[U#TermName](self.matchRes)
final val ifRes = new NameSource[U#TermName](self.matchRes)
final val await = new NameSource[U#TermName](self.await)
final val completed = self.completed
final val result = self.result
final val apply = self.apply
final val stateMachine = self.stateMachine
final val stateMachineT = self.stateMachineT
final val state: u.TermName = self.state
final val execContext = self.execContext
final val tr: u.TermName = self.tr
final val t: u.TermName = self.t

private val seenPrefixes = mutable.AnyRefMap[Name, AtomicInteger]()
private val freshened = mutable.HashSet[Name]()

final def freshenIfNeeded(name: TermName): TermName = {
seenPrefixes.getOrNull(name) match {
case null =>
seenPrefixes.put(name, new AtomicInteger())
name
case counter =>
freshen(name, counter)
}
}
final def freshenIfNeeded(name: TypeName): TypeName = {
seenPrefixes.getOrNull(name) match {
case null =>
seenPrefixes.put(name, new AtomicInteger())
name
case counter =>
freshen(name, counter)
}
}
final def freshen(name: TermName): TermName = {
val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger())
freshen(name, counter)
}
final def freshen(name: TypeName): TypeName = {
val counter = seenPrefixes.getOrElseUpdate(name, new AtomicInteger())
freshen(name, counter)
}
private def freshen(name: TermName, counter: AtomicInteger): TermName = {
if (freshened.contains(name)) name
else TermName(freshenString(name.toString, counter.incrementAndGet()))
}
private def freshen(name: TypeName, counter: AtomicInteger): TypeName = {
if (freshened.contains(name)) name
else TypeName(freshenString(name.toString, counter.incrementAndGet()))
}
}

private def freshenString(name: String, counter: Int): String = name.toString + "$async$" + counter
}
12 changes: 8 additions & 4 deletions src/main/scala/scala/async/internal/AsyncTransform.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,6 @@ trait AsyncTransform {
buildAsyncBlock(anfTree, symLookup)
}

if(AsyncUtils.verbose)
logDiagnostics(anfTree, asyncBlock.asyncStates.map(_.toString))

val liftedFields: List[Tree] = liftables(asyncBlock.asyncStates)

// live variables analysis
Expand Down Expand Up @@ -114,10 +111,15 @@ trait AsyncTransform {
futureSystemOps.spawn(body, execContext) // generate lean code for the simple case of `async { 1 + 1 }`
else
startStateMachine

if(AsyncUtils.verbose) {
logDiagnostics(anfTree, asyncBlock, asyncBlock.asyncStates.map(_.toString))
}
futureSystemOps.dot(enclosingOwner, body).foreach(f => f(asyncBlock.toDot))
cleanupContainsAwaitAttachments(result)
}

def logDiagnostics(anfTree: Tree, states: Seq[String]): Unit = {
def logDiagnostics(anfTree: Tree, block: AsyncBlock, states: Seq[String]): Unit = {
def location = try {
macroPos.source.path
} catch {
Expand All @@ -129,6 +131,8 @@ trait AsyncTransform {
AsyncUtils.vprintln(s"${c.macroApplication}")
AsyncUtils.vprintln(s"ANF transform expands to:\n $anfTree")
states foreach (s => AsyncUtils.vprintln(s))
AsyncUtils.vprintln("===== DOT =====")
AsyncUtils.vprintln(block.toDot)
}

/**
Expand Down
Loading