
Commit 7293f3a

KAFKA-19183 Replace Pool with ConcurrentHashMap (#19535)
1. Replace `Pool.scala` with `ConcurrentHashMap`.
2. Remove `PoolTest.scala`.

Reviewers: Chia-Ping Tsai <[email protected]>
1 parent 51ef290 commit 7293f3a
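For orientation, here is a minimal, self-contained Scala sketch (illustrative only, with made-up names; not code from this commit) of the API mapping the hunks below apply over and over: the `Pool` `valueFactory`/`getAndMaybePut` becomes `computeIfAbsent`, `putIfNotExists` becomes `putIfAbsent`, `removeAll` becomes `keySet.removeIf`, and iteration goes through `forEach` or an `.asScala` view.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object PoolToConcurrentHashMapSketch {
  def main(args: Array[String]): Unit = {
    val replicasByBroker = new ConcurrentHashMap[Int, String]

    // Pool valueFactory / getAndMaybePut(id)  ->  computeIfAbsent(id, factory)
    val first = replicasByBroker.computeIfAbsent(1, id => s"replica-$id")

    // Pool.putIfNotExists(k, v)  ->  putIfAbsent(k, v), which returns the previous value or null
    val winner = Option(replicasByBroker.putIfAbsent(1, "loser")).getOrElse(first)

    // Pool.removeAll(removedKeys)  ->  keySet.removeIf(predicate)
    val keep = Set(1)
    replicasByBroker.keySet.removeIf(id => !keep.contains(id))

    // Pool iteration  ->  forEach over entries, or an .asScala view of keys/values
    replicasByBroker.forEach((id, replica) => println(s"$id -> $replica"))
    println(replicasByBroker.values.asScala.toList)
    println(winner == first) // true: the second insert on key 1 lost
  }
}
```

The same handful of substitutions covers essentially every call site in the files below.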

File tree

14 files changed: +85 -233 lines


checkstyle/import-control-jmh-benchmarks.xml

-1

@@ -42,7 +42,6 @@
     <allow pkg="kafka.network"/>
     <allow pkg="kafka.utils"/>
     <allow pkg="kafka.zk"/>
-    <allow class="kafka.utils.Pool"/>
     <allow class="kafka.utils.KafkaScheduler"/>
     <allow class="org.apache.kafka.clients.FetchSessionHandler"/>
     <allow pkg="kafka.common"/>

core/src/main/scala/kafka/cluster/Partition.scala

+6 -7

@@ -19,7 +19,7 @@ package kafka.cluster
 import java.lang.{Long => JLong}
 import java.util.concurrent.locks.ReentrantReadWriteLock
 import java.util.Optional
-import java.util.concurrent.{CompletableFuture, CopyOnWriteArrayList}
+import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, CopyOnWriteArrayList}
 import kafka.controller.StateChangeLogger
 import kafka.log._
 import kafka.server._
@@ -322,7 +322,7 @@ class Partition(val topicPartition: TopicPartition,
   def partitionId: Int = topicPartition.partition

   private val stateChangeLogger = new StateChangeLogger(localBrokerId, inControllerContext = false, None)
-  private val remoteReplicasMap = new Pool[Int, Replica]
+  private val remoteReplicasMap = new ConcurrentHashMap[Int, Replica]
   // The read lock is only required when multiple reads are executed and needs to be in a consistent manner
   private val leaderIsrUpdateLock = new ReentrantReadWriteLock

@@ -604,7 +604,7 @@ class Partition(val topicPartition: TopicPartition,

   // remoteReplicas will be called in the hot path, and must be inexpensive
   def remoteReplicas: Iterable[Replica] =
-    remoteReplicasMap.values
+    remoteReplicasMap.values.asScala

   def futureReplicaDirChanged(newDestinationDir: String): Boolean = {
     inReadLock(leaderIsrUpdateLock) {
@@ -983,12 +983,11 @@ class Partition(val topicPartition: TopicPartition,
   ): Unit = {
     if (isLeader) {
       val followers = replicas.filter(_ != localBrokerId)
-      val removedReplicas = remoteReplicasMap.keys.filterNot(followers.contains(_))

       // Due to code paths accessing remoteReplicasMap without a lock,
       // first add the new replicas and then remove the old ones.
-      followers.foreach(id => remoteReplicasMap.getAndMaybePut(id, new Replica(id, topicPartition, metadataCache)))
-      remoteReplicasMap.removeAll(removedReplicas)
+      followers.foreach(id => remoteReplicasMap.computeIfAbsent(id, _ => new Replica(id, topicPartition, metadataCache)))
+      remoteReplicasMap.keySet.removeIf(replica => !followers.contains(replica))
     } else {
       remoteReplicasMap.clear()
     }
@@ -1158,7 +1157,7 @@ class Partition(val topicPartition: TopicPartition,
     // avoid unnecessary collection generation
     val leaderLogEndOffset = leaderLog.logEndOffsetMetadata
     var newHighWatermark = leaderLogEndOffset
-    remoteReplicasMap.values.foreach { replica =>
+    remoteReplicasMap.forEach { (_, replica) =>
       val replicaState = replica.stateSnapshot

       def shouldWaitForReplicaToJoinIsr: Boolean = {
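The comment in the hunk above explains the ordering: some code paths read `remoteReplicasMap` without taking `leaderIsrUpdateLock`, so new followers are inserted before stale ones are removed. A hedged sketch of that pattern with hypothetical types (not the actual `Partition` class):

```scala
import java.util.concurrent.ConcurrentHashMap

object FollowerSetUpdateSketch {
  final case class Replica(brokerId: Int)

  private val remoteReplicas = new ConcurrentHashMap[Int, Replica]

  def updateFollowers(followers: Set[Int]): Unit = {
    // Add missing followers first (idempotent; never replaces an existing Replica) so that
    // lock-free readers see every live follower at all times during the update.
    followers.foreach(id => remoteReplicas.computeIfAbsent(id, newId => Replica(newId)))
    // Only then drop followers that are no longer assigned.
    remoteReplicas.keySet.removeIf(id => !followers.contains(id))
  }

  def main(args: Array[String]): Unit = {
    updateFollowers(Set(1, 2, 3))
    updateFollowers(Set(2, 3, 4))
    remoteReplicas.forEach((id, replica) => println(s"$id -> $replica")) // 2, 3, 4 remain
  }
}
```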

core/src/main/scala/kafka/coordinator/transaction/TransactionStateManager.scala

+24 -24

@@ -18,11 +18,12 @@ package kafka.coordinator.transaction

 import java.nio.ByteBuffer
 import java.util.Properties
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.ReentrantReadWriteLock
 import kafka.server.ReplicaManager
 import kafka.utils.CoreUtils.{inReadLock, inWriteLock}
-import kafka.utils.{Logging, Pool}
+import kafka.utils.Logging
 import org.apache.kafka.common.config.TopicConfig
 import org.apache.kafka.common.internals.Topic
 import org.apache.kafka.common.message.ListTransactionsResponseData
@@ -126,7 +127,7 @@ class TransactionStateManager(brokerId: Int,
     val now = time.milliseconds()
     inReadLock(stateLock) {
       transactionMetadataCache.flatMap { case (_, entry) =>
-        entry.metadataPerTransactionalId.filter { case (_, txnMetadata) =>
+        entry.metadataPerTransactionalId.asScala.filter { case (_, txnMetadata) =>
           if (txnMetadata.pendingTransitionInProgress) {
             false
           } else {
@@ -156,7 +157,7 @@ class TransactionStateManager(brokerId: Int,
     val maxBatchSize = logConfig.maxMessageSize
     val expired = mutable.ListBuffer.empty[TransactionalIdCoordinatorEpochAndMetadata]
     var recordsBuilder: MemoryRecordsBuilder = null
-    val stateEntries = txnMetadataCacheEntry.metadataPerTransactionalId.values.iterator.buffered
+    val stateEntries = txnMetadataCacheEntry.metadataPerTransactionalId.values.asScala.iterator.buffered

     def flushRecordsBuilder(): Unit = {
       writeTombstonesForExpiredTransactionalIds(
@@ -350,7 +351,7 @@ class TransactionStateManager(brokerId: Int,

     val states = new java.util.ArrayList[ListTransactionsResponseData.TransactionState]
     transactionMetadataCache.foreachEntry { (_, cache) =>
-      cache.metadataPerTransactionalId.values.foreach { txnMetadata =>
+      cache.metadataPerTransactionalId.forEach { (_, txnMetadata) =>
         txnMetadata.inLock {
           if (shouldInclude(txnMetadata)) {
             states.add(new ListTransactionsResponseData.TransactionState()
@@ -386,7 +387,7 @@ class TransactionStateManager(brokerId: Int,
       case Some(cacheEntry) =>
         val txnMetadata = Option(cacheEntry.metadataPerTransactionalId.get(transactionalId)).orElse {
           createdTxnMetadataOpt.map { createdTxnMetadata =>
-            Option(cacheEntry.metadataPerTransactionalId.putIfNotExists(transactionalId, createdTxnMetadata))
+            Option(cacheEntry.metadataPerTransactionalId.putIfAbsent(transactionalId, createdTxnMetadata))
               .getOrElse(createdTxnMetadata)
           }
         }
@@ -428,10 +429,10 @@ class TransactionStateManager(brokerId: Int,

   def partitionFor(transactionalId: String): Int = Utils.abs(transactionalId.hashCode) % transactionTopicPartitionCount

-  private def loadTransactionMetadata(topicPartition: TopicPartition, coordinatorEpoch: Int): Pool[String, TransactionMetadata] = {
+  private def loadTransactionMetadata(topicPartition: TopicPartition, coordinatorEpoch: Int): ConcurrentMap[String, TransactionMetadata] = {
     def logEndOffset = replicaManager.getLogEndOffset(topicPartition).getOrElse(-1L)

-    val loadedTransactions = new Pool[String, TransactionMetadata]
+    val loadedTransactions = new ConcurrentHashMap[String, TransactionMetadata]

     replicaManager.getLog(topicPartition) match {
       case None =>
@@ -509,7 +510,7 @@ class TransactionStateManager(brokerId: Int,
    */
  private[transaction] def addLoadedTransactionsToCache(txnTopicPartition: Int,
                                                        coordinatorEpoch: Int,
-                                                        loadedTransactions: Pool[String, TransactionMetadata]): Unit = {
+                                                        loadedTransactions: ConcurrentMap[String, TransactionMetadata]): Unit = {
    val txnMetadataCacheEntry = TxnMetadataCacheEntry(coordinatorEpoch, loadedTransactions)
    val previousTxnMetadataCacheEntryOpt = transactionMetadataCache.put(txnTopicPartition, txnMetadataCacheEntry)

@@ -549,22 +550,21 @@ class TransactionStateManager(brokerId: Int,
          addLoadedTransactionsToCache(topicPartition.partition, coordinatorEpoch, loadedTransactions)

          val transactionsPendingForCompletion = new mutable.ListBuffer[TransactionalIdCoordinatorEpochAndTransitMetadata]
-          loadedTransactions.foreach {
-            case (transactionalId, txnMetadata) =>
-              txnMetadata.inLock {
-                // if state is PrepareCommit or PrepareAbort we need to complete the transaction
-                txnMetadata.state match {
-                  case PrepareAbort =>
-                    transactionsPendingForCompletion +=
-                      TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.ABORT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
-                  case PrepareCommit =>
-                    transactionsPendingForCompletion +=
-                      TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
-                  case _ =>
-                    // nothing needs to be done
-                }
+          loadedTransactions.forEach((transactionalId, txnMetadata) => {
+            txnMetadata.inLock {
+              // if state is PrepareCommit or PrepareAbort we need to complete the transaction
+              txnMetadata.state match {
+                case PrepareAbort =>
+                  transactionsPendingForCompletion +=
+                    TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.ABORT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
+                case PrepareCommit =>
+                  transactionsPendingForCompletion +=
+                    TransactionalIdCoordinatorEpochAndTransitMetadata(transactionalId, coordinatorEpoch, TransactionResult.COMMIT, txnMetadata, txnMetadata.prepareComplete(time.milliseconds()))
+                case _ =>
+                  // nothing needs to be done
              }
-              }
+            }
+          })

          // we first remove the partition from loading partition then send out the markers for those pending to be
          // completed transactions, so that when the markers get sent the attempt of appending the complete transaction
@@ -820,7 +820,7 @@ class TransactionStateManager(brokerId: Int,


 private[transaction] case class TxnMetadataCacheEntry(coordinatorEpoch: Int,
-                                                       metadataPerTransactionalId: Pool[String, TransactionMetadata]) {
+                                                       metadataPerTransactionalId: ConcurrentMap[String, TransactionMetadata]) {
   override def toString: String = {
     s"TxnMetadataCacheEntry(coordinatorEpoch=$coordinatorEpoch, numTransactionalEntries=${metadataPerTransactionalId.size})"
   }
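A hedged sketch (hypothetical names) of the get-or-create idiom behind the `putIfAbsent` hunk above: `putIfAbsent` returns the previously mapped value, or `null` when the key was absent, so wrapping the result in `Option` and falling back to the freshly created value leaves exactly one winner per transactional id.

```scala
import java.util.concurrent.ConcurrentHashMap

object GetOrCreateSketch {
  final case class TxnMetadata(transactionalId: String)

  private val cache = new ConcurrentHashMap[String, TxnMetadata]

  def getOrCreate(transactionalId: String): TxnMetadata =
    Option(cache.get(transactionalId)).getOrElse {
      val created = TxnMetadata(transactionalId)
      // Another thread may have raced us between get and putIfAbsent; keep whatever won.
      Option(cache.putIfAbsent(transactionalId, created)).getOrElse(created)
    }

  def main(args: Array[String]): Unit = {
    val a = getOrCreate("txn-1")
    val b = getOrCreate("txn-1")
    println(a eq b) // true: both callers observe the same cached instance
  }
}
```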

core/src/main/scala/kafka/log/LogManager.scala

+2 -2

@@ -24,7 +24,7 @@ import java.util.concurrent._
 import java.util.concurrent.atomic.AtomicInteger
 import kafka.server.{KafkaConfig, KafkaRaftServer}
 import kafka.utils.threadsafe
-import kafka.utils.{CoreUtils, Logging, Pool}
+import kafka.utils.{CoreUtils, Logging}
 import org.apache.kafka.common.{DirectoryId, KafkaException, TopicPartition, Uuid}
 import org.apache.kafka.common.utils.{Exit, KafkaThread, Time, Utils}
 import org.apache.kafka.common.errors.{InconsistentTopicIdException, KafkaStorageException, LogDirNotFoundException}
@@ -92,7 +92,7 @@ class LogManager(logDirs: Seq[File],

   // Map of stray partition to stray log. This holds all stray logs detected on the broker.
   // Visible for testing
-  private val strayLogs = new Pool[TopicPartition, UnifiedLog]()
+  private val strayLogs = new ConcurrentHashMap[TopicPartition, UnifiedLog]()

   private val _liveLogDirs: ConcurrentLinkedQueue[File] = createAndValidateLogDirs(logDirs, initialOfflineDirs)
   @volatile private var _currentDefaultConfig = initialDefaultConfig

core/src/main/scala/kafka/server/AbstractFetcherManager.scala

+1 -2

@@ -43,8 +43,7 @@ abstract class AbstractFetcherManager[T <: AbstractFetcherThread](val name: Stri
     metricsGroup.newGauge("MaxLag", () => {
       // current max lag across all fetchers/topics/partitions
       fetcherThreadMap.values.foldLeft(0L) { (curMaxLagAll, fetcherThread) =>
-        val maxLagThread = fetcherThread.fetcherLagStats.stats.values.foldLeft(0L)((curMaxLagThread, lagMetrics) =>
-          math.max(curMaxLagThread, lagMetrics.lag))
+        val maxLagThread = fetcherThread.fetcherLagStats.stats.values.stream().mapToLong(v => v.lag).max().orElse(0L)
         math.max(curMaxLagAll, maxLagThread)
       }
     }, tags)
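A hedged sketch (hypothetical names) of the `MaxLag` gauge change above: the inner Scala `foldLeft` over `Pool` values becomes a Java stream over the `ConcurrentHashMap` values, with `orElse(0L)` playing the role of the old fold's initial accumulator.

```scala
import java.util.concurrent.ConcurrentHashMap

object MaxLagSketch {
  final case class LagMetrics(lag: Long)

  def main(args: Array[String]): Unit = {
    val stats = new ConcurrentHashMap[String, LagMetrics]
    stats.put("topic-0", LagMetrics(3L))
    stats.put("topic-1", LagMetrics(7L))

    // foldLeft(0L)(math.max)  ->  stream().mapToLong(...).max().orElse(0L)
    val maxLag = stats.values.stream().mapToLong(m => m.lag).max().orElse(0L)
    println(maxLag) // 7
  }
}
```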

core/src/main/scala/kafka/server/AbstractFetcherThread.scala

+5 -8

@@ -20,7 +20,7 @@ package kafka.server
 import com.yammer.metrics.core.Meter
 import kafka.server.AbstractFetcherThread.{ReplicaFetch, ResultWithPartitions}
 import kafka.utils.CoreUtils.inLock
-import kafka.utils.{Logging, Pool}
+import kafka.utils.Logging
 import org.apache.kafka.common.errors._
 import org.apache.kafka.common.internals.PartitionStates
 import org.apache.kafka.common.message.OffsetForLeaderEpochResponseData.EpochEndOffset
@@ -41,7 +41,7 @@ import org.apache.kafka.storage.log.metrics.BrokerTopicStats
 import java.nio.ByteBuffer
 import java.util
 import java.util.Optional
-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.atomic.AtomicLong
 import java.util.concurrent.locks.ReentrantLock
 import scala.collection.{Map, Set, mutable}
@@ -903,11 +903,10 @@ class FetcherLagMetrics(metricId: ClientIdTopicPartition) {
 }

 class FetcherLagStats(metricId: ClientIdAndBroker) {
-  private val valueFactory = (k: TopicPartition) => new FetcherLagMetrics(ClientIdTopicPartition(metricId.clientId, k))
-  val stats = new Pool[TopicPartition, FetcherLagMetrics](Some(valueFactory))
+  val stats = new ConcurrentHashMap[TopicPartition, FetcherLagMetrics]

   def getAndMaybePut(topicPartition: TopicPartition): FetcherLagMetrics = {
-    stats.getAndMaybePut(topicPartition)
+    stats.computeIfAbsent(topicPartition, k => new FetcherLagMetrics(ClientIdTopicPartition(metricId.clientId, k)))
   }

   def unregister(topicPartition: TopicPartition): Unit = {
@@ -916,9 +915,7 @@ class FetcherLagStats(metricId: ClientIdAndBroker) {
   }

   def unregister(): Unit = {
-    stats.keys.toBuffer.foreach { key: TopicPartition =>
-      unregister(key)
-    }
+    stats.forEach((key, _) => unregister(key))
   }
 }

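A hedged sketch (hypothetical names) of why `unregister()` no longer snapshots the keys with `toBuffer`: traversal of a `ConcurrentHashMap` is weakly consistent, so entries may be removed from the map while `forEach` is walking it without triggering a `ConcurrentModificationException`.

```scala
import java.util.concurrent.ConcurrentHashMap

object UnregisterAllSketch {
  private val stats = new ConcurrentHashMap[String, Long]

  private def unregister(key: String): Unit = stats.remove(key)

  def main(args: Array[String]): Unit = {
    (1 to 3).foreach(i => stats.put(s"topic-0-$i", i.toLong))
    // Removing the current entry (and others) during forEach is safe on ConcurrentHashMap.
    stats.forEach((key, _) => unregister(key))
    println(stats.isEmpty) // true
  }
}
```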

core/src/main/scala/kafka/server/DelayedProduce.scala

+7 -9

@@ -17,11 +17,11 @@

 package kafka.server

-import java.util.concurrent.TimeUnit
+import java.util.concurrent.{ConcurrentHashMap, TimeUnit}
 import java.util.concurrent.locks.Lock
 import com.typesafe.scalalogging.Logger
 import com.yammer.metrics.core.Meter
-import kafka.utils.{Logging, Pool}
+import kafka.utils.Logging
 import org.apache.kafka.common.TopicPartition
 import org.apache.kafka.common.protocol.Errors
 import org.apache.kafka.common.requests.ProduceResponse.PartitionResponse
@@ -142,15 +142,13 @@ object DelayedProduceMetrics {

   private val aggregateExpirationMeter = metricsGroup.newMeter("ExpiresPerSec", "requests", TimeUnit.SECONDS)

-  private val partitionExpirationMeterFactory = (key: TopicPartition) =>
-    metricsGroup.newMeter("ExpiresPerSec",
-      "requests",
-      TimeUnit.SECONDS,
-      Map("topic" -> key.topic, "partition" -> key.partition.toString).asJava)
-  private val partitionExpirationMeters = new Pool[TopicPartition, Meter](valueFactory = Some(partitionExpirationMeterFactory))
+  private val partitionExpirationMeters = new ConcurrentHashMap[TopicPartition, Meter]

   def recordExpiration(partition: TopicPartition): Unit = {
     aggregateExpirationMeter.mark()
-    partitionExpirationMeters.getAndMaybePut(partition).mark()
+    partitionExpirationMeters.computeIfAbsent(partition, key => metricsGroup.newMeter("ExpiresPerSec",
+      "requests",
+      TimeUnit.SECONDS,
+      Map("topic" -> key.topic, "partition" -> key.partition.toString).asJava)).mark()
   }
 }
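A hedged sketch (hypothetical names, not the Kafka metrics API) of moving the old `Pool` `valueFactory` to the `computeIfAbsent` call site, as the hunk above does for `partitionExpirationMeters`: the mapping function is applied atomically and at most once per key insertion, so a per-partition meter is created only once even under concurrent expirations.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.AtomicInteger

object LazyMeterSketch {
  final class Meter(val name: String) { def mark(): Unit = () }

  private val registrations = new AtomicInteger(0)
  private val meters = new ConcurrentHashMap[String, Meter]

  private def newMeter(partition: String): Meter = {
    registrations.incrementAndGet() // stands in for metricsGroup.newMeter(...)
    new Meter(s"ExpiresPerSec-$partition")
  }

  def recordExpiration(partition: String): Unit =
    meters.computeIfAbsent(partition, p => newMeter(p)).mark()

  def main(args: Array[String]): Unit = {
    (1 to 5).foreach(_ => recordExpiration("topic-0"))
    println(registrations.get()) // 1: the factory ran only once for "topic-0"
  }
}
```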

core/src/main/scala/kafka/server/ReplicaManager.scala

+8 -10

@@ -71,7 +71,7 @@ import java.nio.file.{Files, Paths}
 import java.util
 import java.util.concurrent.atomic.AtomicBoolean
 import java.util.concurrent.locks.Lock
-import java.util.concurrent.{CompletableFuture, Future, RejectedExecutionException, TimeUnit}
+import java.util.concurrent.{CompletableFuture, ConcurrentHashMap, Future, RejectedExecutionException, TimeUnit}
 import java.util.{Collections, Optional, OptionalInt, OptionalLong}
 import java.util.function.Consumer
 import scala.collection.{Map, Seq, Set, immutable, mutable}
@@ -317,9 +317,7 @@ class ReplicaManager(val config: KafkaConfig,
   /* epoch of the controller that last changed the leader */
   @volatile private[server] var controllerEpoch: Int = 0
   protected val localBrokerId = config.brokerId
-  protected val allPartitions = new Pool[TopicPartition, HostedPartition](
-    valueFactory = Some(tp => HostedPartition.Online(Partition(tp, time, this)))
-  )
+  protected val allPartitions = new ConcurrentHashMap[TopicPartition, HostedPartition]
   private val replicaStateChangeLock = new Object
   val replicaFetcherManager = createReplicaFetcherManager(metrics, time, threadNamePrefix, quotaManagers.follower)
   private[server] val replicaAlterLogDirsManager = createReplicaAlterLogDirsManager(quotaManagers.alterLogDirs, brokerTopicStats)
@@ -402,7 +400,7 @@ class ReplicaManager(val config: KafkaConfig,
   }

   private def maybeRemoveTopicMetrics(topic: String): Unit = {
-    val topicHasNonOfflinePartition = allPartitions.values.exists {
+    val topicHasNonOfflinePartition = allPartitions.values.asScala.exists {
       case online: HostedPartition.Online => topic == online.partition.topic
       case HostedPartition.None | HostedPartition.Offline(_) => false
     }
@@ -561,14 +559,14 @@ class ReplicaManager(val config: KafkaConfig,
   // An iterator over all non offline partitions. This is a weakly consistent iterator; a partition made offline after
   // the iterator has been constructed could still be returned by this iterator.
   private def onlinePartitionsIterator: Iterator[Partition] = {
-    allPartitions.values.iterator.flatMap {
+    allPartitions.values.asScala.iterator.flatMap {
       case HostedPartition.Online(partition) => Some(partition)
       case _ => None
     }
   }

   private def offlinePartitionCount: Int = {
-    allPartitions.values.iterator.count(_.getClass == HostedPartition.Offline.getClass)
+    allPartitions.values.asScala.iterator.count(_.getClass == HostedPartition.Offline.getClass)
   }

   def getPartitionOrException(topicPartition: TopicPartition): Partition = {
@@ -2071,7 +2069,7 @@ class ReplicaManager(val config: KafkaConfig,

       case HostedPartition.None =>
         val partition = Partition(topicPartition, time, this)
-        allPartitions.putIfNotExists(topicPartition, HostedPartition.Online(partition))
+        allPartitions.putIfAbsent(topicPartition, HostedPartition.Online(partition))
         Some(partition)
     }

@@ -2512,7 +2510,7 @@ class ReplicaManager(val config: KafkaConfig,
     trace("Evaluating ISR list of partitions to see which replicas can be removed from the ISR")

     // Shrink ISRs for non offline partitions
-    allPartitions.keys.foreach { topicPartition =>
+    allPartitions.forEach { (topicPartition, _) =>
       onlinePartition(topicPartition).foreach(_.maybeShrinkIsr())
     }
   }
@@ -2643,7 +2641,7 @@ class ReplicaManager(val config: KafkaConfig,

   private def removeAllTopicMetrics(): Unit = {
     val allTopics = new util.HashSet[String]
-    allPartitions.keys.foreach(partition =>
+    allPartitions.forEach((partition, _) =>
       if (allTopics.add(partition.topic())) {
         brokerTopicStats.removeMetrics(partition.topic())
       })
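A hedged sketch (hypothetical types) of the iteration styles `ReplicaManager` switches to above: `values.asScala` is a live Scala view over the map, and both it and `forEach` are weakly consistent, which matches the existing comment that `onlinePartitionsIterator` may still return a partition made offline after iteration started.

```scala
import java.util.concurrent.ConcurrentHashMap
import scala.jdk.CollectionConverters._

object WeaklyConsistentIterationSketch {
  sealed trait Hosted
  case object Offline extends Hosted
  final case class Online(topic: String) extends Hosted

  def main(args: Array[String]): Unit = {
    val allPartitions = new ConcurrentHashMap[String, Hosted]
    allPartitions.put("topic-0-0", Online("topic-0"))
    allPartitions.put("topic-0-1", Offline)

    // values.asScala.exists / iterator.count mirror the call sites changed above.
    val hasOnline = allPartitions.values.asScala.exists {
      case Online(_) => true
      case _         => false
    }
    val offlineCount = allPartitions.values.asScala.iterator.count(_ == Offline)

    // A keys-only traversal becomes forEach over entries, ignoring the value.
    allPartitions.forEach((tp, _) => println(tp))
    println(s"hasOnline=$hasOnline offlineCount=$offlineCount")
  }
}
```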
