diff --git a/cpp/include/tensorrt_llm/runtime/decoderState.h b/cpp/include/tensorrt_llm/runtime/decoderState.h
index 6c5ab093e9a..6781e177766 100644
--- a/cpp/include/tensorrt_llm/runtime/decoderState.h
+++ b/cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -182,6 +182,14 @@ class DecoderState
     //! @brief Cache indirection output for beam search.
     [[nodiscard]] TensorPtr getCacheIndirectionOutput() const;
 
+    //! @brief Get the generation steps for all requests in the batch.
+    //! @returns The generation steps for all requests in the batch.
+    [[nodiscard]] std::optional<std::vector<SizeType32>> const& getGenerationSteps() const;
+
+    //! @brief Set the generation steps for all requests in the batch.
+    //! @param generationSteps The generation steps for all requests in the batch.
+    void setGenerationSteps(std::vector<SizeType32> const& generationSteps);
+
     //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
     [[nodiscard]] DecodingInput& getJointDecodingInput() const;
 
diff --git a/cpp/include/tensorrt_llm/runtime/decodingInput.h b/cpp/include/tensorrt_llm/runtime/decodingInput.h
index 40c8a757be5..deeb0fa0af4 100644
--- a/cpp/include/tensorrt_llm/runtime/decodingInput.h
+++ b/cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -142,24 +142,6 @@ class DecodingInput
 
     struct EagleInputs
     {
-        EagleInputs(TensorConstPtr nextDraftTokens, TensorConstPtr nextDraftLens, TensorConstPtr nextDraftPaths,
-            TensorConstPtr lastDraftTokens, TensorConstPtr lastDraftLens, TensorConstPtr lastDraftPaths,
-            TensorConstPtr acceptedTokens, TensorConstPtr acceptedLens, TensorConstPtr acceptedPathIds,
-            TensorConstPtr chunkedContextNextTokens, TensorConstPtr seqSlots)
-            : nextDraftTokens(std::move(nextDraftTokens))
-            , nextDraftLens(std::move(nextDraftLens))
-            , nextDraftPaths(std::move(nextDraftPaths))
-            , lastDraftTokens(std::move(lastDraftTokens))
-            , lastDraftLens(std::move(lastDraftLens))
-            , lastDraftPaths(std::move(lastDraftPaths))
-            , acceptedTokens(std::move(acceptedTokens))
-            , acceptedLens(std::move(acceptedLens))
-            , acceptedPathIds(std::move(acceptedPathIds))
-            , chunkedContextNextTokens(std::move(chunkedContextNextTokens))
-            , seqSlots(std::move(seqSlots))
-        {
-        }
-
         TensorConstPtr nextDraftTokens; // [batchSize, maxDecodingDraftTokens]
         TensorConstPtr nextDraftLens;   // [batchSize]
         TensorConstPtr nextDraftPaths;  // [batchSize, maxDecodingTokens, maxPathLen]
diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
index ed37c1260e9..327af71f8a7 100644
--- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
+++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -18,9 +18,9 @@
 
 #include "tensorrt_llm/runtime/cudaEvent.h"
 #include "tensorrt_llm/runtime/cudaStream.h"
-#include "tensorrt_llm/runtime/eagleBuffers.h"
-#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
 #include "tensorrt_llm/runtime/iTensor.h"
+#include "tensorrt_llm/runtime/modelConfig.h"
+#include "tensorrt_llm/runtime/worldConfig.h"
 
 #include <memory>
 #include <vector>
@@ -72,25 +72,6 @@ class Input
 
     //! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
     std::vector<TensorPtr> batchSlots;
-
-    //! Filled with slots in request order, [batchSize]
-    TensorPtr batchSlotsRequestOrder;
-
-    //! For Beam Search
-    //! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
-    std::vector<SizeType32> generationSteps;
-
-    //! For speculative decoding
-    //! Logits of draft
-    //! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
-    std::vector<std::vector<TensorPtr>> predictedDraftLogits;
-
-    //! Explicit draft tokens data
-    std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
-    std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
-
-    //! Eagle data
-    std::optional<EagleBuffers::EngineOutputs> eagleInputs;
-    std::optional<EagleBuffers::Inputs> eagleLastInputs;
 };
 
 } // namespace decoder_batch
diff --git a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
index 3b65ceaf376..f659720e377 100644
--- a/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
+++ b/cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
@@ -118,6 +118,62 @@ std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
     return {activeSlots, generationSteps};
 }
 
+//! @brief Sets inputs for explicit draft tokens.
+void setExplicitDraftTokensInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
+{
+    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+    TLLM_CHECK(fusedRuntimeBuffers.mExplicitDraftTokensBuffers);
+    auto const& explicitDraftTokensInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineOutputs;
+    auto const& explicitDraftTokensLastInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineInputs;
+
+    dInput.explicitDraftTokensInputs = tr::DecodingInput::ExplicitDraftTokensInputs();
+    dInput.explicitDraftTokensInputs->nextDraftTokens = explicitDraftTokensInputs.nextDraftTokens;
+    dInput.explicitDraftTokensInputs->nextFlatTokens = explicitDraftTokensInputs.nextFlatTokens;
+    dInput.explicitDraftTokensInputs->nextDraftIndices = explicitDraftTokensInputs.nextDraftIndices;
+    dInput.explicitDraftTokensInputs->nextDraftProbs = explicitDraftTokensInputs.nextDraftProbs;
+    dInput.explicitDraftTokensInputs->lastDraftTokens = explicitDraftTokensLastInputs.draftTokens;
+    dInput.explicitDraftTokensInputs->lastDraftIndices = explicitDraftTokensLastInputs.draftIndices;
+    dInput.explicitDraftTokensInputs->lastPositionIdsBase = explicitDraftTokensLastInputs.positionIdsBase;
+    dInput.explicitDraftTokensInputs->masks = explicitDraftTokensInputs.masks;
+    dInput.explicitDraftTokensInputs->packedPositionIds = explicitDraftTokensInputs.packedPositionIds;
+    dInput.explicitDraftTokensInputs->bestPathLengths = explicitDraftTokensInputs.bestPathLengths;
+    dInput.explicitDraftTokensInputs->bestPathIndices = explicitDraftTokensInputs.bestPathIndices;
+    dInput.explicitDraftTokensInputs->nextGenerationLengths = explicitDraftTokensInputs.nextGenerationLengths;
+    dInput.explicitDraftTokensInputs->lastGenerationLengths = explicitDraftTokensLastInputs.generationLengths;
+    dInput.explicitDraftTokensInputs->maxGenLengthDevice = explicitDraftTokensInputs.maxGenToken;
+    // Slots in request order
+    dInput.explicitDraftTokensInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
+
+    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
+//! @brief Sets inputs for eagle decoding.
+void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
+{
+    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
+
+    TLLM_CHECK(fusedRuntimeBuffers.mEagleBuffers);
+    auto const& eagleInputs = fusedRuntimeBuffers.mEagleBuffers->engineOutputs;
+    auto const& eagleLastInputs = fusedRuntimeBuffers.mEagleBuffers->engineInputs;
+
+    dInput.eagleInputs = tr::DecodingInput::EagleInputs();
+    dInput.eagleInputs->nextDraftTokens = eagleInputs.nextDraftTokens;
+    dInput.eagleInputs->nextDraftLens = eagleInputs.nextDraftLens;
+    dInput.eagleInputs->nextDraftPaths = eagleInputs.nextDraftPaths;
+    dInput.eagleInputs->lastDraftTokens = eagleLastInputs.draftTokens;
+    dInput.eagleInputs->lastDraftLens = eagleLastInputs.draftLens;
+    dInput.eagleInputs->lastDraftPaths = eagleLastInputs.draftPaths;
+    dInput.eagleInputs->acceptedTokens = eagleInputs.acceptedTokens;
+    dInput.eagleInputs->acceptedLens = eagleInputs.acceptedLens;
+    dInput.eagleInputs->acceptedPathIds = eagleInputs.acceptedPaths;
+    dInput.eagleInputs->chunkedContextNextTokens = eagleInputs.chunkedContextNextTokens;
+    // Slots in request order
+    dInput.eagleInputs->seqSlots = fusedRuntimeBuffers.seqSlots;
+
+    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
+}
+
 } // namespace
 
 std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests,
@@ -131,28 +187,30 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator
 
     auto decodingInput = createDecoderBatchInputs(
         activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);
-    decodingInput->generationSteps = generationSteps;
+
+    auto const maxBeamWidth = decoderState.getMaxBeamWidth();
+    if (maxBeamWidth > 1)
+    {
+        // For Variable-Beam-Width-Search
+        decoderState.getJointDecodingInput().generationSteps = generationSteps;
+    }
 
     if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
     {
-        decodingInput->predictedDraftLogits = inputBuffers.predictedDraftLogits;
+        decoderState.getJointDecodingInput().medusaInputs->medusaLogits = inputBuffers.predictedDraftLogits;
     }
 
     if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
     {
         TLLM_CHECK(fusedRuntimeBuffers); // requires mCtxGenFusion == true
-        decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
-        decodingInput->explicitDraftTokensInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineOutputs;
-        decodingInput->explicitDraftTokensLastInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineInputs;
+        setExplicitDraftTokensInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
     }
     else if (modelConfig.getSpeculativeDecodingMode().isEagle())
     {
         TLLM_CHECK(fusedRuntimeBuffers); // requires mCtxGenFusion == true
-        decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
-        decodingInput->eagleInputs = fusedRuntimeBuffers->mEagleBuffers->engineOutputs;
-        decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers->engineInputs;
+        setEagleInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
     }
 
     TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
diff --git a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
index 318cb2cef97..b6303a14311 100644
--- a/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
+++ b/cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -23,6 +23,7 @@
 #include "tensorrt_llm/kernels/delayStream.h"
 #include "tensorrt_llm/runtime/cudaEvent.h"
 #include "tensorrt_llm/runtime/cudaStream.h"
+#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h" #include "tensorrt_llm/runtime/decodingOutput.h" #include "tensorrt_llm/runtime/gptDecoder.h" @@ -273,10 +274,7 @@ void initBindings(pybind11::module_& m) .def(py::init>(), py::arg("logits")) .def_readwrite("logits", &tr::decoder_batch::Input::logits) .def_readwrite("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps) - .def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots) - .def_readwrite("batch_slots_request_order", &tr::decoder_batch::Input::batchSlotsRequestOrder) - .def_readwrite("generation_steps", &tr::decoder_batch::Input::generationSteps) - .def_readwrite("predicted_draft_logits", &tr::decoder_batch::Input::predictedDraftLogits); + .def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots); py::class_(m, "LookaheadDecodingBuffers") .def(py::init(), py::arg("max_num_sequences"), @@ -380,7 +378,9 @@ void initBindings(pybind11::module_& m) py::arg("batch_idx")) .def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens, py::arg("batch_idx"), py::arg("num_tokens")) - .def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode); + .def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode) + .def_property("generation_steps", &tr::decoder::DecoderState::getGenerationSteps, + &tr::decoder::DecoderState::setGenerationSteps); py::class_(m, "GptDecoderBatched") .def(py::init(), py::arg("stream")) diff --git a/cpp/tensorrt_llm/runtime/decoderState.cpp b/cpp/tensorrt_llm/runtime/decoderState.cpp index 76cb54cea4c..d59a989bb6b 100644 --- a/cpp/tensorrt_llm/runtime/decoderState.cpp +++ b/cpp/tensorrt_llm/runtime/decoderState.cpp @@ -644,6 +644,16 @@ TensorPtr DecoderState::getCacheIndirectionOutput() const return mJointDecodingOutput->cacheIndirection; } +std::optional> const& DecoderState::getGenerationSteps() const +{ + return mJointDecodingInput->generationSteps; +} + +void DecoderState::setGenerationSteps(std::vector const& generationSteps) +{ + mJointDecodingInput->generationSteps = generationSteps; +} + DecodingInput& DecoderState::getJointDecodingInput() const { return *mJointDecodingInput; diff --git a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp index 0cdf72f3980..6e22b8f2f49 100644 --- a/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp +++ b/cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp @@ -100,51 +100,6 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max namespace { -//! @brief Sets inputs for explicit draft tokens. 
-void setExplicitDraftTokensInputs(DecodingInput& dInput, decoder_batch::Input const& input)
-{
-    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-
-    dInput.explicitDraftTokensInputs = DecodingInput::ExplicitDraftTokensInputs();
-    TLLM_CHECK(input.explicitDraftTokensInputs.has_value());
-    TLLM_CHECK(input.explicitDraftTokensLastInputs.has_value());
-
-    dInput.explicitDraftTokensInputs->nextDraftTokens = input.explicitDraftTokensInputs->nextDraftTokens;
-    dInput.explicitDraftTokensInputs->nextFlatTokens = input.explicitDraftTokensInputs->nextFlatTokens;
-    dInput.explicitDraftTokensInputs->nextDraftIndices = input.explicitDraftTokensInputs->nextDraftIndices;
-    dInput.explicitDraftTokensInputs->nextDraftProbs = input.explicitDraftTokensInputs->nextDraftProbs;
-    dInput.explicitDraftTokensInputs->lastDraftTokens = input.explicitDraftTokensLastInputs->draftTokens;
-    dInput.explicitDraftTokensInputs->lastDraftIndices = input.explicitDraftTokensLastInputs->draftIndices;
-    dInput.explicitDraftTokensInputs->lastPositionIdsBase = input.explicitDraftTokensLastInputs->positionIdsBase;
-    dInput.explicitDraftTokensInputs->masks = input.explicitDraftTokensInputs->masks;
-    dInput.explicitDraftTokensInputs->packedPositionIds = input.explicitDraftTokensInputs->packedPositionIds;
-    dInput.explicitDraftTokensInputs->bestPathLengths = input.explicitDraftTokensInputs->bestPathLengths;
-    dInput.explicitDraftTokensInputs->bestPathIndices = input.explicitDraftTokensInputs->bestPathIndices;
-    dInput.explicitDraftTokensInputs->nextGenerationLengths = input.explicitDraftTokensInputs->nextGenerationLengths;
-    dInput.explicitDraftTokensInputs->lastGenerationLengths = input.explicitDraftTokensLastInputs->generationLengths;
-    dInput.explicitDraftTokensInputs->maxGenLengthDevice = input.explicitDraftTokensInputs->maxGenToken;
-    dInput.explicitDraftTokensInputs->seqSlots = input.batchSlotsRequestOrder;
-
-    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-}
-
-//! @brief Sets inputs for eagle decoding.
-void setEagleInputs(DecodingInput& dInput, decoder_batch::Input const& input)
-{
-    TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
-
-    TLLM_CHECK(input.eagleInputs.has_value());
-    TLLM_CHECK(input.eagleLastInputs.has_value());
-
-    dInput.eagleInputs = DecodingInput::EagleInputs(input.eagleInputs->nextDraftTokens,
-        input.eagleInputs->nextDraftLens, input.eagleInputs->nextDraftPaths, input.eagleLastInputs->draftTokens,
-        input.eagleLastInputs->draftLens, input.eagleLastInputs->draftPaths, input.eagleInputs->acceptedTokens,
-        input.eagleInputs->acceptedLens, input.eagleInputs->acceptedPaths, input.eagleInputs->chunkedContextNextTokens,
-        input.batchSlotsRequestOrder);
-
-    TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
-}
-
 //! @brief Prepare Input and Output for decoder step.
 // TODO: produce new input and output objects
 void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Input const& input,
@@ -152,26 +107,11 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,
 {
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
-    auto const maxBeamWidth = decoderState.getMaxBeamWidth();
     auto const speculativeDecodingMode = decoderState.getSpeculativeDecodingMode();
 
     auto& dInput = decoderState.getJointDecodingInput();
     auto& dOutput = decoderState.getJointDecodingOutput();
 
-    if (maxBeamWidth > 1)
-    {
-        dInput.generationSteps = input.generationSteps; // For Variable-Beam-Width-Search
-    }
-
-    if (speculativeDecodingMode.isExplicitDraftTokens())
-    {
-        setExplicitDraftTokensInputs(dInput, input);
-    }
-    else if (speculativeDecodingMode.isEagle())
-    {
-        setEagleInputs(dInput, input);
-    }
-
     dInput.batchSlots = input.batchSlots.at(step);
     dInput.batchSize = static_cast<SizeType32>(dInput.batchSlots->getSize());
     dInput.logitsVec = input.logits.at(step);
@@ -186,11 +126,6 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,
 
     dInput.finishReasons = finishedStepsInput;
 
-    if (speculativeDecodingMode.isMedusa())
-    {
-        dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
-    }
-
     if (speculativeDecodingMode.isDraftTokensExternal())
     {
         dInput.externalDraftTokensInputs->step = step;
diff --git a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
index 5dab55a7c78..28f87919f07 100644
--- a/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
+++ b/tensorrt_llm/_torch/pyexecutor/make_decoding_batch_input_output.py
@@ -131,7 +131,7 @@ def __call__(
             max_num_sequences=max_num_sequences,
             batch_slots=decoder_input_buffers.forward_batch_slots,
         )
-        decoding_input.generation_steps = generation_steps
+        decoder_state.generation_steps = generation_steps
 
         # TODO: Handle speculative decoding modes.
         # fused_runtime_buffers is not created in the pytorch framework.
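Taken together, the patch moves per-request decoder inputs (generation steps, Medusa draft logits, explicit-draft-token and Eagle buffers) off the per-iteration `decoder_batch::Input` and onto the persistent `DecoderState` and its joint `DecodingInput`. Below is a minimal sketch of how the relocated `generation_steps` property behaves through the new pybind11 binding; the `decoder_state` object and the step values are illustrative assumptions, not part of the patch:

```python
# Sketch only: assumes a fully constructed tensorrt_llm runtime
# DecoderState (e.g. the `decoder_state` used by the pyexecutor above).

# One generation-step counter per active request, as collected by
# getActiveSlots(); only consumed when max_beam_width > 1
# (Variable-Beam-Width-Search).
generation_steps = [3, 5, 1]

# The new `generation_steps` property forwards to
# DecoderState::setGenerationSteps, which writes through to the joint
# DecodingInput, so the decoder loop no longer reads this from the
# per-iteration decoder_batch::Input.
decoder_state.generation_steps = generation_steps

# Reading back goes through DecoderState::getGenerationSteps.
assert decoder_state.generation_steps == generation_steps
```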