diff --git a/README.md b/README.md index 9c2e4f2c02..32f4d59c11 100644 --- a/README.md +++ b/README.md @@ -71,6 +71,8 @@ See the [CONTRIBUTING](CONTRIBUTING.md) file. - [x] `foldLeft` - [x] `foldRight` - [x] `get` +- [x] `getOrElse` +- [x] `getOrElseUpdate` - [x] `head` - [x] `indexWhere` - [x] `isDefinedAt` @@ -88,7 +90,7 @@ See the [CONTRIBUTING](CONTRIBUTING.md) file. - [x] `drop` - [x] `empty` - [x] `filter` / `filterNot` -- [ ] `groupBy` +- [x] `groupBy` - [x] `intersect` - [x] `partition` - [x] `range` diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/HashSetBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/HashSetBenchmark.scala index a7f95006a8..0018327302 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/HashSetBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/HashSetBenchmark.scala @@ -59,4 +59,10 @@ class HashSetBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ImmutableArrayBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ImmutableArrayBenchmark.scala index e39d7a69aa..ecd3b3ae2b 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ImmutableArrayBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/ImmutableArrayBenchmark.scala @@ -87,4 +87,10 @@ class ImmutableArrayBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/LazyListBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/LazyListBenchmark.scala index e5757c6e45..1722015a9a 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/LazyListBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/LazyListBenchmark.scala @@ -79,4 +79,10 @@ class LazyListBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ListBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ListBenchmark.scala index 12692cd872..13f0d56472 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ListBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/ListBenchmark.scala @@ -87,4 +87,10 @@ class ListBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/PrimitiveArrayBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/PrimitiveArrayBenchmark.scala index 939bbc4a75..5f04d0281d 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/PrimitiveArrayBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/PrimitiveArrayBenchmark.scala @@ -3,7 +3,10 @@ package strawman.collection.immutable import 
java.util.concurrent.TimeUnit import org.openjdk.jmh.annotations._ -import scala.{Any, AnyRef, Int, Unit} +import org.openjdk.jmh.infra.Blackhole + +import scala.{Any, AnyRef, Int, Long, Unit} +import scala.Predef.intWrapper @BenchmarkMode(scala.Array(Mode.AverageTime)) @OutputTimeUnit(TimeUnit.NANOSECONDS) @@ -13,46 +16,81 @@ import scala.{Any, AnyRef, Int, Unit} @State(Scope.Benchmark) class PrimitiveArrayBenchmark { - @Param(scala.Array("8", "64", "512", "4096", "32768", "262144"/*, "2097152"*/)) + @Param(scala.Array("0", "1", "2", "3", "4", "7", "8", "15", "16", "17", "39", "282", "73121", "7312102")) var size: Int = _ var xs: ImmutableArray[Int] = _ - var obj: Int = _ + var xss: scala.Array[ImmutableArray[Int]] = _ + var randomIndices: scala.Array[Int] = _ @Setup(Level.Trial) def initData(): Unit = { - xs = ImmutableArray.fill(size)(obj) - obj = 123 + def freshCollection() = ImmutableArray((1 to size): _*) + xs = freshCollection() + xss = scala.Array.fill(1000)(freshCollection()) + if (size > 0) { + randomIndices = scala.Array.fill(1000)(scala.util.Random.nextInt(size)) + } } @Benchmark - def cons(): Any = { + // @OperationsPerInvocation(size) + def cons(bh: Blackhole): Unit = { var ys = ImmutableArray.empty[Int] var i = 0 while (i < size) { - ys = ys :+ obj - i += 1 + ys = ys :+ i + i = i + 1 } - ys + bh.consume(ys) } @Benchmark - def uncons(): Any = xs.tail + def uncons(bh: Blackhole): Unit = bh.consume(xs.tail) + + @Benchmark + def concat(bh: Blackhole): Unit = bh.consume(xs ++ xs) + + @Benchmark + def foreach(bh: Blackhole): Unit = xs.foreach(x => bh.consume(x)) + + @Benchmark + // @OperationsPerInvocation(size) + def foreach_while(bh: Blackhole): Unit = { + var ys = xs + while (ys.nonEmpty) { + bh.consume(ys.head) + ys = ys.tail + } + } @Benchmark - def concat(): Any = xs ++ xs + @OperationsPerInvocation(1000) + def lookupLast(bh: Blackhole): Unit = { + var i = 0 + while (i < 1000) { + bh.consume(xss(i)(size - 1)) + i = i + 1 + } + } @Benchmark - def foreach(): Any = { - var n = 0 - xs.foreach(x => if (x == 0) n += 1) - n + @OperationsPerInvocation(1000) + def randomLookup(bh: Blackhole): Unit = { + var i = 0 + while (i < 1000) { + bh.consume(xs(randomIndices(i))) + i = i + 1 + } } @Benchmark - def lookup(): Any = xs(size - 1) + def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) @Benchmark - def map(): Any = xs.map(x => if (x == 0) "foo" else "bar") + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaHashSetBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaHashSetBenchmark.scala index a8d55aa758..a4340d67f8 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaHashSetBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaHashSetBenchmark.scala @@ -57,4 +57,10 @@ class ScalaHashSetBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaListBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaListBenchmark.scala index ec43e1ba58..accf420719 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaListBenchmark.scala +++ 
b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaListBenchmark.scala @@ -87,4 +87,10 @@ class ScalaListBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaTreeSetBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaTreeSetBenchmark.scala index 792f74402f..aa32ca1a66 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaTreeSetBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaTreeSetBenchmark.scala @@ -59,4 +59,10 @@ class ScalaTreeSetBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaVectorBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaVectorBenchmark.scala index cad1a8648e..dafe7e93c8 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaVectorBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/ScalaVectorBenchmark.scala @@ -3,7 +3,10 @@ package strawman.collection.immutable import java.util.concurrent.TimeUnit import org.openjdk.jmh.annotations._ -import scala.{Any, AnyRef, Int, Unit} +import org.openjdk.jmh.infra.Blackhole + +import scala.{Any, AnyRef, Int, Long, Unit} +import scala.Predef.intWrapper @BenchmarkMode(scala.Array(Mode.AverageTime)) @OutputTimeUnit(TimeUnit.NANOSECONDS) @@ -16,43 +19,67 @@ class ScalaVectorBenchmark { @Param(scala.Array("8", "64", "512", "4096", "32768", "262144"/*, "2097152"*/)) var size: Int = _ - var xs: scala.Vector[AnyRef] = _ - var obj: Any = _ + var xs: scala.Vector[Long] = _ + var xss: scala.Array[scala.Vector[Long]] = _ + var randomIndices: scala.Array[Int] = _ @Setup(Level.Trial) def initData(): Unit = { - xs = scala.Vector.fill(size)("") - obj = "" + def freshCollection() = scala.Vector((1 to size).map(_.toLong): _*) + xs = freshCollection() + xss = scala.Array.fill(1000)(freshCollection()) + if (size > 0) { + randomIndices = scala.Array.fill(1000)(scala.util.Random.nextInt(size)) + } } @Benchmark - def cons(): Any = { - var ys = scala.Vector.empty[Any] - var i = 0 + def cons(bh: Blackhole): Unit = { + var ys = scala.Vector.empty[Long] + var i = 0L while (i < size) { - ys = ys :+ obj + ys = ys :+ i i += 1 } - ys + bh.consume(ys) } @Benchmark - def uncons(): Any = xs.tail + def uncons(bh: Blackhole): Unit = bh.consume(xs.tail) @Benchmark - def concat(): Any = xs ++ xs + def concat(bh: Blackhole): Unit = bh.consume(xs ++ xs) @Benchmark - def foreach(): Any = { - var n = 0 - xs.foreach(x => if (x eq null) n += 1) - n + def foreach(bh: Blackhole): Unit = xs.foreach(x => bh.consume(x)) + + @Benchmark + @OperationsPerInvocation(1000) + def lookupLast(bh: Blackhole): Unit = { + var i = 0 + while (i < 1000) { + bh.consume(xss(i)(size - 1)) + i = i + 1 + } } @Benchmark - def lookup(): Any = xs(size - 1) + @OperationsPerInvocation(1000) + def randomLookup(bh: Blackhole): Unit = { + var i = 0 + while (i < 1000) { + bh.consume(xs(randomIndices(i))) + i = i + 1 + } + } @Benchmark - def map(): Any = xs.map(x => if (x eq null) "foo" else "bar") + def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + + 
@Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } } diff --git a/benchmarks/time/src/main/scala/strawman/collection/immutable/TreeSetBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/immutable/TreeSetBenchmark.scala index fb0db16087..c8f521fb93 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/immutable/TreeSetBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/immutable/TreeSetBenchmark.scala @@ -59,4 +59,10 @@ class TreeSetBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/mutable/ArrayBufferBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/mutable/ArrayBufferBenchmark.scala index 8a19f0070b..c1c8b20491 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/mutable/ArrayBufferBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/mutable/ArrayBufferBenchmark.scala @@ -78,4 +78,10 @@ class ArrayBufferBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/benchmarks/time/src/main/scala/strawman/collection/mutable/ListBufferBenchmark.scala b/benchmarks/time/src/main/scala/strawman/collection/mutable/ListBufferBenchmark.scala index cec46694f4..00e4765002 100644 --- a/benchmarks/time/src/main/scala/strawman/collection/mutable/ListBufferBenchmark.scala +++ b/benchmarks/time/src/main/scala/strawman/collection/mutable/ListBufferBenchmark.scala @@ -77,4 +77,10 @@ class ListBufferBenchmark { @Benchmark def map(bh: Blackhole): Unit = bh.consume(xs.map(x => x + 1)) + @Benchmark + def groupBy(bh: Blackhole): Unit = { + val result = xs.groupBy(_ % 5) + bh.consume(result) + } + } diff --git a/src/main/scala/strawman/collection/Iterable.scala b/src/main/scala/strawman/collection/Iterable.scala index 545f86f529..71120a5342 100644 --- a/src/main/scala/strawman/collection/Iterable.scala +++ b/src/main/scala/strawman/collection/Iterable.scala @@ -237,6 +237,33 @@ trait IterableOps[+A, +CC[X], +C] extends Any { def slice(from: Int, until: Int): C = fromSpecificIterable(View.Take(View.Drop(coll, from), until - from)) + /** Partitions this $coll into a map of ${coll}s according to some discriminator function. + * + * Note: When applied to a view or a lazy collection it will always force the elements. + * + * @param f the discriminator function. + * @tparam K the type of keys returned by the discriminator function. + * @return A map from keys to ${coll}s such that the following invariant holds: + * {{{ + * (xs groupBy f)(k) = xs filter (x => f(x) == k) + * }}} + * That is, every key `k` is bound to a $coll of those elements `x` + * for which `f(x)` equals `k`. 
+ * + */ + def groupBy[K](f: A => K): immutable.Map[K, C] = { + val m = mutable.Map.empty[K, Builder[A, C]] + for (elem <- coll) { + val key = f(elem) + val bldr = m.getOrElseUpdate(key, newSpecificBuilder()) + bldr += elem + } + var result = immutable.Map.empty[K, C] + m.foreach { case (k, v) => + result = result + ((k, v.result())) + } + result + } /** Map */ def map[B](f: A => B): CC[B] = fromIterable(View.Map(coll, f)) diff --git a/src/main/scala/strawman/collection/Map.scala b/src/main/scala/strawman/collection/Map.scala index 3b2a61107a..1a4e3e51e5 100644 --- a/src/main/scala/strawman/collection/Map.scala +++ b/src/main/scala/strawman/collection/Map.scala @@ -27,6 +27,22 @@ trait MapOps[K, +V, +CC[X, Y] <: Map[X, Y], +C <: Map[K, V]] */ def get(key: K): Option[V] + /** Returns the value associated with a key, or a default value if the key is not contained in the map. + * @param key the key. + * @param default a computation that yields a default value in case no binding for `key` is + * found in the map. + * @tparam V1 the result type of the default computation. + * @return the value associated with `key` if it exists, + * otherwise the result of the `default` computation. + * + * @usecase def getOrElse(key: K, default: => V): V + * @inheritdoc + */ + def getOrElse[V1 >: V](key: K, default: => V1): V1 = get(key) match { + case Some(v) => v + case None => default + } + /** Retrieves the value which is associated with the given key. This * method invokes the `default` method of the map if there is no mapping * from the given key to a value. Unless overridden, the `default` method throws a diff --git a/src/main/scala/strawman/collection/immutable/BitSet.scala b/src/main/scala/strawman/collection/immutable/BitSet.scala index b9052554c5..2b031a30b6 100644 --- a/src/main/scala/strawman/collection/immutable/BitSet.scala +++ b/src/main/scala/strawman/collection/immutable/BitSet.scala @@ -67,6 +67,11 @@ object BitSet extends SpecificIterableFactoryWithBuilder[Int, BitSet] { case _ => empty.++(it) } + def empty: BitSet = new BitSet1(0L) + + def newBuilder(): Builder[Int, BitSet] = + new GrowableBuilder(mutable.BitSet.empty).mapResult(bs => fromBitMaskNoCopy(bs.elems)) + private def createSmall(a: Long, b: Long): BitSet = if (b == 0L) new BitSet1(a) else new BitSet2(a, b) /** A bitset containing all the bits in an array */ @@ -92,11 +97,6 @@ object BitSet extends SpecificIterableFactoryWithBuilder[Int, BitSet] { else new BitSetN(elems) } - def empty: BitSet = new BitSet1(0L) - - def newBuilder(): Builder[Int, BitSet] = - new GrowableBuilder(mutable.BitSet.empty).mapResult(bs => fromBitMaskNoCopy(bs.elems)) - @SerialVersionUID(2260107458435649300L) class BitSet1(val elems: Long) extends BitSet { protected[collection] def nwords = 1 diff --git a/src/main/scala/strawman/collection/immutable/HashMap.scala b/src/main/scala/strawman/collection/immutable/HashMap.scala index 3ee86bce53..1dbfb5b24e 100644 --- a/src/main/scala/strawman/collection/immutable/HashMap.scala +++ b/src/main/scala/strawman/collection/immutable/HashMap.scala @@ -1,7 +1,7 @@ package strawman package collection.immutable -import collection.{Iterator, MapFactory, MapFactoryWithBuilder, StrictOptimizedIterableOps} +import collection.{Iterator, MapFactoryWithBuilder, StrictOptimizedIterableOps} import collection.Hashing.{computeHash, keepBits} import scala.annotation.unchecked.{uncheckedVariance => uV} diff --git a/src/main/scala/strawman/collection/immutable/ImmutableArray.scala 
b/src/main/scala/strawman/collection/immutable/ImmutableArray.scala index 2227d4d1a4..f9c3cc5f3a 100644 --- a/src/main/scala/strawman/collection/immutable/ImmutableArray.scala +++ b/src/main/scala/strawman/collection/immutable/ImmutableArray.scala @@ -14,7 +14,7 @@ import scala.Predef.{???, intWrapper} */ class ImmutableArray[+A] private (private val elements: scala.Array[Any]) extends IndexedSeq[A] - with SeqOps[A, ImmutableArray, ImmutableArray[A]] + with IndexedSeqOps[A, ImmutableArray, ImmutableArray[A]] with StrictOptimizedIterableOps[A, ImmutableArray[A]] { def iterableFactory: IterableFactory[ImmutableArray] = ImmutableArray diff --git a/src/main/scala/strawman/collection/immutable/LazyList.scala b/src/main/scala/strawman/collection/immutable/LazyList.scala index 5c9c918628..ae097f6682 100644 --- a/src/main/scala/strawman/collection/immutable/LazyList.scala +++ b/src/main/scala/strawman/collection/immutable/LazyList.scala @@ -68,5 +68,6 @@ object LazyList extends IterableFactory[LazyList] { def fromIterator[A](it: Iterator[A]): LazyList[A] = new LazyList(if (it.hasNext) Some(it.next(), fromIterator(it)) else None) - def empty[A]: LazyList[A] = new LazyList[A](None) + def empty[A]: LazyList[A] = Empty + } diff --git a/src/main/scala/strawman/collection/immutable/TreeSet.scala b/src/main/scala/strawman/collection/immutable/TreeSet.scala index ea910755e9..f7a1ea4186 100644 --- a/src/main/scala/strawman/collection/immutable/TreeSet.scala +++ b/src/main/scala/strawman/collection/immutable/TreeSet.scala @@ -34,6 +34,16 @@ final class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: O def this()(implicit ordering: Ordering[A]) = this(null)(ordering) + def iterableFactory = Set + + protected[this] def fromSpecificIterable(coll: strawman.collection.Iterable[A]): TreeSet[A] = + TreeSet.sortedFromIterable(coll) + + protected[this] def sortedFromIterable[B : Ordering](coll: strawman.collection.Iterable[B]): TreeSet[B] = + TreeSet.sortedFromIterable(coll) + + protected[this] def newSpecificBuilder(): Builder[A, TreeSet[A]] = TreeSet.newBuilder() + private def newSet(t: RB.Tree[A, Unit]) = new TreeSet[A](t) override def size: Int = RB.count(tree) @@ -62,16 +72,6 @@ final class TreeSet[A] private (tree: RB.Tree[A, Unit])(implicit val ordering: O def keysIteratorFrom(start: A): Iterator[A] = RB.keysIterator(tree, Some(start)) - def iterableFactory = Set - - protected[this] def fromSpecificIterable(coll: strawman.collection.Iterable[A]): TreeSet[A] = - TreeSet.sortedFromIterable(coll) - - protected[this] def sortedFromIterable[B : Ordering](coll: strawman.collection.Iterable[B]): TreeSet[B] = - TreeSet.sortedFromIterable(coll) - - protected[this] def newSpecificBuilder(): Builder[A, TreeSet[A]] = TreeSet.newBuilder() - def unordered: Set[A] = this /** Checks if this set contains element `elem`. 
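// Editor's sketch (not part of the patch): the contract the new `Iterable.groupBy`
// documents is (xs groupBy f)(k) == xs filter (x => f(x) == k). The check below uses
// the standard library's List and Map as stand-ins for the strawman collections and
// reuses the `_ % 5` discriminator that the new benchmarks exercise; all names here
// are illustrative, not part of the strawman API.
object GroupByInvariantSketch {
  def main(args: Array[String]): Unit = {
    val xs = List(1, 2, 3, 4, 5, 6, 7)
    val f: Int => Int = _ % 5
    val grouped: Map[Int, List[Int]] = xs.groupBy(f)

    // Every key maps to exactly the elements the discriminator sends to it ...
    val invariantHolds = grouped.forall { case (k, group) => group == xs.filter(x => f(x) == k) }
    // ... and the groups partition the input: nothing is lost or duplicated.
    val sizesAddUp = grouped.values.map(_.size).sum == xs.size

    println(s"invariant holds: $invariantHolds, groups partition the input: $sizesAddUp")
  }
}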
diff --git a/src/main/scala/strawman/collection/mutable/HashMap.scala b/src/main/scala/strawman/collection/mutable/HashMap.scala index 192c398d57..ad6f8243a0 100644 --- a/src/main/scala/strawman/collection/mutable/HashMap.scala +++ b/src/main/scala/strawman/collection/mutable/HashMap.scala @@ -80,6 +80,22 @@ final class HashMap[K, V] private[collection] (contents: HashTable.Contents[K, D else { val v = e.value; e.value = value; Some(v) } } + override def getOrElseUpdate(key: K, defaultValue: => V): V = { + val hash = table.elemHashCode(key) + val i = table.index(hash) + val entry = table.findEntry0(key, i) + if (entry != null) entry.value + else { + val table0 = table + val default = defaultValue + // Avoid recomputing index if the `defaultValue()` hasn't triggered + // a table resize. + val newEntryIndex = if (table0 eq table) i else table.index(hash) + table.addEntry0(table.createNewEntry(key, default), newEntryIndex) + default + } + } + private def writeObject(out: java.io.ObjectOutputStream): Unit = { table.serializeTo(out, { entry => out.writeObject(entry.key) diff --git a/src/main/scala/strawman/collection/mutable/HashTable.scala b/src/main/scala/strawman/collection/mutable/HashTable.scala index f32dc36bb0..28671c7fa8 100644 --- a/src/main/scala/strawman/collection/mutable/HashTable.scala +++ b/src/main/scala/strawman/collection/mutable/HashTable.scala @@ -139,7 +139,7 @@ private[mutable] abstract class HashTable[A, B, Entry >: Null <: HashEntry[A, En final def findEntry(key: A): Entry = findEntry0(key, index(elemHashCode(key))) - protected[this] final def findEntry0(key: A, h: Int): Entry = { + protected[collection] final def findEntry0(key: A, h: Int): Entry = { var e = table(h).asInstanceOf[Entry] while (e != null && !elemEquals(e.key, key)) e = e.next e @@ -152,7 +152,7 @@ private[mutable] abstract class HashTable[A, B, Entry >: Null <: HashEntry[A, En addEntry0(e, index(elemHashCode(e.key))) } - protected[this] final def addEntry0(e: Entry, h: Int): Unit = { + protected[collection] final def addEntry0(e: Entry, h: Int): Unit = { e.next = table(h).asInstanceOf[Entry] table(h) = e tableSize = tableSize + 1 @@ -363,7 +363,7 @@ private[mutable] abstract class HashTable[A, B, Entry >: Null <: HashEntry[A, En * Note: we take the most significant bits of the hashcode, not the lower ones * this is of crucial importance when populating the table in parallel */ - protected final def index(hcode: Int): Int = { + protected[collection] final def index(hcode: Int): Int = { val ones = table.length - 1 val exponent = Integer.numberOfLeadingZeros(ones) (improve(hcode, seedvalue) >>> exponent) & ones @@ -408,7 +408,7 @@ private[collection] object HashTable { // so that: protected final def sizeMapBucketSize = 1 << sizeMapBucketBitSize - protected def elemHashCode(key: KeyType) = key.## + protected[collection] def elemHashCode(key: KeyType) = key.## /** * Defer to a high-quality hash in [[scala.util.hashing]]. 
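// Editor's sketch (not part of the patch): the HashMap.getOrElseUpdate override above
// hashes the key once, reuses findEntry0/addEntry0, and only recomputes the bucket
// index when evaluating the by-name default has replaced the table (i.e. triggered a
// resize). The scenario that guard protects against is a default expression that
// mutates the same map, as below. scala.collection.mutable.HashMap stands in for the
// strawman class here; the printed results show the intended behaviour of both.
object GetOrElseUpdateResizeSketch {
  def main(args: Array[String]): Unit = {
    val cache = scala.collection.mutable.HashMap.empty[Int, String]

    val v = cache.getOrElseUpdate(1, {
      // The default itself grows the map, which can force a table resize before
      // the new entry for key 1 is inserted.
      (2 to 100).foreach(i => cache.update(i, i.toString))
      "one"
    })

    println(v)            // one
    println(cache.get(1)) // Some(one) -- the new binding still lands in the right bucket
  }
}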
diff --git a/src/main/scala/strawman/collection/mutable/Map.scala b/src/main/scala/strawman/collection/mutable/Map.scala index c1ddd27be8..de69f583c2 100644 --- a/src/main/scala/strawman/collection/mutable/Map.scala +++ b/src/main/scala/strawman/collection/mutable/Map.scala @@ -4,7 +4,7 @@ package mutable import strawman.collection.{IterableOnce, MapFactory} -import scala.{Boolean, Option, Unit, `inline`} +import scala.{Boolean, None, Option, Some, Unit, `inline`} /** Base type of mutable Maps */ trait Map[K, V] @@ -47,6 +47,26 @@ trait MapOps[K, V, +CC[X, Y] <: Map[X, Y], +C <: Map[K, V]] */ def update(key: K, value: V): Unit = { coll += ((key, value)) } + /** If given key is already in this map, returns associated value. + * + * Otherwise, computes value from given expression `op`, stores with key + * in map and returns that value. + * + * Concurrent map implementations may evaluate the expression `op` + * multiple times, or may evaluate `op` without inserting the result. + * + * @param key the key to test + * @param op the computation yielding the value to associate with `key`, if + * `key` is previously unbound. + * @return the value associated with key (either previously or as a result + * of executing the method). + */ + def getOrElseUpdate(key: K, op: => V): V = + get(key) match { + case Some(v) => v + case None => val d = op; this(key) = d; d + } + override def clone(): C = empty ++= coll def mapInPlace(f: ((K, V)) => (K, V)): this.type = {
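// Editor's sketch (not part of the patch): the default MapOps.getOrElseUpdate above is
// a get followed by an update (two lookups); the HashMap override earlier in this diff
// collapses that into a single hashed lookup. Either way, it is the hook the new
// Iterable.groupBy relies on: one builder per key, created the first time the key is
// seen. A standalone version of that pattern, with standard-library mutable.Map,
// Builder and List standing in for the strawman types:
object GroupWithGetOrElseUpdate {
  import scala.collection.mutable

  def wordsByLength(words: Seq[String]): Map[Int, List[String]] = {
    val builders = mutable.Map.empty[Int, mutable.Builder[String, List[String]]]
    for (w <- words)
      builders.getOrElseUpdate(w.length, List.newBuilder[String]) += w
    builders.iterator.map { case (k, b) => (k, b.result()) }.toMap
  }

  def main(args: Array[String]): Unit =
    println(wordsByLength(Seq("to", "be", "or", "not", "to", "be")))
    // e.g. Map(2 -> List(to, be, or, to, be), 3 -> List(not)) -- key order may vary
}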