Skip to content

Commit bcc7a76

Browse files
committed
revert kvcache transfer
Signed-off-by: Chuang Zhu <[email protected]>
1 parent 980929e commit bcc7a76

File tree

3 files changed

+13
-24
lines changed

3 files changed

+13
-24
lines changed

cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques
7575
bool CacheFormatter::needSendCache(
7676
CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx)
7777
{
78+
// int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
7879
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
7980
if (targetInfo.mDupHeadFactor <= 1)
8081
{
@@ -89,9 +90,8 @@ bool CacheFormatter::needSendCache(
8990
= selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
9091
selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup;
9192
}
92-
int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
9393

94-
return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor);
94+
return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0;
9595
}
9696

9797
void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig,
@@ -128,12 +128,11 @@ std::vector<size_t> CacheFormatter::pickRecvConnections(
128128
return ret;
129129
}
130130
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
131-
int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
132131

133132
std::vector<size_t> ret;
134133
for (int i = 0; i < targetInfo.mDomainTPSize; i++)
135134
{
136-
if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor))
135+
if (i % targetInfo.mPeerDupHeadFactor == 0)
137136
{
138137
for (int j = 0; j < targetInfo.mDomainPPSize; j++)
139138
{

cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,10 @@ std::vector<size_t> MLACacheFormatter::pickRecvConnections(
4545
auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx);
4646
TLLM_CHECK(numConnections == targetInfo.mIRanks.size());
4747
std::vector<size_t> ret;
48-
// targetInfo , mRanks [tpranks, ppranks]
49-
int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0;
50-
48+
// targetInfo.mIRanks is laid out as [tpRanks, ppRanks] (TP-major, PP-minor)
5149
for (int i = 0; i < targetInfo.mDomainPPSize; i++)
5250
{
53-
ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize);
51+
ret.push_back(i);
5452
}
5553
return ret;
5654
}
@@ -60,24 +58,19 @@ bool MLACacheFormatter::needSendCache(
6058
{
6159
int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism;
6260

63-
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
64-
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
65-
: destConfig.getParallelConfig().mTensorParallelism;
66-
int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0;
67-
6861
if (selfConfig.getParallelConfig().mEnableAttentionDP)
6962
{
7063
int selfTPNumInDPGroup
7164
= selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize;
72-
65+
int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP
66+
? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize
67+
: destConfig.getParallelConfig().mTensorParallelism;
7368
int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup;
7469
if (selfTPNumInDPGroup <= destTPNumInDPGroup)
7570
{
7671
return true;
7772
}
78-
79-
int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup;
80-
return selfTPrankINDPGroup % dupHeadFactor == destDPRank;
73+
return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0;
8174
}
8275

8376
int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP
@@ -88,8 +81,7 @@ bool MLACacheFormatter::needSendCache(
8881
{
8982
return true;
9083
}
91-
int dupHeadFactor = selfTPNum / destTPNum;
92-
return selfTpRank % dupHeadFactor == destDPRank;
84+
return selfTpRank % (selfTPNum / destTPNum) == 0;
9385
}
9486

9587
void MLACacheFormatter::format(TransferSession& session)

cpp/tests/batch_manager/cacheTransceiverTest.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1457,15 +1457,12 @@ TEST(targetTest, CacheStateNODP)
14571457

14581458
verifyContext(
14591459
/*contextRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
1460-
14611460
verifyContext(
14621461
/*contextRank*/ 1, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1463-
14641462
verifyContext(
14651463
/*contextRank*/ 2, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
14661464
verifyContext(
14671465
/*contextRank*/ 3, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ false);
1468-
14691466
verifyContext(
14701467
/*contextRank*/ 4, /*expectRanks*/ {2}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1, /*expectNeedSend*/ true);
14711468
verifyContext(
@@ -1477,6 +1474,7 @@ TEST(targetTest, CacheStateNODP)
14771474

14781475
contextTP = 2;
14791476
genTP = 4;
1477+
14801478
verifyContext(
14811479
/*contextRank*/ 0, /*expectRanks*/ {0, 1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2, /*expectNeedSend*/ true);
14821480
verifyContext(/*contextRank*/ 1, /*expectRanks*/ {2, 3}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 2,
@@ -1566,13 +1564,13 @@ TEST(targetTest, CacheStateContextDP)
15661564
/*expectNeedSend*/ true);
15671565
verifyContext(
15681566
/*contextRank*/ 0, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
1569-
/*expectNeedSend*/ false);
1567+
/*expectNeedSend*/ true);
15701568
verifyContext(
15711569
/*contextRank*/ 1, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
15721570
/*expectNeedSend*/ false);
15731571
verifyContext(
15741572
/*contextRank*/ 1, /*generationRank*/ 1, /*expectRanks*/ {1}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
1575-
/*expectNeedSend*/ true);
1573+
/*expectNeedSend*/ false);
15761574
verifyContext(
15771575
/*contextRank*/ 2, /*generationRank*/ 0, /*expectRanks*/ {0}, /*expectPPDomain*/ 1, /*expectTPDomain*/ 1,
15781576
/*expectNeedSend*/ false);

0 commit comments

Comments
 (0)