cpp/include/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.h
@@ -21,13 +21,20 @@
 #include "tensorrt_llm/common/algorithm.h"
 #include "tensorrt_llm/common/optionalRef.h"
 #include "tensorrt_llm/runtime/common.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
 
-namespace tensorrt_llm::runtime::decoder
+namespace tensorrt_llm::runtime
 {
+namespace decoder
+{
 class DecoderState;
-} // namespace tensorrt_llm::runtime::decoder
+} // namespace decoder
+
+namespace decoder_batch
+{
+class Input;
+} // namespace decoder_batch
+} // namespace tensorrt_llm::runtime
 
 namespace tensorrt_llm::batch_manager
 {
@@ -40,7 +47,7 @@ class MakeDecodingBatchInputOutput : Algorithm
     constexpr static auto name{"MakeDecodingBatchInputOutput"};
 
     using SizeType32 = tensorrt_llm::runtime::SizeType32;
-    using TensorPtr = runtime::decoder_batch::Input::TensorPtr;
+    using TensorPtr = runtime::ITensor::SharedPtr;
     template <typename T>
     using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
 
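A note on the alias change above: pointing TensorPtr at runtime::ITensor::SharedPtr instead of runtime::decoder_batch::Input::TensorPtr is what lets this header drop the iGptDecoderBatched.h include and get by with forward declarations alone. A minimal sketch of the pattern, with hypothetical names rather than code from this PR:

    #include <memory>

    namespace rt
    {
    class Tensor; // forward declaration: enough to form pointers and references

    namespace decoder_batch
    {
    class Input; // forward declaration: the full definition is only needed in the .cpp
    } // namespace decoder_batch
    } // namespace rt

    class MakeBatchInput
    {
    public:
        // The alias is built from the forward-declared type directly rather than
        // via Input::TensorPtr, so Input's definition never has to be visible here.
        using TensorPtr = std::shared_ptr<rt::Tensor>;

        TensorPtr operator()(rt::decoder_batch::Input const& input) const; // defined in a .cpp
    };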
cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h (79 additions, 11 deletions)
@@ -19,13 +19,13 @@
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/cudaEvent.h"
 #include "tensorrt_llm/runtime/cudaStream.h"
-#include "tensorrt_llm/runtime/decoderState.h"
-#include "tensorrt_llm/runtime/gptDecoder.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
+#include "tensorrt_llm/runtime/eagleBuffers.h"
+#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
+#include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
 
 #include <memory>
+#include <optional>
 #include <vector>
 
 namespace tensorrt_llm::batch_manager
@@ -35,9 +35,72 @@ class LlmRequest;
 
 namespace tensorrt_llm::runtime
 {
+class SamplingConfig;
+class IGptDecoder;
+
+namespace decoder
+{
+class DecoderState;
+}
+
+namespace decoder_batch
+{
+
+class Input

> Copilot AI commented on Jul 1, 2025 (review comment on class Input):
> [nitpick] Consider extracting decoder_batch::Input into its own header (e.g., gptDecoderBatchedInput.h) to keep gptDecoderBatched.h focused on the decoder class and improve compilation modularity.

+{
+public:
+    using TensorConstPtr = ITensor::SharedConstPtr;
+    using TensorPtr = ITensor::SharedPtr;
+
+    explicit Input(std::vector<std::vector<TensorConstPtr>> const& logits, SizeType32 maxDecoderSteps)
+        : logits{logits}
+        , maxDecoderSteps{maxDecoderSteps}
+    {
+        TLLM_CHECK_WITH_INFO(
+            logits.size() == static_cast<size_t>(maxDecoderSteps), "logits vector size does not match maxDecoderSteps");
+    }
+
+    explicit Input(std::vector<TensorConstPtr> const& logits)
+        : Input{{logits}, 1}
+    {
+    }
+
+    //! Mandatory parameters
+    //! Logits
+    // FIXME: remove first dimension of tensors
+    //! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
+    std::vector<std::vector<TensorConstPtr>> logits;
+
+    //! Maximum number of decoding tokens of active slots
+    SizeType32 maxDecoderSteps;
+
+    //! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
+    std::vector<TensorPtr> batchSlots;
+    //! Filled with slots in request order, [batchSize]
+    TensorPtr batchSlotsRequestOrder;
+
+    //! For Beam Search
+    //! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
+    std::vector<SizeType32> generationSteps;
+
+    //! For speculative decoding
+    //! Logits of draft
+    //! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
+    std::vector<std::vector<TensorPtr>> predictedDraftLogits;
+
+    //! Explicit draft tokens data
+    std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
+    std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
+
+    //! Eagle data
+    std::optional<EagleBuffers::EngineOutputs> eagleInputs;
+    std::optional<EagleBuffers::Inputs> eagleLastInputs;
+};
+
+} // namespace decoder_batch
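For orientation, here is how a caller might build the simplest Input: the outer logits vector is indexed by decoder step, the inner one by request. A sketch only; the BufferManager-based allocation and the beamWidth/vocabSizePadded parameters are assumptions, not code from this PR:

    #include "tensorrt_llm/runtime/bufferManager.h"
    #include "tensorrt_llm/runtime/gptDecoderBatched.h"

    #include <vector>

    namespace tr = tensorrt_llm::runtime;

    // Sketch: build a one-step Input with one logits tensor per active request.
    tr::decoder_batch::Input makeSingleStepInput(tr::BufferManager& manager, tr::SizeType32 batchSize,
        tr::SizeType32 beamWidth, tr::SizeType32 vocabSizePadded)
    {
        std::vector<tr::ITensor::SharedConstPtr> logits;
        for (tr::SizeType32 i = 0; i < batchSize; ++i)
        {
            // One [1, beamWidth, vocabSizePadded] GPU tensor per request, matching
            // the shape documented on Input::logits.
            logits.emplace_back(
                manager.gpu(tr::ITensor::makeShape({1, beamWidth, vocabSizePadded}), nvinfer1::DataType::kFLOAT));
        }
        // The single-argument constructor wraps this as {logits} with maxDecoderSteps == 1,
        // which satisfies the TLLM_CHECK in the two-argument constructor.
        return tr::decoder_batch::Input{logits};
    }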

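On the nitpick above, one possible shape for the suggested split; the file name gptDecoderBatchedInput.h comes from the review comment and this split is not part of the PR, which keeps Input in gptDecoderBatched.h:

    // gptDecoderBatchedInput.h -- hypothetical header sketching the suggested extraction.
    #pragma once

    #include "tensorrt_llm/runtime/eagleBuffers.h"
    #include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
    #include "tensorrt_llm/runtime/iTensor.h"

    #include <optional>
    #include <vector>

    namespace tensorrt_llm::runtime::decoder_batch
    {

    class Input
    {
        // ... body exactly as in the diff above ...
    };

    } // namespace tensorrt_llm::runtime::decoder_batch

gptDecoderBatched.h would then include this header instead of defining Input inline, and the eagleBuffers.h / explicitDraftTokensBuffers.h includes could move with it, so translation units that only need Input stop pulling in the decoder class.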
 //! GPT decoder class with support for in-flight batching
-class GptDecoderBatched : public IGptDecoderBatched
+class GptDecoderBatched
 {
 public:
     using CudaStreamPtr = std::shared_ptr<CudaStream>;
@@ -47,25 +110,29 @@ class GptDecoderBatched : public IGptDecoderBatched
 
     explicit GptDecoderBatched(CudaStreamPtr stream);
 
+    //! @brief Set up the decoder before calling `forward()`.
     void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
-        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;
+        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig);
 
-    void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;
+    //! @brief Disable Lookahead decoding.
+    void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots);
 
-    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
-    void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
+    //! @brief Run one step for all requests without blocking the host process and return an event for synchronization.
+    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);
+    //! @brief Run one step for all requests and wait for completion on the host.
+    void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);
 
     //! @brief Gather final beam search results for request `batchSlot`.
     //! The result is only available after the returned event has completed.
     [[nodiscard]] CudaEvent finalize(decoder::DecoderState const& decoderState, SizeType32 batchSlot,
-        SamplingConfig const& samplingConfig, bool streaming) const override;
+        SamplingConfig const& samplingConfig, bool streaming) const;
 
-    CudaStreamPtr getDecoderStream() const
+    [[nodiscard]] CudaStreamPtr getDecoderStream() const
     {
         return mDecoderStream;
     }
 
-    IGptDecoder& getUnderlyingDecoder() const
+    [[nodiscard]] IGptDecoder& getUnderlyingDecoder() const
     {
         return *mDecoder.get();
     }
@@ -87,4 +154,5 @@
     using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
     GptDecoderPtr mDecoder;
 };
+
 } // namespace tensorrt_llm::runtime
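The de-virtualized class keeps the same call sequence as before. A hedged sketch of a caller, using only the signatures declared in this header; decoderState and the per-step input are assumed to be built elsewhere, as they are not shown in this PR:

    #include "tensorrt_llm/runtime/gptDecoderBatched.h"

    namespace tr = tensorrt_llm::runtime;

    // Hypothetical driver for one decoding step.
    void decodeStep(tr::GptDecoderBatched& decoder, tr::decoder::DecoderState const& decoderState,
        tr::decoder_batch::Input const& input)
    {
        // Launch one step on the decoder stream without blocking the host...
        tr::CudaEvent event = decoder.forwardAsync(decoderState, input);
        // ...do other host-side work here, then wait on the returned event.
        event.synchronize();

        // Equivalent blocking form: decoder.forward(decoderState, input);
        // For beam search, finalize(...) likewise returns an event to wait on
        // before reading the gathered results.
    }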
cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h (0 additions, 113 deletions)

This file was deleted.

@@ -22,7 +22,7 @@
 #include "tensorrt_llm/common/cudaUtils.h"
 #include "tensorrt_llm/common/logger.h"
 #include "tensorrt_llm/runtime/decoderState.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
+#include "tensorrt_llm/runtime/gptDecoderBatched.h"
 
 namespace tr = tensorrt_llm::runtime;
 
@@ -55,6 +55,8 @@
 #include "tensorrt_llm/kernels/decodingCommon.h"
 #include "tensorrt_llm/layers/defaultDecodingParams.h"
 #include "tensorrt_llm/runtime/common.h"
+#include "tensorrt_llm/runtime/decoderState.h"
+#include "tensorrt_llm/runtime/gptDecoder.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
 #include "tensorrt_llm/runtime/iTensor.h"
cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp (1 addition, 0 deletions)
@@ -25,6 +25,7 @@
 #include "tensorrt_llm/batch_manager/rnnStateManager.h"
 #include "tensorrt_llm/batch_manager/sequenceSlotManager.h"
 #include "tensorrt_llm/nanobind/common/bindTypes.h"
+#include "tensorrt_llm/runtime/decoderState.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/runtimeKernels.h"
 #include "tensorrt_llm/runtime/torch.h"
cpp/tensorrt_llm/nanobind/runtime/bindings.cpp (0 additions, 1 deletion)
@@ -30,7 +30,6 @@
 #include "tensorrt_llm/runtime/gptDecoder.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/ipcUtils.h"
 #include "tensorrt_llm/runtime/lookaheadBuffers.h"
cpp/tensorrt_llm/pybind/batch_manager/algorithms.cpp (1 addition, 0 deletions)
@@ -28,6 +28,7 @@
 #include "tensorrt_llm/batch_manager/pauseRequests.h"
 #include "tensorrt_llm/batch_manager/peftCacheManager.h"
 #include "tensorrt_llm/runtime/decoderState.h"
+#include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/torch.h"
 #include "tensorrt_llm/runtime/torchView.h"
 
cpp/tensorrt_llm/pybind/runtime/bindings.cpp (0 additions, 1 deletion)
@@ -29,7 +29,6 @@
 #include "tensorrt_llm/runtime/gptDecoder.h"
 #include "tensorrt_llm/runtime/gptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/ipcUtils.h"
 #include "tensorrt_llm/runtime/lookaheadBuffers.h"
cpp/tensorrt_llm/runtime/gptDecoderBatched.cpp (1 addition, 0 deletions)
@@ -26,6 +26,7 @@
 #include "tensorrt_llm/kernels/decodingKernels.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/cudaEvent.h"
+#include "tensorrt_llm/runtime/gptDecoder.h"
 
 #include <algorithm>
 #include <cassert>
cpp/tests/runtime/gptDecoderBatchedTest.cpp (2 additions, 1 deletion)
@@ -24,10 +24,11 @@
 #include "tensorrt_llm/executor/types.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/common.h"
+#include "tensorrt_llm/runtime/decoderState.h"
+#include "tensorrt_llm/runtime/gptDecoder.h"
 #include "tensorrt_llm/runtime/iBuffer.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
 #include "tensorrt_llm/runtime/runtimeKernels.h"
 #include "tensorrt_llm/runtime/worldConfig.h"
 
 #include <gmock/gmock-matchers.h>