16 changes: 8 additions & 8 deletions cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -51,13 +51,13 @@ class DecoderState
DecoderState();

//! @brief Setup buffers for the decoder excluding speculative decoding.
- void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
+ void setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype,
ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager);

//! @brief Setup buffers for the cache indirection.
//! @details This is used for beam search on pipeline parallel ranks without a decoder.
- void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
+ void setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow,
BufferManager const& bufferManager);

//! @brief Setup buffers for speculative decoding.
@@ -137,7 +137,7 @@ class DecoderState
//! @returns [maxTokensPerStep, batchSize, beamWidth], finished states of type FinishedState, on gpu
[[nodiscard]] TensorPtr getFinishedSteps() const;

- [[nodiscard]] SizeType32 getMaxBatchSize() const;
+ [[nodiscard]] SizeType32 getMaxNumSequences() const;

[[nodiscard]] SizeType32 getMaxBeamWidth() const;

@@ -190,10 +190,10 @@ class DecoderState
//! @param generationSteps The generation steps for all requests in the batch.
void setGenerationSteps(std::vector<SizeType32> const& generationSteps);

- //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
+ //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots.
[[nodiscard]] DecodingInput& getJointDecodingInput() const;

- //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots.
+ //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots.
[[nodiscard]] DecodingOutput& getJointDecodingOutput() const;

private:
@@ -212,13 +212,13 @@ class DecoderState
SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
BufferManager const& bufferManager);

- SizeType32 mMaxBatchSize{};
+ SizeType32 mMaxNumSequences{};
SizeType32 mMaxBeamWidth{};
SizeType32 mMaxSequenceLength{};

- //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots.
+ //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots.
DecodingInputPtr mJointDecodingInput;
- //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots.
+ //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots.
DecodingOutputPtr mJointDecodingOutput;

//! @brief [maxTokensPerStep, batchSize, beamWidth] finished states of type FinishedState for each generated token
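A minimal call-site sketch of the renamed DecoderState::setup (illustrative only, not code from this PR; the sizes, the kFLOAT dtype, the helper function, and the tr:: namespace alias are assumptions):

// Illustrative sketch, not part of this change: sizing a DecoderState by the number
// of concurrently decoded sequences. All sizes and the dtype below are placeholders.
#include "tensorrt_llm/runtime/decoderState.h"

namespace tr = tensorrt_llm::runtime; // assumed target of the tr:: alias used in bindings.cpp

void setupDecoderStateExample(
    tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig, tr::BufferManager const& bufferManager)
{
    tr::decoder::DecoderState decoderState;

    tr::SizeType32 const maxNumSequences = 64; // previously named maxBatchSize
    tr::SizeType32 const maxBeamWidth = 1;
    tr::SizeType32 const maxAttentionWindow = 4096;
    tr::SizeType32 const sinkTokenLength = 0;
    tr::SizeType32 const maxSequenceLength = 4096;

    // Allocates decoder buffers for maxNumSequences slots, excluding speculative decoding.
    decoderState.setup(maxNumSequences, maxBeamWidth, maxAttentionWindow, sinkTokenLength, maxSequenceLength,
        nvinfer1::DataType::kFLOAT, modelConfig, worldConfig, bufferManager);
}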
12 changes: 6 additions & 6 deletions cpp/include/tensorrt_llm/runtime/gptDecoder.h
@@ -72,7 +72,7 @@ class IGptDecoder
= 0;

static std::unique_ptr<IGptDecoder> create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
- size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
+ size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
BufferManager::CudaStreamPtr const& stream,
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule = nullptr);
};
@@ -85,7 +85,7 @@ class GptDecoder : public virtual IGptDecoder
using CudaStreamPtr = BufferManager::CudaStreamPtr;
using TensorPtr = std::shared_ptr<ITensor>;

- GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize,
+ GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize,
size_t vocabSizePadded, CudaStreamPtr const& stream,
std::shared_ptr<SpeculativeDecodingModule const> speculativeDecodingModule = nullptr);

@@ -115,26 +115,26 @@ class GptDecoder : public virtual IGptDecoder

SamplingConfig mSamplingConfig;

- size_t mMaxBatchSize;
+ size_t mMaxNumSequences;
size_t mVocabSize;
size_t mVocabSizePadded;

executor::DecodingMode mDecodingMode;
};

inline std::unique_ptr<IGptDecoder> IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype,
- size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
+ size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded,
BufferManager::CudaStreamPtr const& stream,
std::shared_ptr<SpeculativeDecodingModule const> const& speculativeDecodingModule)
{
switch (dtype)
{
case nvinfer1::DataType::kFLOAT:
return std::make_unique<GptDecoder<float>>(
- mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
+ mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
case nvinfer1::DataType::kHALF:
return std::make_unique<GptDecoder<half>>(
- mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
+ mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule);
default:
TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast<int>(dtype));
return nullptr;
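A minimal sketch of the factory shown above with the renamed parameter (illustrative only; the vocab sizes, dtype, and helper function are assumptions, not code from this PR):

// Illustrative sketch only: creating a float decoder through IGptDecoder::create.
// The vocab sizes and dtype are placeholder assumptions.
#include <memory>

#include "tensorrt_llm/runtime/gptDecoder.h"

namespace tr = tensorrt_llm::runtime;
namespace te = tensorrt_llm::executor;

std::unique_ptr<tr::IGptDecoder> makeDecoderExample(
    te::DecodingMode const& mode, tr::BufferManager::CudaStreamPtr const& stream)
{
    size_t const maxNumSequences = 64; // previously named maxBatchSize
    size_t const maxBeamWidth = 1;
    size_t const vocabSize = 32000;
    size_t const vocabSizePadded = 32000;

    // The inline factory above dispatches to GptDecoder<float> or GptDecoder<half> based on dtype;
    // the optional speculativeDecodingModule argument is left at its nullptr default here.
    return tr::IGptDecoder::create(
        mode, nvinfer1::DataType::kFLOAT, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream);
}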
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -47,7 +47,7 @@ class GptDecoderBatched : public IGptDecoderBatched

explicit GptDecoderBatched(CudaStreamPtr stream);

- void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
+ void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;

void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -86,7 +86,7 @@ class IGptDecoderBatched
using TensorPtr = std::shared_ptr<ITensor>;

//! @brief Setup the decoder before calling `forward()`
- virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
+ virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig)
= 0;

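A minimal sketch of the batched decoder setup path covered by the two headers above (illustrative only; the stream, mode, configs, dtype, and sizes are assumed to come from the caller and are not values from this PR):

// Illustrative sketch only: GptDecoderBatched::setup with the renamed parameter.
#include <utility>

#include "tensorrt_llm/runtime/gptDecoderBatched.h"

namespace tr = tensorrt_llm::runtime;
namespace te = tensorrt_llm::executor;

void setupBatchedDecoderExample(tr::GptDecoderBatched::CudaStreamPtr stream, te::DecodingMode const& mode,
    tr::ModelConfig const& modelConfig, tr::WorldConfig const& worldConfig)
{
    tr::GptDecoderBatched decoder{std::move(stream)};

    tr::SizeType32 const maxNumSequences = 64; // previously named maxBatchSize
    tr::SizeType32 const maxBeamWidth = 1;

    // Setup must be called before forward(); dtype here is a placeholder assumption.
    decoder.setup(mode, maxNumSequences, maxBeamWidth, nvinfer1::DataType::kHALF, modelConfig, worldConfig);
}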
@@ -176,7 +176,7 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder

BufferManager manager{std::make_shared<CudaStream>(decoderStream.get())};

- auto const batchSize = decoderState.getMaxBatchSize();
+ auto const batchSize = decoderState.getMaxNumSequences();
TLLM_CHECK(0 <= batchSize && batchSlot < batchSize);
auto const maxBeamWidth = decoderState.getMaxBeamWidth();
auto const beamWidth = samplingConfig.beamWidth;
6 changes: 3 additions & 3 deletions cpp/tensorrt_llm/pybind/runtime/bindings.cpp
@@ -326,10 +326,10 @@ void initBindings(pybind11::module_& m)

py::class_<tr::decoder::DecoderState>(m, "DecoderState")
.def(py::init<>())
.def("setup", &tr::decoder::DecoderState::setup, py::arg("max_batch_size"), py::arg("max_beam_width"),
.def("setup", &tr::decoder::DecoderState::setup, py::arg("max_num_sequences"), py::arg("max_beam_width"),
py::arg("max_attention_window"), py::arg("sink_token_length"), py::arg("max_sequence_length"),
py::arg("dtype"), py::arg("model_config"), py::arg("world_config"), py::arg("buffer_manager"))
.def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_batch_size"),
.def("setup_cache_indirection", &tr::decoder::DecoderState::setupCacheIndirection, py::arg("max_num_sequences"),
py::arg("max_beam_width"), py::arg("max_attention_window"), py::arg("buffer_manager"))
.def("setup_speculative_decoding", &tr::decoder::DecoderState::setupSpeculativeDecoding,
py::arg("speculative_decoding_mode"), py::arg("max_tokens_per_engine_step"), py::arg("dtype"),
@@ -386,7 +386,7 @@ void initBindings(pybind11::module_& m)

py::class_<tr::GptDecoderBatched>(m, "GptDecoderBatched")
.def(py::init<tr::GptDecoderBatched::CudaStreamPtr>(), py::arg("stream"))
.def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_batch_size"),
.def("setup", &tr::GptDecoderBatched::setup, py::arg("mode"), py::arg("max_num_sequences"),
py::arg("max_beam_width"), py::arg("dtype"), py::arg("model_config"), py::arg("world_config"))
.def("forward_async", &tr::GptDecoderBatched::forwardAsync, py::arg("decoder_state"), py::arg("input"))
.def("underlying_decoder", &tr::GptDecoderBatched::getUnderlyingDecoder, py::return_value_policy::reference)