refactor: Remove IGptDecoderBatched interface #5577
@@ -19,13 +19,13 @@
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/cudaEvent.h"
 #include "tensorrt_llm/runtime/cudaStream.h"
-#include "tensorrt_llm/runtime/decoderState.h"
-#include "tensorrt_llm/runtime/gptDecoder.h"
-#include "tensorrt_llm/runtime/iGptDecoderBatched.h"
+#include "tensorrt_llm/runtime/eagleBuffers.h"
+#include "tensorrt_llm/runtime/explicitDraftTokensBuffers.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/worldConfig.h"

 #include <memory>
 #include <optional>
 #include <vector>

 namespace tensorrt_llm::batch_manager
@@ -35,9 +35,72 @@ class LlmRequest;

 namespace tensorrt_llm::runtime
 {
 class SamplingConfig;
+class IGptDecoder;
+
+namespace decoder
+{
+class DecoderState;
+}
+
+namespace decoder_batch
+{
+
+class Input
+{
+public:
+    using TensorConstPtr = ITensor::SharedConstPtr;
+    using TensorPtr = ITensor::SharedPtr;
+
+    explicit Input(std::vector<std::vector<TensorConstPtr>> const& logits, SizeType32 maxDecoderSteps)
+        : logits{logits}
+        , maxDecoderSteps{maxDecoderSteps}
+    {
+        TLLM_CHECK_WITH_INFO(
+            logits.size() == static_cast<size_t>(maxDecoderSteps), "logits vector size does not match maxDecoderSteps");
+    }
+
+    explicit Input(std::vector<TensorConstPtr> const& logits)
+        : Input{{logits}, 1}
+    {
+    }
+
+    //! Mandatory parameters
+    //! Logits
+    // FIXME: remove first dimension of tensors
+    //! [maxDecoderSteps][batchSize][1, beamWidth, vocabSizePadded], on gpu
+    std::vector<std::vector<TensorConstPtr>> logits;
+
+    //! Maximum number of decoding tokens of active slots
+    SizeType32 maxDecoderSteps;
+
+    //! Batch of active decoder slots, sorted by slots, [maxDecoderSteps][batchSize]
+    std::vector<TensorPtr> batchSlots;
+    //! Filled with slots in request order, [batchSize]
+    TensorPtr batchSlotsRequestOrder;
+
+    //! For Beam Search
+    //! The generation step of each request (for Variable-Beam-Width-Search), [batchSize]
+    std::vector<SizeType32> generationSteps;
+
+    //! For speculative decoding
+    //! Logits of draft
+    //! [maxBatchSize][maxAcceptedDraftTokensPerStep][maxDraftTokens + 1, vocabSizePadded]
+    std::vector<std::vector<TensorPtr>> predictedDraftLogits;
+
+    //! Explicit draft tokens data
+    std::optional<ExplicitDraftTokensBuffers::EngineOutputs> explicitDraftTokensInputs;
+    std::optional<ExplicitDraftTokensBuffers::EngineInputs> explicitDraftTokensLastInputs;
+
+    //! Eagle data
+    std::optional<EagleBuffers::EngineOutputs> eagleInputs;
+    std::optional<EagleBuffers::Inputs> eagleLastInputs;
+};
+
+} // namespace decoder_batch

 //! GPT decoder class with support for in-flight batching
-class GptDecoderBatched : public IGptDecoderBatched
+class GptDecoderBatched
 {
 public:
     using CudaStreamPtr = std::shared_ptr<CudaStream>;
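Editor's note: to make the new decoder_batch::Input class above concrete, here is a minimal usage sketch. It uses only the two constructors shown in the diff; gatherLogits is a hypothetical helper standing in for wherever the engine's logits come from, and is not part of this PR.

// Hypothetical sketch (not from this PR): constructing the decoder input.
using namespace tensorrt_llm::runtime;

// One [1, beamWidth, vocabSizePadded] logits tensor per active request,
// produced elsewhere by the engine; gatherLogits is an assumed helper.
std::vector<ITensor::SharedConstPtr> stepLogits = gatherLogits();

// Single decoding step: delegates to the multi-step ctor with maxDecoderSteps == 1.
decoder_batch::Input singleStep{stepLogits};

// Multiple decoding steps (e.g. for accepted draft tokens): one logits vector
// per step. The ctor verifies logits.size() == maxDecoderSteps via TLLM_CHECK_WITH_INFO.
std::vector<std::vector<ITensor::SharedConstPtr>> allSteps{stepLogits, stepLogits};
decoder_batch::Input multiStep{allSteps, /*maxDecoderSteps=*/2};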
@@ -47,25 +110,29 @@ class GptDecoderBatched : public IGptDecoderBatched

     explicit GptDecoderBatched(CudaStreamPtr stream);

     //! @brief Setup the decoder before calling `forward()`
     void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
-        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;
+        nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig);

-    void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;
+    //! @brief Disable Lookahead decoding.
+    void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots);

-    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
-    void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
+    //! @brief Run one step for all requests without blocking the host process and return an event for synchronization.
+    CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);
+    //! @brief Run one step for all requests and wait for completion on the host.
+    void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input);

     //! @brief Gather final beam search results for request `batchSlot`.
     //! Result will only be available after event returned.
     [[nodiscard]] CudaEvent finalize(decoder::DecoderState const& decoderState, SizeType32 batchSlot,
-        SamplingConfig const& samplingConfig, bool streaming) const override;
+        SamplingConfig const& samplingConfig, bool streaming) const;

-    CudaStreamPtr getDecoderStream() const
+    [[nodiscard]] CudaStreamPtr getDecoderStream() const
     {
         return mDecoderStream;
     }

-    IGptDecoder& getUnderlyingDecoder() const
+    [[nodiscard]] IGptDecoder& getUnderlyingDecoder() const
     {
         return *mDecoder.get();
     }
@@ -87,4 +154,5 @@ class GptDecoderBatched : public IGptDecoderBatched
     using GptDecoderPtr = std::unique_ptr<IGptDecoder>;
     GptDecoderPtr mDecoder;
 };
+
 } // namespace tensorrt_llm::runtime
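Editor's note: since GptDecoderBatched no longer implements an interface, call sites that held the decoder through IGptDecoderBatched can own the concrete class directly. A minimal before/after sketch, assuming a caller that previously stored the decoder behind the interface pointer; only member functions shown in this diff are used.

// Before (hypothetical call site): decoder held through the removed interface.
// std::unique_ptr<IGptDecoderBatched> decoder
//     = std::make_unique<GptDecoderBatched>(stream);

// After: the concrete class is used directly; the virtual calls become plain
// member calls, matching the removal of `override` in the declarations above.
GptDecoderBatched decoder{stream};
decoder.setup(mode, maxNumSequences, maxBeamWidth, dtype, modelConfig, worldConfig);
decoder.forward(decoderState, input); // one blocking decoding step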
This file was deleted.