diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
index a232230c4ff..09a96a56eee 100644
--- a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
+++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h
@@ -18,6 +18,7 @@
 #include "tensorrt_llm/executor/executor.h"
+#include
 #include
 #include
 #include
@@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr<KVCacheBlock>;
 class KVCacheEventManager
 {
 public:
-    explicit KVCacheEventManager(size_t maxKVEventEntries);
+    explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional<SizeType32> attentionDpRank = std::nullopt,
+        std::optional<SizeType32> attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5);
     ~KVCacheEventManager();
 
     KVCacheEventManager(KVCacheEventManager& other) = delete;
@@ -61,14 +63,19 @@ class KVCacheEventManager
     // Worker thread which adds events to mEvents.
     void worker();
 
+    // Thread which exchanges events if attention DP is enabled
+    void exchangeAttentionDpThread();
+
 private:
     // Add an event to mEventQueue
     void enqueueEvent(executor::KVCacheEvent&& event);
 
     /// @brief Flag to terminate the worker
-    bool mRun;
+    std::atomic<bool> mRun;
     /// @brief Worker thread
     std::thread mWorkerThread;
+    /// @brief Exchange thread for attention DP events
+    std::thread mExchangeAttentionDpThread;
 
     /// @brief The deque of events
     std::deque<executor::KVCacheEvent> mEvents;
@@ -91,6 +98,17 @@ class KVCacheEventManager
     size_t mMaxSize;
     /// @brief An auto-incrementing event id counter
     size_t mEventId;
+
+    /// @brief Attention DP rank and size
+    /// If set, we will exchange KV cache events and accumulate on rank 0
+    std::optional<SizeType32> mAttentionDpRank;
+    std::optional<SizeType32> mAttentionDpSize;
+
+    /// @brief The period in milliseconds to gather attention DP events across ranks
+    SizeType32 mAttentionDpEventsGatherPeriodMs;
+
+    /// @brief MPI communicator for attention DP
+    std::unique_ptr<tensorrt_llm::mpi::MpiComm> mMpiComm;
 };
 
 } // namespace tensorrt_llm::batch_manager::kv_cache_manager
diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h
index 6d592654ffd..0a58298c279 100644
--- a/cpp/include/tensorrt_llm/executor/executor.h
+++ b/cpp/include/tensorrt_llm/executor/executor.h
@@ -1001,6 +1001,7 @@ class KvCacheConfig
         std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
         std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
         bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
+        SizeType32 attentionDpEventsGatherPeriodMs = 5,
         std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
 
     [[nodiscard]] bool getEnableBlockReuse() const;
@@ -1016,6 +1017,7 @@ class KvCacheConfig
     [[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
     [[nodiscard]] size_t getEventBufferMaxSize() const;
     [[nodiscard]] bool getUseUvm() const;
+    [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const;
 
     void setEnableBlockReuse(bool enableBlockReuse);
     void setEnablePartialReuse(bool enablePartialReuse);
@@ -1030,6 +1032,7 @@ class KvCacheConfig
     void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
     void setEventBufferMaxSize(size_t eventBufferMaxSize);
     void setUseUvm(bool useUvm);
+    void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs);
 
     void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
 
@@ -1085,6 +1088,9 @@ class KvCacheConfig
     /// @brief Whether to use UVM for the KV cache.
bool mUseUvm; + + /// @brief The period in milliseconds to gather attention DP events across ranks + SizeType32 mAttentionDpEventsGatherPeriodMs; }; /// @brief Configuration class for the runtime perf knobs @@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData explicit KVCacheUpdatedData(IdType blockHash) : blockHash{blockHash} {}; + explicit KVCacheUpdatedData(IdType blockHash, std::optional> cacheLevel, + std::optional> priority) + : blockHash{blockHash} + , cacheLevel{cacheLevel} + , priority{priority} {}; + KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue) { cacheLevel = KVCacheEventDiff{oldValue, newValue}; @@ -1726,8 +1738,8 @@ using KVCacheEventData = std::variant attentionDpRank = std::nullopt); /// @brief The unique id of this event IdType eventId; @@ -1735,6 +1747,8 @@ struct KVCacheEvent KVCacheEventData data; /// @brief The sliding window size SizeType32 windowSize; + /// @brief The attention DP rank of the event, if applicable + std::optional attentionDpRank; }; /// @brief Exposes a limited set of KV cache manager functionalities diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index b2ecfc66c84..c370a652350 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -302,6 +302,53 @@ class Serialization [[nodiscard]] static std::vector deserializeRequestStatsPerIterationVec( std::vector& buffer); + // KVCacheEvent deque + [[nodiscard]] static std::vector serialize(std::deque const& kvCacheEvents); + [[nodiscard]] static std::deque deserializeKVCacheEvents(std::vector& buffer); + + // KVCacheEvent + [[nodiscard]] static size_t serializedSize(KVCacheEvent const& event); + static void serialize(KVCacheEvent const& event, std::ostream& os); + [[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is); + + // KVCacheCreatedData + [[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data); + static void serialize(KVCacheCreatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is); + + // KVCacheStoredData + [[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data); + static void serialize(KVCacheStoredData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is); + + // KVCacheStoredBlockData + [[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data); + static void serialize(KVCacheStoredBlockData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is); + + // KVCacheRemovedData + [[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data); + static void serialize(KVCacheRemovedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is); + + // KVCacheEventDiff + template + [[nodiscard]] static size_t serializedSize(KVCacheEventDiff const& data); + template + static void serialize(KVCacheEventDiff const& data, std::ostream& os); + template + [[nodiscard]] static KVCacheEventDiff deserializeKVCacheEventDiff(std::istream& is); + + // KVCacheUpdateData + [[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data); + static void serialize(KVCacheUpdatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheUpdatedData 
deserializeKVCacheUpdatedData(std::istream& is); + + // UniqueToken + [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token); + static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os); + [[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is); + // String static std::string deserializeString(std::istream& is); diff --git a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h index 4443d422ab8..32c086c84ee 100644 --- a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h +++ b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h @@ -68,6 +68,10 @@ enum class MpiTag : int // LogitsThread kSpecDecLogitsId = 129, kSpecDecLogitsData = 1025, + + // KvCacheEventManager + kKvCacheEventSize = 1026, + kKvCacheEvent = 1027 }; } // namespace tensorrt_llm::mpi diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index ff2a2f6b787..ac37278d45f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -18,20 +18,51 @@ #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" namespace tle = tensorrt_llm::executor; namespace tensorrt_llm::batch_manager::kv_cache_manager { -KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries) +KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional attentionDpRank, + std::optional attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs) : mRun{true} , mMaxSize{maxKVEventEntries} , mEventId{0} + , mAttentionDpRank{attentionDpRank} + , mAttentionDpSize{attentionDpSize} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { TLLM_CHECK(mMaxSize > 0); - // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this)); + if (mAttentionDpRank) + { + TLLM_CHECK_WITH_INFO( + mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set"); + TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(), + "Attention DP rank must be less than attention DP size"); + if (mAttentionDpRank.value() == 0) + { + // Rank 0 will gather events from all other ranks + // Need to increase size + mMaxSize *= mAttentionDpSize.value(); + } + // Create a communicator to be used for event exchange + mMpiComm = std::make_unique(COMM_SESSION.split(0, mAttentionDpRank.value())); + } + else + { + TLLM_CHECK_WITH_INFO( + !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set"); + } mWorkerThread = std::thread([this]() { this->worker(); }); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); }); + } +#endif }; KVCacheEventManager::~KVCacheEventManager() @@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager() mPendingEmptyCV.notify_all(); mEmptyCV.notify_all(); mWorkerThread.join(); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread.join(); + } +#endif } void KVCacheEventManager::enqueueCreatedEvent( std::vector const& numBlocksPerCacheLevel, SizeType32 windowSize) { - enqueueEvent({mEventId++, 
tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize});
+    enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank});
 }
 
 void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks, SizeType32 windowSize)
@@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector<BlockPtr> const& blocks
             block->isPrimary() ? kPrimaryLevel : kSecondaryLevel, block->getPriority());
     }
 
-    enqueueEvent({mEventId++, data, windowSize});
+    enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
 }
 
 void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize)
@@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32
     }
     else
     {
-        enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize});
+        enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank});
     }
 }
 
 void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize)
 {
-    enqueueEvent({mEventId++, data, windowSize});
+    enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank});
 }
 
 void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event)
@@ -120,8 +157,76 @@ void KVCacheEventManager::flush()
     mPendingEmptyCV.notify_one();
 }
 
+void KVCacheEventManager::exchangeAttentionDpThread()
+{
+#if ENABLE_MULTI_DEVICE
+    while (true)
+    {
+        TLLM_CHECK(mAttentionDpRank);
+
+        // Check if any of the ranks have been shutdown
+        int32_t numFinished = 0;
+        int32_t finished = mRun ? 0 : 1;
+        mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM);
+        if (numFinished > 0)
+        {
+            TLLM_LOG_INFO("One of the ranks has been shut down, exiting");
+            break;
+        }
+
+        // If we are not rank 0, send events to rank 0
+        if (mAttentionDpRank.value() != 0)
+        {
+            std::vector<char> serializedEvents;
+            uint64_t numEvents = 0;
+            {
+                std::lock_guard<std::mutex> lck(mEventsMutex);
+                serializedEvents = executor::Serialization::serialize(mEvents);
+                numEvents = mEvents.size();
+                mEvents.clear();
+            }
+            uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0;
+            mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize);
+            if (vecSize > 0)
+            {
+                mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0,
+                    mpi::MpiTag::kKvCacheEvent);
+            }
+        }
+        else
+        {
+            TLLM_CHECK(mAttentionDpSize.has_value());
+            // Loop until we have received events from all ranks
+            for (int rank = 1; rank < mAttentionDpSize.value(); ++rank)
+            {
+                uint64_t vecSize{0};
+                mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize);
+                if (vecSize > 0)
+                {
+                    std::vector<char> serializedEvents(vecSize);
+                    mMpiComm->recv(
+                        serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent);
+
+                    // Deserialize the events and add them to the local queue
+                    auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents);
+                    {
+                        std::lock_guard<std::mutex> lck(mEventsMutex);
+                        mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end());
+                        mEmptyCV.notify_one();
+                    }
+                }
+            }
+        }
+        std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs));
+    }
+#else
+    TLLM_THROW("Multi device support is disabled.");
+#endif
+}
+
 void KVCacheEventManager::worker()
 {
+    while (true)
     {
         std::deque<tle::KVCacheEvent> events;
@@ -151,6 +256,8 @@ void KVCacheEventManager::worker()
         // If there's still too many events, take from the front of the events queue.
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end()); + + // Notify the empty condition variable to wake up any waiting threads mEmptyCV.notify_one(); } } diff --git a/cpp/tensorrt_llm/executor/executor.cpp b/cpp/tensorrt_llm/executor/executor.cpp index 70ca2be41ab..091bb512823 100644 --- a/cpp/tensorrt_llm/executor/executor.cpp +++ b/cpp/tensorrt_llm/executor/executor.cpp @@ -132,10 +132,12 @@ std::optional> Executor::getKVCacheEventMan return mImpl->getKVCacheEventManager(); } -KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize) +KVCacheEvent::KVCacheEvent( + size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional attentionDpRank) : eventId{eventId} , data{std::move(data)} , windowSize{windowSize} + , attentionDpRank{attentionDpRank} { } diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 51b047ebd27..21cf314c875 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co std::optional const& hostCacheSize, bool onboardBlocks, std::optional const& crossKvCacheFraction, std::optional secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, + SizeType32 attentionDpEventsGatherPeriodMs, std::optional const& runtimeDefaults) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) @@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mUseUvm{useUvm} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { if (maxTokens) { @@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co { fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value()); } + TLLM_CHECK_WITH_INFO( + mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0"); } bool KvCacheConfig::getEnableBlockReuse() const @@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const return mUseUvm; } +SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const +{ + return mAttentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse) { mEnableBlockReuse = enableBlockReuse; @@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm) mUseUvm = useUvm; } +void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs) +{ + TLLM_CHECK(attentionDpEventsGatherPeriodMs > 0); + mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults) { if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec) diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 65718f0405d..38256edbc75 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -23,6 +23,7 @@ #include "tensorrt_llm/executor/serializeUtils.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/runtime/cudaStream.h" +#include #include #include #include @@ -1162,10 +1163,11 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto secondaryOffloadMinPriority = su::deserialize>(is); 
auto eventBufferMaxSize = su::deserialize(is); auto useUvm = su::deserialize(is); + auto attentionDpEventsGatherPeriodMs = su::deserialize(is); return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction, hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, - enablePartialReuse, copyOnPartialReuse, useUvm}; + enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; } void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os) @@ -1183,6 +1185,7 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os); su::serialize(kvCacheConfig.getEventBufferMaxSize(), os); su::serialize(kvCacheConfig.getUseUvm(), os); + su::serialize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), os); } size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) @@ -1202,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority()); totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize()); totalSize += su::serializedSize(kvCacheConfig.getUseUvm()); + totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs()); return totalSize; } @@ -2181,6 +2185,237 @@ std::vector Serialization::deserializeRequestStatsPerI return iterRequestStatsVec; } +// KVCacheEvents deque +std::vector Serialization::serialize(std::deque const& eventQueue) +{ + // Compute the size of serialized buffer + size_t totalSize = 0; + totalSize += sizeof(size_t); + for (auto const& event : eventQueue) + { + totalSize += su::serializedSize(event); + } + + std::vector buffer(totalSize); + std::stringbuf strbuf(std::ios_base::out | std::ios_base::in); + strbuf.pubsetbuf(buffer.data(), buffer.size()); + std::ostream os(&strbuf); + + su::serialize(eventQueue.size(), os); + for (auto const& event : eventQueue) + { + su::serialize(event, os); + } + return buffer; +} + +std::deque Serialization::deserializeKVCacheEvents(std::vector& buffer) +{ + std::deque kvCacheEvents; + su::VectorWrapBuf strbuf(buffer); + std::istream is(&strbuf); + auto numEvents = su::deserialize(is); + for (std::size_t event = 0; event < numEvents; ++event) + { + kvCacheEvents.emplace_back(Serialization::deserializeKVCacheEvent(is)); + } + return kvCacheEvents; +} + +// KVCacheEvent +size_t Serialization::serializedSize(KVCacheEvent const& event) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(event.eventId); + totalSize += su::serializedSize(event.data); + totalSize += su::serializedSize(event.windowSize); + totalSize += su::serializedSize(event.attentionDpRank); + return totalSize; +} + +void Serialization::serialize(KVCacheEvent const& event, std::ostream& os) +{ + su::serialize(event.eventId, os); + su::serialize(event.data, os); + su::serialize(event.windowSize, os); + su::serialize(event.attentionDpRank, os); +} + +KVCacheEvent Serialization::deserializeKVCacheEvent(std::istream& is) +{ + auto eventId = su::deserialize(is); + auto data = su::deserialize(is); + auto windowSize = su::deserialize(is); + auto attentionDpRank = su::deserialize>(is); + + return KVCacheEvent{eventId, data, windowSize, attentionDpRank}; +} + +// KVCacheCreatedData +size_t Serialization::serializedSize(KVCacheCreatedData const& data) +{ + size_t totalSize = 0; + totalSize += 
su::serializedSize(data.numBlocksPerCacheLevel); + return totalSize; +} + +void Serialization::serialize(KVCacheCreatedData const& data, std::ostream& os) +{ + su::serialize(data.numBlocksPerCacheLevel, os); +} + +KVCacheCreatedData Serialization::deserializeKVCacheCreatedData(std::istream& is) +{ + auto numBlocksPerCacheLevel = su::deserialize>(is); + return KVCacheCreatedData{numBlocksPerCacheLevel}; +} + +// KVCacheStoredData +size_t Serialization::serializedSize(KVCacheStoredData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.parentHash); + totalSize += su::serializedSize(data.blocks); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredData const& data, std::ostream& os) +{ + su::serialize(data.parentHash, os); + su::serialize(data.blocks, os); +} + +KVCacheStoredData Serialization::deserializeKVCacheStoredData(std::istream& is) +{ + auto parentHash = su::deserialize>(is); + auto blocks = su::deserialize>(is); + return KVCacheStoredData{parentHash, blocks}; +} + +// KVCacheStoredBlockData +size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.tokens); + totalSize += su::serializedSize(data.loraId); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.tokens, os); + su::serialize(data.loraId, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto tokens = su::deserialize(is); + auto loraId = su::deserialize>(is); + auto cacheLevel = su::deserialize(is); + auto priority = su::deserialize(is); + + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority}; +} + +// KVcacheRemovedData + +size_t Serialization::serializedSize(KVCacheRemovedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHashes); + return totalSize; +} + +void Serialization::serialize(KVCacheRemovedData const& data, std::ostream& os) +{ + su::serialize(data.blockHashes, os); +} + +KVCacheRemovedData Serialization::deserializeKVCacheRemovedData(std::istream& is) +{ + auto blockHashes = su::deserialize>(is); + return KVCacheRemovedData{blockHashes}; +} + +// KVCacheEventDiff +template +size_t Serialization::serializedSize(KVCacheEventDiff const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.oldValue); + totalSize += su::serializedSize(data.newValue); + return totalSize; +} + +template +void Serialization::serialize(KVCacheEventDiff const& data, std::ostream& os) +{ + su::serialize(data.oldValue, os); + su::serialize(data.newValue, os); +} + +template +KVCacheEventDiff Serialization::deserializeKVCacheEventDiff(std::istream& is) +{ + auto oldValue = su::deserialize(is); + auto newValue = su::deserialize(is); + return KVCacheEventDiff{oldValue, newValue}; +} + +// KVCacheUpdatedData +size_t Serialization::serializedSize(KVCacheUpdatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void 
Serialization::serialize(KVCacheUpdatedData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto cacheLevel = su::deserialize>>(is); + auto priority = su::deserialize>>(is); + return KVCacheUpdatedData{blockHash, cacheLevel, priority}; +} + +// UniqueToken +size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(token.tokenId); + totalSize += su::serializedSize(token.tokenExtraId); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os) +{ + su::serialize(token.tokenId, os); + su::serialize(token.tokenExtraId, os); +} + +tensorrt_llm::runtime::UniqueToken Serialization::deserializeUniqueToken(std::istream& is) +{ + auto tokenId = su::deserialize(is); + auto tokenExtraId = su::deserialize(is); + return tensorrt_llm::runtime::UniqueToken{tokenId, tokenExtraId}; +} + // String std::string Serialization::deserializeString(std::istream& is) { diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 8f26c58d622..40b50f92309 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -122,6 +122,14 @@ static_assert(hasSerializedSize(size_t())); static_assert(!hasSerializedSize(size_t())); static_assert(!hasSerializedSize>(size_t())); static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize>(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); template size_t serializedSize(T const& data) @@ -219,6 +227,14 @@ static_assert(hasSerialize(nullptr)); static_assert(!hasSerialize(nullptr)); static_assert(!hasSerialize>(nullptr)); static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize>(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); template void serialize(T const& data, std::ostream& os) @@ -291,6 +307,22 @@ struct get_variant_alternative_type } }; +template +T deserialize(std::istream& is); + +// Helper function to deserialize variant by index using template recursion +template +T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence /*indices*/) +{ + T result; + bool found = ((Is == index ? 
(result = deserialize>(is), true) : false) || ...); + if (!found) + { + TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index)); + } + return result; +} + // Deserialize template T deserialize(std::istream& is) @@ -511,6 +543,38 @@ T deserialize(std::istream& is) { return Serialization::deserializeCacheTransceiverConfig(is); } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheEvent(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheCreatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredBlockData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheRemovedData(is); + } + else if constexpr (std::is_same_v>) + { + return Serialization::deserializeKVCacheEventDiff(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheUpdatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeUniqueToken(is); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -547,23 +611,7 @@ T deserialize(std::istream& is) std::size_t index = 0; is.read(reinterpret_cast(&index), sizeof(index)); - // TODO: Is there a better way to implement this? - T data; - if (index == 0) - { - using U = std::variant_alternative_t<0, T>; - data = deserialize(is); - } - else if (index == 1) - { - using U = std::variant_alternative_t<1, T>; - data = deserialize(is); - } - else - { - TLLM_THROW("Serialization of variant of size > 2 is not supported."); - } - return data; + return deserializeVariantByIndex(is, index, std::make_index_sequence>{}); } else { diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp index 74049eaf96b..412698215aa 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp @@ -325,7 +325,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(nb::module_& m) .def_static("hash", &tbk::BlockKeyHasher::hash, nb::arg("block_key"), nb::arg("parent_hash") = 0); nb::class_(m, "KVCacheEventManager") - .def(nb::init(), nb::arg("max_kv_event_entries")); + .def(nb::init, std::optional, SizeType32>(), + nb::arg("max_kv_event_entries"), nb::arg("attention_dp_rank") = std::nullopt, + nb::arg("attention_dp_size") = std::nullopt, nb::arg("attention_dp_events_gather_period_ms") = 5); nb::class_(m, "BaseKVCacheManager") .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, nb::arg("config"), diff --git a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp index 5760d77fb47..505ecfca595 100644 --- a/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp @@ -110,11 +110,12 @@ void initConfigBindings(nb::module_& m) return nb::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm()); + self.getEventBufferMaxSize(), 
self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getAttentionDpEventsGatherPeriodMs()); }; auto kvCacheConfigSetstate = [](tle::KvCacheConfig& self, nb::tuple const& state) { - if (state.size() != 13) + if (state.size() != 14) { throw std::runtime_error("Invalid state!"); } @@ -123,20 +124,21 @@ void initConfigBindings(nb::module_& m) nb::cast>(state[4]), nb::cast>(state[5]), nb::cast(state[6]), nb::cast>(state[7]), nb::cast>(state[8]), nb::cast(state[9]), - nb::cast(state[10]), nb::cast(state[11]), nb::cast(state[12])); + nb::cast(state[10]), nb::cast(state[11]), nb::cast(state[12]), + nb::cast(state[13])); }; nb::class_(m, "KvCacheConfig") .def(nb::init const&, std::optional> const&, std::optional const&, std::optional const&, std::optional const&, bool, std::optional const&, std::optional, size_t const&, bool, bool, bool, - std::optional const&>(), + SizeType32, std::optional const&>(), nb::arg("enable_block_reuse") = true, nb::arg("max_tokens") = nb::none(), nb::arg("max_attention_window") = nb::none(), nb::arg("sink_token_length") = nb::none(), nb::arg("free_gpu_memory_fraction") = nb::none(), nb::arg("host_cache_size") = nb::none(), nb::arg("onboard_blocks") = true, nb::arg("cross_kv_cache_fraction") = nb::none(), nb::arg("secondary_offload_min_priority") = nb::none(), nb::arg("event_buffer_max_size") = 0, nb::kw_only(), nb::arg("enable_partial_reuse") = true, nb::arg("copy_on_partial_reuse") = true, nb::arg("use_uvm") = false, - nb::arg("runtime_defaults") = nb::none()) + nb::arg("attention_dp_events_gather_period_ms") = 5, nb::arg("runtime_defaults") = nb::none()) .def_prop_rw( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_prop_rw("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -159,6 +161,8 @@ void initConfigBindings(nb::module_& m) .def_prop_rw("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse, &tle::KvCacheConfig::setCopyOnPartialReuse) .def_prop_rw("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) + .def_prop_rw("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs, + &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs) .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) .def("__getstate__", kvCacheConfigGetstate) .def("__setstate__", kvCacheConfigSetstate); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp index 255b0f8efa3..54835e81d7f 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp @@ -321,7 +321,9 @@ void tb::kv_cache_manager::KVCacheManagerBindings::initBindings(py::module_& m) .def_static("hash", &tbk::BlockKeyHasher::hash, py::arg("block_key"), py::arg("parent_hash") = 0); py::class_>(m, "KVCacheEventManager") - .def(py::init(), py::arg("max_kv_event_entries")); + .def(py::init, std::optional, SizeType32>(), + py::arg("max_kv_event_entries"), py::arg("attention_dp_rank") = std::nullopt, + py::arg("attention_dp_size") = std::nullopt, py::arg("attention_dp_events_gather_period_ms") = 5); py::classh(m, "BaseKVCacheManager") .def_static("calculate_max_num_blocks", &tbk::BaseKVCacheManager::calculateMaxNumBlocks, py::arg("config"), diff --git a/cpp/tensorrt_llm/pybind/executor/bindings.cpp b/cpp/tensorrt_llm/pybind/executor/bindings.cpp 
index a8f6aaef73d..bbb843bedba 100644 --- a/cpp/tensorrt_llm/pybind/executor/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/executor/bindings.cpp @@ -240,7 +240,8 @@ void initBindings(pybind11::module_& m) py::class_(executor_kv_cache, "KVCacheEvent") .def_readonly("event_id", &tle::KVCacheEvent::eventId) .def_readonly("data", &tle::KVCacheEvent::data) - .def_readonly("window_size", &tle::KVCacheEvent::windowSize); + .def_readonly("window_size", &tle::KVCacheEvent::windowSize) + .def_readonly("attention_dp_rank", &tle::KVCacheEvent::attentionDpRank); py::class_>( executor_kv_cache, "KVCacheEventManager") diff --git a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp index ccbb21aab21..0e279a3e47b 100644 --- a/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp +++ b/cpp/tensorrt_llm/pybind/executor/executorConfig.cpp @@ -103,11 +103,12 @@ void initConfigBindings(pybind11::module_& m) return py::make_tuple(self.getEnableBlockReuse(), self.getMaxTokens(), self.getMaxAttentionWindowVec(), self.getSinkTokenLength(), self.getFreeGpuMemoryFraction(), self.getHostCacheSize(), self.getOnboardBlocks(), self.getCrossKvCacheFraction(), self.getSecondaryOffloadMinPriority(), - self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm()); + self.getEventBufferMaxSize(), self.getEnablePartialReuse(), self.getCopyOnPartialReuse(), self.getUseUvm(), + self.getAttentionDpEventsGatherPeriodMs()); }; auto kvCacheConfigSetstate = [](py::tuple const& state) { - if (state.size() != 13) + if (state.size() != 14) { throw std::runtime_error("Invalid state!"); } @@ -115,20 +116,21 @@ void initConfigBindings(pybind11::module_& m) state[2].cast>>(), state[3].cast>(), state[4].cast>(), state[5].cast>(), state[6].cast(), state[7].cast>(), state[8].cast>(), - state[9].cast(), state[10].cast(), state[11].cast(), state[12].cast()); + state[9].cast(), state[10].cast(), state[11].cast(), state[12].cast(), + state[13].cast()); }; py::class_(m, "KvCacheConfig") .def(py::init const&, std::optional> const&, std::optional const&, std::optional const&, std::optional const&, bool, std::optional const&, std::optional, size_t const&, bool, bool, bool, - std::optional const&>(), + SizeType32, std::optional const&>(), py::arg("enable_block_reuse") = true, py::arg("max_tokens") = py::none(), py::arg("max_attention_window") = py::none(), py::arg("sink_token_length") = py::none(), py::arg("free_gpu_memory_fraction") = py::none(), py::arg("host_cache_size") = py::none(), py::arg("onboard_blocks") = true, py::arg("cross_kv_cache_fraction") = py::none(), py::arg("secondary_offload_min_priority") = py::none(), py::arg("event_buffer_max_size") = 0, py::kw_only(), py::arg("enable_partial_reuse") = true, py::arg("copy_on_partial_reuse") = true, py::arg("use_uvm") = false, - py::arg("runtime_defaults") = py::none()) + py::arg("attention_dp_events_gather_period_ms") = 5, py::arg("runtime_defaults") = py::none()) .def_property( "enable_block_reuse", &tle::KvCacheConfig::getEnableBlockReuse, &tle::KvCacheConfig::setEnableBlockReuse) .def_property("max_tokens", &tle::KvCacheConfig::getMaxTokens, &tle::KvCacheConfig::setMaxTokens) @@ -151,6 +153,8 @@ void initConfigBindings(pybind11::module_& m) .def_property("copy_on_partial_reuse", &tle::KvCacheConfig::getCopyOnPartialReuse, &tle::KvCacheConfig::setCopyOnPartialReuse) .def_property("use_uvm", &tle::KvCacheConfig::getUseUvm, &tle::KvCacheConfig::setUseUvm) + 
.def_property("attention_dp_events_gather_period_ms", &tle::KvCacheConfig::getAttentionDpEventsGatherPeriodMs, + &tle::KvCacheConfig::setAttentionDpEventsGatherPeriodMs) .def("fill_empty_fields_from_runtime_defaults", &tle::KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults) .def(py::pickle(kvCacheConfigGetstate, kvCacheConfigSetstate)); diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 18f7e6f5379..27fff8df7d0 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -474,7 +474,7 @@ TEST(SerializeUtilsTest, VectorResponses) TEST(SerializeUtilsTest, KvCacheConfig) { texec::KvCacheConfig kvCacheConfig( - true, 10, std::vector(1, 100), 2, 0.1, 10000, false, 0.5, 50, 1024, false, false, true); + true, 10, std::vector(1, 100), 2, 0.1, 10000, false, 0.5, 50, 1024, false, false, true, 77); auto kvCacheConfig2 = serializeDeserialize(kvCacheConfig); EXPECT_EQ(kvCacheConfig.getEnableBlockReuse(), kvCacheConfig2.getEnableBlockReuse()); @@ -490,6 +490,7 @@ TEST(SerializeUtilsTest, KvCacheConfig) EXPECT_EQ(kvCacheConfig.getSecondaryOffloadMinPriority(), kvCacheConfig2.getSecondaryOffloadMinPriority()); EXPECT_EQ(kvCacheConfig.getEventBufferMaxSize(), kvCacheConfig2.getEventBufferMaxSize()); EXPECT_EQ(kvCacheConfig.getUseUvm(), kvCacheConfig2.getUseUvm()); + EXPECT_EQ(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), kvCacheConfig2.getAttentionDpEventsGatherPeriodMs()); } TEST(SerializeUtilsTest, SchedulerConfig) @@ -846,6 +847,168 @@ TEST(SerializeUtilsTest, RequestStatsPerIteration) compareRequestStatsPerIteration(requestStatsPerIteration, requestStatsPerIteration2); } +void compareKvCacheEvents(texec::KVCacheEvent const& kvCacheEvent, texec::KVCacheEvent const& kvCacheEvent2) +{ + EXPECT_EQ(kvCacheEvent.eventId, kvCacheEvent2.eventId); + EXPECT_EQ(kvCacheEvent.windowSize, kvCacheEvent2.windowSize); + EXPECT_EQ(kvCacheEvent.attentionDpRank, kvCacheEvent2.attentionDpRank); + + if (std::holds_alternative(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative(kvCacheEvent2.data)); + auto data = std::get(kvCacheEvent.data); + auto data2 = std::get(kvCacheEvent2.data); + EXPECT_EQ(data.numBlocksPerCacheLevel, data2.numBlocksPerCacheLevel); + } + else if (std::holds_alternative(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative(kvCacheEvent2.data)); + auto data = std::get(kvCacheEvent.data); + auto data2 = std::get(kvCacheEvent2.data); + EXPECT_EQ(data.blockHashes, data2.blockHashes); + } + else if (std::holds_alternative(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative(kvCacheEvent2.data)); + auto data = std::get(kvCacheEvent.data); + auto data2 = std::get(kvCacheEvent2.data); + EXPECT_EQ(data.parentHash, data2.parentHash); + EXPECT_EQ(data.blocks.size(), data2.blocks.size()); + for (size_t i = 0; i < data.blocks.size(); ++i) + { + auto blockData = data.blocks[i]; + auto blockData2 = data2.blocks[i]; + EXPECT_EQ(blockData.blockHash, blockData2.blockHash); + EXPECT_EQ(blockData.loraId, blockData2.loraId); + EXPECT_EQ(blockData.cacheLevel, blockData2.cacheLevel); + EXPECT_EQ(blockData.priority, blockData2.priority); + EXPECT_EQ(blockData.tokens.size(), blockData2.tokens.size()); + for (size_t j = 0; j < blockData.tokens.size(); ++j) + { + EXPECT_EQ(blockData.tokens[j].tokenId, blockData2.tokens[j].tokenId); + EXPECT_EQ(blockData.tokens[j].tokenExtraId, blockData2.tokens[j].tokenExtraId); + } + } + } + else if 
(std::holds_alternative(kvCacheEvent.data)) + { + EXPECT_TRUE(std::holds_alternative(kvCacheEvent2.data)); + auto data = std::get(kvCacheEvent.data); + auto data2 = std::get(kvCacheEvent2.data); + EXPECT_EQ(data.blockHash, data2.blockHash); + if (data.cacheLevel) + { + EXPECT_TRUE(data2.cacheLevel); + EXPECT_EQ(data.cacheLevel.value().oldValue, data2.cacheLevel.value().oldValue); + EXPECT_EQ(data.cacheLevel.value().newValue, data2.cacheLevel.value().newValue); + } + if (data.priority) + { + EXPECT_TRUE(data2.priority); + EXPECT_EQ(data.priority.value().oldValue, data2.priority.value().oldValue); + EXPECT_EQ(data.priority.value().newValue, data2.priority.value().newValue); + } + } + else + { + FAIL() << "Unknown KVCacheEvent data type"; + } +} + +TEST(SerializeUtilsTest, KvCacheEventsDeque) +{ + // Created event + texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}}; + texec::KVCacheEvent kvCacheCreatedEvent(1, kvCacheCreatedData, 32); + + // Removed event + texec::KVCacheEvent kvCacheRemovedEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32); + + // Stored event + auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99); + auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101); + texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}}; + texec::KVCacheEvent kvCacheStoredEvent(1, kvCacheStoredData, 32); + + // Updated event + texec::KVCacheEventDiff diff{0, 1}; + texec::KVCacheEventDiff diff2{90, 99}; + texec::KVCacheUpdatedData kvCacheUpdatedData(999, diff, diff2); + texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32); + + std::deque kvCacheEvents{ + kvCacheCreatedEvent, kvCacheRemovedEvent, kvCacheStoredEvent, kvCacheEvent}; + + auto serializedEvents = texec::Serialization::serialize(kvCacheEvents); + auto kvCacheEvents2 = texec::Serialization::deserializeKVCacheEvents(serializedEvents); + + EXPECT_EQ(kvCacheEvents.size(), kvCacheEvents2.size()); + for (size_t i = 0; i < kvCacheEvents.size(); ++i) + { + compareKvCacheEvents(kvCacheEvents[i], kvCacheEvents2[i]); + } +} + +// Test for KVCacheEvent with KVCacheCreatedData +TEST(SerializeUtilsTest, KVCacheCreatedEvent) +{ + texec::KVCacheCreatedData kvCacheCreatedData{{1, 2}}; + texec::KVCacheEvent kvCacheEvent(1, kvCacheCreatedData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheRemovedData +TEST(SerializeUtilsTest, KVCacheRemovedEvents) +{ + texec::KVCacheEvent kvCacheEvent(1, texec::KVCacheRemovedData{{3, 4}}, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheStoredData +TEST(SerializeUtilsTest, KVCacheStoredEvent) +{ + auto storedBlockData1 = texec::KVCacheStoredBlockData(77, {{1, 2}, {3, 4}, {5, 6}}, 88, 0, 99); + auto storedBlockData2 = texec::KVCacheStoredBlockData(99, {{11, 12}, {3, 4}, {15, 6}}, 77, 1, 101); + + texec::KVCacheStoredData kvCacheStoredData{177, {storedBlockData1, storedBlockData2}}; + texec::KVCacheEvent kvCacheEvent(1, kvCacheStoredData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +// Test for KVCacheEvent with KVCacheUpdatedData +TEST(SerializeUtilsTest, KVCacheUpdatedEvent) +{ + texec::KVCacheEventDiff diff{0, 1}; + texec::KVCacheEventDiff diff2{90, 99}; + texec::KVCacheUpdatedData kvCacheUpdatedData(999, 
diff, diff2); + texec::KVCacheEvent kvCacheEvent(1, kvCacheUpdatedData, 32); + auto kvCacheEvent2 = serializeDeserialize(kvCacheEvent); + compareKvCacheEvents(kvCacheEvent, kvCacheEvent2); +} + +TEST(SerializeUtilsTest, UniqueToken) +{ + tensorrt_llm::runtime::UniqueToken token{1, 2}; + auto token2 = serializeDeserialize(token); + EXPECT_EQ(token.tokenId, token2.tokenId); + EXPECT_EQ(token.tokenExtraId, token2.tokenExtraId); +} + +TEST(SerializeUtilsTest, UniqueTokenVector) +{ + std::vector tokens{{1, 2}, {3, 4}, {5, 6}}; + auto tokens2 = serializeDeserialize(tokens); + EXPECT_EQ(tokens.size(), tokens2.size()); + for (size_t i = 0; i < tokens.size(); ++i) + { + EXPECT_EQ(tokens[i].tokenId, tokens2[i].tokenId); + EXPECT_EQ(tokens[i].tokenExtraId, tokens2[i].tokenExtraId); + } +} + TEST(SerializeUtilsTest, MethodReturnType) { struct S diff --git a/tensorrt_llm/_torch/pyexecutor/resource_manager.py b/tensorrt_llm/_torch/pyexecutor/resource_manager.py index 9f44649b494..f125303973e 100644 --- a/tensorrt_llm/_torch/pyexecutor/resource_manager.py +++ b/tensorrt_llm/_torch/pyexecutor/resource_manager.py @@ -196,6 +196,7 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], from ..speculative import get_num_extra_kv_tokens self.num_extra_kv_tokens = get_num_extra_kv_tokens(spec_config) self.event_buffer_max_size = kv_cache_config.event_buffer_max_size + self.attention_dp_events_gather_period_ms = kv_cache_config.attention_dp_events_gather_period_ms self.max_num_tokens = max_num_tokens # Determine max_attention_window_vec @@ -299,8 +300,17 @@ def append_to_kv_heads_per_layer(num_kv_heads_per_layer: List[int], 'copy_on_partial_reuse': kv_cache_config.copy_on_partial_reuse, } if self.event_buffer_max_size > 0: - kwargs['event_manager'] = KVCacheEventManagerCpp( - max_kv_event_entries=self.event_buffer_max_size) + if mapping.enable_attention_dp: + kwargs['event_manager'] = KVCacheEventManagerCpp( + max_kv_event_entries=self.event_buffer_max_size, + attention_dp_rank=mapping.rank, + attention_dp_size=mapping.world_size, + attention_dp_events_gather_period_ms=self. + attention_dp_events_gather_period_ms, + ) + else: + kwargs['event_manager'] = KVCacheEventManagerCpp( + max_kv_event_entries=self.event_buffer_max_size) self.impl = KVCacheManagerCpp(**kwargs) diff --git a/tensorrt_llm/_utils.py b/tensorrt_llm/_utils.py index 75be2727918..2b9b39f2e58 100644 --- a/tensorrt_llm/_utils.py +++ b/tensorrt_llm/_utils.py @@ -1014,11 +1014,15 @@ def to_json_str(cls, event): if event_serialize_func is None: raise ValueError(f"Unknown KVCache event data type: {event_type}") - return { + json_str = { "event_id": event.event_id, "data": event_serialize_func(event.data), - "window_size": event.window_size + "window_size": event.window_size, } + if event.attention_dp_rank is not None: + json_str["attention_dp_rank"] = event.attention_dp_rank + + return json_str @staticmethod def _created_to_json(data): diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index b7d46ed6fa2..0563bf23add 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -969,6 +969,11 @@ class KvCacheConfig(StrictBaseModel, PybindMirror): description= "Maximum size of the event buffer. If set to 0, the event buffer will not be used." ) + attention_dp_events_gather_period_ms: int = Field( + default=5, + description= + "The period in milliseconds to gather attention DP events across ranks." 
+ ) enable_partial_reuse: bool = Field( default=True, description= @@ -999,7 +1004,10 @@ def _to_pybind(self): event_buffer_max_size=self.event_buffer_max_size, enable_partial_reuse=self.enable_partial_reuse, copy_on_partial_reuse=self.copy_on_partial_reuse, - use_uvm=self.use_uvm) + use_uvm=self.use_uvm, + attention_dp_events_gather_period_ms=self. + attention_dp_events_gather_period_ms, + ) @PybindMirror.mirror_pybind_fields(_ExtendedRuntimePerfKnobConfig) diff --git a/tests/integration/test_lists/test-db/l0_dgx_h200.yml b/tests/integration/test_lists/test-db/l0_dgx_h200.yml index 33542dd8d75..42667225456 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h200.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h200.yml @@ -136,6 +136,7 @@ l0_dgx_h200: - unittest/llmapi/test_llm_multi_gpu.py -m "gpu2 and part3" - unittest/llmapi/test_llm_multi_gpu.py -m "gpu4 and part0" - unittest/llmapi/test_llm_multi_gpu.py -m "not (gpu2 or gpu4)" + - unittest/llmapi/test_llm_kv_cache_events.py::test_llm_api_attention_dp_kv_events - examples/test_enc_dec.py::test_llm_enc_dec_general[compare_hf-t5-small-float32-enable_gemm_plugin-enable_attention_plugin-enable_paged_kv_cache-tp:2-pp:2-nb:1-enable_fp8] - llmapi/test_llm_e2e.py::test_llmapi_exit_multi_gpu - test_e2e.py::test_trtllm_bench_llmapi_launch[trt_backend-llama-v3-llama3-8b] diff --git a/tests/unittest/bindings/test_executor_bindings.py b/tests/unittest/bindings/test_executor_bindings.py index 6dcaa0d9535..8556cf54d69 100644 --- a/tests/unittest/bindings/test_executor_bindings.py +++ b/tests/unittest/bindings/test_executor_bindings.py @@ -1314,6 +1314,7 @@ def test_kv_cache_config(): assert config.enable_partial_reuse == True assert config.copy_on_partial_reuse == True assert config.use_uvm == False + assert config.attention_dp_events_gather_period_ms == 5 config.enable_block_reuse = False config.max_tokens = 1 @@ -1328,6 +1329,7 @@ def test_kv_cache_config(): config.enable_partial_reuse = False config.copy_on_partial_reuse = False config.use_uvm = True + config.attention_dp_events_gather_period_ms = 10 assert config.enable_block_reuse == False assert config.max_tokens == 1 assert config.max_attention_window == [2] @@ -1341,6 +1343,7 @@ def test_kv_cache_config(): assert config.enable_partial_reuse == False assert config.copy_on_partial_reuse == False assert config.use_uvm == True + assert config.attention_dp_events_gather_period_ms == 10 kwargs = { "enable_block_reuse": True, @@ -1354,7 +1357,8 @@ def test_kv_cache_config(): "event_buffer_max_size": 2048, "enable_partial_reuse": True, "copy_on_partial_reuse": False, - "use_uvm": True + "use_uvm": True, + "attention_dp_events_gather_period_ms": 10 } config = trtllm.KvCacheConfig(**kwargs) for k, v in kwargs.items(): diff --git a/tests/unittest/llmapi/test_llm_args.py b/tests/unittest/llmapi/test_llm_args.py index acb831837cd..a0ca2a6fabf 100644 --- a/tests/unittest/llmapi/test_llm_args.py +++ b/tests/unittest/llmapi/test_llm_args.py @@ -162,7 +162,8 @@ def test_KvCacheConfig_declaration(): secondary_offload_min_priority=1, event_buffer_max_size=0, enable_partial_reuse=True, - copy_on_partial_reuse=True) + copy_on_partial_reuse=True, + attention_dp_events_gather_period_ms=10) pybind_config = config._to_pybind() assert pybind_config.enable_block_reuse == True @@ -177,6 +178,7 @@ def test_KvCacheConfig_declaration(): assert pybind_config.event_buffer_max_size == 0 assert pybind_config.enable_partial_reuse == True assert pybind_config.copy_on_partial_reuse == True + assert 
pybind_config.attention_dp_events_gather_period_ms == 10 def test_KvCacheConfig_default_values(): diff --git a/tests/unittest/llmapi/test_llm_kv_cache_events.py b/tests/unittest/llmapi/test_llm_kv_cache_events.py index 8f7fb75c7f4..48665c3e25a 100644 --- a/tests/unittest/llmapi/test_llm_kv_cache_events.py +++ b/tests/unittest/llmapi/test_llm_kv_cache_events.py @@ -1,6 +1,9 @@ import asyncio import time +import pytest +from utils.util import skip_single_gpu + import tensorrt_llm from tensorrt_llm import LLM from tensorrt_llm._torch.pyexecutor.llm_request import LlmRequest @@ -9,6 +12,7 @@ from tensorrt_llm.llmapi import KvCacheConfig from tensorrt_llm.mapping import Mapping from tensorrt_llm.sampling_params import SamplingParams +from tensorrt_llm.scheduling_params import SchedulingParams from .test_llm import get_model_path @@ -145,33 +149,51 @@ async def main(): asyncio.run(main()) -def test_llm_kv_events_api(): - llm = create_llm() - sampling_params = SamplingParams(max_tokens=6, temperature=0.01) - - requests = [] - for i in range(3): - input_tokens = list(range(127 + i))[i:] - requests.append(input_tokens) +def check_events(llm, + requests, + sampling_params, + scheduling_params=None, + attention_dp_rank=None): - _ = llm.generate(requests[0], sampling_params=sampling_params) - events1 = llm.get_kv_cache_events(5) - - # Should have 1 stored event and 1 created event - event = events1.pop(0) # created event - while events1: - event = events1.pop(0) - if event: - assert event["event_id"] == 1 - assert event["data"]["type"] == "stored" - assert len(event["data"]["blocks"]) == 5 + _ = llm.generate(requests[0], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) + events = llm.get_kv_cache_events(5) - _ = llm.generate(requests[1], sampling_params=sampling_params) + # Created or stored event + if attention_dp_rank is None: + event = events.pop(0) # created event + assert event["event_id"] == 0 + assert event["data"]["type"] == "created" + while events: + event = events.pop(0) + if event: + assert event["event_id"] == 1 + assert event["data"]["type"] == "stored" + assert len(event["data"]["blocks"]) == 5 + else: + while events: + event = events.pop(0) + if event and event["attention_dp_rank"] == attention_dp_rank: + assert event["event_id"] in [0, 1] + assert event["data"]["type"] in ["created", "stored"] + if event["data"]["type"] == "created": + assert event["event_id"] == 0 + if event["data"]["type"] == "stored": + assert event["event_id"] == 1 + assert len(event["data"]["blocks"]) == 5 + + _ = llm.generate(requests[1], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) events2 = llm.get_kv_cache_events(5) while events2: event = events2.pop(0) - if event: + if event and (event["attention_dp_rank"] == attention_dp_rank + or attention_dp_rank is None): if event["event_id"] == 2: # 2 removed events needed # should be a removed event to make space for context block @@ -185,12 +207,16 @@ def test_llm_kv_events_api(): assert event["data"]["type"] == "stored" assert len(event["data"]["blocks"]) == 5 - _ = llm.generate(requests[2], sampling_params=sampling_params) + _ = llm.generate(requests[2], + sampling_params=sampling_params, + scheduling_params=scheduling_params) + time.sleep(1) events3 = llm.get_kv_cache_events(5) while events3: event = events3.pop(0) - if event: + if event and (event["attention_dp_rank"] == attention_dp_rank + or attention_dp_rank is None): if event["event_id"] == 5: assert event["data"]["type"] == 
"removed" assert event["data"]["block_hashes"] @@ -203,3 +229,44 @@ def test_llm_kv_events_api(): # no more events after request is finished assert not llm.get_kv_cache_events(5) + + +def test_llm_kv_events_api(): + llm = create_llm() + sampling_params = SamplingParams(max_tokens=6, + temperature=0.01, + ignore_eos=True) + + requests = [] + for i in range(3): + input_tokens = list(range(127 + i))[i:] + requests.append(input_tokens) + + check_events(llm, requests, sampling_params) + + +@skip_single_gpu +@pytest.mark.threadleak(enabled=False) +def test_llm_api_attention_dp_kv_events(): + + llm = LLM(model=llama_model_path, + tensor_parallel_size=2, + enable_attention_dp=True, + kv_cache_config=global_kvcache_config, + enable_autotuner=False) + + sampling_params = SamplingParams(max_tokens=6, + temperature=0.01, + ignore_eos=True) + + for attention_dp_rank in range(2): + requests = [] + for i in range(3): + input_tokens = list(range(127 + i))[i:] + requests.append(input_tokens) + + scheduling_params = SchedulingParams( + attention_dp_rank=attention_dp_rank, attention_dp_relax=False) + + check_events(llm, requests, sampling_params, scheduling_params, + attention_dp_rank)