@@ -70,7 +70,7 @@ class CreateNewDecoderRequests : Algorithm
{
}

std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
std::tuple<TensorPtr, runtime::SamplingConfig, std::vector<runtime::ITensor::SharedConstPtr>,
std::vector<executor::LookaheadDecodingConfig>>
operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests,
@@ -86,6 +86,10 @@ class CreateNewDecoderRequests : Algorithm
runtime::CudaStream const& runtimeStream, runtime::CudaStream const& decoderStream,
SizeType32 maxSequenceLength, OptionalRef<MedusaBuffers const> medusaBuffers) const;

static TensorPtr fillBatchSlots(RequestVector const& requests, DecoderInputBuffers& inputBuffers);

static std::optional<SamplingConfig> fuseSamplingConfigs(RequestVector const& requests);

private:
bool mSpeculativeDecodingFastLogits;
bool mIsLeaderInOrchMode;
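Taken together, the two new static helpers expose the setup work that previously lived only inside operator(): fillBatchSlots writes each request's seqSlot into the setup batch-slot buffer, and fuseSamplingConfigs folds the per-request sampling configs into a single batched SamplingConfig. A minimal caller-side sketch, mirroring the disableLookaheadDecoder() call site added in trtGptModelInflightBatching.cpp later in this diff (requests, inputBuffers, and decoder stand for the caller's own objects; the snippet is illustrative, not part of the change):

// Sketch only: combine the new helpers to reconfigure the underlying decoder.
auto batchSlots = CreateNewDecoderRequests::fillBatchSlots(requests, inputBuffers);
auto samplingConfig = CreateNewDecoderRequests::fuseSamplingConfigs(requests);
// fuseSamplingConfigs returns std::nullopt for an empty request vector.
decoder.disableLookahead(samplingConfig, batchSlots->getSize(), batchSlots);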
2 changes: 1 addition & 1 deletion cpp/include/tensorrt_llm/runtime/decoderState.h
@@ -66,7 +66,7 @@ class DecoderState
WorldConfig const& worldConfig, BufferManager const& bufferManager);

//! @brief Disable lookahead decoding.
void disableLookahead(RequestVector const& genRequests);
void disableLookahead();

//! @returns [batchSize], number of finished sequences per request, on gpu
[[nodiscard]] TensorPtr getFinishedSum() const;
14 changes: 2 additions & 12 deletions cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h
@@ -26,12 +26,6 @@
#include "tensorrt_llm/runtime/worldConfig.h"

#include <memory>
#include <vector>

namespace tensorrt_llm::batch_manager
{
class LlmRequest;
} // namespace tensorrt_llm::batch_manager

namespace tensorrt_llm::runtime
{
@@ -41,17 +35,13 @@ class GptDecoderBatched : public IGptDecoderBatched
{
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using LlmRequestPtr = std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>;
using RequestVector = std::vector<LlmRequestPtr>;
using TensorPtr = ITensor::SharedPtr;

explicit GptDecoderBatched(CudaStreamPtr stream);

void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override;

void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override;

CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;
void forward(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) override;

@@ -60,12 +50,12 @@ class GptDecoderBatched : public IGptDecoderBatched
[[nodiscard]] CudaEvent finalize(decoder::DecoderState const& decoderState, SizeType32 batchSlot,
SamplingConfig const& samplingConfig, bool streaming) const override;

CudaStreamPtr getDecoderStream() const
[[nodiscard]] CudaStreamPtr getDecoderStream() const
{
return mDecoderStream;
}

IGptDecoder& getUnderlyingDecoder() const
[[nodiscard]] IGptDecoder& getUnderlyingDecoder() const
{
return *mDecoder.get();
}
10 changes: 0 additions & 10 deletions cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h
@@ -25,11 +25,6 @@
#include <memory>
#include <vector>

namespace tensorrt_llm::batch_manager
{
class LlmRequest;
}

namespace tensorrt_llm::runtime
{
class SamplingConfig;
@@ -81,18 +76,13 @@ class IGptDecoderBatched
{
public:
using CudaStreamPtr = std::shared_ptr<CudaStream>;
using LlmRequestPtr = std::shared_ptr<tensorrt_llm::batch_manager::LlmRequest>;
using RequestVector = std::vector<LlmRequestPtr>;
using TensorPtr = std::shared_ptr<ITensor>;

//! @brief Setup the decoder before calling `forward()`
virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth,
nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig)
= 0;

//! @brief Disable Lookahead decoding.
virtual void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) = 0;

//! @brief Run one step for all requests without blocking the host process and return the token for synchronization.
virtual CudaEvent forwardAsync(decoder::DecoderState const& decoderState, decoder_batch::Input const& input) = 0;

77 changes: 54 additions & 23 deletions cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp
@@ -54,7 +54,8 @@ using OptionalRef = tensorrt_llm::common::OptionalRef<T>;
namespace
{

void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers,
//! @brief Fills the seqSlots and sequence lengths in the inputBuffers.
TensorPtr copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffers& inputBuffers,
ITensor& sequenceLengths, SizeType32 beamWidth, runtime::CudaStream const& stream)
{
auto const bufferManager = BufferManager{std::make_shared<CudaStream>(stream.get())};
@@ -82,6 +83,7 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
}

// copy sequence lengths
if (!contextRequests.empty())
{
auto batchSlotsDeviceView = tr::ITensor::slice(inputBuffers.setupBatchSlotsDevice, 0, batchSize);
auto fillValuesViewDevice = tr::ITensor::slice(inputBuffers.fillValuesDevice, 0, batchSize);
@@ -90,6 +92,8 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe
bufferManager.copy(*fillValuesView, *fillValuesViewDevice);
tr::kernels::invokeFillBatch(sequenceLengths, *batchSlotsDeviceView, beamWidth, *fillValuesViewDevice, stream);
}

return batchSlotsView;
}
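For readers unfamiliar with invokeFillBatch, a hedged host-side sketch of what the kernel call above is assumed to do: scatter each request's fill value (its current sequence length) into all beamWidth entries of the sequence-length buffer at that request's batch slot. The layout and names below are illustrative assumptions, not taken from the kernel itself:

#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative CPU equivalent of the batched fill (assumed layout:
// sequenceLengths has maxNumSequences * beamWidth entries, grouped by slot).
void fillBatchOnHost(std::vector<std::int32_t>& sequenceLengths, std::vector<std::int32_t> const& batchSlots,
    std::vector<std::int32_t> const& fillValues, std::int32_t beamWidth)
{
    for (std::size_t i = 0; i < batchSlots.size(); ++i)
    {
        auto const slot = static_cast<std::size_t>(batchSlots[i]);
        for (std::int32_t beam = 0; beam < beamWidth; ++beam)
        {
            sequenceLengths[slot * static_cast<std::size_t>(beamWidth) + static_cast<std::size_t>(beam)] = fillValues[i];
        }
    }
}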

/// @brief Retrieve the embedding bias from the request. This potentially makes a copy of the tensor
@@ -131,7 +135,46 @@ void copySequenceLengths(RequestVector const& contextRequests, DecoderInputBuffe

} // namespace

std::tuple<TensorPtr, std::vector<runtime::SamplingConfig>, std::vector<runtime::ITensor::SharedConstPtr>,
// Similar to copySequenceLengths, but only fills the seqSlots.
TensorPtr CreateNewDecoderRequests::fillBatchSlots(RequestVector const& requests, DecoderInputBuffers& inputBuffers)
{
auto const batchSize = requests.size();
auto batchSlotsView = tr::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);

auto batchSlotsRange = tr::BufferRange<SizeType32>(*batchSlotsView);

// fill buffers on host
SizeType32 batchIdx{0};
for (auto const& llmReq : requests)
{
auto const seqSlot = llmReq->mSeqSlot.value();
batchSlotsRange[batchIdx] = seqSlot;
++batchIdx;
}

// TODO: copy to device and use in GptDecoder
// manager.copy(*batchSlotsView, *batchSlotsDeviceView);

return batchSlotsView;
}

std::optional<SamplingConfig> CreateNewDecoderRequests::fuseSamplingConfigs(RequestVector const& requests)
{
if (requests.empty())
{
return std::nullopt;
}

std::vector<SamplingConfig> samplingConfigs;
samplingConfigs.reserve(requests.size());
for (auto const& llmReq : requests)
{
samplingConfigs.push_back(llmReq->mSamplingConfig);
}
return SamplingConfig(samplingConfigs);
}
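
Conceptually, the fusion relies on the SamplingConfig constructor that takes a std::vector<SamplingConfig>, as used above, to turn N per-request configs into one batched config. A simplified, self-contained sketch of the idea with hypothetical stand-in types (not the real SamplingConfig fields or API; the assumption is that the real vector constructor performs an equivalent per-field merge):

#include <vector>

// Hypothetical per-request config (illustration only, not the real SamplingConfig).
struct MiniSamplingConfig
{
    float temperature{1.0F};
    int topK{0};
};

// Hypothetical fused form: one entry per request in the batch.
struct MiniBatchedSamplingConfig
{
    std::vector<float> temperature;
    std::vector<int> topK;
};

// Fuse per-request configs into batched vectors, mirroring the intent of
// CreateNewDecoderRequests::fuseSamplingConfigs above.
MiniBatchedSamplingConfig fuse(std::vector<MiniSamplingConfig> const& configs)
{
    MiniBatchedSamplingConfig fused;
    fused.temperature.reserve(configs.size());
    fused.topK.reserve(configs.size());
    for (auto const& cfg : configs)
    {
        fused.temperature.push_back(cfg.temperature);
        fused.topK.push_back(cfg.topK);
    }
    return fused;
}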

std::tuple<TensorPtr, runtime::SamplingConfig, std::vector<runtime::ITensor::SharedConstPtr>,
std::vector<executor::LookaheadDecodingConfig>>
CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, runtime::WorldConfig const& worldConfig,
executor::DecodingConfig const& decodingConfig, RequestVector const& contextRequests, nvinfer1::DataType logitsType,
@@ -142,33 +185,21 @@ CreateNewDecoderRequests::operator()(runtime::ModelConfig const& modelConfig, ru
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
NVTX3_SCOPED_RANGE(CreateNewDecoderRequests);

RequestVector finishedContextRequests;
std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
[](auto const& llmReq) { return llmReq->isLastContextChunk(); });
TLLM_CHECK_WITH_INFO(
!contextRequests.empty(), "CreateNewDecoderRequests should be called with at least one request");

if (!finishedContextRequests.empty())
{
copySequenceLengths(
finishedContextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, runtimeStream);
}
auto batchSlotsView = copySequenceLengths(
contextRequests, inputBuffers, *decoderState.getSequenceLengths(), beamWidth, runtimeStream);

auto [lookaheadPrompt, lookaheadAlgoConfigs]
= createDecoderRequests(finishedContextRequests, inputBuffers.inputsIds, decodingConfig, decoderState,
logitsType, modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);

auto const batchSize = finishedContextRequests.size();

std::vector<SamplingConfig> samplingConfigs;
samplingConfigs.reserve(batchSize);
for (auto const& llmReq : finishedContextRequests)
{
samplingConfigs.push_back(llmReq->mSamplingConfig);
}
= createDecoderRequests(contextRequests, inputBuffers.inputsIds, decodingConfig, decoderState, logitsType,
modelConfig, worldConfig, runtimeStream, decoderStream, maxSequenceLength, medusaBuffers);

TensorPtr batchSlotsView = runtime::ITensor::slice(inputBuffers.setupBatchSlots, 0, batchSize);
auto samplingConfig = fuseSamplingConfigs(contextRequests);
TLLM_CHECK(samplingConfig.has_value());

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
return {std::move(batchSlotsView), std::move(samplingConfigs), std::move(lookaheadPrompt),
return {std::move(batchSlotsView), std::move(samplingConfig.value()), std::move(lookaheadPrompt),
std::move(lookaheadAlgoConfigs)};
}

56 changes: 42 additions & 14 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp
@@ -1863,24 +1863,29 @@ void TrtGptModelInflightBatching::setupDecoderStep(

if (mWorldConfig.isLastPipelineParallelRank() && !contextRequests.empty())
{
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");
RequestVector finishedContextRequests;
std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
[](auto const& llmReq) { return llmReq->isLastContextChunk(); });

auto [batchSlots, samplingConfigs, lookaheadPrompt, lookaheadAlgoConfigs]
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, contextRequests, logitsType,
inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(), getMaxSequenceLen(),
mOperatingBeamWidth, buffers.mMedusaBuffers);

auto const localBatchSize = batchSlots->getSize();
if (localBatchSize > 0)
if (!finishedContextRequests.empty())
{
Comment on lines +1866 to 1871
⚠️ Potential issue

Don’t drop disaggregated-generation setup when no “last context chunk” is present.

setupDecoderStep now filters to last-context chunks only. When called from prepareDistGenBufferAndDecoder (Lines 1681-1683) with generation requests, isLastContextChunk() will typically be false, so finishedContextRequests becomes empty and decoder setup is skipped, breaking the disaggregated generation init path.

Fix: fall back to using the passed-in requests if no last-context chunk is found.

-        RequestVector finishedContextRequests;
-        std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
-            [](auto const& llmReq) { return llmReq->isLastContextChunk(); });
+        RequestVector finishedContextRequests;
+        std::copy_if(contextRequests.begin(), contextRequests.end(), std::back_inserter(finishedContextRequests),
+            [](auto const& llmReq) { return llmReq->isLastContextChunk(); });
+        // Disagg generation init path calls this with generation-ready requests (no context chunk in-flight).
+        if (finishedContextRequests.empty())
+        {
+            finishedContextRequests = contextRequests;
+        }

auto samplingConfig = SamplingConfig(samplingConfigs);
auto const logitsType = mRuntime->getEngine().getTensorDataType("logits");

auto [batchSlots, samplingConfig, lookaheadPrompt, lookaheadAlgoConfigs]
= (*mCreateNewDecoderRequests)(mModelConfig, mWorldConfig, mDecodingConfig, finishedContextRequests,
logitsType, inputBuffers, *mDecoderState, mRuntime->getStream(), *mDecoder->getDecoderStream(),
getMaxSequenceLen(), mOperatingBeamWidth, buffers.mMedusaBuffers);

auto const localBatchSize = batchSlots->getSize();
TLLM_CHECK_WITH_INFO(localBatchSize > 0, "Decoder setup should be called with at least one request");

mDecoder->getUnderlyingDecoder().setup(samplingConfig, localBatchSize, batchSlots,
{mDecoderState->getJointDecodingOutput()}, mModelConfig.getDataType(), lookaheadPrompt,
lookaheadAlgoConfigs);

auto const& stream = mDecoder->getDecoderStream();
auto const& decoderStream = mDecoder->getDecoderStream();
CudaEvent event{};
stream->record(event);
decoderStream->record(event);
mRuntime->getStreamPtr()->wait(event);
}
}
@@ -2515,6 +2520,24 @@ void TrtGptModelInflightBatching::changeBeamWidth(SizeType32 beamWidth)
TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtGptModelInflightBatching::disableLookaheadDecoder(
RequestVector const& genRequests, DecoderInputBuffers& inputBuffers)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);

auto batchSlots = CreateNewDecoderRequests::fillBatchSlots(genRequests, inputBuffers);
auto samplingConfig = CreateNewDecoderRequests::fuseSamplingConfigs(genRequests);

mDecoder->getUnderlyingDecoder().disableLookahead(samplingConfig, batchSlots->getSize(), batchSlots);

auto const& decoderStream = mDecoder->getDecoderStream();
CudaEvent event{};
decoderStream->record(event);
mRuntime->getStreamPtr()->wait(event);

TLLM_LOG_TRACE("%s stop", __PRETTY_FUNCTION__);
}

void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& scheduledRequests)
{
TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
@@ -2602,11 +2625,16 @@ void TrtGptModelInflightBatching::changeSpecDecMode(ScheduledRequests const& sch
mDecodingConfig.setDecodingMode(executor::DecodingMode::Auto());
mBuffers.at(bufferId)->mLookaheadBuffers->disableLookaheadDecoding();
mDecoderOutputBuffers.at(getFusedBufferId()).disableLookaheadDecoding(getMaxNumSequences());
mDecoder->disableLookahead(
scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()).setupBatchSlots);
mDecoderState->disableLookahead(scheduledRequests.generationRequests);
disableLookaheadDecoder(scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()));
mDecoderState->disableLookahead();

for (auto const& llmReq : scheduledRequests.generationRequests)
{
if (llmReq->mSeqSlot)
{
mDecoderState->setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1);
}

Comment on lines +2628 to +2637
🛠️ Refactor suggestion

Disable lookahead decoding (LAD) for all active slots (context and generation), and reset engine tokens consistently.

The current call disables lookahead only for scheduledRequests.generationRequests. When LAD is turned off due to constraints, there may be only context requests in-flight; underlying decoder state for those slots won’t be updated. Also, resetting “numDecodingEngineTokens = 1” should apply to all slots in this transition.

-        disableLookaheadDecoder(scheduledRequests.generationRequests, mDecoderInputBuffers.at(getFusedBufferId()));
-        mDecoderState->disableLookahead();
-
-        for (auto const& llmReq : scheduledRequests.generationRequests)
+        // Apply to all scheduled slots (both ctx and gen have seqSlot at this point)
+        RequestVector requestsForDisable;
+        requestsForDisable.reserve(
+            scheduledRequests.contextRequests.size() + scheduledRequests.generationRequests.size());
+        requestsForDisable.insert(requestsForDisable.end(),
+            scheduledRequests.contextRequests.begin(), scheduledRequests.contextRequests.end());
+        requestsForDisable.insert(requestsForDisable.end(),
+            scheduledRequests.generationRequests.begin(), scheduledRequests.generationRequests.end());
+
+        disableLookaheadDecoder(requestsForDisable, mDecoderInputBuffers.at(getFusedBufferId()));
+        mDecoderState->disableLookahead();
+
+        for (auto const& llmReq : requestsForDisable)
         {
             if (llmReq->mSeqSlot)
             {
                 mDecoderState->setNumDecodingEngineTokens(llmReq->mSeqSlot.value(), 1);
             }

if (llmReq->getNumDraftTokens() > 0)
{
llmReq->discardDraftTokens(llmReq->getNumDraftTokens());
2 changes: 2 additions & 0 deletions cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.h
@@ -388,6 +388,8 @@ class TrtGptModelInflightBatching : public TrtGptModel
/// @brief Change the speculative decoding mode.
void changeSpecDecMode(ScheduledRequests const& scheduledRequests);

void disableLookaheadDecoder(RequestVector const& genRequests, DecoderInputBuffers& inputBuffers);

void prefetchNextPromptTableChunk(RequestVector const& contextRequests, bool isFirstChunk, SizeType32 bufferId);

void remapInputTokensForPromptTable(
12 changes: 0 additions & 12 deletions cpp/tensorrt_llm/nanobind/bindings.cpp
@@ -64,14 +64,6 @@ using OptVec = std::optional<std::vector<T>>;
#error "TRTLLM_NB_MODULE must be defined"
#endif

namespace
{
tr::SamplingConfig makeSamplingConfig(std::vector<tr::SamplingConfig> const& configs)
{
return tr::SamplingConfig(configs);
}
} // namespace

NB_MODULE(TRTLLM_NB_MODULE, m)
{
m.doc() = "TensorRT-LLM Python bindings for C++ runtime";
@@ -425,10 +417,6 @@ NB_MODULE(TRTLLM_NB_MODULE, m)
.def("__setstate__", SamplingConfigSetState)
.def("__eq__", &tr::SamplingConfig::operator==);

nb::bind_vector<std::vector<tr::SamplingConfig>>(m, "SamplingConfigVector");

m.def("make_sampling_config", &makeSamplingConfig, nb::arg("configs"));

nb::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(nb::init<std::string, std::string, std::string, SizeType32, SizeType32, SizeType32, SizeType32,
tr::ModelConfig, std::optional<tr::RuntimeDefaults>>(),
2 changes: 0 additions & 2 deletions cpp/tensorrt_llm/nanobind/common/customCasters.h
@@ -21,7 +21,6 @@
#include "tensorrt_llm/batch_manager/decoderBuffers.h"
#include "tensorrt_llm/common/optionalRef.h"
#include "tensorrt_llm/runtime/cudaStream.h"
#include "tensorrt_llm/runtime/samplingConfig.h"
#include "tensorrt_llm/runtime/torch.h"
#include "tensorrt_llm/runtime/torchView.h"

@@ -45,7 +44,6 @@
// Opaque bindings
NB_MAKE_OPAQUE(tensorrt_llm::batch_manager::ReqIdsSet)
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::batch_manager::SlotDecoderBuffers>)
NB_MAKE_OPAQUE(std::vector<tensorrt_llm::runtime::SamplingConfig>)

namespace nb = nanobind;

10 changes: 0 additions & 10 deletions cpp/tensorrt_llm/pybind/bindings.cpp
@@ -58,14 +58,6 @@ using OptVec = std::optional<std::vector<T>>;
#error "TRTLLM_PYBIND_MODULE must be defined"
#endif

namespace
{
tr::SamplingConfig makeSamplingConfig(std::vector<tr::SamplingConfig> const& configs)
{
return tr::SamplingConfig(configs);
}
} // namespace

PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
{
m.doc() = "TensorRT-LLM Python bindings for C++ runtime";
@@ -415,8 +407,6 @@ PYBIND11_MODULE(TRTLLM_PYBIND_MODULE, m)
.def(py::pickle(SamplingConfigGetState, SamplingConfigSetState))
.def("__eq__", &tr::SamplingConfig::operator==);

m.def("make_sampling_config", &makeSamplingConfig, py::arg("configs"));

py::class_<tr::GptJsonConfig>(m, "GptJsonConfig")
.def(py::init<std::string, std::string, std::string, SizeType32, SizeType32, SizeType32, SizeType32,
tr::ModelConfig, std::optional<tr::RuntimeDefaults>>(),