8 changes: 8 additions & 0 deletions cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -182,6 +182,14 @@ class DecoderState
//! @brief Cache indirection output for beam search.
[[nodiscard]] TensorPtr getCacheIndirectionOutput() const;

//! @brief Get the generation steps for all requests in the batch.
//! @returns The generation steps for all requests in the batch.
[[nodiscard]] std::optional<std::vector<SizeType32>> const& getGenerationSteps() const;

//! @brief Set the generation steps for all requests in the batch.
//! @param generationSteps The generation steps for all requests in the batch.
void setGenerationSteps(std::vector<SizeType32> const& generationSteps);

//! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
[[nodiscard]] DecodingInput& getJointDecodingInput() const;

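A minimal sketch of how the new accessors might be used (hypothetical caller code, not part of this PR; assumes a fully constructed decoder state):

#include "tensorrt_llm/runtime/decoderState.h"

#include <cassert>
#include <vector>

namespace tr = tensorrt_llm::runtime;

void syncGenerationSteps(tr::decoder::DecoderState& state, std::vector<tr::SizeType32> const& steps)
{
    // One entry per request in the batch, forwarded to the joint decoding input.
    state.setGenerationSteps(steps);

    // The getter returns an optional that stays empty until the first set.
    auto const& current = state.getGenerationSteps();
    assert(current.has_value() && current->size() == steps.size());
}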
18 changes: 0 additions & 18 deletions cpp/include/tensorrt_llm/runtime/decodingInput.h
@@ -142,24 +142,6 @@ class DecodingInput

struct EagleInputs
{
EagleInputs(TensorConstPtr nextDraftTokens, TensorConstPtr nextDraftLens, TensorConstPtr nextDraftPaths,
TensorConstPtr lastDraftTokens, TensorConstPtr lastDraftLens, TensorConstPtr lastDraftPaths,
TensorConstPtr acceptedTokens, TensorConstPtr acceptedLens, TensorConstPtr acceptedPathIds,
TensorConstPtr chunkedContextNextTokens, TensorConstPtr seqSlots)
: nextDraftTokens(std::move(nextDraftTokens))
, nextDraftLens(std::move(nextDraftLens))
, nextDraftPaths(std::move(nextDraftPaths))
, lastDraftTokens(std::move(lastDraftTokens))
, lastDraftLens(std::move(lastDraftLens))
, lastDraftPaths(std::move(lastDraftPaths))
, acceptedTokens(std::move(acceptedTokens))
, acceptedLens(std::move(acceptedLens))
, acceptedPathIds(std::move(acceptedPathIds))
, chunkedContextNextTokens(std::move(chunkedContextNextTokens))
, seqSlots(std::move(seqSlots))
{
}

TensorConstPtr nextDraftTokens; // [batchSize, maxDecodingDraftTokens]
TensorConstPtr nextDraftLens; // [batchSize]
TensorConstPtr nextDraftPaths; // [batchSize, maxDecodingTokens, maxPathLen]
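With the member-wise constructor removed, EagleInputs becomes a plain struct that callers populate field by field, as the new setEagleInputs helper later in this diff does. A condensed sketch (hypothetical tensors on the right-hand side; only a few of the fields shown):

// dInput is a tensorrt_llm::runtime::DecodingInput; the tensors come
// from the engine's eagle buffers.
dInput.eagleInputs = DecodingInput::EagleInputs();
dInput.eagleInputs->nextDraftTokens = nextDraftTokens; // [batchSize, maxDecodingDraftTokens]
dInput.eagleInputs->nextDraftLens = nextDraftLens;     // [batchSize]
dInput.eagleInputs->seqSlots = seqSlots;               // slots in request order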
23 changes: 2 additions & 21 deletions cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -18,9 +18,9 @@

#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/eagleBuffers.h"
#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
#include "tensorrt_llm/runtime/iTensor.h"
#include "tensorrt_llm/runtime/modelConfig.h"
#include "tensorrt_llm/runtime/worldConfig.h"

#include <memory>
#include <vector>
@@ -72,25 +72,6 @@ class Input

//! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
std::vector<TensorPtr> batchSlots;
//! Filled with slots in request order, [batchSize]
TensorPtr batchSlotsRequestOrder;

//! For Beam Search
//! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
std::vector<SizeType32> generationSteps;

//! For speculative decoding
//! Logits of draft
//! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
std::vector<std::vector<TensorPtr>> predictedDraftLogits;

//! Explicit draft tokens data
std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;

//! Eagle data
std::optional<EagleBuffers::EngineOutputs> eagleInputs;
std::optional<EagleBuffers::Inputs> eagleLastInputs;
};

} // namespace decoder_batch
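After this change decoder_batch::Input keeps only what the stepwise decoder loop itself consumes: logits, maxDecoderSteps, and batchSlots; all speculative-decoding state now flows through DecodingInput instead. A construction sketch for the single-step case (hypothetical helper; constructor signature as exposed through the pybind11 binding further down):

#include "tensorrt_llm/runtime/iGptDecoderBatched.h"

#include <vector>

namespace tr = tensorrt_llm::runtime;

tr::decoder_batch::Input makeSingleStepInput(
    std::vector<tr::ITensor::SharedConstPtr> logits, tr::ITensor::SharedPtr slots)
{
    tr::decoder_batch::Input input(std::move(logits)); // one logits tensor per active request
    input.maxDecoderSteps = 1;                         // a single decoder step
    input.batchSlots = {std::move(slots)};             // [maxDecoderSteps][batchSize]
    return input;
}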
74 changes: 66 additions & 8 deletions cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp
@@ -118,6 +118,62 @@ std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
return {activeSlots, generationSteps};
}

//! @brief Sets inputs for explicit draft tokens.
void setExplicitDraftTokensInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK(fusedRuntimeBuffers.mExplicitDraftTokensBuffers);
auto const& explicitDraftTokensInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineOutputs;
auto const& explicitDraftTokensLastInputs = fusedRuntimeBuffers.mExplicitDraftTokensBuffers->engineInputs;

dInput.explicitDraftTokensInputs = tr::DecodingInput::ExplicitDraftTokensInputs();
dInput.explicitDraftTokensInputs->nextDraftTokens = explicitDraftTokensInputs.nextDraftTokens;
dInput.explicitDraftTokensInputs->nextFlatTokens = explicitDraftTokensInputs.nextFlatTokens;
dInput.explicitDraftTokensInputs->nextDraftIndices = explicitDraftTokensInputs.nextDraftIndices;
dInput.explicitDraftTokensInputs->nextDraftProbs = explicitDraftTokensInputs.nextDraftProbs;
dInput.explicitDraftTokensInputs->lastDraftTokens = explicitDraftTokensLastInputs.draftTokens;
dInput.explicitDraftTokensInputs->lastDraftIndices = explicitDraftTokensLastInputs.draftIndices;
dInput.explicitDraftTokensInputs->lastPositionIdsBase = explicitDraftTokensLastInputs.positionIdsBase;
dInput.explicitDraftTokensInputs->masks = explicitDraftTokensInputs.masks;
dInput.explicitDraftTokensInputs->packedPositionIds = explicitDraftTokensInputs.packedPositionIds;
dInput.explicitDraftTokensInputs->bestPathLengths = explicitDraftTokensInputs.bestPathLengths;
dInput.explicitDraftTokensInputs->bestPathIndices = explicitDraftTokensInputs.bestPathIndices;
dInput.explicitDraftTokensInputs->nextGenerationLengths = explicitDraftTokensInputs.nextGenerationLengths;
dInput.explicitDraftTokensInputs->lastGenerationLengths = explicitDraftTokensLastInputs.generationLengths;
dInput.explicitDraftTokensInputs->maxGenLengthDevice = explicitDraftTokensInputs.maxGenToken;
// Slots in request order
dInput.explicitDraftTokensInputs->seqSlots = fusedRuntimeBuffers.seqSlots;

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! @brief Sets inputs for eagle decoding.
void setEagleInputs(tr::DecodingInput& dInput, RuntimeBuffers const& fusedRuntimeBuffers)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK(fusedRuntimeBuffers.mEagleBuffers);
auto const& eagleInputs = fusedRuntimeBuffers.mEagleBuffers->engineOutputs;
auto const& eagleLastInputs = fusedRuntimeBuffers.mEagleBuffers->engineInputs;

dInput.eagleInputs = tr::DecodingInput::EagleInputs();
dInput.eagleInputs->nextDraftTokens = eagleInputs.nextDraftTokens;
dInput.eagleInputs->nextDraftLens = eagleInputs.nextDraftLens;
dInput.eagleInputs->nextDraftPaths = eagleInputs.nextDraftPaths;
dInput.eagleInputs->lastDraftTokens = eagleLastInputs.draftTokens;
dInput.eagleInputs->lastDraftLens = eagleLastInputs.draftLens;
dInput.eagleInputs->lastDraftPaths = eagleLastInputs.draftPaths;
dInput.eagleInputs->acceptedTokens = eagleInputs.acceptedTokens;
dInput.eagleInputs->acceptedLens = eagleInputs.acceptedLens;
dInput.eagleInputs->acceptedPathIds = eagleInputs.acceptedPaths;
dInput.eagleInputs->chunkedContextNextTokens = eagleInputs.chunkedContextNextTokens;
// Slots in request order
dInput.eagleInputs->seqSlots = fusedRuntimeBuffers.seqSlots;

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

} // namespace

std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator()(RequestVector const& contextRequests,
@@ -131,28 +187,30 @@ std::unique_ptr<tr::decoder_batch::Input> MakeDecodingBatchInputOutput::operator

auto decodingInput = createDecoderBatchInputs(
activeSlots, decoderState, inputBuffers.logits, maxNumSequences, inputBuffers.forwardBatchSlots);
decodingInput->generationSteps = generationSteps;

auto const maxBeamWidth = decoderState.getMaxBeamWidth();
if (maxBeamWidth > 1)
{
// For Variable-Beam-Width-Search
decoderState.getJointDecodingInput().generationSteps = generationSteps;
}

if (modelConfig.getSpeculativeDecodingMode().hasDraftLogits())
{
decodingInput->predictedDraftLogits = inputBuffers.predictedDraftLogits;
decoderState.getJointDecodingInput().medusaInputs->medusaLogits = inputBuffers.predictedDraftLogits;
}

if (modelConfig.getSpeculativeDecodingMode().isExplicitDraftTokens())
{
TLLM_CHECK(fusedRuntimeBuffers);
// requires mCtxGenFusion == true
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
decodingInput->explicitDraftTokensInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineOutputs;
decodingInput->explicitDraftTokensLastInputs = fusedRuntimeBuffers->mExplicitDraftTokensBuffers->engineInputs;
setExplicitDraftTokensInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
}
else if (modelConfig.getSpeculativeDecodingMode().isEagle())
{
TLLM_CHECK(fusedRuntimeBuffers);
// requires mCtxGenFusion == true
decodingInput->batchSlotsRequestOrder = fusedRuntimeBuffers->seqSlots;
decodingInput->eagleInputs = fusedRuntimeBuffers->mEagleBuffers->engineOutputs;
decodingInput->eagleLastInputs = fusedRuntimeBuffers->mEagleBuffers->engineInputs;
setEagleInputs(decoderState.getJointDecodingInput(), *fusedRuntimeBuffers);
}

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
10 changes: 5 additions & 5 deletions cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -23,6 +23,7 @@
#include "tensorrt_llm/kernels/delayStream.h"
#include "tensorrt_llm/runtime/cudaEvent.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/decoderState.h"
#include "tensorrt_llm/runtime/decodingInput.h"
#include "tensorrt_llm/runtime/decodingOutput.h"
#include "tensorrt_llm/runtime/gptDecoder.h"
@@ -273,10 +274,7 @@ void initBindings(pybind11::module_& m)
.def(py::init<std::vector<tr::ITensor::SharedConstPtr>>(), py::arg("logits"))
.def_readwrite("logits", &tr::decoder_batch::Input::logits)
.def_readwrite("max_decoder_steps", &tr::decoder_batch::Input::maxDecoderSteps)
.def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots)
.def_readwrite("batch_slots_request_order", &tr::decoder_batch::Input::batchSlotsRequestOrder)
.def_readwrite("generation_steps", &tr::decoder_batch::Input::generationSteps)
.def_readwrite("predicted_draft_logits", &tr::decoder_batch::Input::predictedDraftLogits);
.def_readwrite("batch_slots", &tr::decoder_batch::Input::batchSlots);

py::class_<tr::LookaheadDecodingBuffers>(m, "LookaheadDecodingBuffers")
.def(py::init<tr::SizeType32, tr::SizeType32, tr::BufferManager const&>(), py::arg("max_num_sequences"),
@@ -380,7 +378,9 @@
py::arg("batch_idx"))
.def("set_num_decoding_engine_tokens", &tr::decoder::DecoderState::setNumDecodingEngineTokens,
py::arg("batch_idx"), py::arg("num_tokens"))
.def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode);
.def_property_readonly("speculative_decoding_mode", &tr::decoder::DecoderState::getSpeculativeDecodingMode)
.def_property("generation_steps", &tr::decoder::DecoderState::getGenerationSteps,
&tr::decoder::DecoderState::setGenerationSteps);

py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
.def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"))
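On the Python side, the generation steps are now set through the new DecoderState property rather than on the decoder input. An illustrative sketch (assumes decoder_state is a DecoderState obtained from the runtime bindings):

# One generation step per active request, e.g. for Variable-Beam-Width-Search.
decoder_state.generation_steps = [3, 5, 4]

# Reads back through the same property; None until the first assignment.
assert decoder_state.generation_steps == [3, 5, 4]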
10 changes: 10 additions & 0 deletions cpp/tensorrt_llm/runtime/decoderState.cpp
@@ -644,6 +644,16 @@ TensorPtr DecoderState::getCacheIndirectionOutput() const
return mJointDecodingOutput->cacheIndirection;
}

std::optional<std::vector<SizeType32>> const& DecoderState::getGenerationSteps() const
{
return mJointDecodingInput->generationSteps;
}

void DecoderState::setGenerationSteps(std::vector<SizeType32> const& generationSteps)
{
mJointDecodingInput->generationSteps = generationSteps;
}

DecodingInput& DecoderState::getJointDecodingInput() const
{
return *mJointDecodingInput;
65 changes: 0 additions & 65 deletions cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp
@@ -100,78 +100,18 @@ void GptDecoderBatched::setup(executor::DecodingMode const& mode, SizeType32 max

namespace
{
//! @brief Sets inputs for explicit draft tokens.
void setExplicitDraftTokensInputs(DecodingInput& dInput, decoder_batch::Input const& input)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

dInput.explicitDraftTokensInputs = DecodingInput::ExplicitDraftTokensInputs();
TLLM_CHECK(input.explicitDraftTokensInputs.has_value());
TLLM_CHECK(input.explicitDraftTokensLastInputs.has_value());

dInput.explicitDraftTokensInputs->nextDraftTokens = input.explicitDraftTokensInputs->nextDraftTokens;
dInput.explicitDraftTokensInputs->nextFlatTokens = input.explicitDraftTokensInputs->nextFlatTokens;
dInput.explicitDraftTokensInputs->nextDraftIndices = input.explicitDraftTokensInputs->nextDraftIndices;
dInput.explicitDraftTokensInputs->nextDraftProbs = input.explicitDraftTokensInputs->nextDraftProbs;
dInput.explicitDraftTokensInputs->lastDraftTokens = input.explicitDraftTokensLastInputs->draftTokens;
dInput.explicitDraftTokensInputs->lastDraftIndices = input.explicitDraftTokensLastInputs->draftIndices;
dInput.explicitDraftTokensInputs->lastPositionIdsBase = input.explicitDraftTokensLastInputs->positionIdsBase;
dInput.explicitDraftTokensInputs->masks = input.explicitDraftTokensInputs->masks;
dInput.explicitDraftTokensInputs->packedPositionIds = input.explicitDraftTokensInputs->packedPositionIds;
dInput.explicitDraftTokensInputs->bestPathLengths = input.explicitDraftTokensInputs->bestPathLengths;
dInput.explicitDraftTokensInputs->bestPathIndices = input.explicitDraftTokensInputs->bestPathIndices;
dInput.explicitDraftTokensInputs->nextGenerationLengths = input.explicitDraftTokensInputs->nextGenerationLengths;
dInput.explicitDraftTokensInputs->lastGenerationLengths = input.explicitDraftTokensLastInputs->generationLengths;
dInput.explicitDraftTokensInputs->maxGenLengthDevice = input.explicitDraftTokensInputs->maxGenToken;
dInput.explicitDraftTokensInputs->seqSlots = input.batchSlotsRequestOrder;

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! @brief Sets inputs for eagle decoding.
void setEagleInputs(DecodingInput& dInput, decoder_batch::Input const& input)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

TLLM_CHECK(input.eagleInputs.has_value());
TLLM_CHECK(input.eagleLastInputs.has_value());

dInput.eagleInputs = DecodingInput::EagleInputs(input.eagleInputs->nextDraftTokens,
input.eagleInputs->nextDraftLens, input.eagleInputs->nextDraftPaths, input.eagleLastInputs->draftTokens,
input.eagleLastInputs->draftLens, input.eagleLastInputs->draftPaths, input.eagleInputs->acceptedTokens,
input.eagleInputs->acceptedLens, input.eagleInputs->acceptedPaths, input.eagleInputs->chunkedContextNextTokens,
input.batchSlotsRequestOrder);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

//! @brief Prepare Input and Output for decoder step.
// TODO: produce new input and output objects
void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step, decoder_batch::Input const& input,
BufferManager const& bufferManager)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto const maxBeamWidth = decoderState.getMaxBeamWidth();
auto const speculativeDecodingMode = decoderState.getSpeculativeDecodingMode();

auto& dInput = decoderState.getJointDecodingInput();
auto& dOutput = decoderState.getJointDecodingOutput();

if (maxBeamWidth > 1)
{
dInput.generationSteps = input.generationSteps; // For Variable-Beam-Width-Search
}

if (speculativeDecodingMode.isExplicitDraftTokens())
{
setExplicitDraftTokensInputs(dInput, input);
}
else if (speculativeDecodingMode.isEagle())
{
setEagleInputs(dInput, input);
}

dInput.batchSlots = input.batchSlots.at(step);
dInput.batchSize = static_cast<SizeType32>(dInput.batchSlots->getSize());
dInput.logitsVec = input.logits.at(step);
@@ -186,11 +126,6 @@ void prepareForward(decoder::DecoderState const& decoderState, SizeType32 step,

dInput.finishReasons = finishedStepsInput;

if (speculativeDecodingMode.isMedusa())
{
dInput.medusaInputs->medusaLogits = input.predictedDraftLogits;
}

if (speculativeDecodingMode.isDraftTokensExternal())
{
dInput.externalDraftTokensInputs->step = step;
@@ -131,7 +131,7 @@ def __call__(
max_num_sequences=max_num_sequences,
batch_slots=decoder_input_buffers.forward_batch_slots,
)
decoding_input.generation_steps = generation_steps
decoder_state.generation_steps = generation_steps

# TODO: Handle speculative decoding modes.
# fused_runtime_buffers is not created in the pytorch framework.