Skip to content

Commit 5b4ca4f

Browse files
committed
Less boxing in field nulling analysis
1 parent c126865 commit 5b4ca4f

File tree

2 files changed

+52
-31
lines changed

2 files changed

+52
-31
lines changed

src/main/scala/scala/async/internal/LiveVariables.scala

+37-27
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package scala.async.internal
22

3+
import java.util
4+
35
import scala.collection.immutable.IntMap
46

57
trait LiveVariables {
@@ -19,19 +21,22 @@ trait LiveVariables {
1921
def fieldsToNullOut(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Int, List[Tree]] = {
2022
// live variables analysis:
2123
// the result map indicates in which states a given field should be nulled out
22-
val liveVarsMap: Map[Tree, Set[Int]] = liveVars(asyncStates, liftables)
24+
val liveVarsMap: Map[Tree, StateSet] = liveVars(asyncStates, liftables)
2325

2426
var assignsOf = Map[Int, List[Tree]]()
2527

26-
for ((fld, where) <- liveVarsMap; state <- where)
27-
assignsOf get state match {
28-
case None =>
29-
assignsOf += (state -> List(fld))
30-
case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
31-
assignsOf += (state -> (fld +: trees))
32-
case _ =>
33-
/* do nothing */
28+
for ((fld, where) <- liveVarsMap) {
29+
where.foreach { (state: Int) =>
30+
assignsOf get state match {
31+
case None =>
32+
assignsOf += (state -> List(fld))
33+
case Some(trees) if !trees.exists(_.symbol == fld.symbol) =>
34+
assignsOf += (state -> (fld +: trees))
35+
case _ =>
36+
// do nothing
37+
}
3438
}
39+
}
3540

3641
assignsOf
3742
}
@@ -48,7 +53,7 @@ trait LiveVariables {
4853
* @param liftables the lifted fields
4954
* @return a map which indicates for a given field (the key) the states in which it should be nulled out
5055
*/
51-
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, Set[Int]] = {
56+
def liveVars(asyncStates: List[AsyncState], liftables: List[Tree]): Map[Tree, StateSet] = {
5257
val liftedSyms: Set[Symbol] = // include only vars
5358
liftables.iterator.filter {
5459
case ValDef(mods, _, _, _) => mods.hasFlag(MUTABLE)
@@ -230,53 +235,58 @@ trait LiveVariables {
230235
}
231236
}
232237

233-
def lastUsagesOf(field: Tree, at: AsyncState): Set[Int] = {
238+
def lastUsagesOf(field: Tree, at: AsyncState): StateSet = {
234239
val avoid = scala.collection.mutable.HashSet[AsyncState]()
235240

236-
def lastUsagesOf0(field: Tree, at: AsyncState): Set[Int] = {
237-
if (avoid(at)) Set()
241+
val result = new StateSet
242+
def lastUsagesOf0(field: Tree, at: AsyncState): Unit = {
243+
if (avoid(at)) ()
238244
else if (captured(field.symbol)) {
239-
Set()
245+
()
240246
}
241247
else LVentry get at.state match {
242248
case Some(fields) if fields.contains(field.symbol) =>
243-
Set(at.state)
249+
result += at.state
244250
case _ =>
245251
avoid += at
246-
val preds = asyncStates.filter(state => contains(state.nextStates, at.state)).toSet
247-
preds.flatMap(p => lastUsagesOf0(field, p))
252+
for (state <- asyncStates) {
253+
if (contains(state.nextStates, at.state)) {
254+
lastUsagesOf0(field, state)
255+
}
256+
}
248257
}
249258
}
250259

251260
lastUsagesOf0(field, at)
261+
result
252262
}
253263

254-
val lastUsages: Map[Tree, Set[Int]] =
255-
liftables.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
264+
val lastUsages: Map[Tree, StateSet] =
265+
liftables.iterator.map(fld => fld -> lastUsagesOf(fld, finalState)).toMap
256266

257267
if(AsyncUtils.verbose) {
258268
for ((fld, lastStates) <- lastUsages)
259-
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.mkString(", ")}")
269+
AsyncUtils.vprintln(s"field ${fld.symbol.name} is last used in states ${lastStates.iterator.mkString(", ")}")
260270
}
261271

262-
val nullOutAt: Map[Tree, Set[Int]] =
272+
val nullOutAt: Map[Tree, StateSet] =
263273
for ((fld, lastStates) <- lastUsages) yield {
264-
val killAt = lastStates.flatMap { s =>
265-
if (s == finalState.state) Set()
266-
else {
274+
var result = new StateSet
275+
lastStates.foreach { s =>
276+
if (s != finalState.state) {
267277
val lastAsyncState = asyncStates.find(_.state == s).get
268278
val succNums = lastAsyncState.nextStates
269279
// all successor states that are not indirect predecessors
270280
// filter out successor states where the field is live at the entry
271-
succNums.iterator.filter(num => !isPred(num, s)).filterNot(num => LVentry(num).contains(fld.symbol))
281+
util.Arrays.stream(succNums).filter(num => !isPred(num, s)).filter(num => !LVentry(num).contains(fld.symbol)).forEach(result += _)
272282
}
273283
}
274-
(fld, killAt)
284+
(fld, result)
275285
}
276286

277287
if(AsyncUtils.verbose) {
278288
for ((fld, killAt) <- nullOutAt)
279-
AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.mkString(", ")}")
289+
AsyncUtils.vprintln(s"field ${fld.symbol.name} should be nulled out in states ${killAt.iterator.mkString(", ")}")
280290
}
281291

282292
nullOutAt
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
1+
/*
2+
* Copyright (C) 2018 Lightbend Inc. <http://www.lightbend.com>
3+
*/
14
package scala.async.internal
25

36
import java.util
7+
import java.util.function.IntConsumer
48

9+
import scala.collection.JavaConverters.{asScalaIteratorConverter, iterableAsScalaIterableConverter}
510
import scala.collection.mutable
611

712
// Set for StateIds, which are either small positive integers or -symbolID.
813
final class StateSet {
9-
private var bitSet = mutable.BitSet()
14+
private var bitSet = new java.util.BitSet()
1015
private var caseSet = new util.HashSet[Integer]()
11-
def +=(stateId: Int): Unit = if (stateId > 0) bitSet += stateId else caseSet.add(stateId)
12-
def contains(stateId: Int): Boolean = if (stateId > 0) bitSet.contains(stateId) else caseSet.contains(stateId)
13-
16+
def +=(stateId: Int): Unit = if (stateId > 0) bitSet.set(stateId) else caseSet.add(stateId)
17+
def contains(stateId: Int): Boolean = if (stateId > 0) bitSet.get(stateId) else caseSet.contains(stateId)
18+
def iterator: Iterator[Integer] = {
19+
bitSet.stream().iterator().asScala ++ caseSet.asScala.iterator
20+
}
21+
def foreach(f: IntConsumer): Unit = {
22+
bitSet.stream().forEach(f)
23+
caseSet.stream().forEach(integer => f.accept(integer))
24+
}
1425
}

0 commit comments

Comments
 (0)