Merged
22 changes: 20 additions & 2 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
@@ -18,6 +18,7 @@

#include "tensorrt_llm/executor/executor.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <deque>
@@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>;
class KVCacheEventManager
{
public:
explicit KVCacheEventManager(size_t maxKVEventEntries);
explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt,
std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5);

~KVCacheEventManager();
KVCacheEventManager(KVCacheEventManager& other) = delete;
@@ -61,14 +63,19 @@ class KVCacheEventManager
// Worker thread which adds events to mEvents.
void worker();

// Thread which exchanges events when attention DP is enabled
void exchangeAttentionDpThread();

private:
// Add an event to mEventQueue
void enqueueEvent(executor::KVCacheEvent&& event);

/// @brief Flag to terminate the worker
bool mRun;
std::atomic<bool> mRun;
/// @brief Worker thread
std::thread mWorkerThread;
/// @brief Exchange thread for attention DP events
std::thread mExchangeAttentionDpThread;

/// @brief The deque of events
std::deque<executor::KVCacheEvent> mEvents;
@@ -91,6 +98,17 @@ class KVCacheEventManager
size_t mMaxSize;
/// @brief An auto-incrementing event id counter
size_t mEventId;

/// @brief Attention DP ranks and size
/// If set, we will exchange KV cache events and accumulate on rank 0
std::optional<SizeType32> mAttentionDpRank;
std::optional<SizeType32> mAttentionDpSize;

/// @brief The period in milliseconds to gather attention DP events across ranks
SizeType32 mAttentionDpEventsGatherPeriodMs;

/// @brief MPI communicator for attention DP
std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm;
};

} // namespace tensorrt_llm::batch_manager::kv_cache_manager
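A minimal usage sketch of the extended constructor with attention DP enabled; the entry budget, rank, and size values below are illustrative assumptions, not values taken from this PR.

```cpp
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"

#include <memory>

using tensorrt_llm::batch_manager::kv_cache_manager::KVCacheEventManager;

// Hypothetical wiring: the rank and size would normally come from the parallel configuration.
auto eventManager = std::make_unique<KVCacheEventManager>(
    /*maxKVEventEntries=*/1024,
    /*attentionDpRank=*/0,                  // rank 0 accumulates events from all DP ranks
    /*attentionDpSize=*/4,                  // total number of attention DP ranks
    /*attentionDpEventsGatherPeriodMs=*/5); // gather period in milliseconds
```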
18 changes: 16 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -1001,6 +1001,7 @@ class KvCacheConfig
std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
SizeType32 attentionDpEventsGatherPeriodMs = 5,
std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);

[[nodiscard]] bool getEnableBlockReuse() const;
@@ -1016,6 +1017,7 @@ class KvCacheConfig
[[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
[[nodiscard]] size_t getEventBufferMaxSize() const;
[[nodiscard]] bool getUseUvm() const;
[[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;

void setEnableBlockReuse(bool enableBlockReuse);
void setEnablePartialReuse(bool enablePartialReuse);
@@ -1030,6 +1032,7 @@ class KvCacheConfig
void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
void setEventBufferMaxSize(size_t eventBufferMaxSize);
void setUseUvm(bool useUvm);
void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs);

void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);

@@ -1085,6 +1088,9 @@ class KvCacheConfig

/// @brief Whether to use UVM for the KV cache.
bool mUseUvm;

/// @brief The period in milliseconds to gather attention DP events across ranks
SizeType32 mAttentionDpEventsGatherPeriodMs;
};
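A hedged sketch of driving the new knob through KvCacheConfig: the setter and getter names match the declarations above, while the surrounding setup (default construction, the 10 ms period) is assumed for illustration.

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <cassert>

namespace tle = tensorrt_llm::executor;

int main()
{
    tle::KvCacheConfig kvCacheConfig;
    kvCacheConfig.setEventBufferMaxSize(1024);            // enable KV cache event tracking
    kvCacheConfig.setAttentionDpEventsGatherPeriodMs(10); // gather attention DP events every 10 ms
    assert(kvCacheConfig.getAttentionDpEventsGatherPeriodMs() == 10);
    return 0;
}
```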

/// @brief Configuration class for the runtime perf knobs
@@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData
explicit KVCacheUpdatedData(IdType blockHash)
: blockHash{blockHash} {};

explicit KVCacheUpdatedData(IdType blockHash, std::optional<KVCacheEventDiff<SizeType32>> cacheLevel,
std::optional<KVCacheEventDiff<SizeType32>> priority)
: blockHash{blockHash}
, cacheLevel{cacheLevel}
, priority{priority} {};

KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue)
{
cacheLevel = KVCacheEventDiff<SizeType32>{oldValue, newValue};
@@ -1726,15 +1738,17 @@ using KVCacheEventData = std::variant<KVCacheCreatedData, KVCacheStoredData, KVC

struct KVCacheEvent
{

KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize);
KVCacheEvent(IdType eventId, KVCacheEventData data, SizeType32 windowSize,
std::optional<SizeType32> attentionDpRank = std::nullopt);

/// @brief The unique id of this event
IdType eventId;
/// @brief The data corresponding to this event
KVCacheEventData data;
/// @brief The sliding window size
SizeType32 windowSize;
/// @brief The attention DP rank of the event, if applicable
std::optional<SizeType32> attentionDpRank;
};
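A consumer-side sketch showing how the new attentionDpRank field might be read. It relies on the Executor::getKVCacheEventManager accessor shown in the executor.cpp hunk below; the getLatestEvents polling call and the 100 ms timeout are assumptions for illustration.

```cpp
#include "tensorrt_llm/executor/executor.h"

#include <chrono>
#include <iostream>

namespace tle = tensorrt_llm::executor;

// `executor` is an already-constructed tle::Executor with event tracking enabled.
void drainKvCacheEvents(tle::Executor& executor)
{
    auto eventManager = executor.getKVCacheEventManager();
    if (!eventManager)
    {
        return; // event tracking not enabled
    }
    // Poll for up to 100 ms; getLatestEvents is assumed here as the polling entry point.
    auto events = (*eventManager)->getLatestEvents(std::chrono::milliseconds(100));
    for (auto const& event : events)
    {
        // With attention DP enabled, rank 0 republishes events from every rank,
        // so the originating rank travels with each event.
        std::cout << "event " << event.eventId << " from attention DP rank "
                  << event.attentionDpRank.value_or(0) << " (window " << event.windowSize << ")\n";
    }
}
```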

/// @brief Exposes a limited set of KV cache manager functionalities
47 changes: 47 additions & 0 deletions cpp/include/tensorrt_llm/executor/serialization.h
@@ -302,6 +302,53 @@ class Serialization
[[nodiscard]] static std::vector<RequestStatsPerIteration> deserializeRequestStatsPerIterationVec(
std::vector<char>& buffer);

// KVCacheEvent deque
[[nodiscard]] static std::vector<char> serialize(std::deque<KVCacheEvent> const& kvCacheEvents);
[[nodiscard]] static std::deque<KVCacheEvent> deserializeKVCacheEvents(std::vector<char>& buffer);

// KVCacheEvent
[[nodiscard]] static size_t serializedSize(KVCacheEvent const& event);
static void serialize(KVCacheEvent const& event, std::ostream& os);
[[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is);

// KVCacheCreatedData
[[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data);
static void serialize(KVCacheCreatedData const& data, std::ostream& os);
[[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is);

// KVCacheStoredData
[[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data);
static void serialize(KVCacheStoredData const& data, std::ostream& os);
[[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is);

// KVCacheStoredBlockData
[[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data);
static void serialize(KVCacheStoredBlockData const& data, std::ostream& os);
[[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is);

// KVCacheRemovedData
[[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data);
static void serialize(KVCacheRemovedData const& data, std::ostream& os);
[[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is);

// KVCacheEventDiff
template <typename T>
[[nodiscard]] static size_t serializedSize(KVCacheEventDiff<T> const& data);
template <typename T>
static void serialize(KVCacheEventDiff<T> const& data, std::ostream& os);
template <typename T>
[[nodiscard]] static KVCacheEventDiff<T> deserializeKVCacheEventDiff(std::istream& is);

// KVCacheUpdateData
[[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data);
static void serialize(KVCacheUpdatedData const& data, std::ostream& os);
[[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is);

// UniqueToken
[[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token);
static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os);
[[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is);

// String
static std::string deserializeString(std::istream& is);
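A round-trip sketch for the new deque overloads, mirroring what the exchange thread does before and after the MPI transfer; the event payload below (a single-level created event with placeholder block counts) is illustrative only.

```cpp
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"

#include <cassert>
#include <deque>
#include <vector>

namespace tle = tensorrt_llm::executor;

int main()
{
    std::deque<tle::KVCacheEvent> events;
    events.emplace_back(/*eventId=*/0, tle::KVCacheCreatedData{{128}}, /*windowSize=*/4096,
        /*attentionDpRank=*/0);

    // Serialize the whole deque, then reconstruct it on the receiving side.
    std::vector<char> buffer = tle::Serialization::serialize(events);
    std::deque<tle::KVCacheEvent> restored = tle::Serialization::deserializeKVCacheEvents(buffer);
    assert(restored.size() == events.size());
    return 0;
}
```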

4 changes: 4 additions & 0 deletions cpp/include/tensorrt_llm/runtime/utils/mpiTags.h
@@ -68,6 +68,10 @@ enum class MpiTag : int
// LogitsThread
kSpecDecLogitsId = 129,
kSpecDecLogitsData = 1025,

// KvCacheEventManager
kKvCacheEventSize = 1026,
kKvCacheEvent = 1027
};

} // namespace tensorrt_llm::mpi
119 changes: 113 additions & 6 deletions cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp
@@ -18,20 +18,51 @@
#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
#include "tensorrt_llm/batch_manager/kvCacheManager.h"
#include "tensorrt_llm/executor/executor.h"
#include "tensorrt_llm/executor/serialization.h"
#include "tensorrt_llm/runtime/utils/mpiUtils.h"

namespace tle = tensorrt_llm::executor;

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries)
KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank,
std::optional<SizeType32> attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs)
: mRun{true}
, mMaxSize{maxKVEventEntries}
, mEventId{0}
, mAttentionDpRank{attentionDpRank}
, mAttentionDpSize{attentionDpSize}
, mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs)
{
TLLM_CHECK(mMaxSize > 0);
// mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this));
if (mAttentionDpRank)
{
TLLM_CHECK_WITH_INFO(
mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set");
TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(),
"Attention DP rank must be less than attention DP size");
if (mAttentionDpRank.value() == 0)
{
// Rank 0 gathers events from all other ranks, so it needs a proportionally larger buffer
mMaxSize *= mAttentionDpSize.value();
}
// Create a communicator to be used for event exchange
mMpiComm = std::make_unique<tensorrt_llm::mpi::MpiComm>(COMM_SESSION.split(0, mAttentionDpRank.value()));
}
else
{
TLLM_CHECK_WITH_INFO(
!mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set");
}
mWorkerThread = std::thread([this]() { this->worker(); });
#if ENABLE_MULTI_DEVICE
if (mAttentionDpRank)
{
mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); });
}
#endif
};

KVCacheEventManager::~KVCacheEventManager()
@@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager()
mPendingEmptyCV.notify_all();
mEmptyCV.notify_all();
mWorkerThread.join();
#if ENABLE_MULTI_DEVICE
if (mAttentionDpRank)
{
mExchangeAttentionDpThread.join();
}
#endif
}

void KVCacheEventManager::enqueueCreatedEvent(
std::vector<SizeType32> const& numBlocksPerCacheLevel, SizeType32 windowSize)
{
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize});
enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank});
}

void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize)
@@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
}

enqueueEvent({mEventId++, data, windowSize});
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
}

void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
@@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
}
else
{
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
}
}

void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
{
enqueueEvent({mEventId++, data, windowSize});
enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
}

void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event)
@@ -120,8 +157,76 @@ void KVCacheEventManager::flush()
mPendingEmptyCV.notify_one();
}

void KVCacheEventManager::exchangeAttentionDpThread()
{
#if ENABLE_MULTI_DEVICE
while (true)
{
TLLM_CHECK(mAttentionDpRank);

// Check whether any of the ranks has been shut down
int32_t numFinished = 0;
int32_t finished = mRun ? 0 : 1;
mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM);
if (numFinished > 0)
{
TLLM_LOG_INFO("One of the rank has been shut down, exiting");
break;
}

// If we are not rank 0, send events to rank 0
if (mAttentionDpRank.value() != 0)
{
std::vector<char> serializedEvents;
uint64_t numEvents = 0;
{
std::lock_guard<std::mutex> lck(mEventsMutex);
serializedEvents = executor::Serialization::serialize(mEvents);
numEvents = mEvents.size();
mEvents.clear();
}
uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0;
mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
if (vecSize > 0)
{
mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0,
mpi::MpiTag::kKvCacheEvent);
}
}
else
{
TLLM_CHECK(mAttentionDpSize.has_value());
// Loop until we have received events from all ranks
for (int rank = 1; rank < mAttentionDpSize.value(); ++rank)
{
uint64_t vecSize{0};
mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize);
if (vecSize > 0)
{
std::vector<char> serializedEvents(vecSize);
mMpiComm->recv(
serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent);

// Deserialize the events and add them to the local queue
auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
{
std::lock_guard<std::mutex> lck(mEventsMutex);
mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
mEmptyCV.notify_one();
}
}
}
}
std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
}
#else
TLLM_THROW("Multi device support is disabled.");
#endif
}

void KVCacheEventManager::worker()
{

while (true)
{
std::deque<tle::KVCacheEvent> events;
@@ -151,6 +256,8 @@ void KVCacheEventManager::worker()

// If there are still too many events, take from the front of the events queue.
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end());

// Notify the empty condition variable to wake up any waiting threads
mEmptyCV.notify_one();
}
}
4 changes: 3 additions & 1 deletion cpp/tensorrt_llm/executor/executor.cpp
@@ -132,10 +132,12 @@ std::optional<std::shared_ptr<KVCacheEventManager>> Executor::getKVCacheEventMan
return mImpl->getKVCacheEventManager();
}

KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize)
KVCacheEvent::KVCacheEvent(
size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional<SizeType32> attentionDpRank)
: eventId{eventId}
, data{std::move(data)}
, windowSize{windowSize}
, attentionDpRank{attentionDpRank}
{
}
