diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu index 469e64a159c..2e370a4900b 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.cu @@ -63,27 +63,24 @@ __device__ void copyChunkedHiddenStates(T const* srcPtr, T* dstPtr, int const nu } template -__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const curMTPLayerIdx, int const batchSize, - int const numContextRequest, int const hiddenSize, int const* inputIds, int const* seqLens, - T** const mtpPastHiddenStatesPtrs, int** const mtpPastTokensPtrs, T* const previousLayerHiddenStates, - int* const previousLayerDraftTokens, int* returnInputIds, T* returnHiddenStates) +__global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const numContextRequest, + int const hiddenSize, int const* inputIds, int const* seqLens, T** const mtpPastHiddenStatesPtrs, + int** const mtpPastTokensPtrs, T* const hiddenStates, int const* acceptedTokens, int const* numAcceptedTokens, + int* returnInputIds, T* returnHiddenStates) { /* In a batch of request: context request (at the beginning) + generation requests numGenerationRequest = batchSize - numContextRequest inputIds: [N] - - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules - + 1) - - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules + - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) seqLens: [batchSize] mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize] mtpPastTokensPtrs: [maxNumRequests][numMTPModules] - previousLayerHiddenStates: [N, hiddenSize] - - When curMTPLayerIdx == 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules - + 1) (from target model) - - When curMTPLayerIdx > 0: N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules - previousLayerDraftTokens: [batchSize], the draft tokens generated by the previous layer + hiddenStates: [N, hiddenSize] + - N = sum(all numContextRequest's prompts) + numGenerationRequest * (numMTPModules + 1) (from target model) + acceptedTokens: [batchSize, numMTPModules + 1] + numAcceptedTokens: [batchSize] returnInputIds: [N] - N = sum(all numContextRequest's prompts) + numGenerationRequest * numMTPModules returnHiddenStates: [N, hiddenSize] @@ -94,6 +91,7 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const T const* curMTPPastHiddenStatesPtr = mtpPastHiddenStatesPtrs[bid]; int const* curMTPPastTokensPtr = mtpPastTokensPtrs[bid]; + int const* curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1); int curSeqLen = seqLens[bid]; @@ -117,63 +115,44 @@ __global__ void mtpPrepareDrafterInputsKernel(int const numMTPModules, int const } int const* curInputIdsPtr = inputIds + inputIdsStartOffset; - T const* curPreviousLayerHiddenStates = previousLayerHiddenStates + inputIdsStartOffset * hiddenSize; + T const* curHiddenStates = hiddenStates + inputIdsStartOffset * hiddenSize; int* curReturnInputIdsPtr = returnInputIds + returnInputIdsStartOffset; T* curReturnHiddenStatesIdsPtr = returnHiddenStates + returnInputIdsStartOffset * hiddenSize; //// main logic - - if (curMTPLayerIdx == 0) + if (bid < numContextRequest) { - if (bid < numContextRequest) + // context requests + if (tid == 0) { - // context requests - if (tid 
== 0) + // 1) For the new inputIds + for (int ii = 0; ii < curSeqLen - 1; ii++) { - // 1) For the new inputIds - for (int ii = 0; ii < curSeqLen - 1; ii++) - { - curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:] - } - // Append the latest golden token, i.e., the last one in the past tokens list - curReturnInputIdsPtr[curSeqLen - 1] = curMTPPastTokensPtr[numMTPModules - 1]; + curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; // +1 because of offset 1, prompt[1:] } - - // 2) For the new past hidden states - copyChunkedHiddenStates(curPreviousLayerHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize); + // Append the latest golden token, i.e., the first one in the accepted tokens list + curReturnInputIdsPtr[curSeqLen - 1] = curAcceptedTokensPtr[0]; } - else - { - // generation requests - if (tid == 0) - { - // 1) For the new inputIds - for (int ii = 0; ii < numMTPModules; ii++) - { - curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii]; - } - } - // 2) For the new past hidden states - copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize); - } + // 2) For the new past hidden states + copyChunkedHiddenStates(curHiddenStates, curReturnHiddenStatesIdsPtr, curSeqLen * hiddenSize); } - else // For curMTPLayerIdx > 0 + else { + // generation requests if (tid == 0) { // 1) For the new inputIds - int numPastTokens = (bid < numContextRequest) ? curSeqLen : numMTPModules; - for (int ii = 0; ii < numPastTokens; ii++) + for (int ii = 0; ii < numMTPModules - 1; ii++) { - curReturnInputIdsPtr[ii] = curInputIdsPtr[ii + 1]; + curReturnInputIdsPtr[ii] = curMTPPastTokensPtr[ii + 1]; } - curReturnInputIdsPtr[numPastTokens - 1] = previousLayerDraftTokens[bid]; + curReturnInputIdsPtr[numMTPModules - 1] = curAcceptedTokensPtr[numAcceptedTokens[bid] - 1]; } // 2) For the new past hidden states - // Directly use previous layer's output hidden states + copyChunkedHiddenStates(curMTPPastHiddenStatesPtr, curReturnHiddenStatesIdsPtr, numMTPModules * hiddenSize); } } @@ -185,10 +164,10 @@ void invokeMTPPrepareDrafterInputs(MTPPrepareDrafterInputsParam& params, cudaStr params.hiddenSize * sizeof(T) % 16 == 0); // Which is because we will use float4 to copy the hidden states. 
mtpPrepareDrafterInputsKernel<<>>(params.numMTPModules, - params.curMTPLayerIdx, params.batchSize, params.numContextRequest, params.hiddenSize, params.inputIds, - params.seqLens, reinterpret_cast(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs, - reinterpret_cast(params.previousLayerHiddenStates), params.previousLayerDraftTokens, params.returnInputIds, - reinterpret_cast(params.returnHiddenStates)); + params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens, + reinterpret_cast(params.mtpPastHiddenStatesPtrs), params.mtpPastTokensPtrs, + reinterpret_cast(params.hiddenStates), params.acceptedTokens, params.numAcceptedTokens, + params.returnInputIds, reinterpret_cast(params.returnHiddenStates)); sync_check_cuda_error(stream); } @@ -362,7 +341,7 @@ template void invokeMTPSampleAndAcceptDraftTokens<__nv_bfloat16>( template __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const batchSize, int const numContextRequest, int const hiddenSize, int const* inputIds, int const* seqLens, T* targetModelHiddenStates, - T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens, int const* acceptedTokens) + T** mtpPastHiddenStatesPtrs, int** mtpPastTokensPtrs, int const* numAcceptedTokens) { /* In a batch of request: context request (at the beginning) + generation requests @@ -374,7 +353,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b mtpPastHiddenStatesPtrs: [maxNumRequests][numMTPModules, hiddenSize] mtpPastTokensPtrs: [maxNumRequests][numMTPModules] numAcceptedTokens: [batchSize] - acceptedTokens: [batchSize][numMTPModules + 1], flatten */ int const bid = static_cast(blockIdx.x); // Each block is responsible for a request. @@ -395,7 +373,6 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b auto curInputIdsPtr = inputIds + inputIdsStartOffset; auto curTargetModelHiddenStatesPtr = targetModelHiddenStates + inputIdsStartOffset * hiddenSize; - auto curAcceptedTokensPtr = acceptedTokens + bid * (numMTPModules + 1); // Update MTP tokens // Just use one thread to execute this copy @@ -405,12 +382,10 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b { // Context request // Copy the end of prompt tokens - for (int ii = 0; ii < numMTPModules - 1; ii++) + for (int ii = 0; ii < numMTPModules; ii++) { - curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + 1 + ii]; + curMTPPastTokensPtr[ii] = curInputIdsPtr[curSeqLen - numMTPModules + ii]; } - // Copy the new generated golden token - curMTPPastTokensPtr[numMTPModules - 1] = curAcceptedTokensPtr[0]; } else { @@ -424,7 +399,7 @@ __global__ void mtpUpdateHiddenStatesKernel(int const numMTPModules, int const b int acceptedTokenStartIdx = max(0, curAcceptedLen - numMTPModules); for (; ii < numMTPModules; ii++, acceptedTokenStartIdx++) { - curMTPPastTokensPtr[ii] = curAcceptedTokensPtr[acceptedTokenStartIdx]; + curMTPPastTokensPtr[ii] = curInputIdsPtr[acceptedTokenStartIdx]; } } } @@ -463,7 +438,7 @@ void invokeMTPUpdateHiddenStates(MTPUpdateHiddenStatesParam& params, cudaStream_ mtpUpdateHiddenStatesKernel<<>>(params.numMTPModules, params.batchSize, params.numContextRequest, params.hiddenSize, params.inputIds, params.seqLens, reinterpret_cast(params.targetModelHiddenStates), reinterpret_cast(params.mtpPastHiddenStatesPtrs), - params.mtpPastTokensPtrs, params.numAcceptedTokens, params.acceptedTokens); + params.mtpPastTokensPtrs, params.numAcceptedTokens); 
sync_check_cuda_error(stream); } diff --git a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h index a930f12356b..e19908101f1 100644 --- a/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h +++ b/cpp/tensorrt_llm/kernels/speculativeDecoding/mtpKernels.h @@ -34,7 +34,6 @@ namespace kernels struct MTPPrepareDrafterInputsParam { int numMTPModules; - int curMTPLayerIdx; int batchSize; int numContextRequest; int hiddenSize; @@ -42,8 +41,9 @@ struct MTPPrepareDrafterInputsParam int* seqLens; void** __restrict__ mtpPastHiddenStatesPtrs; int** mtpPastTokensPtrs; - void* __restrict__ previousLayerHiddenStates; - int* previousLayerDraftTokens; + void* __restrict__ hiddenStates; + int* acceptedTokens; + int* numAcceptedTokens; int* returnInputIds; void* __restrict__ returnHiddenStates; }; diff --git a/cpp/tensorrt_llm/thop/mtpOp.cpp b/cpp/tensorrt_llm/thop/mtpOp.cpp index d16efedef64..4f926cfe6f6 100644 --- a/cpp/tensorrt_llm/thop/mtpOp.cpp +++ b/cpp/tensorrt_llm/thop/mtpOp.cpp @@ -29,26 +29,26 @@ namespace torch_ext //////////////////////////////////////////////////////////////////////////////////////////////////////////// std::tuple mtp_prepare_drafter_inputs_op(th::Tensor& inputIds, th::Tensor& seqLens, - th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& previousLayerHiddenStates, - th::Tensor& previousLayerDraftTokens, th::Tensor& returnInputIds, th::Tensor& returnHiddenStates, - int64_t numMTPModules, int64_t curMTPLayerIdx, int64_t batchSize, int64_t numContextRequest, int64_t hiddenSize) + th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, th::Tensor& hiddenStates, + th::Tensor& acceptedTokens, th::Tensor& numAcceptedTokens, th::Tensor& returnInputIds, + th::Tensor& returnHiddenStates, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest, + int64_t hiddenSize) { - auto dataType = previousLayerHiddenStates.scalar_type(); + auto dataType = hiddenStates.scalar_type(); // Check auto inputIdsSizes = inputIds.sizes(); - auto previousLayerHiddenStatesSizes = previousLayerHiddenStates.sizes(); - TLLM_CHECK(inputIdsSizes[0] == previousLayerHiddenStatesSizes[0]); + auto hiddenStatesSizes = hiddenStates.sizes(); + TLLM_CHECK(inputIdsSizes[0] == hiddenStatesSizes[0]); auto seqLensSizes = seqLens.sizes(); TLLM_CHECK(seqLensSizes[0] == batchSize); - auto stream = at::cuda::getCurrentCUDAStream(previousLayerHiddenStates.get_device()); + auto stream = at::cuda::getCurrentCUDAStream(hiddenStates.get_device()); // Fill params tk::MTPPrepareDrafterInputsParam params; params.numMTPModules = numMTPModules; - params.curMTPLayerIdx = curMTPLayerIdx; params.batchSize = batchSize; params.numContextRequest = numContextRequest; params.hiddenSize = hiddenSize; @@ -56,8 +56,9 @@ std::tuple mtp_prepare_drafter_inputs_op(th::Tensor& inp params.seqLens = reinterpret_cast(seqLens.data_ptr()); params.mtpPastHiddenStatesPtrs = reinterpret_cast(mtpPastHiddenStatesPtrs.data_ptr()); params.mtpPastTokensPtrs = reinterpret_cast(mtpPastTokensPtrs.data_ptr()); - params.previousLayerHiddenStates = reinterpret_cast(previousLayerHiddenStates.data_ptr()); - params.previousLayerDraftTokens = reinterpret_cast(previousLayerDraftTokens.data_ptr()); + params.hiddenStates = reinterpret_cast(hiddenStates.data_ptr()); + params.acceptedTokens = reinterpret_cast(acceptedTokens.data_ptr()); + params.numAcceptedTokens = reinterpret_cast(numAcceptedTokens.data_ptr()); params.returnInputIds = 
reinterpret_cast(returnInputIds.data_ptr()); params.returnHiddenStates = reinterpret_cast(returnHiddenStates.data_ptr()); @@ -81,14 +82,7 @@ std::tuple mtp_prepare_drafter_inputs_op(th::Tensor& inp break; } - if (curMTPLayerIdx > 0) - { - return std::make_tuple(returnInputIds, previousLayerHiddenStates); - } - else - { - return std::make_tuple(returnInputIds, returnHiddenStates); - } + return std::make_tuple(returnInputIds, returnHiddenStates); } //////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -151,8 +145,8 @@ std::tuple mtp_sampling_and_accepted_draft_tokens_op(th: //////////////////////////////////////////////////////////////////////////////////////////////////////////// std::tuple mtp_update_hidden_states_op(th::Tensor& inputIds, th::Tensor& seqLens, th::Tensor& targetModelHiddenStates, th::Tensor& mtpPastHiddenStatesPtrs, th::Tensor& mtpPastTokensPtrs, - th::Tensor& numAcceptedTokens, th::Tensor& acceptedTokens, int64_t numMTPModules, int64_t batchSize, - int64_t numContextRequest, int64_t hiddenSize) + th::Tensor& numAcceptedTokens, int64_t numMTPModules, int64_t batchSize, int64_t numContextRequest, + int64_t hiddenSize) { auto dataType = targetModelHiddenStates.scalar_type(); @@ -178,7 +172,6 @@ std::tuple mtp_update_hidden_states_op(th::Tensor& input params.mtpPastHiddenStatesPtrs = reinterpret_cast(mtpPastHiddenStatesPtrs.data_ptr()); params.mtpPastTokensPtrs = reinterpret_cast(mtpPastTokensPtrs.data_ptr()); params.numAcceptedTokens = reinterpret_cast(numAcceptedTokens.data_ptr()); - params.acceptedTokens = reinterpret_cast(acceptedTokens.data_ptr()); switch (dataType) { @@ -274,9 +267,9 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "mtp_prepare_drafter_inputs_op(Tensor inputIds, Tensor seqLens, Tensor " - "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor previousLayerHiddenStates, " - "Tensor previousLayerDraftTokens, Tensor returnInputIds, Tensor returnHiddenStates, " - "int numMTPModules, int curMTPLayerIdx, int batchSize, int numContextRequest," + "mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor hiddenStates, " + "Tensor acceptedTokens, Tensor numAcceptedTokens, Tensor returnInputIds, Tensor returnHiddenStates, " + "int numMTPModules, int batchSize, int numContextRequest," "int hiddenSize) -> (Tensor, Tensor)"); } @@ -306,7 +299,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m) { m.def( "mtp_update_hidden_states_op(Tensor inputIds, Tensor seqLens, Tensor targetModelHiddenStates, " - "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, Tensor acceptedTokens, " + "Tensor mtpPastHiddenStatesPtrs, Tensor mtpPastTokensPtrs, Tensor numAcceptedTokens, " "int numMTPModules, int batchSize, int numContextRequest, int hiddenSize) -> (Tensor, Tensor)"); } diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index d8fdabb8e62..f66663e9203 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -1068,15 +1068,13 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]): ckpt_nextn = self.config.num_nextn_predict_layers self.num_hidden_layers = self.config.num_hidden_layers assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint." 
-        if ckpt_nextn == 1:
+        if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
             mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
                                       self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
             self.epilogue.append(mtp_layer)
             self.mtp_worker = MTPEagleWorker(model_config.spec_config)
         else:
-            # TODO: fix the accuracy issue and remove this assert.
-            assert False, "Cannot support num_nextn_predict_layers>1 in checkpoint now. Will fix it soon"
             mtp_layers = nn.ModuleList([
                 DeepseekV3MTP(model_config,
                               layer_idx + self.num_hidden_layers,
diff --git a/tensorrt_llm/_torch/speculative/eagle3.py b/tensorrt_llm/_torch/speculative/eagle3.py
index a76a6792f82..e6183cc1528 100644
--- a/tensorrt_llm/_torch/speculative/eagle3.py
+++ b/tensorrt_llm/_torch/speculative/eagle3.py
@@ -33,7 +33,6 @@ def __post_init__(self):
         else:
             self.spec_dec_mode = SpeculativeDecodingMode.from_string(
                 self.spec_dec_name)
-        self.num_extra_kv_tokens = 0
         logger.info(f"EAGLE3 Config: {self}")

     def update_from_model_config(self, model_config):
diff --git a/tensorrt_llm/_torch/speculative/interface.py b/tensorrt_llm/_torch/speculative/interface.py
index 85630335b7f..7c7a5318ed7 100644
--- a/tensorrt_llm/_torch/speculative/interface.py
+++ b/tensorrt_llm/_torch/speculative/interface.py
@@ -105,6 +105,8 @@ class SpecConfig:
     max_draft_tokens: int = 1024
     # The path to the draft model
     draft_model_path: Optional[str] = None
+    # The number of extra kv tokens
+    num_extra_kv_tokens: int = 0

     def __post_init__(self) -> None:
         self.spec_dec_mode = SpeculativeDecodingMode.from_string(
diff --git a/tensorrt_llm/_torch/speculative/mtp.py b/tensorrt_llm/_torch/speculative/mtp.py
index 2b6c580411b..25edbdae363 100644
--- a/tensorrt_llm/_torch/speculative/mtp.py
+++ b/tensorrt_llm/_torch/speculative/mtp.py
@@ -46,6 +46,9 @@ class MTPConfig(SpecConfig):
     # Filter out tokens with a large probability gap between the top-1 token's log probability.
     relaxed_delta: float = 0.

+    # Whether to use vanilla MTP
+    use_mtp_vanilla: bool = False
+
     # TODO: Hard code for DeepSeek R1
     # When encounter , start thinking phase.
     # When encounter , end thinking phase.
@@ -60,7 +63,7 @@ def __post_init__(self) -> None:

     def update_from_model_config(self, model_config):
         assert self.num_nextn_predict_layers > 0
-        if model_config.num_nextn_predict_layers == 1:
+        if model_config.num_nextn_predict_layers == 1 and not self.use_mtp_vanilla:
             self.spec_dec_mode = SpeculativeDecodingMode.MTP_EAGLE
             self.num_extra_kv_tokens = self.num_nextn_predict_layers - 1
@@ -172,6 +175,10 @@ def __post_init__(self) -> None:
             dtype=torch.int,
             device='cuda',
         )
+        self.draft_token_indices_cuda = torch.arange(
+            self.mtp_num_modules,
+            device='cuda',
+        )

     @property
     def all_rank_num_seqs(self):
@@ -348,7 +355,7 @@ class MTPWorker(nn.Module):
     def __init__(self, spec_config: MTPConfig):
         super().__init__()
         self.spec_config = spec_config
-        self.is_thop = True
+        self.is_thop = False

     def forward(
         self,
@@ -362,6 +369,106 @@ def forward(
         spec_metadata,
         mtp_layers,
     ):
+        '''
+        Example:
+            Assume there are 3 MTP layers
+            Notation:
+                - H_t: token t's hidden state, generated by the target model
+                - h_t: token t's hidden state, generated by the draft model
+
+            Prompt: ABCD
+
+            Context phase:
+                Target model:
+                    - input tokens: ABCD + []
+                    - sampling tokens: E
+                    - accepted tokens: E
+                    - KV cache: ABCD
+                    - hidden states: H_A, H_B, H_C, H_D
+                Draft model:
+                    MTP1:
+                        # For context request, prompt[1:] + the newly generated golden token is the input.
+ - input tokens: BCDE + - input hidden states: H_A, H_B, H_C, H_D + # '()' means historical KV cache + - KV cache: () + BCDE + - output hidden states: h_B, h_C, h_D, h_E + - output next draft token: F + MTP2: + - input token: CDEF + - input hidden states: H_B, H_C, H_D, h_E + - KV cache: () + CDEF + - output hidden states: h_C, h_D, h_E, h_F + - output next draft token: G + MTP3: + - input tokens: DEFG + - input hidden states: H_C, H_D, h_E, h_F + - KV cache: () + DEFG + - output hidden states: h_D, h_E, h_F, h_G + - output next draft token: H + After 3 MTP layers: + - new generated draft tokens: FGH + + Generation phase 1: accept partial draft tokens + Target model: + - input tokens: E + FGH + - sampling tokens: FGXY + - accepted tokens: FGX + - KV cache: (ABCD) + EFGH (H's KV cache is invalid) + - hidden states: H_E, H_F, H_G, H_H (H_H is invalid) + Draft model: + MPT1: + # For generation request, `mtp_num_modules` of tokens will be used as input. + - input tokens: FGX + - input hidden states: H_E, H_F, H_G + - KV cache: (BCDE) + FGX + - output hidden states: h_F, h_G, h_X + - output next draft token: N + MPT2: + - input tokens: GXN + - input hidden states: H_F, H_G, h_X + - KV cache: (CDEF) + GXN + - output hidden states: h_G, h_X, h_N + - output next draft token: O + MPT3: + - input tokens: XNO + - input hidden states: H_G, H_X, h_N + - KV cache: (DEFG) + XNO + - output hidden states: h_X, h_N, h_O + - output next draft token: P + After 3 MTP layers: + - new generated draft tokens: NOP + + Generation 2: accept none draft tokens + Target model: + - input tokens: X + NOP + - sampling tokens: KMZY + - accepted tokens: K + - KV cache: (ABCDEFG) + NOP (NOP's KV cache is invalid) + - hidden states: H_X, H_N, H_O, H_P (H_N, H_O, H_P is invalid) + Draft model: + MTP1: + - input tokens: GXK + - input hidden states: H_F, H_G, H_X + - KV cache: (BCDE + F) + GXK + - output hidden states: h_G, h_X, h_K + - output next draft token: U + MTP2: + - input tokens: XKU + - input hidden states: H_G, H_X, h_K + - KV cache: (CDEF + G) + XKU + - output hidden states: h_X, h_K, h_U + - output next draft token: V + MTP3: + - input tokens: KUV + - input hidden states: H_X, h_K, h_U + - KV cache: (DEFG + X) + KUV + - output hidden states: h_K, h_U, h_V + - output next draft token: Q + After 3 MTP layers: + - new generated draft tokens: UVQ + ''' + batch_size = attn_metadata.num_seqs # Sample and verify draft tokens @@ -371,38 +478,55 @@ def forward( # Update MTP past hidden states self.update_mtp_hidden_states(input_ids=input_ids, - target_model_hidden_states=hidden_states, + hidden_states=hidden_states, num_accepted_tokens=num_accepted_tokens, - accepted_tokens=accepted_tokens, spec_metadata=spec_metadata, attn_metadata=attn_metadata) - # Predict draft tokens + # prepare draft layer inputs + position_ids = position_ids.squeeze(0) + draft_inputs = self.prepare_drafter_inputs( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=hidden_states, + accepted_tokens=accepted_tokens, + num_accepted_tokens=num_accepted_tokens, + spec_metadata=spec_metadata, + attn_metadata=attn_metadata) + + # update attn metadata + if attn_metadata is not None: + self.change_attn_metadata(num_accepted_tokens, attn_metadata) + draft_inputs.update(attn_metadata=attn_metadata) + + # Run MTP layers to predict draft tokens next_draft_tokens = [] - # will not be used in the first MTP, just a placeholder to avoid Nonetype - previous_layer_draft_tokens = torch.empty(1, - dtype=torch.int, - device='cpu') - for mtp_layer_idx, 
mtp_layer in enumerate(mtp_layers): - draft_inputs = self.prepare_drafter_inputs( - mtp_layer_idx=mtp_layer_idx, - input_ids=input_ids, - position_ids=position_ids, - previous_layer_hidden_states=hidden_states, - previous_layer_draft_tokens=previous_layer_draft_tokens, - num_accepted_tokens=num_accepted_tokens, - spec_metadata=spec_metadata, - attn_metadata=attn_metadata) + last_tokens_idx = torch.cumsum( + attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 + for _, mtp_layer in enumerate(mtp_layers): hidden_states = mtp_layer(embed_tokens=embed_tokens, **draft_inputs) logits = mtp_layer.shared_head(hidden_states, lm_head, attn_metadata).float() - previous_layer_draft_tokens = self.draft_sampler(logits) - next_draft_tokens.append(previous_layer_draft_tokens) - + new_draft_token = self.draft_sampler(logits) + next_draft_tokens.append(new_draft_token) + # shift input_ids and hidden_states input_ids = draft_inputs["input_ids"] - position_ids = draft_inputs["position_ids"] - attn_metadata = draft_inputs["attn_metadata"] + input_ids[:-1] = input_ids[1:].clone() + input_ids[last_tokens_idx] = new_draft_token + draft_hidden_states = draft_inputs["hidden_states"] + draft_hidden_states[:-1] = draft_hidden_states[1:].clone() + draft_hidden_states[last_tokens_idx] = hidden_states[ + last_tokens_idx, :] + draft_inputs = { + "input_ids": input_ids, + "position_ids": draft_inputs["position_ids"], + "hidden_states": draft_hidden_states, + "attn_metadata": draft_inputs["attn_metadata"], + "spec_metadata": draft_inputs["spec_metadata"], + } next_draft_tokens = torch.stack(next_draft_tokens, dim=1) + + # restore attn metadata if attn_metadata.is_cuda_graph and attn_metadata is not None: self.restore_attn_metadata(attn_metadata=attn_metadata) @@ -458,9 +582,8 @@ def skip_forward( def update_mtp_hidden_states( self, input_ids: torch.IntTensor, - target_model_hidden_states: torch.Tensor, + hidden_states: torch.Tensor, num_accepted_tokens: torch.Tensor, - accepted_tokens: torch.Tensor, spec_metadata: MTPSpecMetadata, attn_metadata: AttentionMetadata, ): @@ -468,14 +591,13 @@ def update_mtp_hidden_states( Update the past hidden states and past tokens in spec_metadata base on the newly accepted tokens and historical hidden states. These past hidden states and past tokens will be use in MTP module. - Also update the seq_len and kv_lens in attention metadata. Args: input_ids: torch.IntTensor [num_tokens] The input ids of all requests. Flatten. - target_model_hidden_states: torch.Tensor + hidden_states: torch.Tensor [num_tokens, hidden_size] Target model's hidden states. @@ -483,10 +605,6 @@ def update_mtp_hidden_states( [batch_size] Number of accepted tokens per request. - accepted_tokens: torch.Tensor - [batch_size, max_draft_tokens + 1] - Accepted token ids. Flattened. - spec_metadata: MTPSpecMetadata MTP speculative decoding metadata @@ -495,173 +613,104 @@ def update_mtp_hidden_states( Returns: None + ''' - Example: - Assume there are 3 MTP layers - Notation: - - H_t: token t's hidden state, generated by the target model - - h_t: token t's hidden state, generated by the draft model - - Prompt: ABCD - - Context phase: - Target model: - - input tokens: ABCD + [] - - sampling tokens: E - - accepted tokens: E - - KV cache: ABCD - - hidden states: H_A, H_B, H_C, H_D - Draft model: - MTP1: - - input tokens: BCDE # For context request, prompt[2: -1] + new generated goloden token is the input. 
- - input hidden states: H_A, H_B, H_C, H_D # Therefore, the input and output tokens/hidden_states will have the dimension of `prompt_length`. - - KV cache: () + BCDE '()' means historical KV cache - - output hidden states: h_B, h_C, h_D, h_E - - output next draft token: F - MTP2: - - input token: CDEF # Also with the `prompt_length`. - - input hidden states: h_B, h_C, h_D, h_E - - KV cache: () + CDEF - - output hidden states: h_C, h_D, h_E, h_F - - output next draft token: G - MTP3: - - input tokens: DEFG - - input hidden states: h_C, h_D, h_E, h_F - - KV cache: () + DEFG - - output hidden states: h_D, h_E, h_F, h_G - - output next draft token: H - After 3 MTP layers: - - input tokens: BCDE - - new generated draft tokens: FGH + def unpack_sequence(packed_seq, seq_len): + max_length = seq_len.max() + num_sequences = seq_len.shape[0] + # initialize a zero tensor to store the result + result = torch.zeros( + (num_sequences, max_length, packed_seq.shape[1]), + dtype=packed_seq.dtype, + device=packed_seq.device) + # get mask + seq_indices = torch.arange( + max_length, + device=seq_len.device).unsqueeze(0).expand(num_sequences, -1) + mask = seq_indices < seq_len.unsqueeze(1) + # unpack + result[mask] = packed_seq + return result - Generation phase 1: accept partial draft tokens - Target model: - - input tokens: E + FGH - - sampling tokens: FGXY - - accepted tokens: FGX - - KV cache: (ABCD) + EFGH (H's KV cache is useless) - - hidden states: H_E, H_F, H_G, H_H (H_H is useless) - Draft model: - MPT1: - - input tokens: FGX # For generation request, `mtp_num_modules` + 1 of tokens will be used as input. - - input hidden states: H_E, H_F, H_G - - KV cache: (BCDE) + FGX - - output hidden states: h_F, h_G, h_X - - output next draft token: N - MPT2: - - input tokens: GXN - - input hidden states: h_F, h_G, h_X - - KV cache: (CDEF) + GXN - - output hidden states: h_G, h_X, h_N - - output next draft token: O - MPT3: - - input tokens: XNO - - input hidden states: h_G, h_X, h_N - - KV cache: (DEFG) + XNO - - output hidden states: h_X, h_N, h_O - - output next draft token: P - After 3 MTP layers: - - input tokens: FGX - - new generated draft tokens: NOP - - Generation 2: accept none draft tokens - Target model: - - input tokens: X + NOP - - sampling tokens: KMZY - - accepted tokens: K - - KV cache: (ABCDEFG) + NOP (NOP's KV cache is useless) - - hidden states: H_X, H_N, H_O, H_P (H_N, H_O, H_P is useless) - Draft model: - MTP1: - - input tokens: GXK - - input hidden states: H_F, H_G, H_X - - KV cache: (BCDE + FGX) + FGX - - output hidden states: h_G, h_X, h_K - - output next draft token: U - MTP2: - - input tokens: XKU - - input hidden states: h_G, h_X, h_K - - KV cache: (CDEF + GXN) + XKU - - output hidden states: h_X, h_K, h_U - - output next draft token: V - MTP3: - - input tokens: KUV - - input hidden states: h_X, h_K, h_U - - KV cache: (DEFG + XNO) + KUV - - output hidden states: h_K, h_U, h_V - - output next draft token: Q - After 3 MTP layers: - - input tokens: GXK - - new generated draft tokens: UVQ - ''' batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts + num_ctx_tokens = attn_metadata.num_ctx_tokens + num_gens = batch_size - num_contexts seq_lens = attn_metadata.seq_lens_cuda - hidden_size = target_model_hidden_states.shape[-1] + hidden_size = hidden_states.shape[-1] mtp_num_modules = self.spec_config.num_nextn_predict_layers if self.is_thop: _, _ = torch.ops.trtllm.mtp_update_hidden_states_op( - input_ids, seq_lens, target_model_hidden_states, + input_ids, seq_lens, 
hidden_states, spec_metadata.mtp_hidden_states_ptrs, spec_metadata.mtp_past_tokens_ptrs, num_accepted_tokens, - accepted_tokens, mtp_num_modules, batch_size, num_contexts, - hidden_size) + mtp_num_modules, batch_size, num_contexts, hidden_size) else: assert len(spec_metadata.request_ids) == batch_size mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool - input_ids_offset = 0 - - for bix in range(batch_size): - slot_id_cuda = spec_metadata.slot_ids[bix] - # [num_nextn_predict_layers - 1, hidden_states] - mtp_hidden_states = torch.index_select( - mtp_past_hidden_states_pool, 0, slot_id_cuda).squeeze(0) - # [num_nextn_predict_layers] - mtp_tokens = torch.index_select(mtp_past_tokens_pool, 0, - slot_id_cuda).squeeze(0) - cur_accepted_len = num_accepted_tokens[bix] - cur_seq_len = seq_lens[bix] - - # Update MTP tokens - cur_accepted_tokens = accepted_tokens[bix, 0:cur_accepted_len] - if bix < num_contexts: - past_input_ids = input_ids[ - input_ids_offset:input_ids_offset + cur_seq_len] - else: - past_input_ids = mtp_tokens - - cat_past_tokens = torch.cat( - (past_input_ids, cur_accepted_tokens), dim=0) - # shape: [mtp_num_modules] - new_mtp_past_tokens = cat_past_tokens[ - -mtp_num_modules:, - ] - # Update the buffer, but keep the pointer unchanged. - mtp_past_tokens_pool.index_copy_( - 0, slot_id_cuda, new_mtp_past_tokens.unsqueeze(0)) - - # Update MTP hidden states - past_hidden_states = mtp_hidden_states - # For context, we need to slice prompt length - # For generation, we only need to slice accepted length - num_slice_tokens = cur_seq_len if bix < num_contexts else cur_accepted_len - accepted_tokens_hidden_states = target_model_hidden_states[ - input_ids_offset:input_ids_offset + num_slice_tokens] - cat_hidden_states = torch.cat( - (past_hidden_states, accepted_tokens_hidden_states), dim=0) - # shape: [mtp_num_modules, hidden_states] - new_mtp_hidden_states = cat_hidden_states[ - -mtp_num_modules:, - ] - # Update the buffer, but keep the pointer unchanged. 
- mtp_past_hidden_states_pool.index_copy_( - 0, slot_id_cuda, new_mtp_hidden_states.unsqueeze(0)) - - # Update offset - input_ids_offset += cur_seq_len + slot_ids = spec_metadata.slot_ids[:batch_size] + mtp_tokens = mtp_past_tokens_pool[slot_ids] + mtp_hidden_states = mtp_past_hidden_states_pool[slot_ids] + + new_mtp_past_tokens, new_mtp_past_hidden_states = [], [] + # context + if num_contexts > 0: + seq_lens_ctx = seq_lens[:num_contexts] + unpacked_input_ids_ctx = unpack_sequence( + input_ids[:num_ctx_tokens].unsqueeze(1), + seq_lens_ctx).squeeze(2) + unpacked_hidden_states_ctx = unpack_sequence( + hidden_states[:num_ctx_tokens], seq_lens_ctx) + cat_tokens_ctx = torch.cat( + (mtp_tokens[:num_contexts], unpacked_input_ids_ctx), dim=1) + cat_hidden_states_ctx = torch.cat( + (mtp_hidden_states[:num_contexts], + unpacked_hidden_states_ctx), + dim=1) + ctx_batch_idx = spec_metadata.batch_indices_cuda[:num_contexts] + row_indices_ctx = ctx_batch_idx.unsqueeze(1).expand( + -1, mtp_num_modules) + col_indices_ctx = (seq_lens_ctx.unsqueeze(1) + + spec_metadata.draft_token_indices_cuda) + new_mtp_past_tokens.append(cat_tokens_ctx[row_indices_ctx, + col_indices_ctx]) + new_mtp_past_hidden_states.append( + cat_hidden_states_ctx[row_indices_ctx, col_indices_ctx, :]) + + # generation + if num_gens > 0: + unpacked_input_ids_gen = input_ids[num_ctx_tokens:].reshape( + num_gens, mtp_num_modules + 1).int() + hidden_states_gen = hidden_states[num_ctx_tokens:, :] + unpacked_hidden_states_gen = hidden_states_gen.reshape( + num_gens, mtp_num_modules + 1, hidden_size) + cat_tokens_gen = torch.cat( + (mtp_tokens[num_contexts:], unpacked_input_ids_gen), dim=1) + cat_hidden_states_gen = torch.cat( + (mtp_hidden_states[num_contexts:], + unpacked_hidden_states_gen), + dim=1) + gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] + row_indices_gen = gen_batch_idx.unsqueeze(1).expand( + -1, mtp_num_modules) + col_indices_gen = ( + num_accepted_tokens[num_contexts:].unsqueeze(1) + + spec_metadata.draft_token_indices_cuda) + new_mtp_past_tokens.append(cat_tokens_gen[row_indices_gen, + col_indices_gen]) + new_mtp_past_hidden_states.append( + cat_hidden_states_gen[row_indices_gen, col_indices_gen, :]) + + # update past tokens and hidden states + new_mtp_past_tokens = torch.cat(new_mtp_past_tokens, dim=0) + new_mtp_past_hidden_states = torch.cat(new_mtp_past_hidden_states, + dim=0) + mtp_past_tokens_pool.index_copy_(0, slot_ids, new_mtp_past_tokens) + mtp_past_hidden_states_pool.index_copy_(0, slot_ids, + new_mtp_past_hidden_states) def sample_and_accept_draft_tokens( self, @@ -867,11 +916,10 @@ def restore_attn_metadata(self, attn_metadata: AttentionMetadata): def prepare_drafter_inputs( self, - mtp_layer_idx: int, input_ids: torch.IntTensor, position_ids: torch.IntTensor, - previous_layer_hidden_states: torch.Tensor, - previous_layer_draft_tokens: torch.Tensor, + hidden_states: torch.Tensor, + accepted_tokens: torch.Tensor, num_accepted_tokens: torch.Tensor, spec_metadata: MTPSpecMetadata, attn_metadata: AttentionMetadata, @@ -880,26 +928,22 @@ def prepare_drafter_inputs( Parepare the input of the draft model. Args: - mtp_layer_idx: int - The index number of the current MTP layer. - input_ids: torch.IntTensor [num_tokens] The input ids of all requests. Flatten. 
- When mtp_layer_idx == 0: num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) - When mtp_layer_idx > 0: num_tokens = sum(all prompts) + num_generation * (mtp_num_modules) + num_tokens = sum(all prompts) + num_generation * (mtp_num_modules + 1) position_ids: torch.IntTensor [1][num_tokens] The position id of all requests. Flatten. - previous_layer_hidden_states: torch.Tensor + hidden_states: torch.Tensor [num_tokens, hidden_size] Target model's hidden states. - previous_layer_draft_tokens: torch.Tensor - [batch_size] - Privous layer's draft tokens. + accepted_tokens: torch.Tensor + [batch_size, max_draft_tokens + 1] + Accepted token ids. Flattened. num_accepted_tokens: torch.Tensor [batch_size] @@ -929,9 +973,13 @@ def prepare_drafter_inputs( attn_metadata: AttentionMetadata Attention metadata + spec_metadata: MTPSpecMetadata + MTP speculative decoding metadata + ''' batch_size = attn_metadata.num_seqs num_contexts = attn_metadata.num_contexts + num_ctx_tokens = attn_metadata.num_ctx_tokens num_gens = batch_size - num_contexts mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool @@ -939,105 +987,81 @@ def prepare_drafter_inputs( if self.is_thop: # Temporary buffer - hidden_size = previous_layer_hidden_states.shape[-1] + hidden_size = hidden_states.shape[-1] # generation requests' golden tokens - num_tokens = input_ids.shape[ - 0] - num_gens if mtp_layer_idx == 0 else input_ids.shape[0] + num_tokens = input_ids.shape[0] - num_gens return_input_ids = torch.empty(num_tokens, dtype=torch.int, device="cuda") - if (mtp_layer_idx == 0): - return_hidden_states = torch.empty( - (num_tokens, hidden_size), - dtype=previous_layer_hidden_states.dtype, - device="cuda") - else: - return_hidden_states = torch.empty( - 1, dtype=previous_layer_hidden_states.dtype, - device="cuda") # Useless, placeholder + return_hidden_states = torch.empty((num_tokens, hidden_size), + dtype=hidden_states.dtype, + device="cuda") (return_input_ids, return_hidden_states ) = torch.ops.trtllm.mtp_prepare_drafter_inputs_op( input_ids, attn_metadata.seq_lens_cuda, spec_metadata.mtp_hidden_states_ptrs, - spec_metadata.mtp_past_tokens_ptrs, - previous_layer_hidden_states, previous_layer_draft_tokens, - return_input_ids, return_hidden_states, mtp_num_modules, - mtp_layer_idx, batch_size, num_contexts, hidden_size) + spec_metadata.mtp_past_tokens_ptrs, hidden_states, + accepted_tokens, num_accepted_tokens, return_input_ids, + return_hidden_states, mtp_num_modules, batch_size, + num_contexts, hidden_size) else: return_input_ids_list = [] return_hidden_states_list = [] - if mtp_layer_idx == 0: # The first MTP layer - input_ids_offset = 0 - for bix in range(batch_size): - slot_id_cuda = spec_metadata.slot_ids[bix] - cur_seq_len = attn_metadata.seq_lens_cuda[bix] - past_tokens = torch.index_select(mtp_past_tokens_pool, 0, - slot_id_cuda).squeeze(0) - past_hidden_states = torch.index_select( - mtp_past_hidden_states_pool, 0, slot_id_cuda).squeeze(0) - - if bix < num_contexts: - # Context request - # MTP past tokens - # cuda Graph should not run this part since has context request - prompt_tokens = input_ids[ - input_ids_offset:input_ids_offset + cur_seq_len] - cat_tensor = torch.cat( - (prompt_tokens[1:], past_tokens[-1:]), dim=0) - return_input_ids_list.append(cat_tensor) - - # MTP past hidden states - prompt_hidden_states = previous_layer_hidden_states[ - input_ids_offset:input_ids_offset + 
cur_seq_len] - return_hidden_states_list.append(prompt_hidden_states) - else: - # Generation request - # Directly append - return_input_ids_list.append(past_tokens) - return_hidden_states_list.append(past_hidden_states) - - # Update offset - input_ids_offset += cur_seq_len - - # Concat into a continuous buffer - return_input_ids = torch.cat(return_input_ids_list, dim=0) - return_hidden_states = torch.cat(return_hidden_states_list, - dim=0) - else: - # this else part should be CUDA Graph supported - input_ids_offset = 0 - for bix in range(batch_size): - # For the generation request, the 'cur_seq_len' already been update to 'num_nextn_predict_layers'. - cur_seq_len = attn_metadata.seq_lens_cuda[bix] - - # The 'input_ids' come from the prvious layer - previous_layer_tokens = input_ids[ - input_ids_offset:input_ids_offset + cur_seq_len] - - # MTP past tokens - previous_draft_tokens = previous_layer_draft_tokens[bix:( - bix + 1)] - cat_tensor = torch.cat( - (previous_layer_tokens, previous_draft_tokens), dim=0) - return_input_ids_list.append(cat_tensor[1:]) - - # Update offset - input_ids_offset += cur_seq_len - - return_input_ids = torch.cat(return_input_ids_list, dim=0) - # Directly use previous_layer_hidden_states as this layer's input hidden states - return_hidden_states = previous_layer_hidden_states - - if mtp_layer_idx == 0 and attn_metadata is not None: - self.change_attn_metadata(num_accepted_tokens, attn_metadata) + # Calculate cumulative sequence lengths for indexing + last_tokens_idx = torch.cumsum( + attn_metadata.seq_lens_cuda, dim=0, dtype=torch.long) - 1 + # context + if num_contexts > 0: + hidden_states_ctx = hidden_states[:num_ctx_tokens, :] + input_prompt_ids = input_ids[:num_ctx_tokens] + input_ids_ctx = torch.empty_like(input_prompt_ids, + dtype=torch.int32, + device="cuda") + input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) + input_ids_ctx[last_tokens_idx[:num_contexts]] = \ + accepted_tokens[:num_contexts, 0] + return_input_ids_list.append(input_ids_ctx) + return_hidden_states_list.append(hidden_states_ctx) + # generation + if num_gens > 0: + slot_ids = spec_metadata.slot_ids[num_contexts:batch_size] + gen_batch_idx = spec_metadata.batch_indices_cuda[:num_gens] + gen_token_idx = num_accepted_tokens[num_contexts:] - 1 + accepted_tokens_gen = accepted_tokens[num_contexts:, :] + input_ids_gen = accepted_tokens_gen[gen_batch_idx, + gen_token_idx].unsqueeze(1) + input_ids_gen = torch.concat( + [mtp_past_tokens_pool[slot_ids][:, 1:], input_ids_gen], + dim=1) + hidden_states_gen = mtp_past_hidden_states_pool[ + slot_ids].flatten(0, 1) + return_input_ids_list.append(input_ids_gen.flatten(0, 1)) + return_hidden_states_list.append(hidden_states_gen) + # Concatenate into continuous buffers + return_input_ids = torch.concat(return_input_ids_list, dim=0) + return_hidden_states = torch.concat(return_hidden_states_list, + dim=0) + + # update position_ids + position_ids_list = [] + if num_contexts > 0: + position_ids_list.append(position_ids[:num_ctx_tokens]) + if num_gens > 0: + position_ids_gen = position_ids[num_ctx_tokens:].reshape( + num_gens, mtp_num_modules + 1)[:, -mtp_num_modules:] + position_ids_gen = position_ids_gen - ( + 1 + mtp_num_modules - + num_accepted_tokens[num_contexts:].unsqueeze(1)) + position_ids_list.append(position_ids_gen.flatten()) + return_position_ids = torch.concat(position_ids_list, dim=-1) return { "input_ids": return_input_ids, - "position_ids": position_ids, + "position_ids": return_position_ids, "hidden_states": return_hidden_states, "attn_metadata": 
attn_metadata, } @@ -1208,11 +1232,11 @@ def prepare_drafter_inputs( num_contexts = attn_metadata.num_contexts # context - input_ctx_ids = input_ids[:attn_metadata.num_ctx_tokens] - input_ids_ctx = torch.empty_like(input_ctx_ids, + input_prompt_ids = input_ids[:attn_metadata.num_ctx_tokens] + input_ids_ctx = torch.empty_like(input_prompt_ids, dtype=torch.int32, device="cuda") - input_ids_ctx[:-1].copy_(input_ctx_ids[1:]) + input_ids_ctx[:-1].copy_(input_prompt_ids[1:]) input_ids_ctx[ last_tokens_idx[:num_contexts]] = accepted_tokens[:num_contexts, 0] diff --git a/tensorrt_llm/llmapi/llm_args.py b/tensorrt_llm/llmapi/llm_args.py index de4aa4365ae..98701436476 100644 --- a/tensorrt_llm/llmapi/llm_args.py +++ b/tensorrt_llm/llmapi/llm_args.py @@ -298,6 +298,7 @@ class MTPDecodingConfig(DecodingBaseConfig): use_relaxed_acceptance_for_thinking: Optional[bool] = False relaxed_topk: Optional[int] = 1 relaxed_delta: Optional[float] = 0. + use_mtp_vanilla: Optional[bool] = False @classmethod def from_dict(cls, data: dict): @@ -1351,7 +1352,8 @@ def validate_speculative_config(self): use_relaxed_acceptance_for_thinking=self.speculative_config. use_relaxed_acceptance_for_thinking, relaxed_topk=self.speculative_config.relaxed_topk, - relaxed_delta=self.speculative_config.relaxed_delta) + relaxed_delta=self.speculative_config.relaxed_delta, + use_mtp_vanilla=self.speculative_config.use_mtp_vanilla) else: raise ValueError( f"Speculative config type not recognized: {self.speculative_config}" diff --git a/tests/integration/defs/.test_durations b/tests/integration/defs/.test_durations index c06064a5c71..8a9b53d819a 100644 --- a/tests/integration/defs/.test_durations +++ b/tests/integration/defs/.test_durations @@ -131,8 +131,8 @@ "examples/test_multimodal.py::test_llm_multimodal_general[llava-v1.6-mistral-7b-hf-vision-trtllm-pp:1-tp:1-float16-bs:1-cpp_e2e:False-nb:1]": 306.38610201328993, "examples/test_prompt_lookup.py::test_llm_prompt_lookup_1gpu[streaming-gpt2-use_cpp_session-use_tokens-max_matching_ngram_size_2-prompt_lookup_num_tokens_8-float16-bs2]": 195.90045699477196, "test_unittests.py::test_unittests_v2[unittest/trt/model/test_gpt.py -k \"partition2\"]": 357.6496359631419, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 413.903915906325, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 143.841789112892, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 413.903915906325, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 143.841789112892, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=False]": 307.12596721109, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=True]": 166.85348949534819, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=True]": 226.39608797896653, @@ -228,10 +228,10 @@ "test_e2e.py::test_llmapi_server_example": 112.925546400249, 
"test_unittests.py::test_unittests_v2[unittest/trt/functional]": 778.6451135131065, "test_unittests.py::test_unittests_v2[unittest/trt/model/test_mamba.py]": 76.84791256207973, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 506.1045090719126, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 184.20976317999884, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 202.37037238897756, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 246.64391099987552, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 506.1045090719126, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False]": 184.20976317999884, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 202.37037238897756, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 246.64391099987552, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=TRTLLM-torch_compile=True]": 313.69273760309443, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=False]": 409.8932851999998, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=True]": 344.8807112099603, @@ -362,10 +362,10 @@ "examples/test_multimodal.py::test_llm_fp8_multimodal_general[fp8-fp8-cnn_dailymail-Qwen2-VL-7B-Instruct-pp:1-tp:1-bfloat16-bs:1-cpp_e2e:False]": 332.0248579243198, "test_e2e.py::test_mistral_large_hidden_vocab_size": 81.36711680702865, "test_e2e.py::test_trtllm_bench_iteration_log[TRT-non-streaming-meta-llama/Llama-3.1-8B-llama-3.1-model/Meta-Llama-3.1-8B]": 285.3362849447876, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 647.6109309499152, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False]": 647.6109309499152, "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[-attention_dp-cuda_graph-overlap_scheduler-torch_compile=False]": 326.1317654890008, - "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 226.01353620411828, - 
"accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-attention_dp-cuda_graph-overlap_scheduler-torch_compile=False]": 336.02580665098503, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False]": 226.01353620411828, + "accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-attention_dp-cuda_graph-overlap_scheduler-torch_compile=False]": 336.02580665098503, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=FLASHINFER-torch_compile=True]": 443.91388061689213, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_bfloat16[attn_backend=TRTLLM-torch_compile=False]": 191.10617867391557, "accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=False-attn_backend=FLASHINFER-torch_compile=False]": 237.24446990108117, diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py index 55375156298..02560c75318 100644 --- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py +++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py @@ -586,10 +586,10 @@ def test_bfloat16_4gpus(self, tp_size, pp_size, ep_size, mtp_nextn, (False, False, True, False), (False, False, False, True), (True, False, True, True), (True, True, True, True)]) - @parametrize_with_ids("mtp_nextn", [0, 2]) - def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph, + @parametrize_with_ids("mtp", ["disable", "eagle", "vanilla"]) + def test_fp8_block_scales(self, mtp, fp8kv, attention_dp, cuda_graph, overlap_scheduler, torch_compile): - if torch_compile and mtp_nextn > 0: + if torch_compile and mtp != "disable": pytest.skip("https://nvbugs/5252313") if torch_compile and attention_dp: pytest.skip("https://nvbugs/5252559") @@ -610,8 +610,12 @@ def test_fp8_block_scales(self, mtp_nextn, fp8kv, attention_dp, cuda_graph, pytorch_config["kv_cache_dtype"] = "fp8" mtp_config = None - if mtp_nextn > 0: + mtp_nextn = 2 + if mtp == "eagle": mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn) + elif mtp == "vanilla": + mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn, + use_mtp_vanilla=True) llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8", kv_cache_config=kv_cache_config, diff --git a/tests/integration/test_lists/qa/examples_test_list.txt b/tests/integration/test_lists/qa/examples_test_list.txt index b2b56fb824f..e5416a1052b 100644 --- a/tests/integration/test_lists/qa/examples_test_list.txt +++ b/tests/integration/test_lists/qa/examples_test_list.txt @@ -458,7 +458,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama4ScoutInstruct::test_auto_dtype[tp8ep accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_fp8_tp2 accuracy/test_llm_api_pytorch.py::TestMixtral8x7B::test_nvfp4_tp2 accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestMinitron4BBaseInstruct::test_fp8_prequantized diff --git a/tests/integration/test_lists/qa/llm_sanity_test.txt b/tests/integration/test_lists/qa/llm_sanity_test.txt index 1d5f225b94c..7da833b3159 100644 --- a/tests/integration/test_lists/qa/llm_sanity_test.txt +++ b/tests/integration/test_lists/qa/llm_sanity_test.txt @@ -127,7 +127,7 @@ accuracy/test_llm_api_pytorch.py::TestLlama3_1NemotronNano8Bv1::test_fp8_prequan accuracy/test_llm_api_pytorch.py::TestNemotronUltra::test_fp8_prequantized[tp8-cuda_graph=True] accuracy/test_llm_api_pytorch.py::TestNemotronH::test_reasoning_fp8_prequantized accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=0-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] -accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] +accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4[moe_backend=CUTLASS-mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] diff --git a/tests/integration/test_lists/test-db/l0_h100.yml b/tests/integration/test_lists/test-db/l0_h100.yml index 324a74e3fb5..ada1ff80d83 100644 --- a/tests/integration/test_lists/test-db/l0_h100.yml +++ b/tests/integration/test_lists/test-db/l0_h100.yml @@ -38,8 +38,9 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=TRTLLM-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8[fp8kv=True-attn_backend=FLASHINFER-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=True-attention_dp=False-cuda_graph=True-overlap_scheduler=True-torch_compile=True] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=fp8-mtp_nextn=2-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_8B::test_fp8_block_scales[latency] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8[latency] @@ -164,20 +165,21 @@ l0_h100: - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16[mtp_nextn=2-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=0-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] - - 
accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp_nextn=2-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=disable-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=True] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=True-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=True-cuda_graph=False-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=True-overlap_scheduler=False-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=eagle-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=True-torch_compile=False] + - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_fp8_block_scales[mtp=vanilla-fp8kv=False-attention_dp=False-cuda_graph=False-overlap_scheduler=False-torch_compile=False] - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_no_kv_cache_reuse[quant_dtype=none-mtp_nextn=2-fp8kv=False-attention_dp=True-cuda_graph=True-overlap_scheduler=True] - accuracy/test_llm_api_pytorch.py::TestQwen3_30B_A3B::test_fp8_block_scales[latency] - condition: diff --git a/tests/unittest/_torch/speculative/test_mtp.py b/tests/unittest/_torch/speculative/test_mtp.py new file mode 100644 index 00000000000..61dcf2a3ce6 --- /dev/null +++ b/tests/unittest/_torch/speculative/test_mtp.py @@ -0,0 +1,1414 @@ +import unittest + +import torch +from parameterized import parameterized + +import tensorrt_llm +from 
tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata +from tensorrt_llm._torch.metadata import KVCacheParams +from tensorrt_llm._torch.speculative.mtp import (MTPConfig, + MTPHiddenStatesManager, + MTPSpecMetadata, MTPWorker) + + +def unittest_name_func(testcase_func, param_num, param): + name = param.args[0] + return "%s_%s" % ( + testcase_func.__name__, + parameterized.to_safe_name(name), + ) + + +class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('warning') + + def load_sample_and_accept_draft_tokens_test_cases(): + test_cases = [] + + # ''' + ################# CASE 0 ########################## + # BS=1, 1 context request + mtp_num_modules = 1 + num_context_requests = 1 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([0], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 0]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([1], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case0", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 1 ########################## + # BS=4, 4 context request + mtp_num_modules = 1 + num_context_requests = 4 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([0, 0, 0, 0], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 0], [3, 0], [3, 0], [6, 0]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([1, 1, 1, 1], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case1", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 2 ########################## + # BS=1, 1 generation request + # Assume there are 3 MTP layers + # For each generation request, there are four logits here: one from golden token + three draft tokens + mtp_num_modules = 3 + num_context_requests = 0 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [1, 3, 4], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([3], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 3, 2, 0]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + 
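        # The reference tensors for the generation cases here encode greedy
        # top-1 acceptance: a draft token is accepted only if it equals the
        # argmax of the logits at the preceding position, acceptance stops at
        # the first mismatch, and the token sampled after the last accepted
        # position is always appended. A rough sketch of that rule (a
        # hypothetical helper, not the MTPWorker API) would be:
        #
        #     def greedy_accept(logits, draft_tokens):
        #         top1 = logits.argmax(dim=-1).tolist()
        #         accepted = [top1[0]]
        #         for i, draft in enumerate(draft_tokens.tolist()):
        #             if draft != top1[i]:
        #                 break
        #             accepted.append(top1[i + 1])
        #         return accepted  # len(accepted) == number of accepted tokens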
ref_num_accepted_tokens = torch.tensor([3], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case2", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 3 ########################## + # BS=2, 2 generation request + # Assume there are 1 MTP layers + # For each generation request, there are two logits here: one golden token + one draft token + mtp_num_modules = 1 + num_context_requests = 0 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [1, 5], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([1, 1], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 3], [4, 0]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([2, 1], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case3", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 4 ########################## + # BS=2, 2 generation request + # Assume there are 3 MTP layers + # For each generation request, there are four logits here: one golden token + three draft token + mtp_num_modules = 3 + num_context_requests = 0 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 + [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 + [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [1, 3, 4, 4, 7, 3], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([3, 3], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 3, 2, 0], [4, 6, 0, 0]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([3, 2], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case4", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 5 ########################## + # BS=3, 3 generation request, 2 accept partial, 1 accept all, 1 accept none + # Assume there are 3 MTP layers + # For each generation request, there are four logits here: one golden token + three draft token + mtp_num_modules = 3 + num_context_requests = 0 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, 
0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 + [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 + [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 + [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 + [-100, -100, -100, -100, -100, -100, -100, 0], # Top1 id = 7 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [1, 3, 5, 4, 6, 5, 5, 7, 4], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([3, 3, 3], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 3, 2, 0], [4, 6, 5, 2], [4, 0, 0, 0]], + dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([3, 4, 1], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case5", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + ################# CASE 6 ########################## + # BS=2, 1 context request, 1 generation request + # request0 is context request, and request1 is generation request + # Assume there are 1 MTP layers + # For each generation request, there are two logits here: one golden token + one draft token + mtp_num_modules = 1 + num_context_requests = 1 + logits = torch.tensor( + [ + [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 + [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 + [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 + ], + dtype=torch.float32, + device="cuda") # [num_tokens, vocab_size] + + draft_tokens = torch.tensor( + [4], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + draft_len = torch.tensor([0, 1], dtype=torch.int, + device="cuda") # [batch_size] + + ref_accepted_tokens = torch.tensor( + [[1, 0], [4, 6]], dtype=torch.int, + device="cuda") # [batch_size * max_draft_tokens] + + ref_num_accepted_tokens = torch.tensor([1, 2], + dtype=torch.int, + device="cuda") # [batch_size] + + test_cases += [[ + "case6", mtp_num_modules, logits, draft_tokens, draft_len, + num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens + ]] + + return test_cases + + @parameterized.expand(load_sample_and_accept_draft_tokens_test_cases, + name_func=unittest_name_func) + def test_sample_and_accept_draft_tokens(self, test_case_name, + mtp_num_modules, logits, + draft_tokens, draft_len, + num_context_requests, + ref_accepted_tokens, + ref_num_accepted_tokens): + batch_size = len(draft_len) + spec_config = MTPConfig(num_nextn_predict_layers=mtp_num_modules) + + # attention metedata + attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, + max_num_tokens=1024, + kv_cache_manager=None) + attn_metadata.seq_lens = torch.tensor( + [1] * batch_size, dtype=torch.int) # dummy sequence length + attn_metadata.num_contexts = num_context_requests + + # speculative decoding metadata + spec_metadata = MTPSpecMetadata(max_num_requests=32, + spec_dec_mode=spec_config.spec_dec_mode, + max_draft_tokens=mtp_num_modules, + mtp_num_modules=mtp_num_modules) + spec_metadata.draft_tokens = draft_tokens + + # mtp worker + mtpworker = MTPWorker(spec_config) + + # Test thop kernel + # Test native torch op + for is_thop in [True, False]: + mtpworker.is_thop = is_thop + # TODO: add 
unit tests for relaxed acceptance + accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens( + None, logits, spec_metadata, attn_metadata) + + torch.testing.assert_close(num_accepted_tokens, + ref_num_accepted_tokens) + for i in range(len(draft_len)): + torch.testing.assert_close( + accepted_tokens[i][0:ref_num_accepted_tokens[i]], + ref_accepted_tokens[i][0:ref_num_accepted_tokens[i]]) + + +class TestMTPUpdateMTPHiddenStates(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('warning') + + def load_update_mtp_hidden_states_test_cases(): + + def gen_data(batch_size, num_nextn_predict_layers, hidden_size): + mtp_past_hidden_states_ptrs = [] + mtp_past_tokens_ptrs = [] + mtp_hidden_states_tensor_pool = torch.ones( + (batch_size, num_nextn_predict_layers, hidden_size), + device='cuda', + dtype=torch.float32) + mtp_tokens_tensor_pool = torch.ones( + (batch_size, num_nextn_predict_layers), + device='cuda', + dtype=torch.int) + + for bix in range(batch_size): + mtp_hidden_states_tensor_pool[ + bix] = mtp_hidden_states_tensor_pool[ + bix] * bix # be different + mtp_past_hidden_states_ptrs.append( + mtp_hidden_states_tensor_pool[bix].data_ptr()) + + mtp_tokens_tensor_pool[ + bix] = mtp_tokens_tensor_pool[bix] * bix # be different + mtp_past_tokens_ptrs.append( + mtp_tokens_tensor_pool[bix].data_ptr()) + return mtp_past_hidden_states_ptrs, mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, mtp_tokens_tensor_pool + + test_cases = [] + + ################# CASE 0 ########################## + # BS=1, 1 context request + batch_size = 1 + num_context_request = 1 + num_nextn_predict_layers = 1 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor([5, 6, 7, 8, 9], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([5], dtype=torch.int, + device="cuda") # [batch_size] + + hidden_states = torch.tensor( + [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + num_accepted_tokens = torch.tensor([1], dtype=torch.int, + device="cuda") # [batch_size] + + ref_mtp_tokens_dict = dict() + ref_mtp_tokens_dict[0] = torch.tensor([9], + dtype=torch.int, + device="cuda") + + ref_mtp_hidden_state_dict = dict() + ref_mtp_hidden_state_dict[0] = torch.tensor([ + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") + + test_cases += [[ + 'case0_is_thop_false', num_nextn_predict_layers, + num_context_request, input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False + ]] + + test_cases += [[ + 'case0_is_thop_true', 
num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True + ]] + + ################## CASE 1 ########################## + # BS=3, 3 context request, num_nextn_predict_layers = 2 + batch_size = 3 + num_context_request = 3 + num_nextn_predict_layers = 2 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor( + [ + 5, + 6, + 7, + 8, + 9, # request 1 + 11, + 12, + 13, + 14, # request 2 + 21, + 22, + 23, + 24, + 25, + 26 # request 3 + ], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([5, 4, 6], dtype=torch.int, + device="cuda") # [batch_size] + + hidden_states = torch.tensor( + [ + # request 1 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + + # request 2 + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + + # request 3 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 1, 1], + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + num_accepted_tokens = torch.tensor([1, 1, 1], + dtype=torch.int, + device="cuda") # [batch_size] + + ref_mtp_tokens_dict = dict() + ref_mtp_tokens_dict[0] = torch.tensor([8, 9], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[1] = torch.tensor([13, 14], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[2] = torch.tensor([25, 26], + dtype=torch.int, + device="cuda") + + ref_mtp_hidden_state_dict = dict() + ref_mtp_hidden_state_dict[0] = torch.tensor([ + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[1] = torch.tensor([ + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[2] = torch.tensor([ + [90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 1, 1], + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + ], + dtype=torch.float32, + device="cuda") + + test_cases += [[ + 'case1_is_thop_false', num_nextn_predict_layers, + num_context_request, input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, 
+ mtp_tokens_tensor_pool_v1, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False + ]] + + test_cases += [[ + 'case1_is_thop_true', num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True + ]] + + ################## CASE 2 ########################## + # BS=1, 1 generation request, num_nextn_predict_layers = 1 + batch_size = 1 + num_context_request = 0 + num_nextn_predict_layers = 1 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor([5, 42], dtype=torch.int, device="cuda") + + seq_lens = torch.tensor([2], dtype=torch.int, + device="cuda") # [batch_size] + + hidden_states = torch.tensor( + [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + num_accepted_tokens = torch.tensor([2], dtype=torch.int, + device="cuda") # [batch_size] + + ref_mtp_tokens_dict = dict() + ref_mtp_tokens_dict[0] = torch.tensor([42], + dtype=torch.int, + device="cuda") + + ref_mtp_hidden_state_dict = dict() + ref_mtp_hidden_state_dict[0] = torch.tensor([ + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + ], + dtype=torch.float32, + device="cuda") + + test_cases += [[ + 'case2_is_thop_false', num_nextn_predict_layers, + num_context_request, input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False + ]] + + test_cases += [[ + 'case2_is_thop_true', num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True + ]] + + ################## CASE 3 ########################## + # BS=4, 4 generation request, num_nextn_predict_layers = 1 + batch_size = 4 + num_context_request = 0 + num_nextn_predict_layers = 1 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + 
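        # As the reference pools below suggest, the hidden-states update keeps
        # a sliding window per request: the newly accepted tokens (the full
        # prompt for a context request) and their hidden states are appended to
        # the existing pool, and only the last num_nextn_predict_layers entries
        # are retained. A rough per-request sketch (a hypothetical helper, not
        # the MTPWorker API) would be:
        #
        #     def update_pool(past_tokens, new_tokens, num_layers):
        #         merged = past_tokens.tolist() + new_tokens.tolist()
        #         return merged[-num_layers:]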
input_ids = torch.tensor( + [ + 5, + 6, # request 1 + 7, + 8, # request 2 + 9, + 10, # request 3 + 11, + 12 # request 4 + ], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([2, 2, 2, 2], dtype=torch.int, + device="cuda") # [batch_size] + + hidden_states = torch.tensor( + [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + num_accepted_tokens = torch.tensor([2, 1, 1, 2], + dtype=torch.int, + device="cuda") # [batch_size] + + ref_mtp_tokens_dict = dict() + ref_mtp_tokens_dict[0] = torch.tensor([6], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[1] = torch.tensor([7], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[2] = torch.tensor([9], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[3] = torch.tensor([12], + dtype=torch.int, + device="cuda") + + ref_mtp_hidden_state_dict = dict() + ref_mtp_hidden_state_dict[0] = torch.tensor([ + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[1] = torch.tensor([ + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[2] = torch.tensor([ + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[3] = torch.tensor([ + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + ], + dtype=torch.float32, + device="cuda") + + test_cases += [[ + 'case3_is_thop_false', num_nextn_predict_layers, + num_context_request, input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False + ]] + + test_cases += [[ + 'case3_is_thop_true', num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True + ]] + + ################## CASE 4 ########################## + # BS=4, 2 context request, 2 generation request, num_nextn_predict_layers = 2 + num_context = 2 + num_generation = 2 + num_context_request = num_context + batch_size = num_context + num_generation + num_nextn_predict_layers = 2 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor( + [ + 5, + 6, + 7, + 8, + 9, # request 1 + 10, + 11, + 
12, + 13, # request 2 + 26, + 27, + 28, # request 3 + 31, + 32, + 33 # request 4 + ], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([5, 4, 3, 3], dtype=torch.int, + device="cuda") # [batch_size] + + hidden_states = torch.tensor( + [ + # request 1 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + + # request 2 + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + + # request 3 + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + + # request 4 + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + num_accepted_tokens = torch.tensor([1, 1, 3, 1], + dtype=torch.int, + device="cuda") # [batch_size] + + ref_mtp_tokens_dict = dict() + ref_mtp_tokens_dict[0] = torch.tensor([8, 9], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[1] = torch.tensor([12, 13], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[2] = torch.tensor([27, 28], + dtype=torch.int, + device="cuda") + ref_mtp_tokens_dict[3] = torch.tensor([3, 31], + dtype=torch.int, + device="cuda") + + ref_mtp_hidden_state_dict = dict() + ref_mtp_hidden_state_dict[0] = torch.tensor([ + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[1] = torch.tensor([ + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[2] = torch.tensor([ + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + ], + dtype=torch.float32, + device="cuda") + ref_mtp_hidden_state_dict[3] = torch.tensor([ + [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + ], + dtype=torch.float32, + device="cuda") + + test_cases += [[ + 'case4_is_thop_false', num_nextn_predict_layers, + num_context_request, input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False + ]] + + test_cases += [[ + 'case4_is_thop_true', num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True + ]] + + return test_cases + + @parameterized.expand(load_update_mtp_hidden_states_test_cases, + name_func=unittest_name_func) + def test_mtp_update_mtp_hidden_states( + self, test_case_name, num_nextn_predict_layers, num_context_request, + input_ids, seq_lens, hidden_states, mtp_hidden_states_ptrs, + mtp_past_tokens_ptrs, 
mtp_hidden_states_tensor_pool, + mtp_tokens_tensor_pool, request_ids, num_accepted_tokens, + ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, is_thop): + + batch_size = len(request_ids) + batch_size - num_context_request + hidden_size = hidden_states.shape[1] + spec_config = MTPConfig( + num_nextn_predict_layers=num_nextn_predict_layers) + + attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, + max_num_tokens=1024, + kv_cache_manager=None) + attn_metadata.seq_lens = seq_lens.to('cpu') + attn_metadata.num_contexts = num_context_request + + spec_manager = MTPHiddenStatesManager(config=spec_config, + dtype=torch.float32, + hidden_size=hidden_size, + max_num_requests=batch_size) + for i in range(batch_size): + # for the generation requests, we also need to manually add slot + # because these generation requests are also first use + spec_manager.slot_manager.add_slot(request_ids[i]) + + spec_metadata = MTPSpecMetadata( + max_num_requests=32, + spec_dec_mode=spec_config.spec_dec_mode, + max_draft_tokens=num_nextn_predict_layers, + mtp_num_modules=num_nextn_predict_layers, + mtp_hidden_states_manager=spec_manager) + spec_metadata.request_ids = request_ids + spec_metadata.mtp_hidden_states_ptrs = mtp_hidden_states_ptrs + spec_metadata.mtp_past_tokens_ptrs = mtp_past_tokens_ptrs + + spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool + spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.prepare() + + mtpworker = MTPWorker(spec_config) + mtpworker.is_thop = is_thop + + mtpworker.update_mtp_hidden_states( + input_ids=input_ids, + hidden_states=hidden_states, + num_accepted_tokens=num_accepted_tokens, + spec_metadata=spec_metadata, + attn_metadata=attn_metadata) + + # Verify + mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool + mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool + + for bix in range(batch_size): + torch.testing.assert_close(mtp_past_tokens_pool[bix], + ref_mtp_tokens_dict[bix]) + torch.testing.assert_close(mtp_past_hidden_states_pool[bix], + ref_mtp_hidden_state_dict[bix]) + + +class TestMTPPrepareDrafterInputs(unittest.TestCase): + + def setUp(self): + tensorrt_llm.logger.set_level('warning') + + def load_prepare_drafter_inputs_test_cases(): + + def gen_data(batch_size, num_nextn_predict_layers, hidden_size): + mtp_past_hidden_states_ptrs = [] + mtp_past_tokens_ptrs = [] + mtp_hidden_states_tensor_pool = torch.ones( + (batch_size, num_nextn_predict_layers, hidden_size), + device='cuda', + dtype=torch.float32) + mtp_tokens_tensor_pool = torch.ones( + (batch_size, num_nextn_predict_layers), + device='cuda', + dtype=torch.int) + + for bix in range(batch_size): + mtp_hidden_states_tensor_pool[ + bix] = mtp_hidden_states_tensor_pool[ + bix] * bix # be different + mtp_past_hidden_states_ptrs.append( + mtp_hidden_states_tensor_pool[bix].data_ptr()) + + mtp_tokens_tensor_pool[ + bix] = mtp_tokens_tensor_pool[bix] * bix # be different + mtp_past_tokens_ptrs.append( + mtp_tokens_tensor_pool[bix].data_ptr()) + return mtp_past_hidden_states_ptrs, mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, mtp_tokens_tensor_pool + + test_cases = [] + + ################# CASE 0 ########################## + # MTP0, BS=1, 1 context request + batch_size = 1 + num_nextn_predict_layers = 1 + num_contexts = 1 + hidden_size = 12 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two 
da ta, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor([5, 6, 7, 8, 9], + dtype=torch.int, + device="cuda") + position_ids = torch.tensor([0, 1, 2, 3, 4], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([5], dtype=torch.int, + device="cuda") # [batch_size] + + previous_layer_hidden_states = torch.tensor( + [ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + ], + dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + accepted_tokens = torch.tensor([[0, -1]], + dtype=torch.int, + device="cuda") + num_accepted_tokens = torch.tensor([1], dtype=torch.int, device="cuda") + + ref_input_ids = torch.tensor([6, 7, 8, 9, 0], + dtype=torch.int, + device="cuda") + ref_previous_hidden_states = previous_layer_hidden_states + attn_metadata = None + + test_cases += [[ + 'case0_is_thop_false', + num_nextn_predict_layers, + input_ids, + position_ids, + seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), + mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, + previous_layer_hidden_states, + request_ids, + num_contexts, + accepted_tokens, + num_accepted_tokens, + attn_metadata, + ref_input_ids, + ref_previous_hidden_states, + False, + ]] + + test_cases += [[ + 'case0_is_thop_true', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, True + ]] + + ################# CASE 1 ########################## + # MTP0, BS=3, 3 context request + batch_size = 3 + num_contexts = 3 + num_nextn_predict_layers = 1 + hidden_size = 16 + request_ids = range(batch_size) + + # Since we will update the data inplace, so we need two data, + # One for is_thop=False, one for is_thop=True + mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( + batch_size, num_nextn_predict_layers, hidden_size) + + input_ids = torch.tensor( + [ + 5, + 6, + 7, + 8, + 9, # request0 + 10, + 11, + 12, + 13, # request1 + 20, + 21, + 22, + 23, + 24, + 25 # request2 + ], + dtype=torch.int, + device="cuda") + position_ids = torch.tensor( + [[ + 0, + 1, + 2, + 3, + 4, # request0 + 0, + 1, + 2, + 3, # request1 + 0, + 1, + 2, + 3, + 4, + 5 # request2 + ]], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([5, 4, 6], dtype=torch.int, + device="cuda") # [batch_size] + + previous_layer_hidden_states = torch.randn( + (len(input_ids), hidden_size), 
dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + accepted_tokens = torch.tensor([[0, -1], [1, -1], [2, -1]], + dtype=torch.int, + device="cuda") + num_accepted_tokens = torch.tensor([1, 1, 1], + dtype=torch.int, + device="cuda") + + ref_input_ids = torch.tensor( + [6, 7, 8, 9, 0, 11, 12, 13, 1, 21, 22, 23, 24, 25, 2], + dtype=torch.int, + device="cuda") + ref_previous_hidden_states = previous_layer_hidden_states + + attn_metadata = None + + test_cases += [[ + 'case1_is_thop_false', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, False + ]] + + test_cases += [[ + 'case1_is_thop_true', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v2, + device="cuda"), mtp_hidden_states_tensor_pool_v2, + mtp_tokens_tensor_pool_v2, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, True + ]] + + ################## CASE 2 ########################## + # BS=1, 1 generation request, num_nextn_predict_layers = 1 + batch_size = 1 + num_contexts = 0 + num_nextn_predict_layers = 1 + hidden_size = 12 + request_ids = range(batch_size) + + mtp_past_hidden_states_ptrs_v1 = [] + mtp_past_tokens_ptrs_v1 = [] + mtp_hidden_states_tensor_pool_v1 = torch.ones( + (batch_size, num_nextn_predict_layers, hidden_size), + device='cuda', + dtype=torch.float32) + mtp_tokens_tensor_pool_v1 = torch.ones( + (batch_size, num_nextn_predict_layers), + device='cuda', + dtype=torch.int) + + mtp_hidden_states_tensor_pool_v1[0] = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], + device='cuda', + dtype=torch.float32) + mtp_past_hidden_states_ptrs_v1.append( + mtp_hidden_states_tensor_pool_v1[0].data_ptr()) + mtp_tokens_tensor_pool_v1[0] = torch.tensor([42], + device='cuda', + dtype=torch.int) + mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[0].data_ptr()) + + input_ids = torch.tensor([6, 42], dtype=torch.int, device="cuda") + + position_ids = torch.tensor([10, 11], dtype=torch.int, device="cuda") + + # already '-1' in 'update_mtp_hidden_states' + seq_lens = torch.tensor([2], dtype=torch.int, + device="cuda") # [batch_size] + + previous_layer_hidden_states = torch.randn( + (len(input_ids), hidden_size), dtype=torch.float32, + device="cuda") # [prompt_length, hidden_size] + + accepted_tokens = torch.tensor([[43, -1]], + dtype=torch.int, + device="cuda") + + num_accepted_tokens = torch.tensor([1], dtype=torch.int, device="cuda") + + ref_input_ids = torch.tensor([43], dtype=torch.int, device="cuda") + + ref_previous_hidden_states = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], + device='cuda', + dtype=torch.float32) + + attn_metadata = None + + test_cases += [[ + 'case2_is_thop_false', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + 
attn_metadata, ref_input_ids, ref_previous_hidden_states, False + ]] + + test_cases += [[ + 'case2_is_thop_true', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, True + ]] + + ################## CASE 3 ########################## + # BS=3, 3 generation request, num_nextn_predict_layers = 3 + batch_size = 3 + num_contexts = 0 + num_nextn_predict_layers = 3 + hidden_size = 12 + request_ids = range(batch_size) + + mtp_past_hidden_states_ptrs_v1 = [] + mtp_past_tokens_ptrs_v1 = [] + mtp_hidden_states_tensor_pool_v1 = torch.ones( + (batch_size, num_nextn_predict_layers, hidden_size), + device='cuda', + dtype=torch.float32) + mtp_tokens_tensor_pool_v1 = torch.ones( + (batch_size, num_nextn_predict_layers), + device='cuda', + dtype=torch.int) + + mtp_hidden_states_tensor_pool_v1[0] = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + ], + device='cuda', + dtype=torch.float32) + mtp_past_hidden_states_ptrs_v1.append( + mtp_hidden_states_tensor_pool_v1[0].data_ptr()) + + mtp_hidden_states_tensor_pool_v1[1] = torch.tensor([ + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + ], + device='cuda', + dtype=torch.float32) + mtp_past_hidden_states_ptrs_v1.append( + mtp_hidden_states_tensor_pool_v1[1].data_ptr()) + + mtp_hidden_states_tensor_pool_v1[2] = torch.tensor([ + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + ], + device='cuda', + dtype=torch.float32) + mtp_past_hidden_states_ptrs_v1.append( + mtp_hidden_states_tensor_pool_v1[2].data_ptr()) + + mtp_tokens_tensor_pool_v1[0] = torch.tensor([19, 20, 21], + device='cuda', + dtype=torch.int) + mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[0].data_ptr()) + + mtp_tokens_tensor_pool_v1[1] = torch.tensor([29, 30, 31], + device='cuda', + dtype=torch.int) + mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[1].data_ptr()) + + mtp_tokens_tensor_pool_v1[2] = torch.tensor([39, 40, 41], + device='cuda', + dtype=torch.int) + mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[2].data_ptr()) + + input_ids = torch.tensor( + [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], + dtype=torch.int, + device="cuda") # useless + position_ids = torch.tensor( + [10, 11, 12, 13, 21, 22, 23, 24, 32, 33, 34, 35], + dtype=torch.int, + device="cuda") + + seq_lens = torch.tensor([4, 4, 4], dtype=torch.int, + device="cuda") # [batch_size] + + previous_layer_hidden_states = torch.randn( + (12, hidden_size), device='cuda', dtype=torch.float32) # useless + + accepted_tokens = torch.tensor( + [[22, -1, -1, -1], [31, 32, -1, -1], [0, 40, 41, 42]], + dtype=torch.int, + device="cuda") + + num_accepted_tokens = torch.tensor([1, 2, 4], + dtype=torch.int, + device="cuda") + + ref_input_ids = torch.tensor([20, 21, 22, 30, 31, 32, 40, 41, 42], + dtype=torch.int, + device="cuda") + + ref_previous_hidden_states = torch.tensor([ + [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], + [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], + [20, 
21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], + [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], + [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], + [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], + [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], + [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], + [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], + ], + device='cuda', + dtype=torch.float32) + attn_metadata = None + + test_cases += [[ + 'case3_is_thop_false', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, False + ]] + + test_cases += [[ + 'case3_is_thop_true', num_nextn_predict_layers, input_ids, + position_ids, seq_lens, + torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), + torch.tensor(mtp_past_tokens_ptrs_v1, + device="cuda"), mtp_hidden_states_tensor_pool_v1, + mtp_tokens_tensor_pool_v1, previous_layer_hidden_states, + request_ids, num_contexts, accepted_tokens, num_accepted_tokens, + attn_metadata, ref_input_ids, ref_previous_hidden_states, True + ]] + + return test_cases + + @parameterized.expand(load_prepare_drafter_inputs_test_cases, + name_func=unittest_name_func) + def test_prepare_drafter_inputs( + self, test_case_name, num_nextn_predict_layers, input_ids, + position_ids, seq_lens, mtp_past_hidden_states_ptrs, + mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, + mtp_tokens_tensor_pool, previous_layer_hidden_states, request_ids, + num_contexts, accepted_tokens, num_accepted_tokens, attn_metadata, + ref_input_ids, ref_previous_hidden_states, is_thop): + + batch_size = len(request_ids) + if previous_layer_hidden_states is not None: + hidden_size = previous_layer_hidden_states.shape[1] + else: + hidden_size = 10 + spec_config = MTPConfig( + num_nextn_predict_layers=num_nextn_predict_layers) + + if attn_metadata is None: + attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, + max_num_tokens=1024, + kv_cache_manager=None) + attn_metadata.seq_lens = seq_lens.to('cpu') + attn_metadata.num_contexts = num_contexts + # dummy kv cache param + attn_metadata.kv_cache_params = KVCacheParams( + use_cache=True, num_cached_tokens_per_seq=[10] * batch_size) + + spec_manager = MTPHiddenStatesManager(config=spec_config, + dtype=torch.float32, + hidden_size=hidden_size, + max_num_requests=batch_size) + for i in range(batch_size): + # for the generation requests, we also need to manually add slot + # because these generation requests are also first use + spec_manager.slot_manager.add_slot(request_ids[i]) + + spec_metadata = MTPSpecMetadata( + max_num_requests=32, + spec_dec_mode=spec_config.spec_dec_mode, + max_draft_tokens=num_nextn_predict_layers, + mtp_num_modules=num_nextn_predict_layers, + mtp_hidden_states_manager=spec_manager) + spec_metadata.request_ids = request_ids + spec_metadata.mtp_hidden_states_ptrs = mtp_past_hidden_states_ptrs + spec_metadata.mtp_past_tokens_ptrs = mtp_past_tokens_ptrs + + spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool + spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool + spec_metadata.prepare() + + mtpworker = MTPWorker(spec_config) + mtpworker.is_thop = is_thop + draft_inputs = 
mtpworker.prepare_drafter_inputs( + input_ids=input_ids, + position_ids=position_ids, + hidden_states=previous_layer_hidden_states, + accepted_tokens=accepted_tokens, + num_accepted_tokens=num_accepted_tokens, + spec_metadata=spec_metadata, + attn_metadata=attn_metadata) + + torch.testing.assert_close(draft_inputs["input_ids"], ref_input_ids) + torch.testing.assert_close(draft_inputs["hidden_states"], + ref_previous_hidden_states) diff --git a/tests/unittest/_torch/speculative/test_mtp_prepare_drafter_inputs.py b/tests/unittest/_torch/speculative/test_mtp_prepare_drafter_inputs.py deleted file mode 100644 index a09ac2bedb1..00000000000 --- a/tests/unittest/_torch/speculative/test_mtp_prepare_drafter_inputs.py +++ /dev/null @@ -1,637 +0,0 @@ -import unittest - -import torch -from parameterized import parameterized -from utils.util import unittest_name_func - -import tensorrt_llm -from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata -from tensorrt_llm._torch.metadata import KVCacheParams -from tensorrt_llm._torch.speculative.mtp import (MTPConfig, - MTPHiddenStatesManager, - MTPSpecMetadata, MTPWorker) - - -class TestMTPPrepareDrafterInputs(unittest.TestCase): - - def setUp(self): - tensorrt_llm.logger.set_level('warning') - - def load_prepare_drafter_inputs_test_cases(): - - def gen_data(batch_size, num_nextn_predict_layers, hidden_size): - mtp_past_hidden_states_ptrs = [] - mtp_past_tokens_ptrs = [] - mtp_hidden_states_tensor_pool = torch.ones( - (batch_size, num_nextn_predict_layers, hidden_size), - device='cuda', - dtype=torch.float32) - mtp_tokens_tensor_pool = torch.ones( - (batch_size, num_nextn_predict_layers), - device='cuda', - dtype=torch.int) - - for bix in range(batch_size): - mtp_hidden_states_tensor_pool[ - bix] = mtp_hidden_states_tensor_pool[ - bix] * bix # be different - mtp_past_hidden_states_ptrs.append( - mtp_hidden_states_tensor_pool[bix].data_ptr()) - - mtp_tokens_tensor_pool[ - bix] = mtp_tokens_tensor_pool[bix] * bix # be different - mtp_past_tokens_ptrs.append( - mtp_tokens_tensor_pool[bix].data_ptr()) - return mtp_past_hidden_states_ptrs, mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, mtp_tokens_tensor_pool - - test_cases = [] - - ################# CASE 0 ########################## - # MTP0, BS=1, 1 context request - mtp_layer_idx = 0 - batch_size = 1 - num_nextn_predict_layers = 1 - num_contexts = 1 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two da ta, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor([5, 6, 7, 8, 9], - dtype=torch.int, - device="cuda") - position_ids = torch.tensor([0, 1, 2, 3, 4], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.tensor( - [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - 
previous_layer_draft_tokens = torch.tensor([-1], - dtype=torch.int, - device="cuda") # useless - num_accepted_tokens = torch.tensor([-1], dtype=torch.int, - device="cuda") # useless - - ref_input_ids = torch.tensor([6, 7, 8, 9, 0], - dtype=torch.int, - device="cuda") - ref_previous_hidden_states = previous_layer_hidden_states - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - ################# CASE 1 ########################## - # MTP1, BS=1, 1 context request - mtp_layer_idx = 1 - batch_size = 1 - num_contexts = 1 - num_nextn_predict_layers = 2 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor([5, 6, 7, 8, 9], - dtype=torch.int, - device="cuda") - position_ids = torch.tensor([2, 3, 4, 5, 6], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.tensor( - [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - previous_layer_draft_tokens = torch.tensor([10], - dtype=torch.int, - device="cuda") - num_accepted_tokens = torch.tensor([-1], dtype=torch.int, - device="cuda") # useless - - ref_input_ids = torch.tensor([6, 7, 8, 9, 10], - dtype=torch.int, - device="cuda") - ref_previous_hidden_states = previous_layer_hidden_states - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - 
torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - ################# CASE 2 ########################## - # MTP0, BS=3, 3 context request - mtp_layer_idx = 0 - batch_size = 3 - num_contexts = 3 - num_nextn_predict_layers = 1 - hidden_size = 16 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor( - [ - 5, - 6, - 7, - 8, - 9, # request0 - 10, - 11, - 12, - 13, # request1 - 20, - 21, - 22, - 23, - 24, - 25 # request2 - ], - dtype=torch.int, - device="cuda") - position_ids = torch.tensor( - [[ - 0, - 1, - 2, - 3, - 4, # request0 - 0, - 1, - 2, - 3, # request1 - 0, - 1, - 2, - 3, - 4, - 5 # request2 - ]], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5, 4, 6], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.randn( - (len(input_ids), hidden_size), dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - previous_layer_draft_tokens = torch.tensor([-1, -1, -1], - dtype=torch.int, - device="cuda") # useless - num_accepted_tokens = torch.tensor([-1, -1, -1], - dtype=torch.int, - device="cuda") # useless - - ref_input_ids = torch.tensor( - [6, 7, 8, 9, 0, 11, 12, 13, 1, 21, 22, 23, 24, 25, 2], - dtype=torch.int, - device="cuda") - ref_previous_hidden_states = previous_layer_hidden_states - - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - ################## CASE 3 ########################## - # BS=1, 1 generation request, num_nextn_predict_layers = 1 - mtp_layer_idx = 0 - batch_size = 1 - num_contexts = 0 - num_nextn_predict_layers = 1 - hidden_size = 12 - request_ids = range(batch_size) - - mtp_past_hidden_states_ptrs_v1 = [] - mtp_past_tokens_ptrs_v1 = [] - mtp_hidden_states_tensor_pool_v1 = torch.ones( - (batch_size, num_nextn_predict_layers, hidden_size), - device='cuda', - dtype=torch.float32) - mtp_tokens_tensor_pool_v1 = torch.ones( - 
(batch_size, num_nextn_predict_layers), - device='cuda', - dtype=torch.int) - - mtp_hidden_states_tensor_pool_v1[0] = torch.tensor( - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], - device='cuda', - dtype=torch.float32) - mtp_past_hidden_states_ptrs_v1.append( - mtp_hidden_states_tensor_pool_v1[0].data_ptr()) - mtp_tokens_tensor_pool_v1[0] = torch.tensor([42], - device='cuda', - dtype=torch.int) - mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[0].data_ptr()) - - input_ids = torch.tensor([6, 42], dtype=torch.int, device="cuda") - - position_ids = torch.tensor([10, 11], dtype=torch.int, device="cuda") - - # already '-1' in 'update_mtp_hidden_states' - seq_lens = torch.tensor([2], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.randn( - (len(input_ids), hidden_size), dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([1], dtype=torch.int, device="cuda") - - previous_layer_draft_tokens = torch.tensor([-1], - dtype=torch.int, - device="cuda") # useless - - ref_input_ids = torch.tensor([42], dtype=torch.int, device="cuda") - - ref_previous_hidden_states = torch.tensor( - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]], - device='cuda', - dtype=torch.float32) - - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - ################## CASE 4 ########################## - # BS=3, 3 generation request, num_nextn_predict_layers = 3 - # mtp_layer_idx = 0 - mtp_layer_idx = 0 - batch_size = 3 - num_contexts = 0 - num_nextn_predict_layers = 3 - hidden_size = 12 - request_ids = range(batch_size) - - mtp_past_hidden_states_ptrs_v1 = [] - mtp_past_tokens_ptrs_v1 = [] - mtp_hidden_states_tensor_pool_v1 = torch.ones( - (batch_size, num_nextn_predict_layers, hidden_size), - device='cuda', - dtype=torch.float32) - mtp_tokens_tensor_pool_v1 = torch.ones( - (batch_size, num_nextn_predict_layers), - device='cuda', - dtype=torch.int) - - mtp_hidden_states_tensor_pool_v1[0] = torch.tensor([ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - ], - device='cuda', - dtype=torch.float32) - mtp_past_hidden_states_ptrs_v1.append( - mtp_hidden_states_tensor_pool_v1[0].data_ptr()) - - mtp_hidden_states_tensor_pool_v1[1] = torch.tensor([ - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - ], - device='cuda', - dtype=torch.float32) - mtp_past_hidden_states_ptrs_v1.append( - mtp_hidden_states_tensor_pool_v1[1].data_ptr()) - - mtp_hidden_states_tensor_pool_v1[2] 
= torch.tensor([ - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - ], - device='cuda', - dtype=torch.float32) - mtp_past_hidden_states_ptrs_v1.append( - mtp_hidden_states_tensor_pool_v1[2].data_ptr()) - - mtp_tokens_tensor_pool_v1[0] = torch.tensor([20, 21, 22], - device='cuda', - dtype=torch.int) - mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[0].data_ptr()) - - mtp_tokens_tensor_pool_v1[1] = torch.tensor([30, 31, 32], - device='cuda', - dtype=torch.int) - mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[1].data_ptr()) - - mtp_tokens_tensor_pool_v1[2] = torch.tensor([40, 41, 42], - device='cuda', - dtype=torch.int) - mtp_past_tokens_ptrs_v1.append(mtp_tokens_tensor_pool_v1[2].data_ptr()) - - input_ids = torch.tensor( - [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], - dtype=torch.int, - device="cuda") # useless - position_ids = torch.tensor( - [10, 11, 12, 13, 21, 22, 23, 24, 32, 33, 34, 35], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([4, 4, 4], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.randn( - (12, hidden_size), device='cuda', dtype=torch.float32) # useless - - num_accepted_tokens = torch.tensor([1, 2, 4], - dtype=torch.int, - device="cuda") - - previous_layer_draft_tokens = torch.tensor([-1], - dtype=torch.int, - device="cuda") # useless - - ref_input_ids = torch.tensor([20, 21, 22, 30, 31, 32, 40, 41, 42], - dtype=torch.int, - device="cuda") - - ref_previous_hidden_states = torch.tensor([ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - ], - device='cuda', - dtype=torch.float32) - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - ################## CASE 5 ########################## - # BS=3, 3 generation request, num_nextn_predict_layers = 3 - # mtp_layer_idx = 1 - mtp_layer_idx = 1 - batch_size = 3 - num_contexts = 0 - num_nextn_predict_layers = 3 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, 
mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor([20, 21, 22, 30, 31, 32, 40, 41, 42], - dtype=torch.int, - device="cuda") - position_ids = torch.tensor([9, 10, 11, 21, 22, 23, 33, 34, 35], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([3, 3, 3], dtype=torch.int, - device="cuda") # [batch_size] - - previous_layer_hidden_states = torch.randn( - (len(input_ids), hidden_size), dtype=torch.float32, device="cuda") - - num_accepted_tokens = torch.tensor([-1, -1, -1], - dtype=torch.int, - device="cuda") # useless - - previous_layer_draft_tokens = torch.tensor([11, 12, 14], - dtype=torch.int, - device="cuda") - - ref_input_ids = torch.tensor([21, 22, 11, 31, 32, 12, 41, 42, 14], - dtype=torch.int, - device="cuda") - - ref_previous_hidden_states = previous_layer_hidden_states - attn_metadata = None - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, False - ]] - - test_cases += [[ - num_nextn_predict_layers, mtp_layer_idx, input_ids, position_ids, - seq_lens, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - previous_layer_hidden_states, previous_layer_draft_tokens, - request_ids, num_contexts, num_accepted_tokens, attn_metadata, - ref_input_ids, ref_previous_hidden_states, True - ]] - - return test_cases - - @parameterized.expand(load_prepare_drafter_inputs_test_cases, - name_func=unittest_name_func) - def test_prepare_drafter_inputs( - self, num_nextn_predict_layers, mtp_layer_idx, input_ids, - position_ids, seq_lens, mtp_past_hidden_states_ptrs, - mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, - mtp_tokens_tensor_pool, previous_layer_hidden_states, - previous_layer_draft_tokens, request_ids, num_contexts, - num_accepted_tokens, attn_metadata, ref_input_ids, - ref_previous_hidden_states, is_thop): - - batch_size = len(request_ids) - if previous_layer_hidden_states is not None: - hidden_size = previous_layer_hidden_states.shape[1] - else: - hidden_size = 10 - spec_config = MTPConfig( - num_nextn_predict_layers=num_nextn_predict_layers) - - if attn_metadata is None: - attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, - max_num_tokens=1024, - kv_cache_manager=None) - attn_metadata.seq_lens = seq_lens.to('cpu') - attn_metadata.num_contexts = num_contexts - # dummy kv cache param - attn_metadata.kv_cache_params = KVCacheParams( - use_cache=True, num_cached_tokens_per_seq=[10] * batch_size) - - spec_manager = MTPHiddenStatesManager(config=spec_config, - dtype=torch.float32, - hidden_size=hidden_size, - max_num_requests=batch_size) - for i in range(batch_size): - # for the generation requests, we also need to manually add slot - # because these generation requests are also first use - spec_manager.slot_manager.add_slot(request_ids[i]) - - spec_metadata = MTPSpecMetadata( - 
max_num_requests=32, - spec_dec_mode=spec_config.spec_dec_mode, - max_draft_tokens=num_nextn_predict_layers, - mtp_num_modules=num_nextn_predict_layers, - mtp_hidden_states_manager=spec_manager) - spec_metadata.request_ids = request_ids - spec_metadata.mtp_hidden_states_ptrs = mtp_past_hidden_states_ptrs - spec_metadata.mtp_past_tokens_ptrs = mtp_past_tokens_ptrs - - spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool - spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool - spec_metadata.prepare() - - mtpworker = MTPWorker(spec_config) - mtpworker.is_thop = is_thop - draft_inputs = mtpworker.prepare_drafter_inputs( - mtp_layer_idx=mtp_layer_idx, - input_ids=input_ids, - position_ids=position_ids, - previous_layer_hidden_states=previous_layer_hidden_states, - previous_layer_draft_tokens=previous_layer_draft_tokens, - num_accepted_tokens=num_accepted_tokens, - spec_metadata=spec_metadata, - attn_metadata=attn_metadata) - - torch.testing.assert_close(draft_inputs["input_ids"], ref_input_ids) - torch.testing.assert_close(draft_inputs["hidden_states"], - ref_previous_hidden_states) diff --git a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py b/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py deleted file mode 100644 index 954d5476507..00000000000 --- a/tests/unittest/_torch/speculative/test_mtp_sample_and_accept_draft_tokens.py +++ /dev/null @@ -1,323 +0,0 @@ -import unittest - -import torch -from parameterized import parameterized -from utils.util import unittest_name_func - -import tensorrt_llm -from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata -from tensorrt_llm._torch.speculative.mtp import (MTPConfig, MTPSpecMetadata, - MTPWorker) - - -class TestMTPSampleAndAcceptDraftTokens(unittest.TestCase): - - def setUp(self): - tensorrt_llm.logger.set_level('warning') - - def load_sample_and_accept_draft_tokens_test_cases(): - test_cases = [] - - # ''' - ################# CASE 0 ########################## - # BS=1, 1 context request - mtp_num_modules = 1 - num_context_requests = 1 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([0], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 0]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([1], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 1 ########################## - # BS=4, 4 context request - mtp_num_modules = 1 - num_context_requests = 4 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([0, 0, 0, 0], dtype=torch.int, - 
device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 0], [3, 0], [3, 0], [6, 0]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([1, 1, 1, 1], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 2 ########################## - # BS=1, 1 generation request - # Assume there are 3 MTP layers - # For each generation request, there are four logits here: one from golden token + three draft tokens - mtp_num_modules = 3 - num_context_requests = 0 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [1, 3, 4], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([3], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 3, 2, 0]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([3], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 3 ########################## - # BS=2, 2 generation request - # Assume there are 1 MTP layers - # For each generation request, there are two logits here: one golden token + one draft token - mtp_num_modules = 1 - num_context_requests = 0 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [1, 5], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([1, 1], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 3], [4, 0]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([2, 1], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 4 ########################## - # BS=2, 2 generation request - # Assume there are 3 MTP layers - # For each generation request, there are four logits here: one golden token + three draft token - mtp_num_modules = 3 - num_context_requests = 0 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 - [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 
- [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [1, 3, 4, 4, 7, 3], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([3, 3], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 3, 2, 0], [4, 6, 0, 0]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([3, 2], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 5 ########################## - # BS=3, 3 generation request, 2 accept partial, 1 accept all, 1 accept none - # Assume there are 3 MTP layers - # For each generation request, there are four logits here: one golden token + three draft token - mtp_num_modules = 3 - num_context_requests = 0 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 - [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 - [-100, -100, 0, -100, -100, -100, -100, -100], # Top1 id = 2 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, 0, -100, -100, -100, -100], # Top1 id = 3 - [-100, -100, -100, -100, -100, 0, -100, -100], # Top1 id = 5 - [-100, -100, -100, -100, -100, -100, -100, 0], # Top1 id = 7 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [1, 3, 5, 4, 6, 5, 5, 7, 4], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([3, 3, 3], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 3, 2, 0], [4, 6, 5, 2], [4, 0, 0, 0]], - dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([3, 4, 1], - dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - ################# CASE 6 ########################## - # BS=2, 1 context request, 1 generation request - # request0 is context request, and request1 is generation request - # Assume there are 1 MTP layers - # For each generation request, there are two logits here: one golden token + one draft token - mtp_num_modules = 1 - num_context_requests = 1 - logits = torch.tensor( - [ - [-100, 0, -100, -100, -100, -100, -100, -100], # Top1 id = 1 - [-100, -100, -100, -100, 0, -100, -100, -100], # Top1 id = 4 - [-100, -100, -100, -100, -100, -100, 0, -100], # Top1 id = 6 - ], - dtype=torch.float32, - device="cuda") # [num_tokens, vocab_size] - - draft_tokens = torch.tensor( - [4], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - draft_len = torch.tensor([0, 1], dtype=torch.int, - device="cuda") # [batch_size] - - ref_accepted_tokens = torch.tensor( - [[1, 0], [4, 6]], dtype=torch.int, - device="cuda") # [batch_size * max_draft_tokens] - - ref_num_accepted_tokens = torch.tensor([1, 2], - 
dtype=torch.int, - device="cuda") # [batch_size] - - test_cases += [[ - mtp_num_modules, logits, draft_tokens, draft_len, - num_context_requests, ref_accepted_tokens, ref_num_accepted_tokens - ]] - - return test_cases - - @parameterized.expand(load_sample_and_accept_draft_tokens_test_cases, - name_func=unittest_name_func) - def test_sample_and_accept_draft_tokens(self, mtp_num_modules, logits, - draft_tokens, draft_len, - num_context_requests, - ref_accepted_tokens, - ref_num_accepted_tokens): - batch_size = len(draft_len) - spec_config = MTPConfig(num_nextn_predict_layers=mtp_num_modules) - - # attention metedata - attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, - max_num_tokens=1024, - kv_cache_manager=None) - attn_metadata.seq_lens = torch.tensor( - [1] * batch_size, dtype=torch.int) # dummy sequence length - attn_metadata.num_contexts = num_context_requests - - # speculative decoding metadata - spec_metadata = MTPSpecMetadata(max_num_requests=32, - spec_dec_mode=spec_config.spec_dec_mode, - max_draft_tokens=mtp_num_modules, - mtp_num_modules=mtp_num_modules) - spec_metadata.draft_tokens = draft_tokens - - # mtp worker - mtpworker = MTPWorker(spec_config) - - # Test thop kernel - # Test native torch op - for is_thop in [True, False]: - mtpworker.is_thop = is_thop - # TODO: add unit tests for relaxed acceptance - accepted_tokens, num_accepted_tokens = mtpworker.sample_and_accept_draft_tokens( - None, logits, spec_metadata, attn_metadata) - - torch.testing.assert_close(num_accepted_tokens, - ref_num_accepted_tokens) - for i in range(len(draft_len)): - torch.testing.assert_close( - accepted_tokens[i][0:ref_num_accepted_tokens[i]], - ref_accepted_tokens[i][0:ref_num_accepted_tokens[i]]) diff --git a/tests/unittest/_torch/speculative/test_mtp_update_mtp_hidden_states.py b/tests/unittest/_torch/speculative/test_mtp_update_mtp_hidden_states.py deleted file mode 100644 index c5765ec473e..00000000000 --- a/tests/unittest/_torch/speculative/test_mtp_update_mtp_hidden_states.py +++ /dev/null @@ -1,625 +0,0 @@ -import unittest - -import torch -from parameterized import parameterized -from utils.util import unittest_name_func - -import tensorrt_llm -from tensorrt_llm._torch.attention_backend import TrtllmAttentionMetadata -from tensorrt_llm._torch.speculative.mtp import (MTPConfig, - MTPHiddenStatesManager, - MTPSpecMetadata, MTPWorker) - - -class TestMTPUpdateMTPHiddenStates(unittest.TestCase): - - def setUp(self): - tensorrt_llm.logger.set_level('warning') - - def load_update_mtp_hidden_states_test_cases(): - - def gen_data(batch_size, num_nextn_predict_layers, hidden_size): - mtp_past_hidden_states_ptrs = [] - mtp_past_tokens_ptrs = [] - mtp_hidden_states_tensor_pool = torch.ones( - (batch_size, num_nextn_predict_layers, hidden_size), - device='cuda', - dtype=torch.float32) - mtp_tokens_tensor_pool = torch.ones( - (batch_size, num_nextn_predict_layers), - device='cuda', - dtype=torch.int) - - for bix in range(batch_size): - mtp_hidden_states_tensor_pool[ - bix] = mtp_hidden_states_tensor_pool[ - bix] * bix # be different - mtp_past_hidden_states_ptrs.append( - mtp_hidden_states_tensor_pool[bix].data_ptr()) - - mtp_tokens_tensor_pool[ - bix] = mtp_tokens_tensor_pool[bix] * bix # be different - mtp_past_tokens_ptrs.append( - mtp_tokens_tensor_pool[bix].data_ptr()) - return mtp_past_hidden_states_ptrs, mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, mtp_tokens_tensor_pool - - test_cases = [] - - ################# CASE 0 ########################## - # BS=1, 1 context request 
- batch_size = 1 - num_context_request = 1 - num_nextn_predict_layers = 1 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor([5, 6, 7, 8, 9], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5], dtype=torch.int, - device="cuda") # [batch_size] - - target_model_hidden_states = torch.tensor( - [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([1], dtype=torch.int, - device="cuda") # [batch_size] - - accepted_tokens = torch.tensor([[42, 0]], - dtype=torch.int, - device="cuda") # [batch_size] - - ref_mtp_tokens_dict = dict() - ref_mtp_tokens_dict[0] = torch.tensor([42], - dtype=torch.int, - device="cuda") - - ref_mtp_hidden_state_dict = dict() - ref_mtp_hidden_state_dict[0] = torch.tensor([ - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False - ]] - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True - ]] - - ################## CASE 1 ########################## - # BS=3, 3 context request, num_nextn_predict_layers = 2 - batch_size = 3 - num_context_request = 3 - num_nextn_predict_layers = 2 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor( - [ - 5, - 6, - 7, - 8, - 9, # request 1 - 11, - 12, - 13, - 14, # request 2 - 21, - 22, - 23, - 24, - 25, - 26 # request 3 - ], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5, 4, 6], dtype=torch.int, - device="cuda") # [batch_size] - - target_model_hidden_states = torch.tensor( - [ - # request 1 - [0, 1, 2, 3, 4, 5, 
6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - - # request 2 - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - - # request 3 - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 1, 1], - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([1, 1, 1], - dtype=torch.int, - device="cuda") # [batch_size] - - accepted_tokens = torch.tensor([[42, 0, 0], [26, 0, 0], [33, 0, 0]], - dtype=torch.int, - device="cuda") # [batch_size] - - ref_mtp_tokens_dict = dict() - ref_mtp_tokens_dict[0] = torch.tensor([9, 42], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[1] = torch.tensor([14, 26], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[2] = torch.tensor([26, 33], - dtype=torch.int, - device="cuda") - - ref_mtp_hidden_state_dict = dict() - ref_mtp_hidden_state_dict[0] = torch.tensor([ - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[1] = torch.tensor([ - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[2] = torch.tensor([ - [90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 1, 1], - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - ], - dtype=torch.float32, - device="cuda") - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False - ]] - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True - ]] - - ################## CASE 2 ########################## - # BS=1, 1 generation request, num_nextn_predict_layers = 1 - batch_size = 1 - num_context_request = 0 - num_nextn_predict_layers = 1 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor([5, 42], 
dtype=torch.int, device="cuda") - - seq_lens = torch.tensor([2], dtype=torch.int, - device="cuda") # [batch_size] - - target_model_hidden_states = torch.tensor( - [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([2], dtype=torch.int, - device="cuda") # [batch_size] - - accepted_tokens = torch.tensor([[42, 43]], - dtype=torch.int, - device="cuda") # [batch_size] - - ref_mtp_tokens_dict = dict() - ref_mtp_tokens_dict[0] = torch.tensor([43], - dtype=torch.int, - device="cuda") - - ref_mtp_hidden_state_dict = dict() - ref_mtp_hidden_state_dict[0] = torch.tensor([ - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - ], - dtype=torch.float32, - device="cuda") - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False - ]] - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True - ]] - - ################## CASE 3 ########################## - # BS=4, 4 generation request, num_nextn_predict_layers = 1 - batch_size = 4 - num_context_request = 0 - num_nextn_predict_layers = 1 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor( - [ - 5, - 6, # request 1 - 7, - 8, # request 2 - 9, - 10, # request 3 - 11, - 12 # request 4 - ], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([2, 2, 2, 2], dtype=torch.int, - device="cuda") # [batch_size] - - target_model_hidden_states = torch.tensor( - [ - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([2, 1, 1, 2], - dtype=torch.int, - device="cuda") # [batch_size] - - accepted_tokens = torch.tensor([[42, 41], [25, 0], [27, 0], [30, 33]], - dtype=torch.int, - device="cuda") # [batch_size] - - ref_mtp_tokens_dict = dict() - ref_mtp_tokens_dict[0] = torch.tensor([41], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[1] = torch.tensor([25], - 
dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[2] = torch.tensor([27], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[3] = torch.tensor([33], - dtype=torch.int, - device="cuda") - - ref_mtp_hidden_state_dict = dict() - ref_mtp_hidden_state_dict[0] = torch.tensor([ - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[1] = torch.tensor([ - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[2] = torch.tensor([ - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[3] = torch.tensor([ - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - ], - dtype=torch.float32, - device="cuda") - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False - ]] - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True - ]] - - ################## CASE 4 ########################## - # BS=4, 2 context request, 2 generation request, num_nextn_predict_layers = 2 - num_context = 2 - num_generation = 2 - num_context_request = num_context - batch_size = num_context + num_generation - num_nextn_predict_layers = 2 - hidden_size = 12 - request_ids = range(batch_size) - - # Since we will update the data inplace, so we need two data, - # One for is_thop=False, one for is_thop=True - mtp_past_hidden_states_ptrs_v1, mtp_past_tokens_ptrs_v1, mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - mtp_past_hidden_states_ptrs_v2, mtp_past_tokens_ptrs_v2, mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2 = gen_data( - batch_size, num_nextn_predict_layers, hidden_size) - - input_ids = torch.tensor( - [ - 5, - 6, - 7, - 8, - 9, # request 1 - 10, - 11, - 12, - 13, # request 2 - 26, - 27, - 28, # request 3 - 31, - 32, - 33 # request 4 - ], - dtype=torch.int, - device="cuda") - - seq_lens = torch.tensor([5, 4, 3, 3], dtype=torch.int, - device="cuda") # [batch_size] - - target_model_hidden_states = torch.tensor( - [ - # request 1 - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - - # request 2 - [50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 1, 1], - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - - # request 3 - [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1], - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - - # request 4 - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], 
- [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - ], - dtype=torch.float32, - device="cuda") # [prompt_length, hidden_size] - - num_accepted_tokens = torch.tensor([1, 1, 3, 1], - dtype=torch.int, - device="cuda") # [batch_size] - - accepted_tokens = torch.tensor( - [[40, 0, 0], [41, 0, 0], [42, 43, 44], [45, 0, 0]], - dtype=torch.int, - device="cuda") # [batch_size] - - ref_mtp_tokens_dict = dict() - ref_mtp_tokens_dict[0] = torch.tensor([9, 40], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[1] = torch.tensor([13, 41], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[2] = torch.tensor([43, 44], - dtype=torch.int, - device="cuda") - ref_mtp_tokens_dict[3] = torch.tensor([3, 45], - dtype=torch.int, - device="cuda") - - ref_mtp_hidden_state_dict = dict() - ref_mtp_hidden_state_dict[0] = torch.tensor([ - [30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 1, 1], - [40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[1] = torch.tensor([ - [70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 1, 1], - [80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[2] = torch.tensor([ - [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 1, 1], - [20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 1, 1], - ], - dtype=torch.float32, - device="cuda") - ref_mtp_hidden_state_dict[3] = torch.tensor([ - [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], - [60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 1, 1], - ], - dtype=torch.float32, - device="cuda") - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v1, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v1, device="cuda"), - mtp_hidden_states_tensor_pool_v1, mtp_tokens_tensor_pool_v1, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, False - ]] - - test_cases += [[ - num_nextn_predict_layers, num_context_request, input_ids, seq_lens, - target_model_hidden_states, - torch.tensor(mtp_past_hidden_states_ptrs_v2, device="cuda"), - torch.tensor(mtp_past_tokens_ptrs_v2, device="cuda"), - mtp_hidden_states_tensor_pool_v2, mtp_tokens_tensor_pool_v2, - request_ids, num_accepted_tokens, accepted_tokens, - ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, True - ]] - - return test_cases - - @parameterized.expand(load_update_mtp_hidden_states_test_cases, - name_func=unittest_name_func) - def test_mtp_update_mtp_hidden_states( - self, num_nextn_predict_layers, num_context_request, input_ids, - seq_lens, target_model_hidden_states, mtp_hidden_states_ptrs, - mtp_past_tokens_ptrs, mtp_hidden_states_tensor_pool, - mtp_tokens_tensor_pool, request_ids, num_accepted_tokens, - accepted_tokens, ref_mtp_tokens_dict, ref_mtp_hidden_state_dict, - is_thop): - - batch_size = len(request_ids) - batch_size - num_context_request - hidden_size = target_model_hidden_states.shape[1] - spec_config = MTPConfig( - num_nextn_predict_layers=num_nextn_predict_layers) - - attn_metadata = TrtllmAttentionMetadata(max_num_requests=batch_size, - max_num_tokens=1024, - kv_cache_manager=None) - attn_metadata.seq_lens = seq_lens.to('cpu') - attn_metadata.num_contexts = num_context_request - - spec_manager = MTPHiddenStatesManager(config=spec_config, - dtype=torch.float32, - hidden_size=hidden_size, - max_num_requests=batch_size) - for i in range(batch_size): - # for the generation requests, we also need to manually add slot - # because these generation 
requests are also first use - spec_manager.slot_manager.add_slot(request_ids[i]) - - spec_metadata = MTPSpecMetadata( - max_num_requests=32, - spec_dec_mode=spec_config.spec_dec_mode, - max_draft_tokens=num_nextn_predict_layers, - mtp_num_modules=num_nextn_predict_layers, - mtp_hidden_states_manager=spec_manager) - spec_metadata.request_ids = request_ids - spec_metadata.mtp_hidden_states_ptrs = mtp_hidden_states_ptrs - spec_metadata.mtp_past_tokens_ptrs = mtp_past_tokens_ptrs - - spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool = mtp_hidden_states_tensor_pool - spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool = mtp_tokens_tensor_pool - spec_metadata.prepare() - - mtpworker = MTPWorker(spec_config) - mtpworker.is_thop = is_thop - - mtpworker.update_mtp_hidden_states( - input_ids=input_ids, - target_model_hidden_states=target_model_hidden_states, - num_accepted_tokens=num_accepted_tokens, - accepted_tokens=accepted_tokens, - spec_metadata=spec_metadata, - attn_metadata=attn_metadata) - - # Verify - mtp_past_hidden_states_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_hidden_states_pool - mtp_past_tokens_pool = spec_metadata.mtp_hidden_states_manager.mtp_past_tokens_pool - - for bix in range(batch_size): - torch.testing.assert_close(mtp_past_tokens_pool[bix], - ref_mtp_tokens_dict[bix]) - torch.testing.assert_close(mtp_past_hidden_states_pool[bix], - ref_mtp_hidden_state_dict[bix])
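
The reference pools (`ref_mtp_tokens_dict` / `ref_mtp_hidden_state_dict`) in the `update_mtp_hidden_states` cases above all encode the same shift rule. As a reading aid only (not part of the diff), here is a minimal plain-PyTorch sketch of that rule; the function name and signature are illustrative and do not correspond to the MTPWorker API.

import torch


def reference_mtp_history_update(past_tokens, past_hidden, token_slice, hidden_slice,
                                 accepted_tokens, num_accepted, is_context):
    # Returns the new rolling (tokens, hidden_states) history of length K for one request,
    # where K == num_nextn_predict_layers.
    #   past_tokens:  [K]      previous token history (only its length is used for context requests)
    #   past_hidden:  [K, H]   previous hidden-state history (unused for context requests)
    #   token_slice:  [S]      this request's slice of input_ids
    #   hidden_slice: [S, H]   this request's slice of target_model_hidden_states
    #   accepted_tokens: [K + 1] accepted tokens for this request, num_accepted: int
    K = past_tokens.shape[0]
    if is_context:
        # Context request: the history is seeded from the tail of the prompt plus the
        # single token sampled from the target model.
        new_tokens = torch.cat([token_slice, accepted_tokens[:1]])[-K:]
        new_hidden = hidden_slice[-K:]
    else:
        # Generation request: shift the history forward by the number of accepted tokens.
        new_tokens = torch.cat([past_tokens, accepted_tokens[:num_accepted]])[-K:]
        new_hidden = torch.cat([past_hidden, hidden_slice[:num_accepted]])[-K:]
    return new_tokens, new_hidden

Applied per request, this sketch should reproduce the expected pools in the mixed context/generation CASE 4 above (e.g. past tokens [9, 40] for request 1 and [43, 44] for request 3).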