diff --git a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h index c39fee6f940..934cb39972b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h +++ b/cpp/include/tensorrt_llm/batch_manager/cacheTransceiver.h @@ -72,20 +72,20 @@ class CacheTransceiver : public BaseCacheTransceiver public: CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig, - nvinfer1::DataType dataType, + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType = executor::kv_cache::CacheState::AttentionType::kDEFAULT, std::optional cacheTransceiverConfig = std::nullopt); CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, std::vector numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, runtime::WorldConfig const& worldConfig, - nvinfer1::DataType dataType, + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType = executor::kv_cache::CacheState::AttentionType::kDEFAULT, std::optional cacheTransceiverConfig = std::nullopt) : CacheTransceiver(cacheManager, executor::kv_cache::CacheState::ModelConfig{numKvHeadsPerLayer, sizePerHead, tokensPerBlock}, worldConfig, - dataType, attentionType, cacheTransceiverConfig) + attentionLayerNumPerPP, dataType, attentionType, cacheTransceiverConfig) { } diff --git a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h index 98b26a276c6..d49447a09a0 100644 --- a/cpp/include/tensorrt_llm/executor/dataTransceiverState.h +++ b/cpp/include/tensorrt_llm/executor/dataTransceiverState.h @@ -48,12 +48,13 @@ class CacheState final kMLA = 1, }; - CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, nvinfer1::DataType dataType, + CacheState(ModelConfig modelConfig, runtime::WorldConfig const& worldConfig, + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2) : mModelConfig(std::move(modelConfig)) , mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(), worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), - worldConfig.getTensorParallelism()} + worldConfig.getTensorParallelism(), attentionLayerNumPerPP} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { @@ -61,10 +62,12 @@ class CacheState final CacheState(std::vector nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, - nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, - bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0) + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, + AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, + int DPrank = 0, int DPsize = 0) : mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock} - , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize} + , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, 
enableAttentionDP, DPrank, DPsize, + attentionLayerNumPerPP} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { @@ -72,10 +75,12 @@ class CacheState final CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism, - nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, - bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0) + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, + AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false, + int DPrank = 0, int DPsize = 0) : mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock} - , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize} + , mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize, + attentionLayerNumPerPP} , mDataType{dataType} , mAttentionConfig(attentionType, kvFactor) { @@ -108,12 +113,16 @@ class CacheState final bool mEnableAttentionDP; SizeType32 mDPrank; SizeType32 mDPsize; + // Number of attention layers per pipeline-parallel rank; the size of the vector equals the pipeline + // parallelism size. + std::vector mAttentionLayerNumPerPP; [[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept { return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism && mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP - && mDPrank == other.mDPrank && mDPsize == other.mDPsize; + && mDPrank == other.mDPrank && mDPsize == other.mDPsize + && mAttentionLayerNumPerPP == other.mAttentionLayerNumPerPP; } }; diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index e73e0f15411..b5a68b74b4d 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -75,7 +75,6 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques bool CacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { - // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (targetInfo.mDupHeadFactor <= 1) { @@ -91,15 +90,27 @@ bool CacheFormatter::needSendCache( selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup; } + // Only TP ranks with rank % dupHeadFactor == 0 need to send the cache. return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0; } void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig, BaseCacheFormatter::CacheState const& destConfig) { + // TODO: VSWA does not support an uneven number of layers per PP rank. + // If the gen PP and context PP sizes differ, the cache formatter only supports alternating-window models such as + // gpt-oss, where one layer uses sliding-window attention and the next uses full attention.
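+ // Illustrative example (an assumption for exposition, not taken from this change): a 24-layer model that
+ // alternates [SWA, full, SWA, full, ...] keeps the two window sizes evenly interleaved, so with context PP=2
+ // (12 layers per rank) and gen PP=4 (6 layers per rank) every rank holds the same mix of both window pools;
+ // a model whose window sizes are not evenly interleaved would not satisfy the checks below.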
+ auto numPools = cacheManager->getBlockManager().getNumPools(); auto layerNum = cacheManager->getBlockManager().getNumLayers(); + auto selfPPNum = selfConfig.getParallelConfig().mPipelineParallelism; + auto selfAllLayerNum = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + auto destPPNum = destConfig.getParallelConfig().mPipelineParallelism; + auto destAllLayerNum = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); + TLLM_CHECK_WITH_INFO(selfAllLayerNum % selfPPNum == 0, "For VSWA, selfAllLayerNum must be divisible by selfPPNum"); + TLLM_CHECK_WITH_INFO(destAllLayerNum % destPPNum == 0, "For VSWA, destAllLayerNum must be divisible by destPPNum"); + std::vector poolIdxs(numPools); TLLM_CHECK(layerNum >= numPools); for (int i = 0; i < numPools; i++) { @@ -156,6 +167,7 @@ void CacheFormatter::format(TransferSession& session) auto const& destConfig = session.getOtherState().getCacheState().value(); auto const selfIdx = session.getSelfState().getCommState().value().getSelfIdx(); auto& bufferManager = session.getBufferManager(); + // Some TP ranks don't need to send the cache because their duplicated KV heads are not needed. if (!needSendCache(selfConfig, destConfig, selfIdx)) { return; } @@ -207,21 +219,22 @@ void CacheFormatter::format(TransferSession& session) int blockNum = 0; size_t allCacheBlockSize = 0; - - std::map> inputKvCacheBlocks; + // Gather the cache blocks of the request. + std::map> inputKvCacheBlocksPerWindow; for (auto poolIdx = 0; poolIdx < numPools; poolIdx++) { blockRange.updatePoolIdx(poolIdx); SizeType32 window = mCacheManager->getBlockManager().getPoolWindowSize(poolIdx); - TLLM_CHECK_WITH_INFO(inputKvCacheBlocks.find(window) == inputKvCacheBlocks.end(), + TLLM_CHECK_WITH_INFO(inputKvCacheBlocksPerWindow.find(window) == inputKvCacheBlocksPerWindow.end(), "window size already exists, which is not supported"); - inputKvCacheBlocks.emplace(window, std::vector()); + inputKvCacheBlocksPerWindow.emplace(window, std::vector()); auto maxBlockThisWindow = window / selfConfig.getModelConfig().mTokensPerBlock; + // Only blocks within the window will be sent.
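+ // For example (illustrative numbers, not from this change): with window = 4096 tokens and
+ // mTokensPerBlock = 64, maxBlockThisWindow = 4096 / 64 = 64, so at most 64 blocks from this pool are
+ // collected below.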
SizeType32 blockNumThisWindow = 0; for (auto it = blockRange.begin(); it != blockRange.end(); ++it) { blockNum++; - inputKvCacheBlocks.at(window).push_back(it); + inputKvCacheBlocksPerWindow.at(window).push_back(it); allCacheBlockSize += it->getSize(); blockNumThisWindow++; if (blockNumThisWindow >= maxBlockThisWindow) @@ -231,7 +244,7 @@ void CacheFormatter::format(TransferSession& session) } } - if (inputKvCacheBlocks.size() > 1) + if (inputKvCacheBlocksPerWindow.size() > 1) { if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism) @@ -239,15 +252,15 @@ void CacheFormatter::format(TransferSession& session) checkAlternateWindow(mCacheManager, selfConfig, destConfig); } } - TLLM_CHECK(!inputKvCacheBlocks.empty()); + TLLM_CHECK(!inputKvCacheBlocksPerWindow.empty()); TLLM_CHECK(blockNum > 0); int deviceId = mCacheManager->getBlockManager().getStreamDevice(); auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (common::getEnvTryZCopyForKVCacheTransfer() && (destConfig.getParallelConfig().mPipelineParallelism - <= selfConfig.getParallelConfig().mPipelineParallelism) - && (destConfig.getParallelConfig().mTensorParallelism <= selfConfig.getParallelConfig().mTensorParallelism)) + == selfConfig.getParallelConfig().mPipelineParallelism) + && (destConfig.getParallelConfig().mTensorParallelism == selfConfig.getParallelConfig().mTensorParallelism)) { TLLM_LOG_DEBUG("Try using zero-copy for the KV cache."); NVTX3_SCOPED_RANGE(sendBufferFun); @@ -257,7 +270,7 @@ void CacheFormatter::format(TransferSession& session) TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); for (size_t i = 0; i < connections.size(); i++) { - for (auto const& [window, blocks] : inputKvCacheBlocks) + for (auto const& [window, blocks] : inputKvCacheBlocksPerWindow) { for (auto const& block : blocks) { @@ -271,80 +284,123 @@ void CacheFormatter::format(TransferSession& session) return; } + // formatter flow + // 1. collect cache blocks of the request. + // 2. compute the buffer size for each target. + // 3. prepare the pre-allocated buffer for each target according to the buffer size. + // 4. call splitKVCacheDispatch to split the cache blocks according to the different parallelis and gather the + // cache blocks to the corresponding buffer. + // 5. send the buffer to the corresponding target. Ideally, we send only once (one buffer) for each target. + auto cacheBufferId = mCacheTransBufferManager->assignBufferIndexForSend(); int peerDuplicateHeadFactor = targetInfo.mPeerDupHeadFactor; auto targetNum = connections.size(); - auto const targetBufferSize = allCacheBlockSize / targetNum * peerDuplicateHeadFactor; auto bufferTargetNum = targetNum / peerDuplicateHeadFactor; - TLLM_LOG_DEBUG(" formatOutput bufferTargetNum: %d, targetNum: %d, peerDuplicateHeadFactor: %d dupliacete:%d ", - bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDupHeadFactor); + auto ppRank = selfIdx + / (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism); + int selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank); + + auto getBufferSizeForTarget = [&]() + { + std::vector bufferSizeForTarget(targetNum, 0); + // only first bufferTargetNum is used. 
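+ // Sizing walkthrough with illustrative numbers (not from this change), for the non-VSWA path below:
+ // with allCacheBlockSize = 1024 elements, peerDuplicateHeadFactor = 1, targetInfo.mDomainTPSize = 2,
+ // selfAttentionLayerNum = 8, and a peer PP domain owning {3, 5} of those layers, the per-layer share is
+ // 1024 * 1 / 2 / 8 = 64 elements, so the two targets get 64 * 3 = 192 and 64 * 5 = 320 elements,
+ // which sum to this rank's 1024 / 2 = 512 elements for that TP domain slice.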
+ if (inputKvCacheBlocksPerWindow.size() > 1) + { + // for VWSA + for (size_t i = 0; i < targetNum; i++) + { + bufferSizeForTarget[i] = allCacheBlockSize * peerDuplicateHeadFactor / targetNum; + } + return bufferSizeForTarget; + } + for (size_t i = 0; i < targetNum; i++) + { + bufferSizeForTarget[i] = allCacheBlockSize * peerDuplicateHeadFactor / targetInfo.mDomainTPSize + / selfAttentionLayerNum * targetInfo.getPeerPPDomainLayerNum(i); + } + + return bufferSizeForTarget; + }; + auto bufferEleSizes = getBufferSizeForTarget(); auto result = mCacheTransBufferManager->getOrAllocateSendBuffers( - cacheBufferId, bufferTargetNum, targetBufferSize, bufferManager); + cacheBufferId, static_cast(bufferTargetNum), bufferEleSizes, bufferManager); auto& outputSplitCaches = std::get<0>(result); auto& bufferCoverTargetNum = std::get<1>(result); auto& onlyUseDynamicBuffer = std::get<2>(result); + + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), + " format bufferTargetNum: %d, targetNum: %d, peerDuplicateHeadFactor: %d duplicate:%d " + "bufferCoverTargetNum:%d connections.size():%ld", + bufferTargetNum, targetNum, peerDuplicateHeadFactor, targetInfo.mDupHeadFactor, bufferCoverTargetNum, + connections.size()); auto* agentConnnecion = dynamic_cast(connections[0]); if (agentConnnecion != nullptr) { TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == bufferTargetNum, "Agent need all buffer pre-allocated"); TLLM_CHECK(onlyUseDynamicBuffer == false); } - + // TODO: add parameters for layerNumForEachOutput tensorrt_llm::executor::kv_cache::splitKVCacheDispatch( - inputKvCacheBlocks, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager); + inputKvCacheBlocksPerWindow, outputSplitCaches, destConfig, selfConfig, selfIdx, bufferManager); bufferManager.getStream().synchronize(); auto preAllocSendBuffer = mCacheTransBufferManager->getSendBuffer(cacheBufferId); if (preAllocSendBuffer != nullptr) { - TLLM_CHECK(preAllocSendBuffer->getDataType() == inputKvCacheBlocks.begin()->second.front()->getDataType()); + TLLM_CHECK(preAllocSendBuffer->getDataType() + == inputKvCacheBlocksPerWindow.begin()->second.front()->getDataType()); } auto sendBufferFun = [&](int deviceId, size_t processIdx) { + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " send processIdx: %ld", processIdx); NVTX3_SCOPED_RANGE(sendBufferFun); TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); TLLM_CHECK(connections.size() > (processIdx / peerDuplicateHeadFactor)); TLLM_CHECK(outputSplitCaches.size() > (processIdx / peerDuplicateHeadFactor)); auto startTime = std::chrono::steady_clock::now(); - size_t size; size_t ppDomainSize = targetInfo.mDomainPPSize; size_t bufferTpRank = (processIdx / ppDomainSize) / peerDuplicateHeadFactor; size_t bufferIdx = (bufferTpRank * ppDomainSize) + (processIdx % ppDomainSize); - if (bufferIdx < bufferCoverTargetNum) - { + size_t size = outputSplitCaches[bufferIdx]->getSizeInBytes(); - size = outputSplitCaches[bufferIdx]->getSizeInBytes(); - session.send(processIdx, outputSplitCaches[bufferIdx]->data(), size); - } - else if (bufferCoverTargetNum > 0) + if (bufferIdx < bufferCoverTargetNum) { - // copy buffer allocated by cudaMallocAsync to buffer allocated by cudaMalloc before sending - auto sendBufferIdx = bufferIdx % bufferCoverTargetNum; - bufferManager.copy(*outputSplitCaches[processIdx], *outputSplitCaches.at(sendBufferIdx)); - bufferManager.getStream().synchronize(); - size = outputSplitCaches.at(sendBufferIdx)->getSizeInBytes(); - session.send(processIdx, outputSplitCaches.at(sendBufferIdx)->data(), size); + 
TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " send processIdx: %d bufferIdx: %d size:%ld", + processIdx, bufferIdx, outputSplitCaches[bufferIdx]->getSizeInBytes()); + session.send( + processIdx, outputSplitCaches[bufferIdx]->data(), outputSplitCaches[bufferIdx]->getSizeInBytes()); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " end send processIdx: %d bufferIdx: %d size:%ld", + processIdx, bufferIdx, outputSplitCaches[bufferIdx]->getSizeInBytes()); } else { + + // If bufferIdx >= bufferCoverTargetNum, outputSplitCaches[bufferIdx] was allocated with cudaMallocAsync, + // which cannot be transferred by UCX GPU-direct RDMA. Copy the data into the pre-allocated cudaMalloc + // buffer and then send it. // bufferCoverTargetNum == 0, mSendBuffer size < one outputSlice // send multiple times - size = targetBufferSize; - size_t remainSendSize = targetBufferSize; + + size_t remainSendSize = outputSplitCaches[processIdx]->getSize(); + size_t needSendSize = outputSplitCaches[processIdx]->getSize(); + auto sendBufferIdx = bufferCoverTargetNum == 0 ? 0 : bufferIdx % bufferCoverTargetNum; + + auto sendUseAllocBuffer + = bufferCoverTargetNum == 0 ? preAllocSendBuffer : outputSplitCaches[sendBufferIdx]; while (remainSendSize > 0) { - TLLM_CHECK(preAllocSendBuffer != nullptr); - auto sendBufferEleSize = preAllocSendBuffer->getSize(); + TLLM_CHECK(sendUseAllocBuffer != nullptr); + auto sendBufferEleSize = sendUseAllocBuffer->getSize(); auto sendSize = std::min(remainSendSize, sendBufferEleSize); auto copySlice = runtime::ITensor::slice( - outputSplitCaches[bufferIdx], targetBufferSize - remainSendSize, sendSize); + outputSplitCaches[bufferIdx], needSendSize - remainSendSize, sendSize); - auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize); + auto copyTargetSlice = runtime::ITensor::slice(sendUseAllocBuffer, 0, sendSize); bufferManager.copy(*copySlice, *copyTargetSlice); bufferManager.getStream().synchronize(); session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes()); @@ -376,7 +432,7 @@ void CacheFormatter::format(TransferSession& session) } else { - // concurrency num + // The concurrency should be <= bufferCoverTargetNum to avoid data races. auto concurrencyNum = std::min(std::max(static_cast(1), bufferCoverTargetNum), connections.size()); @@ -462,6 +518,7 @@ void CacheFormatter::unformat(TransferSession& session) TLLM_CHECK(!outputBuffersPerWindow.empty()); if (outputBuffersPerWindow.size() > 1) { + // We only support a limited set of cases for VSWA. if (selfConfig.getParallelConfig().mPipelineParallelism != destConfig.getParallelConfig().mPipelineParallelism) { checkAlternateWindow(mCacheManager, selfConfig, destConfig); } @@ -560,13 +617,13 @@ void CacheFormatter::unformat(TransferSession& session) ctxReqId); return; } - // legacyPath: context executor rank only send data to one gen executor rank. it sends multiple cache - // blocks. - auto legacyPath = common::getEnvTryZCopyForKVCacheTransfer() - && (destConfig.getParallelConfig().mPipelineParallelism - >= selfConfig.getParallelConfig().mPipelineParallelism) - && (destConfig.getParallelConfig().mTensorParallelism - >= selfConfig.getParallelConfig().mTensorParallelism); + // unformat flow + // 1. collect cache blocks of the request. + // 2. compute the buffer size for each target. + // 3. prepare the pre-allocated buffer for each target according to the buffer size. + // 4. receive the buffer from the corresponding target. Ideally, we receive only once (one buffer) for each + // target. + // 5.
call concatKvCacheV2Dispatch to concatenate the cache blocks according to the different parallelis runtime::ITensor::SharedPtr recvBufferTemp; std::vector recvSplitCaches; @@ -574,7 +631,44 @@ void CacheFormatter::unformat(TransferSession& session) auto dataType = outputBuffersPerWindow.begin()->second.front()->getDataType(); auto targetNum = pickUpConnections.size(); TLLM_CHECK(cacheBlockSizeSum % targetNum == 0); - auto targetBufferSize = cacheBlockSizeSum / targetNum; + auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); + auto ppRank = selfIdx + / (selfConfig.getParallelConfig().mTensorParallelism + * selfConfig.getParallelConfig().mContextParallelism); + int selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank); + auto getTargetBufferEleSize = [&]() + { + if (outputBuffersPerWindow.size() > 1) + { + std::vector bufferSizeForTarget(targetNum, 0); + for (size_t i = 0; i < targetNum; i++) + { + bufferSizeForTarget[i] = cacheBlockSizeSum / targetNum; + } + return bufferSizeForTarget; + } + // for duplicate header, gen will not recv from TP which has duplicate header, and will not prepare + // buffer for it. + size_t validTpSize = pickUpConnections.size() / targetInfo.mDomainPPSize; + TLLM_CHECK_WITH_INFO(cacheBlockSizeSum % validTpSize == 0, + "cacheBlockSizeSum must be divisible by validTpSize %ld", validTpSize); + TLLM_CHECK_WITH_INFO((cacheBlockSizeSum % (selfAttentionLayerNum * validTpSize)) == 0, + "cacheBlockSizeSum must be divisible by validTpSize %ld * selfAttentionLayerNum %d", validTpSize, + selfAttentionLayerNum); + TLLM_CHECK(targetNum == pickUpConnections.size()); + // the sum of buffer size is cacheBlockSizeSum. + size_t cacheBlockSizePerLayer = cacheBlockSizeSum / (validTpSize * selfAttentionLayerNum); + + std::vector bufferEleSizes(targetNum, 0); + + for (size_t i = 0; i < targetNum; i++) + { + auto layerNum = targetInfo.getPeerPPDomainLayerNum(static_cast(pickUpConnections[i])); + bufferEleSizes[i] = cacheBlockSizePerLayer * layerNum; + } + return bufferEleSizes; + }; + auto bufferEleSizes = getTargetBufferEleSize(); size_t remainNoCoverTargetNum = 0; size_t bufferCoverTargetNum = 0; @@ -583,49 +677,31 @@ void CacheFormatter::unformat(TransferSession& session) NVTX3_SCOPED_RANGE(formatInputAllocBuffer); TLLM_CHECK(blockNum > 0); - if (legacyPath) - { - - TLLM_LOG_DEBUG("formatOutput using legacy path"); - auto cacheShape = executor::kv_cache::makeShapeFromCacheState(destConfig); - auto cacheVolume = runtime::ITensor::volume(cacheShape); - size_t bufferNum = blockNum * pickUpConnections.size(); - recvBufferTemp = bufferManager.gpu( - runtime::ITensor::makeShape({static_cast(cacheVolume * bufferNum)}), dataType); - recvSplitCaches.resize(bufferNum); - for (size_t i = 0; i < bufferNum; i++) - { - recvSplitCaches[i] = runtime::ITensor::slice(recvBufferTemp, i * cacheVolume, cacheVolume); - } + auto* agentConnnecion + = dynamic_cast(connections[pickUpConnections[0]]); + if (agentConnnecion != nullptr) + { + cacheBufferId = agentConnnecion->getCacheBufferId(); + TLLM_CHECK(cacheBufferId.has_value()); } else { - auto* agentConnnecion - = dynamic_cast(connections[pickUpConnections[0]]); - if (agentConnnecion != nullptr) - { - cacheBufferId = agentConnnecion->getCacheBufferId(); - TLLM_CHECK(cacheBufferId.has_value()); - } - else - { - cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv(); - } - TLLM_CHECK(cacheBufferId.has_value()); - auto [recvSplitCachestmp, bufferCoverTargetNumtmp, 
onlyUseDynamicBuffer] - = mCacheTransBufferManager->getOrAllocateRecvBuffers( - cacheBufferId, targetNum, targetBufferSize, bufferManager); - bufferCoverTargetNum = bufferCoverTargetNumtmp; - remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? targetNum - bufferCoverTargetNum : 0; - - if (agentConnnecion != nullptr) - { - TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == targetNum, "Agent need buffer pre-allocated"); - TLLM_CHECK(onlyUseDynamicBuffer == false); - } - recvSplitCaches = std::move(recvSplitCachestmp); + cacheBufferId = mCacheTransBufferManager->assignBufferIndexForRecv(); + } + TLLM_CHECK(cacheBufferId.has_value()); + auto [recvSplitCachestmp, bufferCoverTargetNumtmp, onlyUseDynamicBuffer] + = mCacheTransBufferManager->getOrAllocateRecvBuffers( + cacheBufferId, static_cast(targetNum), bufferEleSizes, bufferManager); + bufferCoverTargetNum = bufferCoverTargetNumtmp; + remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? targetNum - bufferCoverTargetNum : 0; + + if (agentConnnecion != nullptr) + { + TLLM_CHECK_WITH_INFO(bufferCoverTargetNum == targetNum, "Agent need buffer pre-allocated"); + TLLM_CHECK(onlyUseDynamicBuffer == false); } + recvSplitCaches = std::move(recvSplitCachestmp); // sync to alloc buffer bufferManager.getStream().synchronize(); @@ -647,63 +723,45 @@ void CacheFormatter::unformat(TransferSession& session) TLLM_CHECK(recvSplitCaches.size() > processIdx); auto startTime = std::chrono::steady_clock::now(); size_t size = 0; - if (legacyPath) - { - size_t idx = processIdx * blockNum; - for (size_t i = 0; i < blockNum; i++) - { - size_t commIdx = idx / (blockNum); - size_t blockIdx = idx % (blockNum); - size_t recvBufferIdx = blockIdx * pickUpConnections.size() + commIdx; - llmRequest.updateKvCacheSize((*recvSplitCaches[recvBufferIdx]).getSizeInBytes()); - auto& buffer = recvSplitCaches.at(recvBufferIdx); - size += buffer->getSizeInBytes(); - session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes()); - idx++; - } + if (processIdx >= remainNoCoverTargetNum) + { + llmRequest.updateKvCacheSize((*recvSplitCaches.at(processIdx)).getSizeInBytes()); + auto& buffer = recvSplitCaches[processIdx]; + size = buffer->getSizeInBytes(); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " start recv bufferIdx: %d size:%ld", processIdx, + buffer->getSizeInBytes()); + session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes()); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), " recv bufferIdx: %d size:%ld", processIdx, + buffer->getSizeInBytes()); } else { - if (processIdx >= remainNoCoverTargetNum) + auto recvBufferIdx + = bufferCoverTargetNum == 0 ? 0 : processIdx % bufferCoverTargetNum + remainNoCoverTargetNum; + // bufferCoverTargetNum == 0 + auto recvBufferUsed + = bufferCoverTargetNum == 0 ? 
preAllocRecvBuffer : recvSplitCaches[recvBufferIdx]; + + size_t remainRecvSize = recvSplitCaches[processIdx]->getSize(); + size_t needRecvSize = recvSplitCaches[processIdx]->getSize(); + while (remainRecvSize > 0) { - llmRequest.updateKvCacheSize((*recvSplitCaches.at(processIdx)).getSizeInBytes()); - auto& buffer = recvSplitCaches[processIdx]; - size = buffer->getSizeInBytes(); - session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes()); - } - else if (bufferCoverTargetNum > 0) - { - auto recvBufferIdx = processIdx % bufferCoverTargetNum - + remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc - llmRequest.updateKvCacheSize((*recvSplitCaches.at(recvBufferIdx)).getSizeInBytes()); - auto& buffer = recvSplitCaches.at(recvBufferIdx); - size = buffer->getSizeInBytes(); - session.recv(pickUpConnections[processIdx], buffer->data(), buffer->getSizeInBytes()); - bufferManager.copy(*recvSplitCaches.at(recvBufferIdx), *recvSplitCaches[processIdx]); + TLLM_CHECK(recvBufferUsed != nullptr); + auto recvBufferEleSize = recvBufferUsed->getSize(); + auto recvSize = std::min(remainRecvSize, recvBufferEleSize); + auto recvSlice = runtime::ITensor::slice(recvBufferUsed, 0, recvSize); + auto copySlice = runtime::ITensor::slice( + recvSplitCaches[processIdx], needRecvSize - remainRecvSize, recvSize); + size += recvSlice->getSizeInBytes(); + llmRequest.updateKvCacheSize((*recvSlice).getSizeInBytes()); + session.recv(pickUpConnections[processIdx], recvSlice->data(), recvSlice->getSizeInBytes()); + bufferManager.copy(*recvSlice, *copySlice); bufferManager.getStream().synchronize(); - } - else - { - // bufferCoverTargetNum == 0 - size_t remainRecvSize = targetBufferSize; - while (remainRecvSize > 0) - { - TLLM_CHECK(preAllocRecvBuffer != nullptr); - auto recvBufferEleSize = preAllocRecvBuffer->getSize(); - auto recvSize = std::min(remainRecvSize, recvBufferEleSize); - auto recvSlice = runtime::ITensor::slice(preAllocRecvBuffer, 0, recvSize); - auto copySlice = runtime::ITensor::slice( - recvSplitCaches[processIdx], targetBufferSize - remainRecvSize, recvSize); - size += recvSlice->getSizeInBytes(); - llmRequest.updateKvCacheSize((*recvSlice).getSizeInBytes()); - session.recv(pickUpConnections[processIdx], recvSlice->data(), recvSlice->getSizeInBytes()); - bufferManager.copy(*recvSlice, *copySlice); - bufferManager.getStream().synchronize(); - remainRecvSize -= recvSize; - } + remainRecvSize -= recvSize; } } + auto endTime = std::chrono::steady_clock::now(); double delay = 0.0; if (recordDelay) @@ -764,19 +822,9 @@ void CacheFormatter::unformat(TransferSession& session) { NVTX3_SCOPED_RANGE(formatInputConcatenate); - if (legacyPath) - { - TLLM_CHECK(outputBuffersPerWindow.size() == 1); - executor::kv_cache::concatKVCacheDispatch(recvSplitCaches.data(), recvSplitCaches.size(), - getCounterparts(selfConfig, selfIdx, destConfig), destConfig, - outputBuffersPerWindow.begin()->second.data(), outputBuffersPerWindow.begin()->second.size(), - selfIdx, selfConfig, bufferManager); - } - else - { - executor::kv_cache::concatKvCacheV2Dispatch( - recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager); - } + executor::kv_cache::concatKvCacheV2Dispatch( + recvSplitCaches, outputBuffersPerWindow, destConfig, selfConfig, selfIdx, bufferManager); + bufferManager.getStream().synchronize(); if (cacheBufferId.has_value()) { @@ -852,27 +900,6 @@ void CacheFormatter::unformat(TransferSession& session) 
destConfig.getModelConfig().mNbKvHeadsPerLayer.size()); return false; } - int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; - int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; - int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - - if (selfPPSize == destPPSize) - { - return true; - } - if (selfNumLayers % selfPPSize != 0) - { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d", - selfNumLayers, selfPPSize); - return false; - } - if (destNumLayers % destPPSize != 0) - { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", - destNumLayers, destPPSize); - return false; - } return true; } diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp index 1a3aed54f41..33986426f54 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @@ -295,17 +295,19 @@ void CacheTransBufferManager::freeBufferIndexForRecv(std::optional bufferId } std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateSendBuffers( - std::optional bufferId, int targetNum, size_t targetBufferSize, + std::optional bufferId, int targetNum, std::vector const& targetBufferEleSizes, runtime::BufferManager const& bufferManagerToUse) { - return getOrAllocateBuffers(bufferId, targetNum, targetBufferSize, bufferManagerToUse, mConcurrenceSendResource); + return getOrAllocateBuffers( + bufferId, targetNum, targetBufferEleSizes, bufferManagerToUse, mConcurrenceSendResource); } std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateRecvBuffers( - std::optional bufferId, int targetNum, size_t targetBufferSize, + std::optional bufferId, int targetNum, std::vector const& targetBufferEleSizes, runtime::BufferManager const& bufferManagerToUse) { - return getOrAllocateBuffers(bufferId, targetNum, targetBufferSize, bufferManagerToUse, mConcurrenceRecvResource); + return getOrAllocateBuffers( + bufferId, targetNum, targetBufferEleSizes, bufferManagerToUse, mConcurrenceRecvResource); } runtime::ITensor::SharedPtr CacheTransBufferManager::getSendBuffer(std::optional bufferId) @@ -332,54 +334,58 @@ runtime::ITensor::SharedPtr CacheTransBufferManager::getRecvBuffer(std::optional } std::tuple, size_t, bool> CacheTransBufferManager::getOrAllocateBuffers( - std::optional bufferId, int targetNum, size_t targetBufferEleSize, + std::optional bufferId, int targetNum, std::vector const& targetBufferEleSizes, runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource) { TLLM_CHECK(bufferId.has_value() || mOnlyUseDynamicBuffer); + TLLM_CHECK(targetBufferEleSizes.size() >= static_cast(targetNum)); std::vector retSplitCaches; - size_t bufferCoverTargetNum = std::min( - static_cast(targetNum), mTransferBufferSize / (targetBufferEleSize * common::getDTypeSize(mDataType))); - TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum); - if (bufferCoverTargetNum < static_cast(targetNum)) - { - TLLM_LOG_WARNING( - "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic buffer, " - "it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may " - "be degraded", - bufferCoverTargetNum, targetNum); - } + + size_t bufferCoverTargetNum = 0; + if 
(bufferId.has_value()) { TLLM_CHECK(static_cast(bufferId.value()) < concurrenceResource.mBuffers.size()); TLLM_CHECK(concurrenceResource.mBufferIndexFlag[bufferId.value()] == 1); - + size_t preBufferEleSize = 0; for (int i = 0; i < targetNum; i++) { - if (static_cast(i) < bufferCoverTargetNum) + // Strict checking. + if (preBufferEleSize + targetBufferEleSizes[i] <= mBufferEleSize) { auto slice = runtime::ITensor::slice( - concurrenceResource.mBuffers[bufferId.value()], i * targetBufferEleSize, targetBufferEleSize); + concurrenceResource.mBuffers[bufferId.value()], preBufferEleSize, targetBufferEleSizes[i]); + preBufferEleSize += targetBufferEleSizes[i]; + bufferCoverTargetNum++; retSplitCaches.push_back(std::move(slice)); } else { retSplitCaches.push_back(bufferManagerToUse.gpu( - runtime::ITensor::makeShape({static_cast(targetBufferEleSize)}), mDataType)); + runtime::ITensor::makeShape({static_cast(targetBufferEleSizes[i])}), mDataType)); } } + TLLM_LOG_DEBUG("getOrAllocateBuffers bufferCoverTargetNum:%d", bufferCoverTargetNum); + if (bufferCoverTargetNum < static_cast(targetNum)) + { + TLLM_LOG_WARNING( + "CacheTransceiver getOrAllocateBuffers: bufferCoverTargetNum:%d < targetNum:%d, may use dynamic " + "buffer, " + "it's better to increase MaxTokensInBuffer in cacheTransceiverConfig, otherwise, the performance may " + "be degraded", + bufferCoverTargetNum, targetNum); + } } else { for (int i = 0; i < targetNum; i++) { retSplitCaches.push_back(bufferManagerToUse.gpu( - runtime::ITensor::makeShape({static_cast(targetBufferEleSize)}), mDataType)); + runtime::ITensor::makeShape({static_cast(targetBufferEleSizes[i])}), mDataType)); } - } - if (mOnlyUseDynamicBuffer) - { bufferCoverTargetNum = targetNum; } + return std::make_tuple(retSplitCaches, bufferCoverTargetNum, mOnlyUseDynamicBuffer); } diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h index e7b050388fe..780cfc4ab9e 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h +++ b/cpp/tensorrt_llm/batch_manager/cacheTransBuffer.h @@ -69,11 +69,11 @@ class CacheTransBufferManager void freeBufferIndexForRecv(std::optional bufferId); std::tuple, size_t, bool> getOrAllocateSendBuffers( - std::optional bufferId, int targetNum, size_t targetBufferSize, + std::optional bufferId, int targetNum, std::vector const& targetBufferEleSizes, runtime::BufferManager const& bufferManagerToUse); std::tuple, size_t, bool> getOrAllocateRecvBuffers( - std::optional bufferId, int targetNum, size_t targetBufferSize, + std::optional bufferId, int targetNum, std::vector const& targetBufferEleSizes, runtime::BufferManager const& bufferManagerToUse); runtime::ITensor::SharedPtr getSendBuffer(std::optional bufferId); @@ -92,8 +92,8 @@ class CacheTransBufferManager }; std::tuple, size_t, bool> getOrAllocateBuffers(std::optional bufferId, - int targetNum, size_t targetBufferEleSize, runtime::BufferManager const& bufferManagerToUse, - ConcurrenceResource& concurrenceResource); + int targetNum, std::vector const& targetBufferEleSizes, + runtime::BufferManager const& bufferManagerToUse, ConcurrenceResource& concurrenceResource); void allocateBuffer(); std::optional assignBufferIndex(ConcurrenceResource& resource, size_t bufferCount, bool onlyUseDynamicBuffer); diff --git a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp index 48ac605a3fd..d42b792c14c 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp +++ 
b/cpp/tensorrt_llm/batch_manager/cacheTransceiver.cpp @@ -96,13 +96,22 @@ std::unique_ptr CacheTransceiverFactory::createCacheTransc executor::kv_cache::CacheState::ModelConfig cacheStateCfg{ modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()}; - return std::make_unique( - cacheManager, cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig); + auto ppSize = worldConfig.getPipelineParallelism(); + + std::vector attentionLayerNumPerPP(ppSize, 0); + for (int ppRank = 0; ppRank < ppSize; ppRank++) + { + attentionLayerNumPerPP[ppRank] = modelConfig.getNbAttentionLayers(ppSize, ppRank); + } + + return std::make_unique(cacheManager, cacheStateCfg, worldConfig, attentionLayerNumPerPP, + modelConfig.getKvDataType(), attentionType, cacheTransceiverConfig); } CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheManager, executor::kv_cache::CacheState::ModelConfig const& cacheStateModelCfg, runtime::WorldConfig const& worldConfig, - nvinfer1::DataType dataType, executor::kv_cache::CacheState::AttentionType attentionType, + std::vector const& attentionLayerNumPerPP, nvinfer1::DataType dataType, + executor::kv_cache::CacheState::AttentionType attentionType, std::optional cacheTransceiverConfig) : mMpiGroupComm(std::addressof(tensorrt_llm::mpi::MpiComm::session())) , mCacheTransceiverConfig{cacheTransceiverConfig} @@ -124,7 +133,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa kvFactor = 1; } mCacheState = std::make_unique( - cacheStateModelCfg, worldConfig, dataType, attentionType, kvFactor); + cacheStateModelCfg, worldConfig, attentionLayerNumPerPP, dataType, attentionType, kvFactor); if (mCacheState->getParallelConfig().mEnableAttentionDP) { @@ -177,7 +186,7 @@ CacheTransceiver::CacheTransceiver(kv_cache_manager::BaseKVCacheManager* cacheMa else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::NIXL) { mManager = std::make_unique( - mCacheTransBufferManager.get()); + mCacheTransBufferManager.get(), *mCacheState); TLLM_LOG_INFO("NIXL Connection Manager created"); } else if (backendType.value() == executor::CacheTransceiverConfig::BackendType::MPI) diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index eaa2e957e87..474a0614d73 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -153,12 +153,24 @@ void MLACacheFormatter::format(TransferSession& session) // diff start auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); - - size_t const pPDomainSize = targetInfo.mDomainPPSize; - TLLM_CHECK((cacheBlockSize * blockNum) % pPDomainSize == 0); - auto const targetBufferSize = (cacheBlockSize * blockNum) / pPDomainSize; + auto ppRank = selfIdx + / (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism); + int selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank); + size_t pPDomainSize = targetInfo.mDomainPPSize; + auto getBufferSizeForTarget = [&]() + { + std::vector bufferSizeForTarget(pPDomainSize, 0); + size_t cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum; + for (size_t i = 0; i < pPDomainSize; i++) + { + auto layerNum = targetInfo.getPeerPPDomainLayerNum(i); + bufferSizeForTarget[i] = cacheSizePerLayer * layerNum; + } + return bufferSizeForTarget; + }; 
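// Both formatters size the per-target buffers in proportion to the attention-layer count each peer PP rank
// owns. A minimal standalone sketch of that idea follows; the function name and parameters are illustrative
// and not part of this change, and it assumes totalEleSize divides evenly by the local layer count.
#include <cstddef>
#include <vector>

// Split totalEleSize (the cache elements this rank holds for one request) across peer PP ranks in proportion
// to how many of this rank's attention layers each peer owns.
std::vector<size_t> splitByPeerLayerCount(
    size_t totalEleSize, int selfLayerNum, std::vector<int> const& peerLayerNums)
{
    std::vector<size_t> sizes(peerLayerNums.size(), 0);
    size_t const perLayer = totalEleSize / static_cast<size_t>(selfLayerNum);
    for (size_t i = 0; i < peerLayerNums.size(); ++i)
    {
        sizes[i] = perLayer * static_cast<size_t>(peerLayerNums[i]);
    }
    return sizes;
}
// Usage: splitByPeerLayerCount(1200, 6, {2, 4}) returns {400, 800}.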
+ auto bufferEleSizes = getBufferSizeForTarget(); auto result = mCacheTransBufferManager->getOrAllocateSendBuffers( - cacheBufferId, pPDomainSize, targetBufferSize, bufferManager); + cacheBufferId, static_cast(pPDomainSize), bufferEleSizes, bufferManager); auto& outputSplitCaches = std::get<0>(result); auto& bufferCoverTargetNum = std::get<1>(result); auto& onlyUseDynamicBuffer = std::get<2>(result); @@ -192,35 +204,30 @@ void MLACacheFormatter::format(TransferSession& session) TLLM_CUDA_CHECK(cudaSetDevice(deviceId)); auto startTime = std::chrono::steady_clock::now(); auto cacheIdx = processIdx % pPDomainSize; - size_t size; if (cacheIdx < bufferCoverTargetNum) { - size = outputSplitCaches.at(cacheIdx)->getSizeInBytes(); + size_t size = outputSplitCaches.at(cacheIdx)->getSizeInBytes(); session.send(processIdx, outputSplitCaches.at(cacheIdx)->data(), size); } - else if (bufferCoverTargetNum > 0) - { - // copy buffer allocated by cudaMallocAsync to buffer allocated by cudaMalloc before sending - auto sendBufferIdx = cacheIdx % bufferCoverTargetNum; - size = outputSplitCaches.at(sendBufferIdx)->getSizeInBytes(); - bufferManager.copy(*outputSplitCaches.at(cacheIdx), *outputSplitCaches.at(sendBufferIdx)); - bufferManager.getStream().synchronize(); - session.send(processIdx, outputSplitCaches.at(sendBufferIdx)->data(), size); - } else { - // bufferCoverTargetNum=0, mSendBuffer size < one outputSlice - // send multiple times - size = targetBufferSize; - size_t remainSendSize = targetBufferSize; + + // If cacheIdx< bufferCoverTargetNum, the ouputSplitCaches.at(cacheIdx) is allocated by cudaMallocAsync, + // which is unable to be transferred by UCX GPU-direct RDMA. We need copy the data to pre-allocated + // cudaMalloc buffer,and then start send. + // bufferCoverTargetNum=0, mSendBuffer size < one outputSlice send multiple times + size_t remainSendSize = outputSplitCaches.at(cacheIdx)->getSize(); + size_t needSendSize = outputSplitCaches.at(cacheIdx)->getSize(); + auto sendBufferIdx = bufferCoverTargetNum == 0 ? 0 : cacheIdx % bufferCoverTargetNum; + auto sendBufferUsed = bufferCoverTargetNum == 0 ? 
preAllocSendBuffer : outputSplitCaches.at(sendBufferIdx); while (remainSendSize > 0) { - TLLM_CHECK(preAllocSendBuffer != nullptr); - auto sendBufferEleSize = preAllocSendBuffer->getSize(); + TLLM_CHECK(sendBufferUsed != nullptr); + auto sendBufferEleSize = sendBufferUsed->getSize(); auto sendSize = std::min(remainSendSize, sendBufferEleSize); - auto copySlice = runtime::ITensor::slice( - outputSplitCaches.at(cacheIdx), targetBufferSize - remainSendSize, sendSize); - auto copyTargetSlice = runtime::ITensor::slice(preAllocSendBuffer, 0, sendSize); + auto copySlice + = runtime::ITensor::slice(outputSplitCaches.at(cacheIdx), needSendSize - remainSendSize, sendSize); + auto copyTargetSlice = runtime::ITensor::slice(sendBufferUsed, 0, sendSize); bufferManager.copy(*copySlice, *copyTargetSlice); bufferManager.getStream().synchronize(); session.send(processIdx, copyTargetSlice->data(), copyTargetSlice->getSizeInBytes()); @@ -236,7 +243,7 @@ void MLACacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - session.appendMeasure(delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, outputSplitCaches.at(cacheIdx)->getSizeInBytes()); }; if (connections.size() > 1) @@ -360,10 +367,27 @@ void MLACacheFormatter::unformat(TransferSession& session) auto cacheBlockSize = outputBuffers.at(0)->getSize(); auto targetNum = pickUpConnections.size(); - TLLM_CHECK((cacheBlockSize * blockNum) % targetNum == 0); - auto targetBufferSize = (cacheBlockSize * blockNum) / targetNum; + auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); + auto ppRank = selfIdx + / (selfConfig.getParallelConfig().mTensorParallelism * selfConfig.getParallelConfig().mContextParallelism); + auto selfAttentionLayerNum = selfConfig.getParallelConfig().mAttentionLayerNumPerPP.at(ppRank); + TLLM_CHECK_WITH_INFO(selfAttentionLayerNum != 0, "selfAttentionLayerNum should not be 0"); + + auto getBufferSizeForTarget = [&]() + { + std::vector bufferEleSizes(targetNum, 0); + auto cacheSizePerLayer = cacheBlockSize * blockNum / selfAttentionLayerNum; + for (size_t i = 0; i < targetNum; i++) + { + auto layerNum = targetInfo.getPeerPPDomainLayerNum(static_cast(pickUpConnections[i])); + bufferEleSizes[i] = cacheSizePerLayer * layerNum; + } + return bufferEleSizes; + }; + auto bufferEleSizes = getBufferSizeForTarget(); + auto result = mCacheTransBufferManager->getOrAllocateRecvBuffers( - cacheBufferId, targetNum, targetBufferSize, bufferManager); + cacheBufferId, static_cast(targetNum), bufferEleSizes, bufferManager); auto& recvSplitCaches = std::get<0>(result); auto& bufferCoverTargetNum = std::get<1>(result); size_t remainNoCoverTargetNum = targetNum > bufferCoverTargetNum ? 
targetNum - bufferCoverTargetNum : 0; @@ -394,29 +418,22 @@ void MLACacheFormatter::unformat(TransferSession& session) size = buffer->getSizeInBytes(); session.recv(pickUpConnections.at(processIdx), buffer->data(), buffer->getSizeInBytes()); } - else if (bufferCoverTargetNum > 0) - { - auto recvBufferIdx = processIdx % bufferCoverTargetNum - + remainNoCoverTargetNum; // caches.at(recvBufferIdx) is allocated by cudaMalloc - auto& buffer = recvSplitCaches.at(recvBufferIdx); - llmRequest.updateKvCacheSize(buffer->getSizeInBytes()); - size = buffer->getSizeInBytes(); - session.recv(pickUpConnections.at(processIdx), buffer->data(), buffer->getSizeInBytes()); - bufferManager.copy(*recvSplitCaches.at(recvBufferIdx), *recvSplitCaches.at(processIdx)); - bufferManager.getStream().synchronize(); - } else { + auto recvBufferIdx + = bufferCoverTargetNum == 0 ? 0 : processIdx % bufferCoverTargetNum + remainNoCoverTargetNum; + auto recvBufferUsed = bufferCoverTargetNum == 0 ? preAllocRecvBuffer : recvSplitCaches[recvBufferIdx]; // bufferCoverTargetNum==0 - size_t remainRecvSize = targetBufferSize; + size_t remainRecvSize = recvBufferUsed->getSize(); + size_t needRecvSize = recvSplitCaches.at(processIdx)->getSize(); while (remainRecvSize > 0) { - TLLM_CHECK(preAllocRecvBuffer != nullptr); - auto recvBufferEleSize = preAllocRecvBuffer->getSize(); + TLLM_CHECK(recvBufferUsed != nullptr); + auto recvBufferEleSize = recvBufferUsed->getSize(); auto recvSize = std::min(remainRecvSize, recvBufferEleSize); - auto recvSlice = runtime::ITensor::slice(preAllocRecvBuffer, 0, recvSize); + auto recvSlice = runtime::ITensor::slice(recvBufferUsed, 0, recvSize); auto copySlice = runtime::ITensor::slice( - recvSplitCaches.at(processIdx), targetBufferSize - remainRecvSize, recvSize); + recvSplitCaches.at(processIdx), needRecvSize - remainRecvSize, recvSize); llmRequest.updateKvCacheSize(recvSlice->getSizeInBytes()); size += recvSlice->getSizeInBytes(); session.recv(pickUpConnections.at(processIdx), recvSlice->data(), recvSlice->getSizeInBytes()); @@ -585,28 +602,6 @@ void MLACacheFormatter::unformat(TransferSession& session) return false; } - int selfNumLayers = selfConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - int selfPPSize = selfConfig.getParallelConfig().mPipelineParallelism; - int destPPSize = destConfig.getParallelConfig().mPipelineParallelism; - int destNumLayers = destConfig.getModelConfig().mNbKvHeadsPerLayer.size(); - - if (selfPPSize == destPPSize) - { - return true; - } - if (selfNumLayers % selfPPSize != 0) - { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d", - selfNumLayers, selfPPSize); - return false; - } - if (destNumLayers % destPPSize != 0) - { - TLLM_LOG_WARNING("CacheFormatter::inquireSupport: layers %d must be divisible by pipeline parallelism :%d ", - destNumLayers, destPPSize); - return false; - } - return true; } } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp index c64d85e1523..a9ba23e414e 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.cpp @@ -17,6 +17,7 @@ #include "connection.h" #include "tensorrt_llm/common/envUtils.h" +#include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" #include #include @@ -34,6 +35,28 @@ std::string 
genUniqueAgentName() { return std::string(hostname) + "_" + std::to_string(pid) + "_" + std::to_string(counter++); } +// The NIXL connection is special and differs from the UCX and MPI connections because NIXL only supports one-sided +// communication. The gen side sends its buffer metadata to the context side when it sends the requestInfo, but it +// cannot send the buffer offset yet, because unformat has not been called and the cache size and offset are still +// unknown. We assume the recv size equals the send size and compute the buffer offset from the layer count of the +// current PP rank and of the preceding PP ranks, since the buffer-size ratio equals the layer-count ratio except in +// the VSWA case. + +auto computeSendOffsetRatio( + CacheState const& peerCacheState, int peerIdx, CacheState const& selfCacheState, int validConnectionIdx) +{ + auto peerTargetInfo = targetIRanks(selfCacheState, peerCacheState, peerIdx); + // int ppRank = validConnectionIdx % peerTargetInfo.mDomainPPSize; + size_t offsetLayer = 0; + for (int i = 0; i < validConnectionIdx; i++) + { + offsetLayer += peerTargetInfo.getPeerPPDomainLayerNum(i); + } + + size_t selfSendLayer = peerTargetInfo.getPeerPPDomainLayerNum(validConnectionIdx); + + return std::make_pair(offsetLayer, selfSendLayer); +} + AgentConnection::AgentConnection( std::string mAgentName, std::string mRemoteAgentName, AgentConnectionManager* mAgentConnectionManager) : mAgentName(mAgentName) @@ -82,7 +105,8 @@ void AgentConnection::send(DataContext const& ctx, void const* data, size_t size reinterpret_cast(data), size, static_cast(mAgentConnectionManager->getDeviceId())}; MemoryDescs srcDescs{MemoryType::kVRAM, {srcDesc}}; auto dstBaseDesc = mSenderState.mReceiverBufferDesc; - MemoryDesc dstDesc{dstBaseDesc.getAddr() + (mSenderState.validSegmentIdx * size), size, dstBaseDesc.getDeviceId()}; + auto offset = size / mSenderState.mOffsetRatio.second * mSenderState.mOffsetRatio.first; + MemoryDesc dstDesc{dstBaseDesc.getAddr() + offset, size, dstBaseDesc.getDeviceId()}; TLLM_LOG_DEBUG( "send dstDesc: %p, size: %ld ,validSegmentIdx: %ld", dstDesc.getAddr(), size, mSenderState.validSegmentIdx); MemoryDescs dstDescs{MemoryType::kVRAM, {dstDesc}}; @@ -137,10 +161,12 @@ void AgentConnection::sendRequestAndBufferInfo( mAgentConnectionManager->getAgent()->notifySyncMessage(mRemoteAgentName, ss.str()); } -void AgentConnection::setSenderState(MemoryDesc mReceiverBufferDesc, int validSegmentIdx) +void AgentConnection::setSenderState( + MemoryDesc mReceiverBufferDesc, int validSegmentIdx, std::pair offsetRatio) { mSenderState.mReceiverBufferDesc = mReceiverBufferDesc; mSenderState.validSegmentIdx = validSegmentIdx; + mSenderState.mOffsetRatio = offsetRatio; } void AgentConnection::setHasLoadRemoteAgent(bool hasLoadRemoteAgent) @@ -155,8 +181,9 @@ bool AgentConnection::hasLoadRemoteAgent() const } AgentConnectionManager::AgentConnectionManager( - batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager) - : mRegMemDescs(MemoryType::kVRAM, {}) + batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager, CacheState cacheState) + : mCacheState(std::move(cacheState)) + , mRegMemDescs(MemoryType::kVRAM, {}) { TLLM_CUDA_CHECK(cudaGetDevice(&mDeviceId)); TLLM_CHECK(mDeviceId != -1); @@ -260,7 +287,10 @@ AgentConnection const* AgentConnectionManager::recvConnectionAndRequestInfo(batc auto remoteAgentName = requestAndBufferInfo.mAgentName; TLLM_LOG_DEBUG(" recv Address:%s", address.c_str()); auto connection = connect(remoteAgentName, address, metadataOpt,
true); - connection->setSenderState(bufferDesc, validConnectionIdx); + // to compute the offset. + auto offsetRatio = computeSendOffsetRatio(requestInfo.getTransState().getCacheState().value(), + requestInfo.getTransState().getCommState()->getSelfIdx(), mCacheState, validConnectionIdx); + connection->setSenderState(bufferDesc, validConnectionIdx, offsetRatio); it2 = notifs.erase(it2); if (notifs.empty()) { @@ -328,7 +358,7 @@ batch_manager::kv_cache_manager::CacheTransBufferManager* AgentConnectionManager return mCacheTransBufferManager; } -AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentName, std::string const& connecitonInfo, +AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentName, std::string const& connectionInfo, std::optional metadata, bool isSender) { @@ -369,7 +399,7 @@ AgentConnection* AgentConnectionManager::connect(std::string const& remoteAgentN TLLM_CHECK_WITH_INFO(!isSender, "Sender shouldn't call connectRemoteAgent"); TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "mAgentName: %s connect to %s with connectRemoteAgent", mAgentName.c_str(), remoteAgentName.c_str()); - m_Agent->connectRemoteAgent(remoteAgentName, connecitonInfo); + m_Agent->connectRemoteAgent(remoteAgentName, connectionInfo); } } else diff --git a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h index 8f73631d1e8..0ee171632f6 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h +++ b/cpp/tensorrt_llm/executor/cache_transmission/agent_utils/connection.h @@ -175,7 +175,7 @@ class AgentConnection : public Connection void recv(DataContext const& ctx, void* data, size_t size) const override; void sendRequestAndBufferInfo( batch_manager::RequestInfo& requestInfo, std::optional cacheBufferId, int validConnectionIdx); - void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx); + void setSenderState(MemoryDesc mReceiverBufferDesc, int valideSegmentIdx, std::pair offsetRatio); [[nodiscard]] std::optional getCacheBufferId() const; void setHasLoadRemoteAgent(bool hasLoadRemoteAgent); [[nodiscard]] bool hasLoadRemoteAgent() const; @@ -188,6 +188,7 @@ class AgentConnection : public Connection { MemoryDesc mReceiverBufferDesc{nullptr, 0, 0}; int validSegmentIdx{0}; + std::pair mOffsetRatio; SenderState() = default; }; @@ -203,7 +204,8 @@ class AgentConnection : public Connection class AgentConnectionManager : public ConnectionManager { public: - AgentConnectionManager(batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager); + AgentConnectionManager( + batch_manager::kv_cache_manager::CacheTransBufferManager* cacheTransBufferManager, CacheState cacheState); ~AgentConnectionManager(); AgentConnection* recvConnect(DataContext const& ctx, void* data, size_t size) override; [[nodiscard]] std::vector getConnections(CommState const& state) override; @@ -222,6 +224,7 @@ class AgentConnectionManager : public ConnectionManager std::map> mConnections; std::mutex mConnectionsMutex; CommState mCommState; + CacheState mCacheState; batch_manager::kv_cache_manager::CacheTransBufferManager* mCacheTransBufferManager; std::mutex mNotificationMutex; std::unordered_map> mUnhandledNotifications; diff --git a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu index 73343b9d1e3..df2b38f09bc 100644 --- 
a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu +++ b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.cu @@ -67,22 +67,49 @@ TargetRanksInfo TargetRanksInfoForDP( int peerPPRankStart = 0; int mDomainPPSize = 1; int peerPPRankEnd = 0; - for (auto val : {peerPPNum, selfPPNum}) - { - TLLM_CHECK(isPowerOfTwo(val)); - } - if (selfPPNum <= peerPPNum) + std::vector peerNumLayerPerPP = peerParConfig.mAttentionLayerNumPerPP; + std::vector selfNumLayerPerPP = selfParConfig.mAttentionLayerNumPerPP; + TLLM_CHECK(peerNumLayerPerPP.size() == peerPPNum); + TLLM_CHECK(selfNumLayerPerPP.size() == selfPPNum); + int selfStartLayerId = 0; + // global start layer id for selfPPrank, which is the sum of the layer num of the previous PP ranks. + // compute the target PP ranks and layer num need to be fetched from each target PP rank, according to [global start + // layer id, global end layer id) + + for (int ppRank = 0; ppRank < selfPPRank; ppRank++) { - mDomainPPSize = peerPPNum / selfPPNum; - peerPPRankStart = selfPPRank * mDomainPPSize; - peerPPRankEnd = (selfPPRank + 1) * mDomainPPSize; + selfStartLayerId += selfNumLayerPerPP[ppRank]; } - else + int selfEndLayerId = selfStartLayerId + selfNumLayerPerPP[selfPPRank]; + int prePeerPPLayerId = 0; + std::vector targetPeerPPRanks; + std::vector targetPeerPPLayerNum; + for (int ppRank = 0; ppRank < peerPPNum; ppRank++) { - peerPPRankStart = selfPPRank / (selfPPNum / peerPPNum); - peerPPRankEnd = peerPPRankStart + mDomainPPSize; + int peerPPStartLayerId = prePeerPPLayerId; + int peerPPEndLayerId = peerPPStartLayerId + peerNumLayerPerPP[ppRank]; + + prePeerPPLayerId += peerNumLayerPerPP[ppRank]; + + if (selfStartLayerId < peerPPEndLayerId && selfEndLayerId > peerPPStartLayerId) + { + targetPeerPPRanks.push_back(ppRank); + int layerNumInDomainPP + = std::min(peerPPEndLayerId, selfEndLayerId) - std::max(peerPPStartLayerId, selfStartLayerId); + targetPeerPPLayerNum.push_back(layerNumInDomainPP); + } } + mDomainPPSize = static_cast(targetPeerPPRanks.size()); + peerPPRankStart = targetPeerPPRanks.front(); + peerPPRankEnd = peerPPRankStart + mDomainPPSize; + TLLM_CHECK(targetPeerPPLayerNum.size() == mDomainPPSize); + int targetPeerPpLayerNumSum = std::accumulate(targetPeerPPLayerNum.begin(), targetPeerPPLayerNum.end(), 0); + TLLM_CHECK(targetPeerPpLayerNumSum == selfNumLayerPerPP[selfPPRank]); + + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), + "selfPPRank:%d,selfPPNum:%d,peerPPNum:%d,selfTPNum:%d,peerTPNum:%d,peerPPRankStart:%d,peerPPRankEnd:%d", + selfPPRank, selfPPNum, peerPPNum, selfTPNum, peerTPNum, peerPPRankStart, peerPPRankEnd); int peerTPRankStart = 0; int mDomainTPSize = 1; int peerTPRankEnd = 0; @@ -156,7 +183,27 @@ TargetRanksInfo TargetRanksInfoForDP( = (peerNbHeadsPerLayer * peerTPSizePerDPGroup) / (selfNbHeadsPerLayer * selfTPSizePerDPGroup); } - return {mDomainPPSize, mDomainTPSize, mDomainCPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor}; + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), + "mDomainPPSize:%d, mDomainTPSize:%d, mDupHeadFactor:%d, mPeerDupHeadFactor:%d, selfPPRank:%d, selfPPNum:%d, " + "peerPPNum:%d, selfTPNum:%d, peerTPNum:%d, selfTPSizePerDPGroup:%d, peerTPSizePerDPGroup:%d, " + "selfNbHeadsPerLayer:%d, peerNbHeadsPerLayer:%d, selfTPrankInDPGroup:%d, peerDpRank:%d, selfRank:%d", + mDomainPPSize, mDomainTPSize, mDupHeadFactor, mPeerDupHeadFactor, selfPPRank, selfPPNum, peerPPNum, selfTPNum, + peerTPNum, selfTPSizePerDPGroup, peerTPSizePerDPGroup, selfNbHeadsPerLayer, peerNbHeadsPerLayer, + 
selfTPrankInDPGroup, peerDpRank, selfRank); + + auto vector_to_string = [](std::vector const& vec) + { + std::stringstream ss; + for (auto val : vec) + { + ss << val << ","; + } + return ss.str(); + }; + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "retRanks:%s , targetPeerPPLayerNum:%s", + vector_to_string(retRanks).c_str(), vector_to_string(targetPeerPPLayerNum).c_str()); + return {mDomainPPSize, mDomainTPSize, mDomainCPSize, std::move(retRanks), mDupHeadFactor, mPeerDupHeadFactor, + std::move(targetPeerPPLayerNum)}; } TargetRanksInfo targetIRanks( @@ -496,12 +543,37 @@ nvinfer1::Dims makeShapeFromCacheState(kv_cache::CacheState const& cacheState) cacheState.getAttentionConfig().mKvFactor, blockSize}); } +__device__ __forceinline__ void getLayerIdInDomainPPandRankInDomainPP(int layerId, int DomainPPSize, + uint64_t* prefixLayerNumDevPtr, int& layerIdInDomainPP, int& rankInDomainPP, int& layerNumInSpecPP) +{ + __shared__ int sharedLayerIdInDomainPP; + __shared__ int sharedRankInDomainPP; + __shared__ int sharedLayerNumInSpecPP; + +#pragma unroll 1 + for (int ppRank = threadIdx.x; ppRank < DomainPPSize; ppRank += blockDim.x) + { + if (layerId >= prefixLayerNumDevPtr[ppRank] && layerId < prefixLayerNumDevPtr[ppRank + 1]) + { + sharedLayerIdInDomainPP = layerId - prefixLayerNumDevPtr[ppRank]; + sharedRankInDomainPP = ppRank; + sharedLayerNumInSpecPP = prefixLayerNumDevPtr[ppRank + 1] - prefixLayerNumDevPtr[ppRank]; + break; + } + } + + __syncthreads(); + layerIdInDomainPP = sharedLayerIdInDomainPP; + rankInDomainPP = sharedRankInDomainPP; + layerNumInSpecPP = sharedLayerNumInSpecPP; +} + // MLA Head 1: One thread block per [(2), tokens, dimsPerHead] template __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** __restrict__ outputCaches, int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize, - int DomainTPSize, int layerNumDomainPP, int kvFactor) + int DomainTPSize, int kvFactor, uint64_t* prefixLayerNumDevPtr) { int const subWarpId = threadIdx.x / subWarpSize; int const laneId = threadIdx.x % subWarpSize; @@ -518,19 +590,25 @@ __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** for (int layerId = blockIdx.x; layerId < numLayers; layerId += gridDim.x) { + int layerIdInDomainPP{}; + int rankInDomainPP{}; + int layerNumInSpecPP{}; + getLayerIdInDomainPPandRankInDomainPP( + layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP); + #pragma unroll 1 for (int headId = 0; headId < headNum; headId++) { T const* inputBlockPtr = inputBlocks[blockId]; T const* kInputPtr = inputBlockPtr + layerId * kvFactor * headNum * tokensPerBlock * dimsPerHead + headId * tokensPerBlock * dimsPerHead; - int const outputCacheIdx = layerId / layerNumDomainPP; + int outputCacheIdx = rankInDomainPP; T* outputCachePtr = outputCaches[outputCacheIdx]; - int const layerIdInDomainPP = layerId % layerNumDomainPP; + int const headIdInDomainTP = headId; T* kOutputPtr = outputCachePtr - + blockId * (layerNumDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead) + + blockId * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead) + layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead + headIdInDomainTP * tokensPerBlock * dimsPerHead; int const kvOffset = headNum * tokensPerBlock * dimsPerHead; @@ -565,7 +643,7 @@ __global__ void splitKVCacheForMLAKernel(T const** __restrict__ inputBlocks, T** template __global__ void splitKVCacheKernel(T const** 
__restrict__ inputBlocks, T** __restrict__ outputCaches, int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int inputBlockNum, int DomainPPSize, - int DomainTPSize, int layerNumDomainPP, int headNumDomainTP) + int DomainTPSize, int headNumDomainTP, uint64_t* prefixLayerNumDevPtr) { // layerNumDomainPP = numLayers/DomainPPSize @@ -587,6 +665,13 @@ __global__ void splitKVCacheKernel(T const** __restrict__ inputBlocks, T** __res for (int layerId = blockIdx.x; layerId < numLayers; layerId += gridDim.x) { + + int layerIdInDomainPP{}; + int rankInDomainPP{}; + int layerNumInSpecPP{}; + getLayerIdInDomainPPandRankInDomainPP( + layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP); + #pragma unroll 1 for (int headId = subWarpGroupId; headId < headNum; headId += subWarpGroupNum) @@ -598,13 +683,12 @@ __global__ void splitKVCacheKernel(T const** __restrict__ inputBlocks, T** __res T const* vInputPtr = inputBlockPtr + (layerId * 2 + 1) * headNum * tokensPerBlock * dimsPerHead + headId * tokensPerBlock * dimsPerHead; - int outputCacheIdx = headId / headNumDomainTP * DomainPPSize + layerId / layerNumDomainPP; + int outputCacheIdx = headId / headNumDomainTP * DomainPPSize + rankInDomainPP; T* outputCachePtr = outputCaches[outputCacheIdx]; - int layerIdInDomainPP = layerId % layerNumDomainPP; int headIdInDomainTP = headId % headNumDomainTP; T* kOutputPtr = outputCachePtr - + blockId * (layerNumDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead) + + blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead) + layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead + headIdInDomainTP * tokensPerBlock * dimsPerHead; @@ -746,7 +830,7 @@ __global__ void splitKVCacheForWindowKernel(T const** __restrict__ inputBlocks, template __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T** __restrict__ outputBlocks, int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int outputBlockNum, int DomainPPSize, - int DomainTPSize, int layerNumDomainPP, int kvFactor) + int DomainTPSize, int kvFactor, uint64_t* prefixLayerNumDevPtr) { int const subWarpId = threadIdx.x / subWarpSize; @@ -761,7 +845,11 @@ __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T* #pragma unroll 1 for (int layerId = blockIdx.x; layerId < numLayers; layerId += gridDim.x) { - + int layerIdInDomainPP{}; + int rankInDomainPP{}; + int layerNumInSpecPP{}; + getLayerIdInDomainPPandRankInDomainPP( + layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP); #pragma unroll 1 for (int headId = 0; headId < headNum; headId++) @@ -769,13 +857,12 @@ __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T* T* outputBlockPtr = outputBlocks[blockId]; T* kOutputPtr = outputBlockPtr + layerId * kvFactor * headNum * tokensPerBlock * dimsPerHead + headId * tokensPerBlock * dimsPerHead; - int inputCacheIdx = layerId / layerNumDomainPP; + int inputCacheIdx = rankInDomainPP; T const* inputCachePtr = inputCaches[inputCacheIdx]; - int layerIdInDomainPP = layerId % layerNumDomainPP; int headIdInDomainTP = headId; T const* kInputPtr = inputCachePtr - + blockId * (layerNumDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead) + + blockId * (layerNumInSpecPP * kvFactor * headNum * tokensPerBlock * dimsPerHead) + layerIdInDomainPP * kvFactor * headNum * tokensPerBlock * dimsPerHead + headIdInDomainTP * tokensPerBlock * dimsPerHead; int const 
kvOffset = headNum * tokensPerBlock * dimsPerHead; @@ -804,7 +891,7 @@ __global__ void concatKVCacheForMLAKernel(T const** __restrict__ inputCaches, T* template __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __restrict__ outputBlocks, int tokensPerBlock, int numLayers, int headNum, int dimsPerHead, int outputBlockNum, int DomainPPSize, - int DomainTPSize, int layerNumDomainPP, int headNumDomainTP) + int DomainTPSize, int headNumDomainTP, uint64_t* prefixLayerNumDevPtr) { int const subWarpId = threadIdx.x / subWarpSize; int const laneId = threadIdx.x % subWarpSize; @@ -821,6 +908,11 @@ __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __re #pragma unroll 1 for (int layerId = blockIdx.x; layerId < numLayers; layerId += gridDim.x) { + int layerIdInDomainPP{}; + int rankInDomainPP{}; + int layerNumInSpecPP{}; + getLayerIdInDomainPPandRankInDomainPP( + layerId, DomainPPSize, prefixLayerNumDevPtr, layerIdInDomainPP, rankInDomainPP, layerNumInSpecPP); #pragma unroll 1 for (int headId = subWarpGroupId; headId < headNum; headId += subWarpGroupNum) @@ -832,13 +924,12 @@ __global__ void concatKVCacheKernel(T const** __restrict__ inputCaches, T** __re T* vOutputPtr = outputBlockPtr + (layerId * 2 + 1) * headNum * tokensPerBlock * dimsPerHead + headId * tokensPerBlock * dimsPerHead; - int inputCacheIdx = headId / headNumDomainTP * DomainPPSize + layerId / layerNumDomainPP; + int inputCacheIdx = headId / headNumDomainTP * DomainPPSize + rankInDomainPP; T const* inputCachePtr = inputCaches[inputCacheIdx]; - int layerIdInDomainPP = layerId % layerNumDomainPP; int headIdInDomainTP = headId % headNumDomainTP; T const* kInputPtr = inputCachePtr - + blockId * (layerNumDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead) + + blockId * (layerNumInSpecPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead) + layerIdInDomainPP * 2 * headNumDomainTP * tokensPerBlock * dimsPerHead + headIdInDomainTP * tokensPerBlock * dimsPerHead; @@ -942,7 +1033,9 @@ void splitKVCache(std::map> } TLLM_CHECK(outputCacheNum == outputSplitBlocks.size()); TLLM_CHECK(inputBlockNumSum > 0); - std::vector cachePtrs; + // we want to reduce the call of `cudaMemcpyAsync H2D` , cachePtrs is used to store the pointers of the cache blocks + // and the values of the prefix layer num. 
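For reference, a minimal host-side sketch of the lookup that the split/concat kernels perform through getLayerIdInDomainPPandRankInDomainPP, assuming the packed prefix array holds DomainPPSize + 1 running layer totals as built below. The struct and helper names here are illustrative only and are not part of the codebase.

#include <cstdint>
#include <vector>

// Host-side illustration of the device helper: prefixLayerNum[r] is the number of
// attention layers owned by domain-PP ranks 0..r-1, so rank r covers the half-open
// range [prefixLayerNum[r], prefixLayerNum[r+1]).
struct LayerLookup
{
    int layerIdInDomainPP; // layer index local to the owning domain-PP rank
    int rankInDomainPP;    // which domain-PP rank owns the layer
    int layerNumInSpecPP;  // number of layers held by that rank
};

inline LayerLookup lookupLayer(int layerId, std::vector<uint64_t> const& prefixLayerNum)
{
    int const domainPPSize = static_cast<int>(prefixLayerNum.size()) - 1;
    for (int ppRank = 0; ppRank < domainPPSize; ++ppRank)
    {
        if (layerId >= static_cast<int>(prefixLayerNum[ppRank])
            && layerId < static_cast<int>(prefixLayerNum[ppRank + 1]))
        {
            return {layerId - static_cast<int>(prefixLayerNum[ppRank]), ppRank,
                static_cast<int>(prefixLayerNum[ppRank + 1] - prefixLayerNum[ppRank])};
        }
    }
    return {-1, -1, -1}; // layerId outside the packed range
}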
+ std::vector cachePtrs; std::vector windowSizes; std::vector blockNumInwindow; std::vector layersInWindow; @@ -965,7 +1058,7 @@ void splitKVCache(std::map> TLLM_CHECK(kvCacheBlock->getDataType() == cacheDataType); TLLM_CHECK(kvCacheBlock->getSize() == cacheBlockSize); cacheBlockSizeSum += kvCacheBlock->getSize(); - cachePtrs.push_back(static_cast(kvCacheBlock->data())); + cachePtrs.push_back(reinterpret_cast((kvCacheBlock->data()))); inputBlockLayerNumSum += layersNum; } } @@ -973,10 +1066,15 @@ void splitKVCache(std::map> for (auto&& outputSplitBlock : outputSplitBlocks) { TLLM_CHECK(outputSplitBlock->getDataType() == cacheDataType); - TLLM_CHECK(outputSplitBlock->getSize() == cacheBlockSizeSum / outputCacheNum); - cachePtrs.push_back(static_cast(outputSplitBlock->data())); + cachePtrs.push_back(reinterpret_cast(outputSplitBlock->data())); } - + std::vector prefixLayerNum(targetRankInfo.mDomainPPSize + 1, 0); + prefixLayerNum[0] = 0; + for (int i = 0; i < targetRankInfo.mDomainPPSize; i++) + { + prefixLayerNum[i + 1] = prefixLayerNum[i] + targetRankInfo.mPeerAttentionLayerNumInDomainPP[i]; + } + cachePtrs.insert(cachePtrs.end(), prefixLayerNum.begin(), prefixLayerNum.end()); bool const isWindow = windowSizes.size() > 1; runtime::BufferManager::IBufferPtr PtrsDeviceBuffer @@ -1037,23 +1135,25 @@ void splitKVCache(std::map> int const sizePerHead = selfModelConfig.mSizePerHead; T const** inputBlockPtrsDev = static_cast(PtrsDeviceBuffer->data()); T** outputCachePtrsDev = static_cast(PtrsDeviceBuffer->data()) + inputBlockNumSum; + uint64_t* prefixLayerNumDevPtr + = static_cast(PtrsDeviceBuffer->data()) + inputBlockNumSum + outputSplitBlocks.size(); + int const tokensPerBlock = selfModelConfig.mTokensPerBlock; - int const numLayers = selfModelConfig.mNbKvHeadsPerLayer.size() / oPPNum; + int selfPPRank = selfIdx / (selfParallelConfig.mTensorParallelism * selfParallelConfig.mContextParallelism); + int const numLayers = selfParallelConfig.mAttentionLayerNumPerPP.at(selfPPRank); int const headNum = selfModelConfig.mNbKvHeadsPerLayer[0]; int const dimsPerHead = selfModelConfig.mSizePerHead; int const DomainPPSize = targetRankInfo.mDomainPPSize; int const DomainTPSize = targetRankInfo.mDomainTPSize; - int const layerNumDomainPP = numLayers / DomainPPSize; - int const headNumDomainTP - = headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); // TODO: duplicate head factor + int const headNumDomainTP = headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); int const kvFactor = selfAttentionConfig.mKvFactor; bool const isMLA = selfAttentionConfig.mAttentionType == CacheState::AttentionType::kMLA; constexpr int mlaSubWarpSize = 16; TLLM_LOG_DEBUG( "splitKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, " - "layersPerDomainPP: %d, headsPerDomainTP: %d", - numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP); + "headsPerDomainTP: %d", + numLayers, headNum, DomainPPSize, DomainTPSize, headNumDomainTP); int const remainder = sizePerHead * sizeof(T) % 16; switch (remainder) @@ -1064,7 +1164,7 @@ void splitKVCache(std::map> { splitKVCacheForMLAKernel<<>>( inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, - inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + inputBlockNumSum, DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1078,7 +1178,7 @@ void splitKVCache(std::map> splitKVCacheKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, 
tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1088,7 +1188,7 @@ void splitKVCache(std::map> { splitKVCacheForMLAKernel<<>>( inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, - inputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + inputBlockNumSum, DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1102,7 +1202,7 @@ void splitKVCache(std::map> splitKVCacheKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1116,7 +1216,7 @@ void splitKVCache(std::map> splitKVCacheForMLAKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, kvFactor); + kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1131,7 +1231,7 @@ void splitKVCache(std::map> splitKVCacheKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1149,7 +1249,7 @@ void splitKVCache(std::map> splitKVCacheForMLAKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, kvFactor); + kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1164,7 +1264,7 @@ void splitKVCache(std::map> splitKVCacheKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1178,7 +1278,7 @@ void splitKVCache(std::map> splitKVCacheForMLAKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, kvFactor); + kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1193,7 +1293,7 @@ void splitKVCache(std::map> splitKVCacheKernel <<>>(inputBlockPtrsDev, outputCachePtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, inputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1274,7 +1374,7 @@ void concatKVCache(std::vector const& inputSplitBlo TLLM_CHECK(inputCacheNum == inputSplitBlocks.size()); TLLM_CHECK(outputBlockNumSum > 0); - std::vector cachePtrs; + std::vector cachePtrs; std::vector windowSizes; std::vector blockNumInwindow; std::vector layersInWindow; @@ -1294,7 +1394,7 @@ void concatKVCache(std::vector const& inputSplitBlo { TLLM_CHECK(kvCacheBlock->getDataType() == cacheDataType); TLLM_CHECK(kvCacheBlock->getSize() == cacheBlockSize); - cachePtrs.push_back(static_cast(kvCacheBlock->data())); + cachePtrs.push_back(reinterpret_cast(kvCacheBlock->data())); cacheBlockSizeSum += kvCacheBlock->getSize(); } outputBlockLayerNumSum += layersNum * blocks.size(); @@ -1302,12 +1402,23 @@ void concatKVCache(std::vector const& inputSplitBlo for (auto&& inputSplitBlock : inputSplitBlocks) { TLLM_CHECK(inputSplitBlock->getDataType() == cacheDataType); - 
TLLM_CHECK(inputSplitBlock->getSize() == cacheBlockSizeSum / inputCacheNum); - cachePtrs.push_back(static_cast(inputSplitBlock->data())); + cachePtrs.push_back(reinterpret_cast(inputSplitBlock->data())); } + + // the prefix layer num is used to store the layer num of the previous PP ranks. + // which is helpful for the kernel to get layer num info. refer to the function + // `getLayerIdInDomainPPandRankInDomainPP`. + + std::vector prefixLayerNum(targetRankInfo.mDomainPPSize + 1, 0); + prefixLayerNum[0] = 0; + for (int i = 0; i < targetRankInfo.mDomainPPSize; i++) + { + prefixLayerNum[i + 1] = prefixLayerNum[i] + targetRankInfo.mPeerAttentionLayerNumInDomainPP[i]; + } + cachePtrs.insert(cachePtrs.end(), prefixLayerNum.begin(), prefixLayerNum.end()); runtime::BufferManager::IBufferPtr PtrsDeviceBuffer = bufferManager.gpu(cachePtrs.size(), nvinfer1::DataType::kINT64); - TLLM_CHECK(PtrsDeviceBuffer->getSizeInBytes() == cachePtrs.size() * sizeof(T*)); + TLLM_CHECK(PtrsDeviceBuffer->getSizeInBytes() == cachePtrs.size() * sizeof(uint64_t)); bufferManager.copy(cachePtrs.data(), *PtrsDeviceBuffer, runtime::MemoryType::kCPU); bool const isWindow = windowSizes.size() > 1; runtime::BufferManager::IBufferPtr windowInfoDeviceBuffer; @@ -1350,14 +1461,17 @@ void concatKVCache(std::vector const& inputSplitBlo int const endLayerId = selfModelConfig.mNbKvHeadsPerLayer.size() / oPPNum; T** ouptutBlockPtrsDev = static_cast(PtrsDeviceBuffer->data()); T const** inputSplitBlockPtrsDev = static_cast(PtrsDeviceBuffer->data()) + outputBlockNumSum; + uint64_t* prefixLayerNumDevPtr + = static_cast(PtrsDeviceBuffer->data()) + outputBlockNumSum + inputSplitBlocks.size(); int const tokensPerBlock = selfModelConfig.mTokensPerBlock; - int const numLayers = selfModelConfig.mNbKvHeadsPerLayer.size() / oPPNum; + int selfPPRank = selfIdx / (selfParallelConfig.mTensorParallelism * selfParallelConfig.mContextParallelism); + int const numLayers = selfParallelConfig.mAttentionLayerNumPerPP.at(selfPPRank); + TLLM_LOG_DEBUG(mpi::MpiComm::world().getRank(), "concatKVCache numLayers:%d", numLayers); int const headNum = selfModelConfig.mNbKvHeadsPerLayer[0]; int const dimsPerHead = selfModelConfig.mSizePerHead; int const DomainPPSize = targetRankInfo.mDomainPPSize; int const DomainTPSize = targetRankInfo.mDomainTPSize; - int const layerNumDomainPP = numLayers / DomainPPSize; int const headNumDomainTP = headNum / (DomainTPSize / targetRankInfo.mPeerDupHeadFactor); // TODO: duplicate head factor int const kvFactor = selfAttentionConfig.mKvFactor; @@ -1365,8 +1479,8 @@ void concatKVCache(std::vector const& inputSplitBlo bool isMLA = selfAttentionConfig.mAttentionType == CacheState::AttentionType::kMLA; TLLM_LOG_DEBUG( "concatKVCache - numLayers: %d, headNum: %d, domainPPSize: %d, domainTPSize: %d, " - "layersPerDomainPP: %d, headsPerDomainTP: %d", - numLayers, headNum, DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP); + "headsPerDomainTP: %d", + numLayers, headNum, DomainPPSize, DomainTPSize, headNumDomainTP); int const remainder = sizePerHead * sizeof(T) % 16; @@ -1380,7 +1494,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheForMLAKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, kvFactor); + kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1394,7 +1508,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheKernel <<>>(inputSplitBlockPtrsDev, 
ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1404,7 +1518,7 @@ void concatKVCache(std::vector const& inputSplitBlo { concatKVCacheForMLAKernel<<>>( inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, - outputBlockNumSum, DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + outputBlockNumSum, DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1418,7 +1532,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, DomainPPSize, DomainTPSize, - layerNumDomainPP, headNumDomainTP); + headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1432,7 +1546,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheForMLAKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1447,7 +1561,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP); + DomainPPSize, DomainTPSize, headNumDomainTP, prefixLayerNumDevPtr); } break; @@ -1465,7 +1579,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheForMLAKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1480,7 +1594,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP); + DomainPPSize, DomainTPSize, headNumDomainTP, prefixLayerNumDevPtr); } break; } @@ -1494,7 +1608,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheForMLAKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, kvFactor); + DomainPPSize, DomainTPSize, kvFactor, prefixLayerNumDevPtr); } else if (isWindow) { @@ -1509,7 +1623,7 @@ void concatKVCache(std::vector const& inputSplitBlo concatKVCacheKernel <<>>(inputSplitBlockPtrsDev, ouptutBlockPtrsDev, tokensPerBlock, numLayers, headNum, dimsPerHead, outputBlockNumSum, - DomainPPSize, DomainTPSize, layerNumDomainPP, headNumDomainTP); + DomainPPSize, DomainTPSize, headNumDomainTP, prefixLayerNumDevPtr); } break; } diff --git a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h index eca8c9a21a6..a54a0389e94 100644 --- a/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h +++ b/cpp/tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h @@ -40,6 +40,18 @@ struct TargetRanksInfo std::vector mIRanks; int mDupHeadFactor; int mPeerDupHeadFactor; + + // the size of the vector is equal to the 
mDomainPPSize. the value of the vector is the layer num should be fetched + // from each target PP rank in domain PP. + std::vector mPeerAttentionLayerNumInDomainPP; + + int getPeerPPDomainLayerNum(int targetRankIdx) + { + //[TP,PP] + + int ppDomainRankIdx = targetRankIdx % mDomainPPSize; + return mPeerAttentionLayerNumInDomainPP[ppDomainRankIdx]; + } }; TargetRanksInfo targetIRanks( diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index bba8d19e2f6..8a6010b18bc 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -535,11 +535,13 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is) auto enableAttentionDP = su::deserialize(is); auto DPrank = su::deserialize(is); auto DPsize = su::deserialize(is); + auto attentionLayerNumPerPP = su::deserialize(is); auto dataType = su::deserialize(is); auto attentionType = su::deserialize(is); auto kvFactor = su::deserialize(is); return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, - contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize}; + contextParallelism, attentionLayerNumPerPP, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, + DPsize}; } void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os) @@ -553,6 +555,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& o su::serialize(state.mParallelConfig.mEnableAttentionDP, os); su::serialize(state.mParallelConfig.mDPrank, os); su::serialize(state.mParallelConfig.mDPsize, os); + su::serialize(state.mParallelConfig.mAttentionLayerNumPerPP, os); su::serialize(state.mDataType, os); su::serialize(state.mAttentionConfig.mAttentionType, os); su::serialize(state.mAttentionConfig.mKvFactor, os); @@ -570,6 +573,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state) totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP); totalSize += su::serializedSize(state.mParallelConfig.mDPrank); totalSize += su::serializedSize(state.mParallelConfig.mDPsize); + totalSize += su::serializedSize(state.mParallelConfig.mAttentionLayerNumPerPP); totalSize += su::serializedSize(state.mDataType); totalSize += su::serializedSize(state.mAttentionConfig.mAttentionType); totalSize += su::serializedSize(state.mAttentionConfig.mKvFactor); diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp index 8a7f73f3b06..19e37ff5b19 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/cacheTransceiver.cpp @@ -25,6 +25,7 @@ #include #include #include +#include #include #include @@ -90,11 +91,11 @@ void tb::CacheTransceiverBindings::initBindings(nb::module_& m) nb::class_(m, "CacheTransceiver") .def(nb::init, SizeType32, SizeType32, - runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType, - std::optional>(), + runtime::WorldConfig, std::vector, nvinfer1::DataType, + executor::kv_cache::CacheState::AttentionType, std::optional>(), nb::arg("cache_manager"), nb::arg("num_kv_heads_per_layer"), nb::arg("size_per_head"), - nb::arg("tokens_per_block"), nb::arg("world_config"), nb::arg("dtype"), nb::arg("attention_type"), - nb::arg("cache_transceiver_config") = std::nullopt); + nb::arg("tokens_per_block"), nb::arg("world_config"), 
nb::arg("attention_layer_num_per_pp"), + nb::arg("dtype"), nb::arg("attention_type"), nb::arg("cache_transceiver_config") = std::nullopt); nb::class_(m, "CacheTransBufferManager") .def(nb::init>(), nb::arg("cache_manager"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp index d92336e6bdf..8baff07809c 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/cacheTransceiver.cpp @@ -87,11 +87,11 @@ void tb::CacheTransceiverBindings::initBindings(py::module_& m) py::classh(m, "CacheTransceiver") .def(py::init, SizeType32, SizeType32, - runtime::WorldConfig, nvinfer1::DataType, executor::kv_cache::CacheState::AttentionType, - std::optional>(), + runtime::WorldConfig, std::vector, nvinfer1::DataType, + executor::kv_cache::CacheState::AttentionType, std::optional>(), py::arg("cache_manager"), py::arg("num_kv_heads_per_layer"), py::arg("size_per_head"), - py::arg("tokens_per_block"), py::arg("world_config"), py::arg("dtype"), py::arg("attention_type"), - py::arg("cache_transceiver_config") = std::nullopt); + py::arg("tokens_per_block"), py::arg("world_config"), py::arg("attention_layer_num_per_pp"), + py::arg("dtype"), py::arg("attention_type"), py::arg("cache_transceiver_config") = std::nullopt); py::class_(m, "CacheTransBufferManager") .def(py::init>(), py::arg("cache_manager"), diff --git a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp index 27e1590e6a2..d5e57797b77 100644 --- a/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/cacheTransBufferTest.cpp @@ -346,8 +346,9 @@ TEST_F(CacheTransBufferTest, TestForNullOptAndNoneTransSize) auto bufferManager = tensorrt_llm::runtime::BufferManager{std::make_shared()}; auto targetNum = 2; auto targetSize = 1024; + std::vector targetSizeVec = std::vector(targetNum, targetSize); auto [sendBuffers, bufferCoverTargetNum, onlyUseDynamicBuffer] - = mTransBufferManager->getOrAllocateSendBuffers(bufferId3, targetNum, targetSize, bufferManager); + = mTransBufferManager->getOrAllocateSendBuffers(bufferId3, targetNum, targetSizeVec, bufferManager); EXPECT_EQ(sendBuffers.size(), targetNum); EXPECT_EQ(bufferCoverTargetNum, targetNum); EXPECT_EQ(onlyUseDynamicBuffer, true); @@ -393,8 +394,9 @@ TEST_F(CacheTransBufferTest, TestForNullOptAndDefaultTransSize) auto bufferManager = tensorrt_llm::runtime::BufferManager{std::make_shared()}; auto targetNum = 2; auto targetSize = 1024; + std::vector targetSizeVec = std::vector(targetNum, targetSize); auto [sendBuffers, bufferCoverTargetNum, onlyUseDynamicBuffer] - = mTransBufferManager->getOrAllocateSendBuffers(bufferId3, targetNum, targetSize, bufferManager); + = mTransBufferManager->getOrAllocateSendBuffers(bufferId3, targetNum, targetSizeVec, bufferManager); EXPECT_EQ(sendBuffers.size(), targetNum); EXPECT_EQ(bufferCoverTargetNum, targetNum); EXPECT_EQ(onlyUseDynamicBuffer, false); @@ -407,8 +409,9 @@ TEST_F(CacheTransBufferTest, TestForNullOptAndDefaultTransSize) auto bufferId4 = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId4.has_value()); EXPECT_EQ(bufferId4.value(), 0); + targetSizeVec = std::vector(targetNum, targetSize); auto [sendBuffers2, bufferCoverTargetNum2, onlyUseDynamicBuffer2] - = mTransBufferManager->getOrAllocateSendBuffers(bufferId4, targetNum, targetSize, bufferManager); + = 
mTransBufferManager->getOrAllocateSendBuffers(bufferId4, targetNum, targetSizeVec, bufferManager); EXPECT_EQ(sendBuffers2.size(), targetNum); EXPECT_EQ(bufferCoverTargetNum2, targetNum / 2); EXPECT_EQ(onlyUseDynamicBuffer2, false); @@ -418,8 +421,9 @@ TEST_F(CacheTransBufferTest, TestForNullOptAndDefaultTransSize) auto bufferId5 = mTransBufferManager->assignBufferIndexForSend(); EXPECT_TRUE(bufferId5.has_value()); EXPECT_EQ(bufferId5.value(), 0); + targetSizeVec = std::vector(targetNum, targetSize); auto [sendBuffers3, bufferCoverTargetNum3, onlyUseDynamicBuffer3] - = mTransBufferManager->getOrAllocateSendBuffers(bufferId5, targetNum, targetSize, bufferManager); + = mTransBufferManager->getOrAllocateSendBuffers(bufferId5, targetNum, targetSizeVec, bufferManager); EXPECT_EQ(sendBuffers3.size(), targetNum); EXPECT_EQ(bufferCoverTargetNum3, targetNum); EXPECT_EQ(onlyUseDynamicBuffer3, false); diff --git a/cpp/tests/unit_tests/executor/agentCommTest.cpp b/cpp/tests/unit_tests/executor/agentCommTest.cpp index d9e6aaa1389..ee561ca816b 100644 --- a/cpp/tests/unit_tests/executor/agentCommTest.cpp +++ b/cpp/tests/unit_tests/executor/agentCommTest.cpp @@ -78,7 +78,7 @@ class AgentCommTest : public ::testing::Test auto constexpr dataType = nvinfer1::DataType::kFLOAT; using BlocksPerWindow = std::map>; - const BlocksPerWindow blocksPerWindow + BlocksPerWindow const blocksPerWindow = {{maxAttentionWindow, std::make_tuple(totalNumBlocks, blocksInSecondaryPool)}}; mCacheManager = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, @@ -90,7 +90,8 @@ class AgentCommTest : public ::testing::Test size_t maxNumTokens = 1024; mTransBufferManager = std::make_unique(mCacheManager.get(), maxNumTokens); - mCacheState = std::make_unique(numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType); + mCacheState = std::make_unique( + numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, std::vector{numLayers}, dataType); } void TearDown() override @@ -107,7 +108,7 @@ class AgentCommTest : public ::testing::Test TEST_F(AgentCommTest, AgentConnectionManagerBasic) { - auto connectionManager = std::make_unique(mTransBufferManager.get()); + auto connectionManager = std::make_unique(mTransBufferManager.get(), *mCacheState); ASSERT_TRUE(connectionManager != nullptr); ASSERT_TRUE(connectionManager->getCacheTransBufferManager() != nullptr); ASSERT_EQ(connectionManager->getDeviceId(), 0); @@ -120,8 +121,8 @@ TEST_F(AgentCommTest, AgentConnectionManagerBasic) TEST_F(AgentCommTest, AgentConnectionManagerConnect) { - auto connectionManager0 = std::make_unique(mTransBufferManager.get()); - auto connectionManager1 = std::make_unique(mTransBufferManager.get()); + auto connectionManager0 = std::make_unique(mTransBufferManager.get(), *mCacheState); + auto connectionManager1 = std::make_unique(mTransBufferManager.get(), *mCacheState); auto agentName0 = connectionManager0->getAgentName(); auto agentName1 = connectionManager1->getAgentName(); ASSERT_TRUE(!agentName0.empty()); diff --git a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp index 1dad1fa2bbb..1faf9540760 100644 --- a/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp +++ b/cpp/tests/unit_tests/executor/serializeUtilsTest.cpp @@ -726,7 +726,7 @@ TEST(SerializeUtilsTest, ContextPhaseParams) { auto state = std::make_unique(); state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"}); - state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, 
nvinfer1::DataType::kFLOAT}); + state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, {4}, nvinfer1::DataType::kFLOAT}); auto stats = texec::ContextPhaseParams({10, 20, 30, 40, 50, 60}, 0, state.release(), VecTokens{10, 20}); auto stats2 = serializeDeserialize(stats); EXPECT_EQ(stats, stats2); diff --git a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp index a3d809b8860..6b4298909ac 100644 --- a/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp +++ b/cpp/tests/unit_tests/multi_gpu/cacheTransceiverTest.cpp @@ -99,7 +99,7 @@ TEST_F(RequestInfoTest, Basic) } auto state = std::make_unique(); state->setCommState(texec::kv_cache::CommState{12, "127.0.0.1"}); - state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, nvinfer1::DataType::kFLOAT}); + state->setCacheState(texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 8, {10}, nvinfer1::DataType::kFLOAT}); RequestInfo info{1, *state}; auto info2 = serializeDeserialize(info); EXPECT_EQ(info, info2); @@ -141,14 +141,16 @@ TEST_F(CacheConfigTest, EqualTo) vocabSize, nbAttentionLayers + nbRnnLayers, nbAttentionLayers, nbRnnLayers, nbHeads, hiddenSize, dtype}; modelConfig.setTokensPerBlock(tokensPerBlock); tr::WorldConfig worldConfig{tensorParallelism, pipelineParallelism, contextParallelism}; + std::vector attentionLayerNumPerPP(pipelineParallelism, nbAttentionLayers / pipelineParallelism); texec::kv_cache::CacheState::ModelConfig cacheStateCfg{ modelConfig.getNumKvHeadsPerLayer(), modelConfig.getSizePerHead(), modelConfig.getTokensPerBlock()}; texec::kv_cache::CacheState state0{ - cacheStateCfg, worldConfig, modelConfig.getKvDataType(), attentionType, kvFactor}; + cacheStateCfg, worldConfig, attentionLayerNumPerPP, modelConfig.getKvDataType(), attentionType, kvFactor}; texec::kv_cache::CacheState state1{nbAttentionLayers, nbHeads, sizePerHead, tokensPerBlock, tensorParallelism, - pipelineParallelism, contextParallelism, dtype, attentionType, kvFactor, false, 0, tensorParallelism}; + pipelineParallelism, contextParallelism, attentionLayerNumPerPP, dtype, attentionType, kvFactor, false, 0, + tensorParallelism}; EXPECT_EQ(state0, state1); } @@ -165,7 +167,7 @@ class MockDataSender : public DataSender ON_CALL(*this, recvRequestInfo) .WillByDefault(Return(RequestInfo{0, texec::DataTransceiverState{ - texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT}, + texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {10}, nvinfer1::DataType::kFLOAT}, texec::kv_cache::CommState{std::vector{0}, 0}}})); ON_CALL(*this, getCounterpartsCount).WillByDefault(Return(1)); } @@ -218,7 +220,7 @@ TEST_F(MockTransceiverTest, MpiResponderBasic) EXPECT_CALL(*sender, recvRequestInfo) .WillOnce(Return(RequestInfo{0, texec::DataTransceiverState{ - texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, nvinfer1::DataType::kFLOAT}, + texec::kv_cache::CacheState{10, 12, 128, 128, 8, 8, 1, {4}, nvinfer1::DataType::kFLOAT}, texec::kv_cache::CommState{std::vector{0}, 0}}})); EXPECT_CALL(*sender, sendSync).WillOnce(Return()); EXPECT_CALL(*sender, getCounterpartsCount).WillOnce(Return(1)); @@ -318,8 +320,9 @@ class SymmetricalCacheTest : public ::testing::Test // NOLINT(cppcoreguidelines- mMaxNumSequences, maxBeamWidth, std::vector{maxAttentionWindow}, std::nullopt, dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, CacheType::kSELF, std::nullopt, nullptr, true); + auto attentionLayerNumPerPP = 
std::vector{numLayers}; mCacheState = std::make_unique( - numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, dataType); + numLayers, numHeads, sizePerHead, tokensPerBlock, 1, 1, 1, attentionLayerNumPerPP, dataType); if (tensorrt_llm::common::getEnvUseUCXKvCache()) { @@ -614,7 +617,29 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(mPpSize, 0); + for (int ppRank = 0; ppRank < mPpSize; ppRank++) + { + mAttentionLayerNumPerPP[ppRank] = getLayerNumPPRank(numLayers, ppRank, mPpSize); + } + int layerNumthisRank = getLayerNumPPRank(numLayers, mPpRank, mPpSize); + + auto contextAttentionLayerNumPerPP = std::vector(mContextPpSize, 0); + for (int ppRank = 0; ppRank < mContextPpSize; ppRank++) + { + contextAttentionLayerNumPerPP[ppRank] = getLayerNumPPRank(numLayers, ppRank, mContextPpSize); + } + if (!isMLA) { // ASSERT_EQ(numHeads % mTpSize , 0); @@ -693,19 +718,19 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(numLayers / mPpSize, numHeadsPerRank, sizePerHead, tokensPerBlock, + mManager = std::make_unique(layerNumthisRank, numHeadsPerRank, sizePerHead, tokensPerBlock, blocksPerWindow, mMaxNumSequences, maxBeamWidth, maxAttentionWindowVec, std::nullopt, dataType, sinkTokenLength, stream, std::nullopt, enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, true); texec::kv_cache::CacheState::AttentionType attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; - mCacheState - = std::make_unique(numLayers, numHeadsPerRank, sizePerHead, tokensPerBlock, - mTpSize, mPpSize, mCpSize, dataType, attentionType, kvFactor, enableDPAttention, DPrank, DPsize); + mCacheState = std::make_unique(numLayers, numHeadsPerRank, sizePerHead, + tokensPerBlock, mTpSize, mPpSize, mCpSize, mAttentionLayerNumPerPP, dataType, attentionType, kvFactor, + enableDPAttention, DPrank, DPsize); mContextCacheState = std::make_unique(numLayers, numHeadsPerRankForContext, - sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, mContextCpSize, dataType, attentionType, - kvFactor, mContextDP, DPrank, mContextTpSize); + sizePerHead, tokensPerBlock, mContextTpSize, mContextPpSize, mContextCpSize, contextAttentionLayerNumPerPP, + dataType, attentionType, kvFactor, mContextDP, DPrank, mContextTpSize); // UVM seems to be incompatible with MPI, and it is continuing to investigate. 
bool constexpr useUvm = false; @@ -751,8 +776,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(mCacheTransBufferManager.get()); + mConnectionManager = std::make_unique( + mCacheTransBufferManager.get(), *mCacheState); } else { @@ -865,7 +890,8 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamgetModelConfig().mSizePerHead, mContextCacheState->getModelConfig().mTokensPerBlock, mContextCacheState->getParallelConfig().mTensorParallelism, mContextCacheState->getParallelConfig().mPipelineParallelism, - mContextCacheState->getParallelConfig().mContextParallelism, mContextCacheState->getDataType(), + mContextCacheState->getParallelConfig().mContextParallelism, + mContextCacheState->getParallelConfig().mAttentionLayerNumPerPP, mContextCacheState->getDataType(), mContextCacheState->getAttentionConfig().mAttentionType, mContextCacheState->getAttentionConfig().mKvFactor, mContextCacheState->getParallelConfig().mEnableAttentionDP, contextDpRank, mContextCacheState->getParallelConfig().mTensorParallelism}; @@ -944,8 +970,19 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(); - int startLayerId = layerSizePerRank * mPpRank; + int layerSizeThisRank = blockData.getDimension<1>(); + int startLayerId = 0; + if (mIsWindowAttention) + { + startLayerId = layerSizeThisRank * mPpRank; + } + else + { + for (int ppRank = 0; ppRank < mPpRank; ppRank++) + { + startLayerId += mAttentionLayerNumPerPP[ppRank]; + } + } int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0); int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor); bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP; @@ -958,7 +995,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParamgetModelConfig().mSizePerHead; auto dataTypeSize = tensorrt_llm::common::getDTypeSize(blockData.getDataType()); - for (int layerId = 0; layerId < layerSizePerRank; layerId++) + for (int layerId = 0; layerId < layerSizeThisRank; layerId++) { for (int headId = 0; headId < headSizePerRank; headId++) { @@ -1008,8 +1045,20 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam(); - int startLayerId = layerSizePerRank * mPpRank; + int layerSizethisRank = blockData.getDimension<1>(); + int startLayerId = 0; + if (mIsWindowAttention) + { + startLayerId = layerSizethisRank * mPpRank; + } + else + { + for (int ppRank = 0; ppRank < mPpRank; ppRank++) + { + startLayerId += mAttentionLayerNumPerPP[ppRank]; + } + } + int headSizePerRank = mCacheState->getModelConfig().mNbKvHeadsPerLayer.at(0); int startHeadId = headSizePerRank * (mTpRank / mDupHeadFactor); bool enableDP = mCacheState->getParallelConfig().mEnableAttentionDP; @@ -1025,7 +1074,7 @@ class AsymmetricalCacheTest : public ::testing::TestWithParam mAttentionLayerNumPerPP; SizeType32 mMaxNumSequences{}; std::unique_ptr mManager; @@ -1351,6 +1401,18 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1, AsymmetricalCacheTest, testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false, true))); +INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1EvenLayer, AsymmetricalCacheTest, + testing::Combine(testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(10), testing::Values(4), testing::Values(4), testing::Values(8), + testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false), 
+ testing::Values(false), testing::Values(false))); + +INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2EvenLayer, AsymmetricalCacheTest, + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(10), testing::Values(4), testing::Values(4), testing::Values(8), + testing::Values(nvinfer1::DataType::kFLOAT), testing::Values(2), testing::Values(false), testing::Values(false), + testing::Values(false), testing::Values(false))); + INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2, AsymmetricalCacheTest, testing::Combine(testing::Values(1), testing::Values(2), testing::Values(1), testing::Values(1), testing::Values(1, 4), testing::Values(1), testing::Values(16), testing::Values(16), testing::Values(4), @@ -1369,6 +1431,18 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLA, AsymmetricalCacheTest, testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), testing::Values(true), testing::Values(false), testing::Values(false), testing::Values(false))); +INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest1ForMLAEvenLayer, AsymmetricalCacheTestWithDP, + testing::Combine(testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), testing::Values(1), + testing::Values(1), testing::Values(10), testing::Values(1), testing::Values(4), testing::Values(8), + testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Values(true), testing::Values(false), testing::Values(false, true), testing::Values(false))); + +INSTANTIATE_TEST_CASE_P(AsymmetricCaseTest2ForMLAEvenLayer, AsymmetricalCacheTestWithDP, + testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1), testing::Values(4), + testing::Values(1), testing::Values(10), testing::Values(1), testing::Values(4), testing::Values(8), + testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(1), + testing::Values(true), testing::Values(false), testing::Values(false, true), testing::Values(false))); + INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForMLA1, AsymmetricalCacheTestWithDP, testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), @@ -1403,12 +1477,19 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLA2, AsymmetricalCacheTest testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(4), testing::Values(4), testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(true), testing::Values(false))); + INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0, AsymmetricalCacheTestWithDP, testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false))); +INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate0EvenLayer, AsymmetricalCacheTestWithDP, + testing::Combine(testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), 
testing::Values(1), + testing::Values(1), testing::Values(5), testing::Values(2), testing::Values(4), testing::Values(16), + testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), + testing::Values(false), testing::Values(true, false), testing::Values(false), testing::Values(false))); + INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate1, AsymmetricalCacheTestWithDP, testing::Combine(testing::Values(1, 2), testing::Values(1, 2), testing::Values(1), testing::Values(2), testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1), testing::Values(4), @@ -1419,6 +1500,7 @@ INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate2, Asymmetrical testing::Values(1), testing::Values(1), testing::Values(4), testing::Values(2), testing::Values(4), testing::Values(16), testing::Values(nvinfer1::DataType::kFLOAT, nvinfer1::DataType::kINT8), testing::Values(2), testing::Values(false), testing::Values(false), testing::Values(false), testing::Values(false))); + INSTANTIATE_TEST_CASE_P(AsymmetricCaseTestWithDPForNoMLADuplicate4, AsymmetricalCacheTestWithDP, testing::Combine(testing::Values(4), testing::Values(1), testing::Values(1), testing::Values(1, 2), testing::Values(2), testing::Values(1), testing::Values(4), testing::Values(1, 2), testing::Values(4), @@ -1444,12 +1526,17 @@ TEST(targetTest, CacheStateNODP) { auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; + std::vector contextAttentionLayerNumPerPP( + contextWC.getPipelineParallelism(), numLayers / contextWC.getPipelineParallelism()); + std::vector genAttentionLayerNumPerPP( + genWC.getPipelineParallelism(), numLayers / genWC.getPipelineParallelism()); auto const sharedModelConfig = texec::kv_cache::CacheState::ModelConfig{std::vector(numLayers, numHeads), sizePerHead, tokensPerBlock}; - auto const contextCache - = texec::kv_cache::CacheState(sharedModelConfig, contextWC, dataType, attentionType, kvFactor); - auto const genCache = texec::kv_cache::CacheState(sharedModelConfig, genWC, dataType, attentionType, kvFactor); + auto const contextCache = texec::kv_cache::CacheState( + sharedModelConfig, contextWC, contextAttentionLayerNumPerPP, dataType, attentionType, kvFactor); + auto const genCache = texec::kv_cache::CacheState( + sharedModelConfig, genWC, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor); auto const contextTargetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank); @@ -1731,6 +1818,8 @@ TEST(targetTest, CacheStateContextDP) int genCP = 1; bool contextEnableDP = true; bool genEnableDP = true; + std::vector contextAttentionLayerNumPerPP(contextPP, numLayers / contextPP); + std::vector genAttentionLayerNumPerPP(genPP, numLayers / genPP); auto const verifyContext = [&](int contextRank, int generationRank, std::vector const& expectRanks, int expectPPDomain, int expectTPDomain, bool expectNeedSend) @@ -1740,13 +1829,13 @@ TEST(targetTest, CacheStateContextDP) auto attentionType = isMLA ? 
texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; - auto const contextCache - = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP, - contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; + auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, + tokensPerBlock, contextTP, contextPP, contextCP, contextAttentionLayerNumPerPP, dataType, attentionType, + kvFactor, contextEnableDP, contextDPRank, contextTP}; - auto const genCache - = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP, - genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; + auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, + tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor, + genEnableDP, generationDPRank, genTP}; auto const contextTragetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(genCache, contextCache, contextRank); @@ -1847,13 +1936,13 @@ TEST(targetTest, CacheStateContextDP) auto attentionType = isMLA ? texec::kv_cache::CacheState::AttentionType::kMLA : texec::kv_cache::CacheState::AttentionType::kDEFAULT; - auto const contextCache - = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, contextTP, - contextPP, contextCP, dataType, attentionType, kvFactor, contextEnableDP, contextDPRank, contextTP}; + auto const contextCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, + tokensPerBlock, contextTP, contextPP, contextCP, contextAttentionLayerNumPerPP, dataType, attentionType, + kvFactor, contextEnableDP, contextDPRank, contextTP}; - auto const genCache - = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, tokensPerBlock, genTP, - genPP, genCP, dataType, attentionType, kvFactor, genEnableDP, generationDPRank, genTP}; + auto const genCache = tensorrt_llm::executor::kv_cache::CacheState{numLayers, numHeads, sizePerHead, + tokensPerBlock, genTP, genPP, genCP, genAttentionLayerNumPerPP, dataType, attentionType, kvFactor, + genEnableDP, generationDPRank, genTP}; auto const contextTragetInfo = tensorrt_llm::executor::kv_cache::TargetRanksInfoForDP(contextCache, genCache, generationRank); @@ -1872,6 +1961,8 @@ TEST(targetTest, CacheStateContextDP) contextPP = 1; genTP = 1; genPP = 2; + contextAttentionLayerNumPerPP = std::vector(contextPP, numLayers / contextPP); + genAttentionLayerNumPerPP = std::vector(genPP, numLayers / genPP); verfiyGeneration( /*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1); @@ -1885,6 +1976,8 @@ TEST(targetTest, CacheStateContextDP) contextPP = 1; genTP = 1; genPP = 1; + contextAttentionLayerNumPerPP = std::vector(contextPP, numLayers / contextPP); + genAttentionLayerNumPerPP = std::vector(genPP, numLayers / genPP); verfiyGeneration( /*contextRank*/ 0, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1); diff --git a/tensorrt_llm/_torch/distributed/communicator.py b/tensorrt_llm/_torch/distributed/communicator.py index 83eb7157495..bf3ef8b1cfc 100644 --- a/tensorrt_llm/_torch/distributed/communicator.py +++ b/tensorrt_llm/_torch/distributed/communicator.py @@ -100,6 +100,7 @@ class MPIDist(Distributed): def 
__init__(self, mapping: Mapping): super().__init__(mapping) self.create_tp_comm() + self.create_pp_comm() def broadcast(self, obj, root=0): return mpi_broadcast(obj, root) @@ -135,6 +136,10 @@ def create_tp_comm(self): new_group = mpi_comm().group.Incl(self.mapping.tp_group) self.tp_comm = mpi_comm().Create_group(new_group) + def create_pp_comm(self): + new_group = mpi_comm().group.Incl(self.mapping.pp_group) + self.pp_comm = mpi_comm().Create_group(new_group) + def tp_allgather(self, obj): return self.tp_comm.allgather(obj) @@ -144,6 +149,15 @@ def tp_gather(self, obj): def tp_broadcast(self, obj, root=0): return self.tp_comm.bcast(obj, root) + def pp_allgather(self, obj): + return self.pp_comm.allgather(obj) + + def pp_gather(self, obj): + return self.pp_comm.gather(obj) + + def pp_broadcast(self, obj, root=0): + return self.pp_comm.bcast(obj, root) + class TorchDist(Distributed): diff --git a/tensorrt_llm/_torch/pyexecutor/_util.py b/tensorrt_llm/_torch/pyexecutor/_util.py index e3021797acf..65e1a93b54c 100644 --- a/tensorrt_llm/_torch/pyexecutor/_util.py +++ b/tensorrt_llm/_torch/pyexecutor/_util.py @@ -644,7 +644,8 @@ def create_py_executor_instance( config) else AttentionTypeCpp.DEFAULT cache_transceiver_config = executor_config.cache_transceiver_config kv_cache_transceiver = create_kv_cache_transceiver( - mapping, kv_cache_manager, attention_type, cache_transceiver_config) + mapping, dist, kv_cache_manager, attention_type, + cache_transceiver_config) return PyExecutor( resource_manager, scheduler, diff --git a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py index eb1f2019781..923f10e7086 100644 --- a/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py +++ b/tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @@ -3,6 +3,7 @@ import tensorrt_llm from tensorrt_llm import logger +from tensorrt_llm._torch.distributed.communicator import Distributed from tensorrt_llm.bindings import WorldConfig from tensorrt_llm.bindings.executor import CacheTransceiverConfig from tensorrt_llm.mapping import Mapping @@ -28,7 +29,7 @@ def mapping_to_world_config(mapping: Mapping) -> WorldConfig: def create_kv_cache_transceiver( - mapping: Mapping, kv_cache_manager: KVCacheManager, + mapping: Mapping, dist: Distributed, kv_cache_manager: KVCacheManager, attention_type: AttentionTypeCpp, cache_transceiver_config: CacheTransceiverConfig): if cache_transceiver_config is None or cache_transceiver_config.backend is None: @@ -59,8 +60,8 @@ def create_kv_cache_transceiver( f"UCX_CUDA_IPC_ENABLE_MNNVL=n, UCX_RNDV_SCHEME=put_zcopy and/or unset UCX_NET_DEVICES upon server " f"hangs or lower-than-expected performance.") - return BindKvCacheTransceiver(mapping, kv_cache_manager, attention_type, - cache_transceiver_config) + return BindKvCacheTransceiver(mapping, dist, kv_cache_manager, + attention_type, cache_transceiver_config) class KvCacheTransceiver(ABC): @@ -92,7 +93,8 @@ def check_gen_transfer_complete(self): class BindKvCacheTransceiver(KvCacheTransceiver): - def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager, + def __init__(self, mapping: Mapping, dist: Distributed, + kv_cache_manager: KVCacheManager, attention_type: AttentionTypeCpp, cache_transceiver_config: CacheTransceiverConfig): world_config = mapping_to_world_config(mapping) @@ -100,10 +102,13 @@ def __init__(self, mapping: Mapping, kv_cache_manager: KVCacheManager, head_dim = kv_cache_manager.head_dim tokens_per_block = kv_cache_manager.tokens_per_block dtype = 
kv_cache_manager.dtype - + # Gather the number of layers on each PP rank, which the cache transceiver requires. + pp_layer_num = len(kv_cache_manager.pp_layers) + pp_layer_num_per_pp_rank = dist.pp_allgather(pp_layer_num) self.impl = CacheTransceiverCpp(kv_cache_manager.impl, total_num_kv_heads_per_layer, head_dim, - tokens_per_block, world_config, dtype, + tokens_per_block, world_config, + pp_layer_num_per_pp_rank, dtype, attention_type, cache_transceiver_config) diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_gentp4.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_gentp4.yaml new file mode 100644 index 00000000000..a1e4ad50a9c --- /dev/null +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxpp4_gentp4.yaml @@ -0,0 +1,36 @@ +model: TinyLlama/TinyLlama-1.1B-Chat-v1.0 +hostname: localhost +port: 8000 +backend: "pytorch" +cuda_graph_config: null +free_gpu_memory_fraction: 0.2 +context_servers: + num_instances: 1 + max_batch_size: 1 + max_num_tokens: 3000 + max_seq_len: 4096 + tensor_parallel_size: 1 + pipeline_parallel_size: 4 + kv_cache_config: + free_gpu_memory_fraction: 0.2 + enable_partial_reuse: False + disable_overlap_scheduler: True + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 4 + pipeline_parallel_size: 1 + max_batch_size: 256 + max_num_tokens: 4096 + max_seq_len: 4096 + kv_cache_config: + free_gpu_memory_fraction: 0.2 + enable_partial_reuse: False + disable_overlap_scheduler: True + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml new file mode 100644 index 00000000000..4a61497e94e --- /dev/null +++ b/tests/integration/defs/disaggregated/test_configs/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml @@ -0,0 +1,32 @@ +hostname: localhost +port: 8000 +model: DeepSeek-V3-Lite/fp8 +free_gpu_memory_fraction: 0.1 +backend: "pytorch" +cuda_graph_config: null +disable_overlap_scheduler: True + +context_servers: + num_instances: 1 + tensor_parallel_size: 1 + pipeline_parallel_size: 2 + enable_attention_dp: false + speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 1 + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8001" +generation_servers: + num_instances: 1 + tensor_parallel_size: 2 + pipeline_parallel_size: 1 + enable_attention_dp: false + speculative_config: + decoding_type: MTP + num_nextn_predict_layers: 1 + cache_transceiver_config: + backend: DEFAULT + urls: + - "localhost:8002" diff --git a/tests/integration/defs/disaggregated/test_disaggregated.py b/tests/integration/defs/disaggregated/test_disaggregated.py index 46c393ab488..9b485da96b3 100644 --- a/tests/integration/defs/disaggregated/test_disaggregated.py +++ b/tests/integration/defs/disaggregated/test_disaggregated.py @@ -87,6 +87,8 @@ def get_test_config(test_desc, example_dir, test_root): (8, f"{test_configs_root}/disagg_config_ctxtp2pp2_gentp2pp2.yaml"), "ctxpp4_genpp4": (8, f"{test_configs_root}/disagg_config_ctxpp4_genpp4.yaml"), + "ctxpp4_gentp4": + (8, f"{test_configs_root}/disagg_config_ctxpp4_gentp4.yaml"), "deepseek_v3_lite_fp8_mpi": (4, f"{test_configs_root}/disagg_config_ctxtp2_gentp2_deepseek_v3_lite_mpi.yaml
@@ -149,6 +151,10 @@ def get_test_config(test_desc, example_dir, test_root): (2, f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_two_mtp.yaml" ), + "deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp": + (4, + f"{test_configs_root}/disagg_config_ctxtp1_gentp1_deepseek_v3_lite_one_mtp_ctxpp2_gentp2.yaml" + ), } if test_desc not in config_map: @@ -776,6 +782,27 @@ def test_disaggregated_ctxpp4_genpp4(disaggregated_test_root, llm_venv, cwd=llm_venv.get_working_directory()) +# TinyLlama with PP=4 ends up with an uneven number of layers per PP rank. +@pytest.mark.skip_less_device(8) +@pytest.mark.parametrize("llama_model_root", ['TinyLlama-1.1B-Chat-v1.0'], + indirect=True) +def test_disaggregated_ctxpp4_gentp4(disaggregated_test_root, llm_venv, + disaggregated_example_root, + llama_model_root): + src_dst_dict = { + llama_model_root: + f"{llm_venv.get_working_directory()}/TinyLlama/TinyLlama-1.1B-Chat-v1.0", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + run_disaggregated_test(disaggregated_example_root, + "ctxpp4_gentp4", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @skip_no_hopper @pytest.mark.skip_less_device(4) @pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'], @@ -842,6 +869,29 @@ def test_disaggregated_deepseek_v3_lite_fp8_tp1_single_gpu_mtp( cwd=llm_venv.get_working_directory()) +@pytest.mark.skip_less_device(4) +@skip_no_hopper +@pytest.mark.parametrize("deepseek_v3_model_root", ['DeepSeek-V3-Lite-fp8'], + indirect=True) +def test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp( + disaggregated_test_root, disaggregated_example_root, llm_venv, + deepseek_v3_model_root): + # With one MTP layer added, PP rank 0 holds 15 layers and PP rank 1 holds 16 layers.
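 + # Illustrative arithmetic only (not executed by the test): 15 + 16 = 31 KV-cache layers in total once the single MTP draft layer is counted, so the two context PP ranks cannot hold equal shares; e.g. with total = 31, a floor/ceil split gives [total // 2, total - total // 2] == [15, 16], matching the counts above. + # The actual per-rank counts come from len(kv_cache_manager.pp_layers) on each rank and are exchanged via dist.pp_allgather before being passed to the C++ cache transceiver, which is what allows the transfer to handle this uneven layout.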
+ src_dst_dict = { + deepseek_v3_model_root: + f"{llm_venv.get_working_directory()}/DeepSeek-V3-Lite/fp8", + } + for src, dst in src_dst_dict.items(): + if not os.path.islink(dst): + os.makedirs(os.path.dirname(dst), exist_ok=True) + os.symlink(src, dst, target_is_directory=True) + + run_disaggregated_test(disaggregated_example_root, + "deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp", + env=llm_venv._new_env, + cwd=llm_venv.get_working_directory()) + + @skip_no_hopper @skip_arm @pytest.mark.skip_less_device(4) diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index f0171fd2c89..52a5a5e6b94 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -33,6 +33,7 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_genpp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxtp2_genpp2[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp2_gentp2[TinyLlama-1.1B-Chat-v1.0] + - disaggregated/test_disaggregated.py::test_disaggregated_ctxpp4_gentp4[TinyLlama-1.1B-Chat-v1.0] - disaggregated/test_disaggregated.py::test_disaggregated_genbs1[TinyLlama-1.1B-Chat-v1.0] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True] @@ -139,6 +140,7 @@ l0_dgx_h100: - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_cache_aware_balance[DeepSeek-V3-Lite-bf16] - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_bf16_conditional[DeepSeek-V3-Lite-bf16] - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2] + - disaggregated/test_disaggregated.py::test_disaggregated_deepseek_v3_lite_fp8_ctxpp2_gentp2_one_mtp[DeepSeek-V3-Lite-fp8] - disaggregated/test_workers.py::test_workers_conditional_disaggregation_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16] - disaggregated/test_workers.py::test_workers_kv_cache_aware_router_deepseek_v3_lite_bf16[DeepSeek-V3-Lite-bf16] - condition: