97 changes: 36 additions & 61 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu
@@ -63,27 +63,24 @@ __device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const nu
}

template <typename T>
__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const curMTPLayerIdx, int const batchSize,
int const numContextRequest, int const hiddenSize, int const* inputIds, int const* seqLens,
T** const mtpPastHiddenStatesPtrs, int** const mtpPastTokensPtrs, T* const previousLayerHiddenStates,
int* const previousLayerDraftTokens, int* returnInputIds, T* returnHiddenStates)
__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const numContextRequest,
int const hiddenSize, int const* inputIds, int const* seqLens, T** const mtpPastHiddenStatesPtrs,
int** const mtpPastTokensPtrs, T* const hiddenStates, int const* acceptedTokens, int const* numAcceptedTokens,
int* returnInputIds, T* returnHiddenStates)
{
/*
In a batch of requests: context requests (at the beginning) + generation requests
numGenerationRequest = batchSize - numContextRequest

inputIds: [N]
- When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
+ 1)
- When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
- N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1)
seqLens: [batchSize]
mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
previousLayerHiddenStates: [N, hiddenSize]
- When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules
+ 1) (from target model)
- When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
previousLayerDraftTokens: [batchSize], the draft tokens generated by the previous layer
hiddenStates: [N, hiddenSize]
- N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) (from target model)
acceptedTokens: [batchSize, numMTPModules + 1]
numAcceptedTokens: [batchSize]
returnInputIds: [N]
- N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules
returnHiddenStates: [N, hiddenSize]
@@ -94,6 +91,7 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const

T const* curMTPPastHiddenStatesPtr = mtpPastHiddenStatesPtrs[bid];
int const* curMTPPastTokensPtr = mtpPastTokensPtrs[bid];
int const* curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);

int curSeqLen = seqLens[bid];

@@ -117,63 +115,44 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const
}

int const* curInputIdsPtr = inputIds + inputIdsStartOffset;
T const* curPreviousLayerHiddenStates = previousLayerHiddenStates + inputIdsStartOffset * hiddenSize;
T const* curHiddenStates = hiddenStates + inputIdsStartOffset * hiddenSize;

int* curReturnInputIdsPtr = returnInputIds + returnInputIdsStartOffset;
T* curReturnHiddenStatesIdsPtr = returnHiddenStates + returnInputIdsStartOffset * hiddenSize;

//// main logic

if (curMTPLayerIdx == 0)
if (bid < numContextRequest)
{
if (bid < numContextRequest)
// context requests
if (tid == 0)
{
// context requests
if (tid == 0)
// 1) For the new inputIds
for (int ii = 0; ii < curSeqLen - 1; ii++)
{
// 1) For the new inputIds
for (int ii = 0; ii < curSeqLen - 1; ii++)
{
curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
}
// Append the latest golden token, i.e., the last one in the past tokens list
curReturnInputIdsPtr[curSeqLen - 1] = curMTPPastTokensPtr[numMTPModules - 1];
curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:]
}

// 2) For the new past hidden states
copyChunkedHiddenStates(curPreviousLayerHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
// Append the latest golden token, i.e., the first one in the accepted tokens list
curReturnInputIdsPtr[curSeqLen - 1] = curAcceptedTokensPtr[0];
}
else
{
// generation requests
if (tid == 0)
{
// 1) For the new inputIds
for (int ii = 0; ii < numMTPModules; ii++)
{
curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii];
}
}

// 2) For the new past hidden states
copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
}
// 2) For the new past hidden states
copyChunkedHiddenStates(curHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize);
}
else // For curMTPLayerIdx > 0
else
{
// generation requests
if (tid == 0)
{
// 1) For the new inputIds
int numPastTokens = (bid < numContextRequest) ? curSeqLen : numMTPModules;
for (int ii = 0; ii < numPastTokens; ii++)
for (int ii = 0; ii < numMTPModules - 1; ii++)
{
curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1];
curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii + 1];
}
curReturnInputIdsPtr[numPastTokens - 1] = previousLayerDraftTokens[bid];
curReturnInputIdsPtr[numMTPModules - 1] = curAcceptedTokensPtr[numAcceptedTokens[bid] - 1];
}

// 2) For the new past hidden states
// Directly use previous layer's output hidden states
copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize);
}
}
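
For reference, the per-request behaviour of the rewritten kernel can be summarised in a short host-side sketch. This is Python with assumed NumPy-style flattened arrays and per-request offsets; names follow the kernel's doc comment, not any real API, and it is only a reading aid for the branches above.

# Hedged host-side sketch of the per-request logic above; arrays are assumed to be
# NumPy arrays laid out as described in the kernel's doc comment.
def prepare_drafter_inputs_ref(bid, num_context_request, num_mtp_modules,
                               input_ids, hidden_states,           # flattened [N] / [N, hidden]
                               seq_lens, mtp_past_tokens, mtp_past_hidden,
                               accepted_tokens, num_accepted_tokens,
                               in_off, out_off,                    # per-request start offsets
                               return_input_ids, return_hidden_states):
    seq_len = seq_lens[bid]
    if bid < num_context_request:
        # Context request: prompt shifted by one, then the first accepted (golden) token.
        return_input_ids[out_off:out_off + seq_len - 1] = input_ids[in_off + 1:in_off + seq_len]
        return_input_ids[out_off + seq_len - 1] = accepted_tokens[bid, 0]
        # Hidden states come straight from the target model for the whole prompt.
        return_hidden_states[out_off:out_off + seq_len] = hidden_states[in_off:in_off + seq_len]
    else:
        # Generation request: past-token window shifted by one, then the last accepted token.
        return_input_ids[out_off:out_off + num_mtp_modules - 1] = mtp_past_tokens[bid, 1:]
        return_input_ids[out_off + num_mtp_modules - 1] = \
            accepted_tokens[bid, num_accepted_tokens[bid] - 1]
        # Hidden states come from the request's MTP past-hidden-states buffer.
        return_hidden_states[out_off:out_off + num_mtp_modules] = mtp_past_hidden[bid]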

@@ -185,10 +164,10 @@ void invokeMTPPrepareDrafterInputs(MTPPrepareDrafterInputsParam& params, cudaStr
params.hiddenSize * sizeof(T) % 16 == 0); // Which is because we will use float4 to copy the hidden states.

mtpPrepareDrafterInputsKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules,
params.curMTPLayerIdx, params.batchSize, params.numContextRequest, params.hiddenSize, params.inputIds,
params.seqLens, reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
reinterpret_cast<T*>(params.previousLayerHiddenStates), params.previousLayerDraftTokens, params.returnInputIds,
reinterpret_cast<T*>(params.returnHiddenStates));
params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs,
reinterpret_cast<T*>(params.hiddenStates), params.acceptedTokens, params.numAcceptedTokens,
params.returnInputIds, reinterpret_cast<T*>(params.returnHiddenStates));

sync_check_cuda_error(stream);
}
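
The two flattened lengths in the doc comment differ only in the per-generation-request contribution. A tiny helper makes that explicit; it is illustrative only, with prompt_lens assumed to hold the context requests' prompt lengths.

def mtp_flattened_lengths(prompt_lens, num_generation_request, num_mtp_modules):
    # Rows in inputIds / hiddenStates coming from the target model forward pass.
    n_in = sum(prompt_lens) + num_generation_request * (num_mtp_modules + 1)
    # Rows in returnInputIds / returnHiddenStates handed to the MTP drafter.
    n_out = sum(prompt_lens) + num_generation_request * num_mtp_modules
    return n_in, n_out
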
@@ -362,7 +341,7 @@ template void invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>(
template <typename T>
__global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const batchSize, int const numContextRequest,
int const hiddenSize, int const* inputIds, int const* seqLens, T* targetModelHiddenStates,
T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens, int const* acceptedTokens)
T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens)
{
/*
In a batch of requests: context requests (at the beginning) + generation requests
@@ -374,7 +353,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize]
mtpPastTokensPtrs: [maxNumRequests][numMTPModules]
numAcceptedTokens: [batchSize]
acceptedTokens: [batchSize][numMTPModules + 1], flatten
*/

int const bid = static_cast<int>(blockIdx.x); // Each block is responsible for a request.
@@ -395,7 +373,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b

auto curInputIdsPtr = inputIds + inputIdsStartOffset;
auto curTargetModelHiddenStatesPtr = targetModelHiddenStates + inputIdsStartOffset * hiddenSize;
auto curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1);

// Update MTP tokens
// Just use one thread to execute this copy
@@ -405,12 +382,10 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
{
// Context request
// Copy the end of prompt tokens
for (int ii = 0; ii < numMTPModules - 1; ii++)
for (int ii = 0; ii < numMTPModules; ii++)
{
curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + 1 + ii];
curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + ii];
}
// Copy the new generated golden token
curMTPPastTokensPtr[numMTPModules - 1] = curAcceptedTokensPtr[0];
}
else
{
@@ -424,7 +399,7 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b
int acceptedTokenStartIdx = max(0, curAcceptedLen - numMTPModules);
for (; ii < numMTPModules; ii++, acceptedTokenStartIdx++)
{
curMTPPastTokensPtr[ii] = curAcceptedTokensPtr[acceptedTokenStartIdx];
curMTPPastTokensPtr[ii] = curInputIdsPtr[acceptedTokenStartIdx];
}
}
}
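
Part of the generation branch is collapsed in this view, so the following is only a hedged summary of the intended effect on the past-token window, not a transcription of the kernel.

# Hedged semantics sketch: after the update, each request's past-token buffer holds the
# last numMTPModules known-good tokens (plain Python lists for clarity; generation-branch
# behaviour is inferred from the visible hunks and is an assumption).
def updated_past_tokens(is_context, request_tokens, old_past_tokens,
                        num_accepted, num_mtp_modules):
    if is_context:
        # Context request: keep the last numMTPModules prompt tokens.
        return request_tokens[-num_mtp_modules:]
    # Generation request: append the tokens accepted this step, keep the newest window.
    return (old_past_tokens + request_tokens[:num_accepted])[-num_mtp_modules:]
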
@@ -463,7 +438,7 @@ void invokeMTPUpdateHiddenStates(MTPUpdateHiddenStatesParam& params, cudaStream_
mtpUpdateHiddenStatesKernel<T><<<params.batchSize, BLOCK_SIZE, 0, stream>>>(params.numMTPModules, params.batchSize,
params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens,
reinterpret_cast<T*>(params.targetModelHiddenStates), reinterpret_cast<T**>(params.mtpPastHiddenStatesPtrs),
params.mtpPastTokensPtrs, params.numAcceptedTokens, params.acceptedTokens);
params.mtpPastTokensPtrs, params.numAcceptedTokens);
sync_check_cuda_error(stream);
}

6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h
@@ -34,16 +34,16 @@ namespace kernels
struct MTPPrepareDrafterInputsParam
{
int numMTPModules;
int curMTPLayerIdx;
int batchSize;
int numContextRequest;
int hiddenSize;
int* inputIds;
int* seqLens;
void** __restrict__ mtpPastHiddenStatesPtrs;
int** mtpPastTokensPtrs;
void* __restrict__ previousLayerHiddenStates;
int* previousLayerDraftTokens;
void* __restrict__ hiddenStates;
int* acceptedTokens;
int* numAcceptedTokens;
int* returnInputIds;
void* __restrict__ returnHiddenStates;
};
43 changes: 18 additions & 25 deletions cpp/tensorrt_llm/thop/mtpOp.cpp
@@ -29,35 +29,36 @@ namespace torch_ext

////////////////////////////////////////////////////////////////////////////////////////////////////////////
std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inputIds, th::Tensor& seqLens,
th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& previousLayerHiddenStates,
th::Tensor& previousLayerDraftTokens, th::Tensor& returnInputIds, th::Tensor& returnHiddenStates,
int64_t numMTPModules, int64_t curMTPLayerIdx, int64_t batchSize, int64_t numContextRequest, int64_t hiddenSize)
th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& hiddenStates,
th::Tensor& acceptedTokens, th::Tensor& numAcceptedTokens, th::Tensor& returnInputIds,
th::Tensor& returnHiddenStates, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
int64_t hiddenSize)
{
auto dataType = previousLayerHiddenStates.scalar_type();
auto dataType = hiddenStates.scalar_type();

// Check
auto inputIdsSizes = inputIds.sizes();
auto previousLayerHiddenStatesSizes = previousLayerHiddenStates.sizes();
TLLM_CHECK(inputIdsSizes[0] == previousLayerHiddenStatesSizes[0]);
auto hiddenStatesSizes = hiddenStates.sizes();
TLLM_CHECK(inputIdsSizes[0] == hiddenStatesSizes[0]);

auto seqLensSizes = seqLens.sizes();
TLLM_CHECK(seqLensSizes[0] == batchSize);

auto stream = at::cuda::getCurrentCUDAStream(previousLayerHiddenStates.get_device());
auto stream = at::cuda::getCurrentCUDAStream(hiddenStates.get_device());

// Fill params
tk::MTPPrepareDrafterInputsParam params;
params.numMTPModules = numMTPModules;
params.curMTPLayerIdx = curMTPLayerIdx;
params.batchSize = batchSize;
params.numContextRequest = numContextRequest;
params.hiddenSize = hiddenSize;
params.inputIds = reinterpret_cast<int*>(inputIds.data_ptr());
params.seqLens = reinterpret_cast<int*>(seqLens.data_ptr());
params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
params.previousLayerHiddenStates = reinterpret_cast<void*>(previousLayerHiddenStates.data_ptr());
params.previousLayerDraftTokens = reinterpret_cast<int*>(previousLayerDraftTokens.data_ptr());
params.hiddenStates = reinterpret_cast<void*>(hiddenStates.data_ptr());
params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());
params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
params.returnInputIds = reinterpret_cast<int*>(returnInputIds.data_ptr());
params.returnHiddenStates = reinterpret_cast<void*>(returnHiddenStates.data_ptr());

@@ -81,14 +82,7 @@ std::tuple<th::Tensor, th::Tensor> mtp_prepare_drafter_inputs_op(th::Tensor& inp
break;
}

if (curMTPLayerIdx > 0)
{
return std::make_tuple(returnInputIds, previousLayerHiddenStates);
}
else
{
return std::make_tuple(returnInputIds, returnHiddenStates);
}
return std::make_tuple(returnInputIds, returnHiddenStates);
}

////////////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -151,8 +145,8 @@ std::tuple<th::Tensor, th::Tensor> mtp_sampling_and_accepted_draft_tokens_op(th:
////////////////////////////////////////////////////////////////////////////////////////////////////////////
std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& inputIds, th::Tensor& seqLens,
th::Tensor& targetModelHiddenStates, th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs,
th::Tensor& numAcceptedTokens, th::Tensor& acceptedTokens, int64_t numMTPModules, int64_t batchSize,
int64_t numContextRequest, int64_t hiddenSize)
th::Tensor& numAcceptedTokens, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest,
int64_t hiddenSize)
{
auto dataType = targetModelHiddenStates.scalar_type();

@@ -178,7 +172,6 @@ std::tuple<th::Tensor, th::Tensor> mtp_update_hidden_states_op(th::Tensor& input
params.mtpPastHiddenStatesPtrs = reinterpret_cast<void**>(mtpPastHiddenStatesPtrs.data_ptr());
params.mtpPastTokensPtrs = reinterpret_cast<int**>(mtpPastTokensPtrs.data_ptr());
params.numAcceptedTokens = reinterpret_cast<int*>(numAcceptedTokens.data_ptr());
params.acceptedTokens = reinterpret_cast<int*>(acceptedTokens.data_ptr());

switch (dataType)
{
@@ -274,9 +267,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mtp_prepare_drafter_inputs_op(Tensor inputIds, Tensor seqLens, Tensor "
"mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor previousLayerHiddenStates, "
"Tensor previousLayerDraftTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
"int numMTPModules, int curMTPLayerIdx, int batchSize, int numContextRequest,"
"mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor hiddenStates, "
"Tensor acceptedTokens, Tensor numAcceptedTokens, Tensor returnInputIds, Tensor returnHiddenStates, "
"int numMTPModules, int batchSize, int numContextRequest,"
"int hiddenSize) -> (Tensor, Tensor)");
}
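
A hypothetical single-request call through this schema might look as follows. It assumes the trtllm custom-op library is already loaded into the process, int32 ids, fp16 hidden states, and int64 tensors of data_ptr() values for the per-request pointer arrays; a sketch of the calling convention, not the production call path.

import torch

num_mtp_modules, hidden_size = 1, 16
batch_size, num_context_request, prompt_len = 1, 1, 8
device, dtype = "cuda", torch.float16

# Per-request MTP state; the op receives device pointers to these buffers.
past_hidden = torch.zeros(num_mtp_modules, hidden_size, dtype=dtype, device=device)
past_tokens = torch.zeros(num_mtp_modules, dtype=torch.int32, device=device)
past_hidden_ptrs = torch.tensor([past_hidden.data_ptr()], dtype=torch.int64, device=device)
past_tokens_ptrs = torch.tensor([past_tokens.data_ptr()], dtype=torch.int64, device=device)

input_ids = torch.arange(prompt_len, dtype=torch.int32, device=device)
seq_lens = torch.tensor([prompt_len], dtype=torch.int32, device=device)
hidden_states = torch.zeros(prompt_len, hidden_size, dtype=dtype, device=device)
accepted_tokens = torch.zeros(batch_size, num_mtp_modules + 1, dtype=torch.int32, device=device)
num_accepted_tokens = torch.ones(batch_size, dtype=torch.int32, device=device)
return_input_ids = torch.empty(prompt_len, dtype=torch.int32, device=device)
return_hidden_states = torch.empty(prompt_len, hidden_size, dtype=dtype, device=device)

out_ids, out_hidden = torch.ops.trtllm.mtp_prepare_drafter_inputs_op(
    input_ids, seq_lens, past_hidden_ptrs, past_tokens_ptrs, hidden_states,
    accepted_tokens, num_accepted_tokens, return_input_ids, return_hidden_states,
    num_mtp_modules, batch_size, num_context_request, hidden_size)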

@@ -306,7 +299,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
{
m.def(
"mtp_update_hidden_states_op(Tensor inputIds, Tensor seqLens, Tensor targetModelHiddenStates, "
"Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, Tensor acceptedTokens, "
"Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, "
"int numMTPModules, int batchSize, int numContextRequest, int hiddenSize) -> (Tensor, Tensor)");
}

4 changes: 1 addition & 3 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -1068,15 +1068,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
ckpt_nextn = self.config.num_nextn_predict_layers
self.num_hidden_layers = self.config.num_hidden_layers
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
if ckpt_nextn == 1:
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
self.epilogue.append(mtp_layer)
self.mtp_worker = MTPEagleWorker(model_config.spec_config)
else:
# TODO: fix the accuracy issue and remove this assert.
assert False, "Cannot support num_nextn_predict_layers>1 in checkpoint now. Will fix it soon"
mtp_layers = nn.ModuleList([
DeepseekV3MTP(model_config,
layer_idx + self.num_hidden_layers,
1 change: 0 additions & 1 deletion tensorrt_llm/_torch/speculative/eagle3.py
@@ -33,7 +33,6 @@ def __post_init__(self):
else:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(
self.spec_dec_name)
self.num_extra_kv_tokens = 0
logger.info(f"EAGLE3 Config: {self}")

def update_from_model_config(self, model_config):
2 changes: 2 additions & 0 deletions tensorrt_llm/_torch/speculative/interface.py
@@ -105,6 +105,8 @@ class SpecConfig:
max_draft_tokens: int = 1024
# The path to the draft model
draft_model_path: Optional[str] = None
# The number of extra kv tokens
num_extra_kv_tokens: int = 0

def __post_init__(self) -> None:
self.spec_dec_mode = SpeculativeDecodingMode.from_string(