Commit e342b35

lfr-0531 authored and k-l-lambda committed
fix: refactor and fix mtp vanilla (NVIDIA#4762)
Signed-off-by: Fanrong Li <[email protected]>
1 parent 24548d7 commit e342b35

File tree

17 files changed: +1826 -1994 lines changed

cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
Lines changed: 36 additions & 61 deletions

@@ -63,27 +63,24 @@ __device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const nu
 }
 
 template <typename T>
-__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const curMTPLayerIdx, int const batchSize,
-    int const numContextRequest, int const hiddenSize, int const* inputIds, int const* seqLens,
-    T** const mtpPastHiddenStatesPtrs, int** const mtpPastTokensPtrs, T* const previousLayerHiddenStates,
-    int* const previousLayerDraftTokens, int* returnInputIds, T* returnHiddenStates)
+__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const numContextRequest,
+    int const hiddenSize, int const* inputIds, int const* seqLens, T** const mtpPastHiddenStatesPtrs,
+    int** const mtpPastTokensPtrs, T* const hiddenStates, int const* acceptedTokens, int const* numAcceptedTokens,
+    int* returnInputIds, T* returnHiddenStates)
 {
     /*
         In a batch of request: context request (at the beginning) + generation requests
         numGenerationRequest = batchSize - numContextRequest
 
         inputIds: [N]
-        - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
-          + 1)
-        - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
+        - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1)
         seqLens: [batchSize]
         mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
         mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
-        previousLayerHiddenStates: [N, hiddenSize]
-        - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
-          + 1) (from target model)
-        - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
-        previousLayerDraftTokens: [batchSize], the draft tokens generated by the previous layer
+        hiddenStates: [N, hiddenSize]
+        - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) (from target model)
+        acceptedTokens: [batchSize, numMTPModules + 1]
+        numAcceptedTokens: [batchSize]
         returnInputIds: [N]
         - N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
         returnHiddenStates: [N, hiddenSize]
@@ -94,6 +91,7 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
 
     T const* curMTPPastHiddenStatesPtr = mtpPastHiddenStatesPtrs[bid];
     int const* curMTPPastTokensPtr = mtpPastTokensPtrs[bid];
+    int const* curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);
 
     int curSeqLen = seqLens[bid];
 
@@ -117,63 +115,44 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
     }
 
     int const* curInputIdsPtr = inputIds + inputIdsStartOffset;
-    T const* curPreviousLayerHiddenStates = previousLayerHiddenStates + inputIdsStartOffset * hiddenSize;
+    T const* curHiddenStates = hiddenStates + inputIdsStartOffset * hiddenSize;
 
     int* curReturnInputIdsPtr = returnInputIds + returnInputIdsStartOffset;
     T* curReturnHiddenStatesIdsPtr = returnHiddenStates + returnInputIdsStartOffset * hiddenSize;
 
     //// main logic
-
-    if (curMTPLayerIdx == 0)
+    if (bid < numContextRequest)
     {
-        if (bid < numContextRequest)
+        // context requests
+        if (tid == 0)
         {
-            // context requests
-            if (tid == 0)
+            // 1) For the new inputIds
+            for (int ii = 0; ii < curSeqLen - 1; ii++)
             {
-                // 1) For the new inputIds
-                for (int ii = 0; ii < curSeqLen - 1; ii++)
-                {
-                    curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
-                }
-                // Append the latest golden token, i.e., the last one in the past tokens list
-                curReturnInputIdsPtr[curSeqLen - 1] = curMTPPastTokensPtr[numMTPModules - 1];
+                curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
             }
-
-            // 2) For the new past hidden states
-            copyChunkedHiddenStates(curPreviousLayerHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
+            // Append the latest golden token, i.e., the first one in the accepted tokens list
+            curReturnInputIdsPtr[curSeqLen - 1] = curAcceptedTokensPtr[0];
         }
-        else
-        {
-            // generation requests
-            if (tid == 0)
-            {
-                // 1) For the new inputIds
-                for (int ii = 0; ii < numMTPModules; ii++)
-                {
-                    curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii];
-                }
-            }
 
-            // 2) For the new past hidden states
-            copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
-        }
+        // 2) For the new past hidden states
+        copyChunkedHiddenStates(curHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
     }
-    else // For curMTPLayerIdx > 0
+    else
     {
+        // generation requests
        if (tid == 0)
        {
            // 1) For the new inputIds
-            int numPastTokens = (bid < numContextRequest) ? curSeqLen : numMTPModules;
-            for (int ii = 0; ii < numPastTokens; ii++)
+            for (int ii = 0; ii < numMTPModules - 1; ii++)
            {
-                curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1];
+                curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii + 1];
            }
-            curReturnInputIdsPtr[numPastTokens - 1] = previousLayerDraftTokens[bid];
+            curReturnInputIdsPtr[numMTPModules - 1] = curAcceptedTokensPtr[numAcceptedTokens[bid] - 1];
        }
 
        // 2) For the new past hidden states
-        // Directly use previous layer's output hidden states
+        copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
    }
 }
 
@@ -185,10 +164,10 @@ void invokeMTPPrepareDrafterInputs(MTPPrepareDrafterInputsParam& params, cudaStr
         params.hiddenSize * sizeof(T) % 16 == 0); // Which is because we will use float4 to copy the hidden states.
 
     mtpPrepareDrafterInputsKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules,
-        params.curMTPLayerIdx, params.batchSize, params.numContextRequest, params.hiddenSize, params.inputIds,
-        params.seqLens, reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
-        reinterpret_cast<T*>(params.previousLayerHiddenStates), params.previousLayerDraftTokens, params.returnInputIds,
-        reinterpret_cast<T*>(params.returnHiddenStates));
+        params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
+        reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
+        reinterpret_cast<T*>(params.hiddenStates), params.acceptedTokens, params.numAcceptedTokens,
+        params.returnInputIds, reinterpret_cast<T*>(params.returnHiddenStates));
 
     sync_check_cuda_error(stream);
 }
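Note on the hunks above: the refactor drops the per-layer dispatch on curMTPLayerIdx, so every MTP module now builds its drafter inputs the same way, from the accepted tokens rather than from the previous layer's draft tokens. As a reading aid, here is a minimal CPU sketch of the new token-selection rule in Python; the function name is illustrative only, and the real kernel additionally copies the matching hidden states in float4 chunks:

def prepare_drafter_input_ids(prompt_ids, past_mtp_tokens, accepted_tokens,
                              num_accepted, is_context_request):
    # Context request: shift the prompt by one and append the first
    # accepted (golden) token.
    if is_context_request:
        return prompt_ids[1:] + [accepted_tokens[0]]
    # Generation request: drop the oldest past MTP token and append the
    # newest accepted token, yielding exactly numMTPModules ids.
    return past_mtp_tokens[1:] + [accepted_tokens[num_accepted - 1]]

# numMTPModules = 2; a generation request that accepted one token:
assert prepare_drafter_input_ids(
    prompt_ids=[], past_mtp_tokens=[11, 12],
    accepted_tokens=[21, 22, 23], num_accepted=1,
    is_context_request=False) == [12, 21]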
@@ -362,7 +341,7 @@ template void invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(
 template <typename T>
 __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const batchSize, int const numContextRequest,
     int const hiddenSize, int const* inputIds, int const* seqLens, T* targetModelHiddenStates,
-    T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens, int const* acceptedTokens)
+    T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens)
 {
     /*
         In a batch of request: context request (at the beginning) + generation requests
@@ -374,7 +353,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
         mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
         mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
         numAcceptedTokens: [batchSize]
-        acceptedTokens: [batchSize][numMTPModules + 1], flatten
     */
 
     int const bid = static_cast<int>(blockIdx.x); // Each block is responsible for a request.
@@ -395,7 +373,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
 
     auto curInputIdsPtr = inputIds + inputIdsStartOffset;
     auto curTargetModelHiddenStatesPtr = targetModelHiddenStates + inputIdsStartOffset * hiddenSize;
-    auto curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);
 
     // Update MTP tokens
     // Just use one thread to execute this copy
@@ -405,12 +382,10 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
         {
             // Context request
             // Copy the end of prompt tokens
-            for (int ii = 0; ii < numMTPModules - 1; ii++)
+            for (int ii = 0; ii < numMTPModules; ii++)
             {
-                curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + 1 + ii];
+                curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + ii];
             }
-            // Copy the new generated golden token
-            curMTPPastTokensPtr[numMTPModules - 1] = curAcceptedTokensPtr[0];
         }
         else
         {
@@ -424,7 +399,7 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
             int acceptedTokenStartIdx = max(0, curAcceptedLen - numMTPModules);
             for (; ii < numMTPModules; ii++, acceptedTokenStartIdx++)
             {
-                curMTPPastTokensPtr[ii] = curAcceptedTokensPtr[acceptedTokenStartIdx];
+                curMTPPastTokensPtr[ii] = curInputIdsPtr[acceptedTokenStartIdx];
             }
         }
     }
@@ -463,7 +438,7 @@ void invokeMTPUpdateHiddenStates(MTPUpdateHiddenStatesParam& params, cudaStream_
     mtpUpdateHiddenStatesKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules, params.batchSize,
         params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
         reinterpret_cast<T*>(params.targetModelHiddenStates), reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs),
-        params.mtpPastTokensPtrs, params.numAcceptedTokens, params.acceptedTokens);
+        params.mtpPastTokensPtrs, params.numAcceptedTokens);
     sync_check_cuda_error(stream);
 }
 
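Reading aid for the update-kernel hunks above: since the accepted tokens are now part of inputIds, mtpUpdateHiddenStatesKernel no longer needs the separate acceptedTokens buffer; a context request simply caches the last numMTPModules prompt tokens, and a generation request indexes straight into inputIds. A small Python reference of the context-request branch (illustrative name, not the CUDA code):

def update_past_tokens_context(input_ids, seq_len, num_mtp_modules):
    # Cache the last numMTPModules prompt tokens as the MTP history;
    # with the new flow the golden token is already among them.
    return [input_ids[seq_len - num_mtp_modules + ii]
            for ii in range(num_mtp_modules)]

assert update_past_tokens_context([1, 2, 3, 4, 5], 5, 2) == [4, 5]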
cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
Lines changed: 3 additions & 3 deletions

@@ -34,16 +34,16 @@ namespace kernels
 struct MTPPrepareDrafterInputsParam
 {
     int numMTPModules;
-    int curMTPLayerIdx;
     int batchSize;
     int numContextRequest;
     int hiddenSize;
     int* inputIds;
     int* seqLens;
     void** __restrict__ mtpPastHiddenStatesPtrs;
     int** mtpPastTokensPtrs;
-    void* __restrict__ previousLayerHiddenStates;
-    int* previousLayerDraftTokens;
+    void* __restrict__ hiddenStates;
+    int* acceptedTokens;
+    int* numAcceptedTokens;
     int* returnInputIds;
     void* __restrict__ returnHiddenStates;
 };

cpp/tensorrt_llm/thop/mtpOp.cpp
Lines changed: 18 additions & 25 deletions

@@ -29,35 +29,36 @@ namespace torch_ext
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
 std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inputIds, th::Tensor& seqLens,
-    th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& previousLayerHiddenStates,
-    th::Tensor& previousLayerDraftTokens, th::Tensor& returnInputIds, th::Tensor& returnHiddenStates,
-    int64_t numMTPModules, int64_t curMTPLayerIdx, int64_t batchSize, int64_t numContextRequest, int64_t hiddenSize)
+    th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& hiddenStates,
+    th::Tensor& acceptedTokens, th::Tensor& numAcceptedTokens, th::Tensor& returnInputIds,
+    th::Tensor& returnHiddenStates, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
+    int64_t hiddenSize)
 {
-    auto dataType = previousLayerHiddenStates.scalar_type();
+    auto dataType = hiddenStates.scalar_type();
 
     // Check
     auto inputIdsSizes = inputIds.sizes();
-    auto previousLayerHiddenStatesSizes = previousLayerHiddenStates.sizes();
-    TLLM_CHECK(inputIdsSizes[0] == previousLayerHiddenStatesSizes[0]);
+    auto hiddenStatesSizes = hiddenStates.sizes();
+    TLLM_CHECK(inputIdsSizes[0] == hiddenStatesSizes[0]);
 
     auto seqLensSizes = seqLens.sizes();
     TLLM_CHECK(seqLensSizes[0] == batchSize);
 
-    auto stream = at::cuda::getCurrentCUDAStream(previousLayerHiddenStates.get_device());
+    auto stream = at::cuda::getCurrentCUDAStream(hiddenStates.get_device());
 
     // Fill params
     tk::MTPPrepareDrafterInputsParam params;
     params.numMTPModules = numMTPModules;
-    params.curMTPLayerIdx = curMTPLayerIdx;
     params.batchSize = batchSize;
     params.numContextRequest = numContextRequest;
     params.hiddenSize = hiddenSize;
     params.inputIds = reinterpret_cast<int*>(inputIds.data_ptr());
     params.seqLens = reinterpret_cast<int*>(seqLens.data_ptr());
     params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
     params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
-    params.previousLayerHiddenStates = reinterpret_cast<void*>(previousLayerHiddenStates.data_ptr());
-    params.previousLayerDraftTokens = reinterpret_cast<int*>(previousLayerDraftTokens.data_ptr());
+    params.hiddenStates = reinterpret_cast<void*>(hiddenStates.data_ptr());
+    params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
+    params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
     params.returnInputIds = reinterpret_cast<int*>(returnInputIds.data_ptr());
     params.returnHiddenStates = reinterpret_cast<void*>(returnHiddenStates.data_ptr());
 
@@ -81,14 +82,7 @@ std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inp
         break;
     }
 
-    if (curMTPLayerIdx > 0)
-    {
-        return std::make_tuple(returnInputIds, previousLayerHiddenStates);
-    }
-    else
-    {
-        return std::make_tuple(returnInputIds, returnHiddenStates);
-    }
+    return std::make_tuple(returnInputIds, returnHiddenStates);
 }
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -151,8 +145,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
 std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& inputIds, th::Tensor& seqLens,
     th::Tensor& targetModelHiddenStates, th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs,
-    th::Tensor& numAcceptedTokens, th::Tensor& acceptedTokens, int64_t numMTPModules, int64_t batchSize,
-    int64_t numContextRequest, int64_t hiddenSize)
+    th::Tensor& numAcceptedTokens, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
+    int64_t hiddenSize)
 {
     auto dataType = targetModelHiddenStates.scalar_type();
 
@@ -178,7 +172,6 @@ std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& input
     params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
     params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
     params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
-    params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
 
     switch (dataType)
     {
@@ -274,9 +267,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "mtp_prepare_drafter_inputs_op(Tensor inputIds, Tensor seqLens, Tensor "
-        "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor previousLayerHiddenStates, "
-        "Tensor previousLayerDraftTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
-        "int numMTPModules, int curMTPLayerIdx, int batchSize, int numContextRequest,"
+        "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor hiddenStates, "
+        "Tensor acceptedTokens, Tensor numAcceptedTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
+        "int numMTPModules, int batchSize, int numContextRequest,"
         "int hiddenSize) -> (Tensor, Tensor)");
 }
 
@@ -306,7 +299,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def(
         "mtp_update_hidden_states_op(Tensor inputIds, Tensor seqLens, Tensor targetModelHiddenStates, "
-        "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, Tensor acceptedTokens, "
+        "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, "
         "int numMTPModules, int batchSize, int numContextRequest, int hiddenSize) -> (Tensor, Tensor)");
 }
 
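For completeness, a hedged sketch of driving the updated op from Python. The schema above fixes the argument order; everything else here is an assumption rather than something this diff shows: the int32/bfloat16 dtypes, the pointer tables passed as int64 tensors of data_ptr() values, and the need to have the trtllm op library loaded (for example via import tensorrt_llm) on a CUDA-capable machine.

import torch

num_mtp, batch, num_ctx, hidden = 1, 1, 0, 64   # one generation request
n = batch * (num_mtp + 1)                       # N per the kernel doc comment

input_ids = torch.randint(0, 1000, (n,), dtype=torch.int32, device="cuda")
seq_lens = torch.full((batch,), num_mtp + 1, dtype=torch.int32, device="cuda")
hidden_states = torch.randn(n, hidden, dtype=torch.bfloat16, device="cuda")
accepted = torch.randint(0, 1000, (batch, num_mtp + 1), dtype=torch.int32, device="cuda")
num_accepted = torch.ones(batch, dtype=torch.int32, device="cuda")

# Per-request past-state buffers, exposed to the kernel as pointer tables
# (assumed layout: one device address per request).
past_hidden = torch.zeros(batch, num_mtp, hidden, dtype=torch.bfloat16, device="cuda")
past_tokens = torch.zeros(batch, num_mtp, dtype=torch.int32, device="cuda")
past_hidden_ptrs = torch.tensor([past_hidden[i].data_ptr() for i in range(batch)],
                                dtype=torch.int64, device="cuda")
past_tokens_ptrs = torch.tensor([past_tokens[i].data_ptr() for i in range(batch)],
                                dtype=torch.int64, device="cuda")

# Outputs: numMTPModules drafter ids (and hidden states) per generation request.
ret_ids = torch.empty(batch * num_mtp, dtype=torch.int32, device="cuda")
ret_hidden = torch.empty(batch * num_mtp, hidden, dtype=torch.bfloat16, device="cuda")

out_ids, out_hidden = torch.ops.trtllm.mtp_prepare_drafter_inputs_op(
    input_ids, seq_lens, past_hidden_ptrs, past_tokens_ptrs, hidden_states,
    accepted, num_accepted, ret_ids, ret_hidden,
    num_mtp, batch, num_ctx, hidden)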
tensorrt_llm/_torch/models/modeling_deepseekv3.py
Lines changed: 1 addition & 3 deletions

@@ -1003,15 +1003,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         ckpt_nextn = self.config.num_nextn_predict_layers
         self.num_hidden_layers = self.config.num_hidden_layers
         assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
-        if ckpt_nextn == 1:
+        if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
             mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
                                       self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
             self.mtp_worker = MTPEagleWorker(model_config.spec_config)
         else:
-            # TODO: fix the accuracy issue and remove this assert.
-            assert False, "Cannot support num_nextn_predict_layers>1 in checkpoint now. Will fix it soon"
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
                               layer_idx + self.num_hidden_layers,
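The practical effect of the new gate: single-nextn checkpoints keep the fused MTPEagleWorker path by default, while the spec config's use_mtp_vanilla flag (defined outside this diff) routes them through the multi-layer vanilla branch, which the deleted assert previously blocked for any multi-layer checkpoint. Schematically:

# ckpt_nextn is config.num_nextn_predict_layers
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
    ...  # one fused DeepseekV3MTP layer + MTPEagleWorker
else:
    ...  # one DeepseekV3MTP per nextn layer (vanilla MTP, now supported)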
tensorrt_llm/_torch/speculative/eagle3.py
Lines changed: 0 additions & 1 deletion

@@ -21,7 +21,6 @@ def __post_init__(self):
 
         self.spec_dec_mode = SpeculativeDecodingMode.from_string(
             self.spec_dec_name)
-        self.num_extra_kv_tokens = 0
 
     def update_from_model_config(self, model_config):
         self.num_layers = model_config.num_hidden_layers
tensorrt_llm/_torch/speculative/interface.py
Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ class SpecConfig:
     spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
     # The max number of draft tokens
     max_draft_tokens: int = 1024
+    # The number of extra kv tokens
+    num_extra_kv_tokens: int = 0
 
     def __post_init__(self) -> None:
         self.spec_dec_mode = SpeculativeDecodingMode.from_string(
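Together with the eagle3.py change above, this hoists num_extra_kv_tokens out of the Eagle3 config's __post_init__ and into the base SpecConfig dataclass, so every speculative mode inherits the default. A minimal illustration (MySpecConfig is a made-up subclass, not from the PR):

from dataclasses import dataclass

@dataclass
class SpecConfig:
    max_draft_tokens: int = 1024
    num_extra_kv_tokens: int = 0  # defined once for all spec modes

@dataclass
class MySpecConfig(SpecConfig):
    pass

assert MySpecConfig().num_extra_kv_tokens == 0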
