Merged
32 changes: 20 additions & 12 deletions cpp/include/tensorrt_llm/executor/dataTransceiverState.h
@@ -52,29 +52,30 @@ class CacheState final
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2)
: mModelConfig(std::move(modelConfig))
, mParallelConfig{worldConfig.getTensorParallelism(), worldConfig.getPipelineParallelism(),
worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(), worldConfig.getTensorParallelism()}
worldConfig.getContextParallelism(), worldConfig.enableAttentionDP(), worldConfig.getTensorParallelRank(),
worldConfig.getTensorParallelism()}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}

CacheState(std::vector<SizeType32> nbKvHeadPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
int DPrank = 0, int DPsize = 0)
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::move(nbKvHeadPerLayer), sizePerHead, tokensPerBlock}
, mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
, mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
}

CacheState(SizeType32 nbAttentionLayers, SizeType32 nbKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock,
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, nvinfer1::DataType dataType,
AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2, bool enableAttentionDP = false,
int DPrank = 0, int DPsize = 0)
SizeType32 tensorParallelism, SizeType32 pipelineParallelism, SizeType32 contextParallelism,
nvinfer1::DataType dataType, AttentionType attentionType = AttentionType::kDEFAULT, int kvFactor = 2,
bool enableAttentionDP = false, int DPrank = 0, int DPsize = 0)
: mModelConfig{std::vector(nbAttentionLayers, nbKvHeads), sizePerHead, tokensPerBlock}
, mParallelConfig{tensorParallelism, pipelineParallelism, enableAttentionDP, DPrank, DPsize}
, mParallelConfig{tensorParallelism, pipelineParallelism, contextParallelism, enableAttentionDP, DPrank, DPsize}
, mDataType{dataType}
, mAttentionConfig(attentionType, kvFactor)
{
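Note (reviewer sketch): every CacheState constructor now takes contextParallelism immediately after pipelineParallelism and before dataType, so existing call sites need one extra argument. A minimal call-site sketch, assuming the headers and namespaces surrounding this file; all values are illustrative, not from this PR:

    std::vector<SizeType32> nbKvHeadPerLayer(32, 8); // e.g. 32 layers, 8 KV heads each
    kv_cache::CacheState state{nbKvHeadPerLayer, /*sizePerHead=*/128, /*tokensPerBlock=*/32,
        /*tensorParallelism=*/4, /*pipelineParallelism=*/2, /*contextParallelism=*/1,
        nvinfer1::DataType::kHALF};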
@@ -83,7 +84,7 @@ class CacheState final
[[nodiscard]] bool operator==(kv_cache::CacheState const& other) const noexcept
{
return mModelConfig == other.mModelConfig && mParallelConfig == other.mParallelConfig
&& mDataType == other.mDataType;
&& mAttentionConfig == other.mAttentionConfig && mDataType == other.mDataType;
}

struct ModelConfig
@@ -103,15 +104,16 @@
{
SizeType32 mTensorParallelism;
SizeType32 mPipelineParallelism;
SizeType32 mContextParallelism;
bool mEnableAttentionDP;
SizeType32 mDPrank;
SizeType32 mDPsize;

[[nodiscard]] bool operator==(ParallelConfig const& other) const noexcept
{
return mTensorParallelism == other.mTensorParallelism && mPipelineParallelism == other.mPipelineParallelism
&& mEnableAttentionDP == other.mEnableAttentionDP && mDPrank == other.mDPrank
&& mDPsize == other.mDPsize;
&& mContextParallelism == other.mContextParallelism && mEnableAttentionDP == other.mEnableAttentionDP
&& mDPrank == other.mDPrank && mDPsize == other.mDPsize;
}
};

@@ -125,6 +127,11 @@ class CacheState final
{
}

[[nodiscard]] bool operator==(AttentionConfig const& other) const noexcept
{
return mAttentionType == other.mAttentionType && mKvFactor == other.mKvFactor;
}

// Attention type used for this model's KV cache.
AttentionType mAttentionType;
int mKvFactor;
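Note (reviewer sketch): because CacheState::operator== above now also compares mAttentionConfig, two states that differ only in attention settings no longer compare equal. A test-style sketch under the same assumptions as the constructor example (and assuming AttentionType is nested in CacheState, as the default arguments suggest):

    std::vector<SizeType32> heads(32, 8);
    kv_cache::CacheState a{heads, 128, 32, /*tp=*/4, /*pp=*/2, /*cp=*/1,
        nvinfer1::DataType::kHALF, kv_cache::CacheState::AttentionType::kDEFAULT, /*kvFactor=*/2};
    kv_cache::CacheState b{heads, 128, 32, /*tp=*/4, /*pp=*/2, /*cp=*/1,
        nvinfer1::DataType::kHALF, kv_cache::CacheState::AttentionType::kDEFAULT, /*kvFactor=*/1};
    assert(!(a == b)); // kvFactor differs, so mAttentionConfig differs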
@@ -162,6 +169,7 @@ class CacheState final
sstring << "mTokensPerBlock:" << mModelConfig.mTokensPerBlock << "\n";
sstring << "tp:" << mParallelConfig.mTensorParallelism << "\n";
sstring << "pp:" << mParallelConfig.mPipelineParallelism << "\n";
sstring << "cp:" << mParallelConfig.mContextParallelism << "\n";
sstring << "enableAttentionDP:" << mParallelConfig.mEnableAttentionDP << "\n";
sstring << "datatype:" << static_cast<int32_t>(mDataType) << "\n";
sstring << "attentionType:" << static_cast<int32_t>(mAttentionConfig.mAttentionType) << "\n";
8 changes: 8 additions & 0 deletions cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp
@@ -822,6 +822,14 @@ void CacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("CacheFormatter::inquireSupport: only support non-MLA");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"CacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}

std::unordered_set<int> setVecDest{
destConfig.getModelConfig().mNbKvHeadsPerLayer.begin(), destConfig.getModelConfig().mNbKvHeadsPerLayer.end()};
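Note (reviewer sketch): this makes CP a hard capability gate; both endpoints must run with CP == 1 for this formatter. Expected behavior in test form, assuming inquireSupport compares a self and a dest CacheState as the log messages suggest; makeState is a hypothetical helper that builds a CacheState with the given CP degree:

    auto cpTwo = makeState(/*contextParallelism=*/2); // hypothetical helper
    auto cpOne = makeState(/*contextParallelism=*/1);
    assert(!formatter.inquireSupport(cpTwo, cpOne)); // self endpoint uses CP
    assert(!formatter.inquireSupport(cpOne, cpTwo)); // dest endpoint uses CP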
14 changes: 8 additions & 6 deletions cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp
@@ -558,18 +558,20 @@ void MLACacheFormatter::unformat(TransferSession& session)
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support MLA");
return false;
}

if (selfConfig.getAttentionConfig().mKvFactor != destConfig.getAttentionConfig().mKvFactor)
{
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: only support same kv factor");
return false;
}
if (selfConfig.getParallelConfig().mEnableAttentionDP
&& (selfConfig.getParallelConfig().mTensorParallelism % selfConfig.getParallelConfig().mDPsize != 0))
{
TLLM_LOG_WARNING("MLACacheFormatter::inquireSupport: TP size must be divisible by DP size");
return false;
}
if (selfConfig.getParallelConfig().mContextParallelism != 1
|| destConfig.getParallelConfig().mContextParallelism != 1)
{
TLLM_LOG_WARNING(
"MLACacheFormatter::inquireSupport: context parallelism is not currently supported (selfCP=%d, destCP=%d).",
selfConfig.getParallelConfig().mContextParallelism, destConfig.getParallelConfig().mContextParallelism);
return false;
}
if (destConfig.getParallelConfig().mEnableAttentionDP
&& (destConfig.getParallelConfig().mTensorParallelism % destConfig.getParallelConfig().mDPsize != 0))
{
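The adjacent DP guards encode that the TP degree must be an integer multiple of the DP size on each side; for example:

    static_assert(8 % 4 == 0); // TP=8 over DPsize=4: 2 TP ranks per DP group, supported
    static_assert(8 % 3 != 0); // TP=8 over DPsize=3: uneven split, inquireSupport refuses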
7 changes: 5 additions & 2 deletions cpp/tensorrt_llm/executor/serialization.cpp
@@ -531,14 +531,15 @@ kv_cache::CacheState Serialization::deserializeCacheState(std::istream& is)
auto tokensPerBlock = su::deserialize<decltype(CacheState::ModelConfig::mTokensPerBlock)>(is);
auto tensorParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mTensorParallelism)>(is);
auto pipelineParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mPipelineParallelism)>(is);
auto contextParallelism = su::deserialize<decltype(CacheState::ParallelConfig::mContextParallelism)>(is);
auto enableAttentionDP = su::deserialize<decltype(CacheState::ParallelConfig::mEnableAttentionDP)>(is);
auto DPrank = su::deserialize<decltype(CacheState::ParallelConfig::mDPrank)>(is);
auto DPsize = su::deserialize<decltype(CacheState::ParallelConfig::mDPsize)>(is);
auto dataType = su::deserialize<decltype(CacheState::mDataType)>(is);
auto attentionType = su::deserialize<decltype(CacheState::AttentionConfig::mAttentionType)>(is);
auto kvFactor = su::deserialize<decltype(CacheState::AttentionConfig::mKvFactor)>(is);
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism, dataType,
attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
return CacheState{nbKvHeadsPerLayer, sizePerHead, tokensPerBlock, tensorParallelism, pipelineParallelism,
contextParallelism, dataType, attentionType, kvFactor, enableAttentionDP, DPrank, DPsize};
}
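Note (reviewer sketch): the reads here must stay in exact lockstep with the writes in Serialization::serialize below; contextParallelism is read, and written, between pipelineParallelism and enableAttentionDP on both sides. A round-trip check that would catch an ordering mismatch, where state is any kv_cache::CacheState:

    #include <cassert>
    #include <sstream>

    std::stringstream ss;
    Serialization::serialize(state, ss);                      // writes cp right after pp
    auto restored = Serialization::deserializeCacheState(ss); // reads cp right after pp
    assert(restored == state); // stronger now that operator== also covers AttentionConfig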

void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
@@ -548,6 +549,7 @@ void Serialization::serialize(kv_cache::CacheState const& state, std::ostream& os)
su::serialize(state.mModelConfig.mTokensPerBlock, os);
su::serialize(state.mParallelConfig.mTensorParallelism, os);
su::serialize(state.mParallelConfig.mPipelineParallelism, os);
su::serialize(state.mParallelConfig.mContextParallelism, os);
su::serialize(state.mParallelConfig.mEnableAttentionDP, os);
su::serialize(state.mParallelConfig.mDPrank, os);
su::serialize(state.mParallelConfig.mDPsize, os);
@@ -564,6 +566,7 @@ size_t Serialization::serializedSize(kv_cache::CacheState const& state)
totalSize += su::serializedSize(state.mModelConfig.mTokensPerBlock);
totalSize += su::serializedSize(state.mParallelConfig.mTensorParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mPipelineParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mContextParallelism);
totalSize += su::serializedSize(state.mParallelConfig.mEnableAttentionDP);
totalSize += su::serializedSize(state.mParallelConfig.mDPrank);
totalSize += su::serializedSize(state.mParallelConfig.mDPsize);
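serializedSize must grow in step with serialize, or callers that size buffers from it will under-allocate. A cheap consistency sketch over any kv_cache::CacheState state:

    #include <cassert>
    #include <sstream>

    std::stringstream ss;
    Serialization::serialize(state, ss);
    assert(ss.str().size() == Serialization::serializedSize(state));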