-
Notifications
You must be signed in to change notification settings - Fork 1.8k
[None][refactor] Improve lookahead decoding interfaces #5576
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4a3892c
f7ab50f
d279e0c
6062c69
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1863,24 +1863,29 @@ void TrtGptModelInflightBatching::setupDecoderStep( | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (mWorldConfig.isLastPipelineParallelRank() && !contextRequests.empty()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
RequestVector finishedContextRequests; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
[](auto const& llmReq) { return llmReq->isLastContextChunk(); }); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mOperatingBeamWidth, buffers.mMedusaBuffers); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const localBatchSize = batchSlots->getSize(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (localBatchSize > 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (!finishedContextRequests.empty()) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+1866
to
1871
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Don’t drop disaggregated-generation setup when no “last context chunk” is present. setupDecoderStep now filters to last-context chunks only. When called from prepareDistGenBufferAndDecoder (Lines 1681-1683) with generation requests, isLastContextChunk() will typically be false, so finishedContextRequests becomes empty and decoder setup is skipped, breaking the disaggregated generation init path. Fix: fall back to using the passed-in requests if no last-context chunk is found. - RequestVector finishedContextRequests;
- std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
- [](auto const& llmReq) { return llmReq->isLastContextChunk(); });
+ RequestVector finishedContextRequests;
+ std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
+ [](auto const& llmReq) { return llmReq->isLastContextChunk(); });
+ // Disagg generation init path calls this with generation-ready requests (no context chunk in-flight).
+ if (finishedContextRequests.empty())
+ {
+ finishedContextRequests = contextRequests;
+ } 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto samplingConfig = SamplingConfig(samplingConfigs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto [batchSlots, samplingConfig, lookaheadPrompt, lookaheadAlgoConfigs] | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, finishedContextRequests, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const localBatchSize = batchSlots->getSize(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TLLM_CHECK_WITH_INFO(localBatchSize > 0, "Decoder setup should be called with at least one request"); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoder->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{mDecoderState->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
lookaheadAlgoConfigs); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const& stream = mDecoder->getDecoderStream(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const& decoderStream = mDecoder->getDecoderStream(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
CudaEvent event{}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
stream->record(event); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
decoderStream->record(event); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mRuntime->getStreamPtr()->wait(event); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2515,6 +2520,24 @@ void TrtGptModelInflightBatching::changeBeamWidth(SizeType32 beamWidth) | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void TrtGptModelInflightBatching::disableLookaheadDecoder( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
RequestVector const& genRequests, DecoderInputBuffers& inputBuffers) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto batchSlots = CreateNewDecoderRequests::fillBatchSlots(genRequests, inputBuffers); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto samplingConfig = CreateNewDecoderRequests::fuseSamplingConfigs(genRequests); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoder->getUnderlyingDecoder().disableLookahead(samplingConfig, batchSlots->getSize(), batchSlots); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
auto const& decoderStream = mDecoder->getDecoderStream(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
CudaEvent event{}; | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
decoderStream->record(event); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mRuntime->getStreamPtr()->wait(event); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& scheduledRequests) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
@@ -2602,11 +2625,16 @@ void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& sch | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecodingConfig.setDecodingMode(executor::DecodingMode::Auto()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mBuffers.at(bufferId)->mLookaheadBuffers->disableLookaheadDecoding(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoderOutputBuffers.at(getFusedBufferId()).disableLookaheadDecoding(getMaxNumSequences()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoder->disableLookahead( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()).setupBatchSlots); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoderState->disableLookahead(scheduledRequests.generationRequests); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
disableLookaheadDecoder(scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId())); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoderState->disableLookahead(); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
for (auto const& llmReq : scheduledRequests.generationRequests) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (llmReq->mSeqSlot) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
mDecoderState->setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+2628
to
+2637
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🛠️ Refactor suggestion Disable LAD for all active slots (ctx + gen), and reset engine tokens consistently. The current call disables lookahead only for scheduledRequests.generationRequests. When LAD is turned off due to constraints, there may be only context requests in-flight; underlying decoder state for those slots won’t be updated. Also, resetting “numDecodingEngineTokens = 1” should apply to all slots in this transition. - disableLookaheadDecoder(scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()));
- mDecoderState->disableLookahead();
-
- for (auto const& llmReq : scheduledRequests.generationRequests)
+ // Apply to all scheduled slots (both ctx and gen have seqSlot at this point)
+ RequestVector requestsForDisable;
+ requestsForDisable.reserve(
+ scheduledRequests.contextRequests.size() + scheduledRequests.generationRequests.size());
+ requestsForDisable.insert(requestsForDisable.end(),
+ scheduledRequests.contextRequests.begin(), scheduledRequests.contextRequests.end());
+ requestsForDisable.insert(requestsForDisable.end(),
+ scheduledRequests.generationRequests.begin(), scheduledRequests.generationRequests.end());
+
+ disableLookaheadDecoder(requestsForDisable, mDecoderInputBuffers.at(getFusedBufferId()));
+ mDecoderState->disableLookahead();
+
+ for (auto const& llmReq : requestsForDisable)
{
if (llmReq->mSeqSlot)
{
mDecoderState->setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1);
} 📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
if (llmReq->getNumDraftTokens() > 0) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
llmReq->discardDraftTokens(llmReq->getNumDraftTokens()); | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Uh oh!
There was an error while loading. Please reload this page.