3 changes: 2 additions & 1 deletion benchmarks/cpp/disaggServerBenchmark.cpp
@@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const
         std::nullopt, // kvCacheRetentionConfig
         std::nullopt, // logitsPostProcessorName
         std::nullopt, // logitsPostProcessor
-        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+        std::nullopt); // cacheSaltID
     request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY);
     return request;
 }
3 changes: 2 additions & 1 deletion benchmarks/cpp/gptManagerBenchmark.cpp
@@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW
         std::nullopt, // kvCacheRetentionConfig
         std::nullopt, // logitsPostProcessorName
         std::nullopt, // logitsPostProcessor
-        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt);
+        encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt,
+        std::nullopt); // cacheSaltID
 }

 void benchmarkExecutor(std::optional<std::filesystem::path> const& decoderEngineDir,
9 changes: 6 additions & 3 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
@@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken;
 using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
 using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
 using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;
+using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType;

 // Type alias for multimodal hash key (hash array + start offset)
 using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;
@@ -115,6 +116,7 @@ struct BlockKey
     // Extra keys for multimodal data (similar to VLLM's approach)
     // Each extra key is a pair of (mm_hash, start_offset_in_block)
     std::vector<MmKey> extraKeys;
+    std::optional<CacheSaltIDType> cacheSaltID = std::nullopt;

     BlockKey() = default;

@@ -129,24 +131,25 @@ struct BlockKey
     }

     explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
-        std::vector<MmKey> extraKeys = {})
+        std::vector<MmKey> extraKeys = {}, std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : usesExtraIds{usesExtraIds}
         , loraTaskId{loraTaskId}
         , uniqueTokens{std::move(uniqueTokens)}
         , extraKeys{std::move(extraKeys)}
+        , cacheSaltID{cacheSaltID}
     {
     }

     bool operator==(BlockKey const& other) const noexcept
     {
         return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
-            && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
+            && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID);
     }

     int partialMatch(BlockKey const& other) const noexcept
     {
         SizeType32 numMatched{0};
-        if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
+        if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID)
         {
             auto [matchEnd, otherMatchEnd] = std::mismatch(
                 uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());
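The net effect of the BlockKey changes above: two identical token prefixes carrying different salts never compare equal, and partialMatch likewise refuses to match across salts, so cached blocks are only reused among requests that share a salt. A minimal standalone sketch of that equality semantics (an illustrative analogue, not the TRT-LLM struct itself):

#include <cassert>
#include <cstdint>
#include <optional>
#include <vector>

// Illustrative analogue of BlockKey: equality now includes the salt.
struct KeySketch
{
    std::vector<int32_t> tokens;
    std::optional<uint64_t> cacheSaltID;

    bool operator==(KeySketch const& other) const
    {
        return tokens == other.tokens && cacheSaltID == other.cacheSaltID;
    }
};

int main()
{
    KeySketch a{{1, 2, 3}, 42}; // request with salt 42
    KeySketch b{{1, 2, 3}, 43}; // same tokens, different salt
    KeySketch c{{1, 2, 3}, 42}; // same tokens, same salt
    assert(!(a == b));          // no reuse across salts
    assert(a == c);             // reuse still possible within a salt
    return 0;
}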
37 changes: 26 additions & 11 deletions cpp/include/tensorrt_llm/batch_manager/llmRequest.h
@@ -97,8 +97,8 @@ class GenericLlmRequest
         RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional<RequestIdType>)>;
     using RequestPtr = std::shared_ptr<GenericLlmRequest>;
     using MillisecondsType = std::chrono::milliseconds;
+    using CacheSaltIDType = runtime::CacheSaltIDType;

-    // 49 parameters, 56 items in initialization list
     GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr<VecTokens> const& inputTokens,
         runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
         std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -134,7 +134,8 @@ class GenericLlmRequest
         std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : mRequestId(requestId)
         , mPromptLen(inputTokens->size())
         , mMaxNewTokens(maxNewTokens)
@@ -191,6 +192,7 @@ class GenericLlmRequest
         , mGuidedDecodingParams(std::move(guidedDecodingParams))
         , mLanguageAdapterUid(languageAdapterUid)
         , mAllottedTimeMs(allottedTimeMs)
+        , mCacheSaltID(cacheSaltID)
     {
         if (mEncoderTokens.has_value() || encoderInputFeatures.has_value())
         {
@@ -200,7 +202,6 @@ class GenericLlmRequest
         initialize(*inputTokens, returnLogProbs);
     }

-    // 32 parameters, 39 items in initialization list
     GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens,
         runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional<SizeType32> endId = std::nullopt,
         std::optional<SizeType32> padId = std::nullopt, std::optional<TensorPtr> embeddingBias = std::nullopt,
@@ -218,7 +219,8 @@ class GenericLlmRequest
         bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
         executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : mRequestId(requestId)
         , mPromptLen(inputTokens.size())
         , mMaxNewTokens(maxNewTokens)
@@ -258,6 +260,7 @@ class GenericLlmRequest
         , mContextPhaseParams(contextPhaseParams)
         , mNumReturnSequences(numReturnSequences)
         , mLanguageAdapterUid(languageAdapterUid)
+        , mCacheSaltID(cacheSaltID)
     {
         if (mEncoderTokens.has_value())
         {
@@ -266,7 +269,6 @@ class GenericLlmRequest
         initialize(inputTokens, returnLogProbs);
     }

-    // 29 items in initialization list
     GenericLlmRequest(RequestIdType requestId, executor::Request const& req)
         : mRequestId(requestId)
         , mPromptLen(req.getInputTokenIds().size())
@@ -297,6 +299,7 @@ class GenericLlmRequest
         , mGuidedDecodingParams(req.getGuidedDecodingParams())
         , mLanguageAdapterUid(req.getLanguageAdapterUid())
         , mAllottedTimeMs(req.getAllottedTimeMs())
+        , mCacheSaltID(req.getCacheSaltID())
     {
         if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY)
         {
@@ -1761,6 +1764,11 @@ class GenericLlmRequest
         return mLanguageAdapterUid;
     }

+    [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const
+    {
+        return mCacheSaltID;
+    }
+
     std::vector<SizeType32> getLanguageAdapterRouting(
         SizeType32 const reqNumLanguages, SizeType32 const inputLength) const
     {
@@ -2039,6 +2047,9 @@ class GenericLlmRequest

     bool mUseDraftModel{false};

+    // Cache salt id for each request.
+    std::optional<CacheSaltIDType> mCacheSaltID{std::nullopt};
+
 private:
     void initialize(VecTokens const& inputTokens, bool outputLogProbs)
     {
@@ -2219,7 +2230,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
         std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId,
             std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
             std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes),
@@ -2231,7 +2243,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
             std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures),
             std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType,
             std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks),
-            returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
+            returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams,
+            cacheSaltID)
     {
     }

@@ -2269,7 +2282,8 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
         std::optional<executor::GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
         std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : Base(requestId, maxNewTokens, std::make_shared<std::vector<TokenIdType>>(std::move(inputTokens)),
             samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList),
             std::move(stopWordsList),
@@ -2299,7 +2313,7 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
             inputTokenExtraIds ? std::make_optional(std::make_shared<VecTokenExtraIds>(std::move(*inputTokenExtraIds)))
                                : std::optional<std::shared_ptr<VecTokenExtraIds>>(std::nullopt),
             numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics,
-            std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams)
+            std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID)
     {
     }

@@ -2321,14 +2335,15 @@ class LlmRequest : public GenericLlmRequest<runtime::ITensor::SharedPtr>
         bool returnEncoderOutput = false, std::optional<RequestIdType> clientId = std::nullopt,
         executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
-        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt)
+        std::optional<executor::ContextPhaseParams> const& contextPhaseParams = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt)
         : Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId,
             std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds),
             std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig),
             lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens),
             std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor),
             applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority,
-            numReturnSequences, languageAdapterUid, contextPhaseParams)
+            numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID)
     {
     }

7 changes: 5 additions & 2 deletions cpp/include/tensorrt_llm/executor/executor.h
@@ -670,7 +670,7 @@ class Request
     /// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut
     /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism
     /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled.
-    // 34 parameters
+    /// @param cacheSaltID Salt ID for KV cache blocks, limiting KV cache reuse to requests with the same salt ID.
     Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false,
         SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(),
         std::optional<SizeType32> const& endId = std::nullopt, std::optional<SizeType32> const& padId = std::nullopt,
@@ -697,7 +697,8 @@ class Request
         std::optional<EagleConfig> eagleConfig = std::nullopt, std::optional<Tensor> skipCrossAttnBlocks = std::nullopt,
         std::optional<GuidedDecodingParams> guidedDecodingParams = std::nullopt,
         std::optional<SizeType32> languageAdapterUid = std::nullopt,
-        std::optional<MillisecondsType> allottedTimeMs = std::nullopt);
+        std::optional<MillisecondsType> allottedTimeMs = std::nullopt,
+        std::optional<CacheSaltIDType> cacheSaltID = std::nullopt);

     /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor
     static auto constexpr kBatchedPostProcessorName = "batched";
@@ -745,6 +746,7 @@ class Request
     [[nodiscard]] std::optional<GuidedDecodingParams> getGuidedDecodingParams() const;
     [[nodiscard]] std::optional<SizeType32> getLanguageAdapterUid() const;
     [[nodiscard]] std::optional<MillisecondsType> getAllottedTimeMs() const;
+    [[nodiscard]] std::optional<CacheSaltIDType> getCacheSaltID() const;
     [[nodiscard]] std::optional<std::vector<std::string>> getAdditionalOutputNames() const;

     void setStreaming(bool streaming);
@@ -780,6 +782,7 @@ class Request
     void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams);
     void setLanguageAdapterUid(SizeType32 languageAdapterUid);
     void setAllottedTimeMs(MillisecondsType allottedTimeMs);
+    void setCacheSaltID(CacheSaltIDType cacheSaltID);

 private:
     friend class Serialization;
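The executor-facing surface is small: one trailing constructor argument plus a getter/setter pair. Below is a sketch of how a caller might isolate two tenants that submit identical prompts; the helper name, token values, and maxTokens value are illustrative assumptions, while Request and setCacheSaltID are exactly the declarations above.

#include "tensorrt_llm/executor/executor.h"

namespace texec = tensorrt_llm::executor;

// Build a request whose KV cache blocks only match requests with the same salt.
texec::Request makeSaltedRequest(texec::VecTokens tokens, texec::CacheSaltIDType salt)
{
    texec::Request request(std::move(tokens), /*maxTokens=*/128);
    request.setCacheSaltID(salt);
    return request;
}

// Same prompt, different salts: no KV cache reuse between the two tenants.
// auto tenantA = makeSaltedRequest({1, 2, 3, 4}, 1001);
// auto tenantB = makeSaltedRequest({1, 2, 3, 4}, 2002);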
1 change: 1 addition & 0 deletions cpp/include/tensorrt_llm/executor/types.h
@@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t;
 using VecLogProbs = std::vector<FloatType>;
 using StreamPtr = std::shared_ptr<tensorrt_llm::runtime::CudaStream>;
 using MillisecondsType = std::chrono::milliseconds;
+using CacheSaltIDType = std::uint64_t;
 using LogitsPostProcessor
     = std::function<void(IdType, Tensor&, BeamTokens const&, StreamPtr const&, std::optional<IdType>)>;
 using LogitsPostProcessorMap = std::unordered_map<std::string, LogitsPostProcessor>;
1 change: 1 addition & 0 deletions cpp/include/tensorrt_llm/runtime/common.h
@@ -44,6 +44,7 @@ using TokenIdType = std::int32_t;
 using LoraTaskIdType = std::uint64_t;
 using TokenExtraIdType = std::uint64_t;
 using VecTokenExtraIds = std::vector<TokenExtraIdType>;
+using CacheSaltIDType = std::uint64_t;

 struct UniqueToken
 {
14 changes: 12 additions & 2 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
@@ -131,7 +131,7 @@ std::vector<MmKey> generateBlockHashExtraKeys(
         // Check if this multimodal content overlaps with the current block
         if (endTokenIdx > startPos && startTokenIdx < startPos + length)
         {
-            SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
+            uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 0 : static_cast<uint64_t>(startTokenIdx - startPos);
             extraKeys.emplace_back(mmHashArray, mmStartInBlock);
         }
     }
@@ -151,7 +151,7 @@ std::vector<BlockKey> buildBlockKeys(
         currentTokenIdx += uniqueTokens.size();

         blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
-            std::move(uniqueTokens), std::move(extraKeys));
+            std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID());
     }
     return blockKeys;
 }
@@ -167,6 +167,16 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept
     // Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
     size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);

+    if (parentHash == 0 && blockKey.cacheSaltID)
+    {
+        // Only hashing the cache salt ID for the first block in the sequence
+        uint64_t c = blockKey.cacheSaltID.value();
+        c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
+        c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
+        c = c ^ (c >> 31);
+        seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
+    }
+
     for (auto const& uniqueToken : blockKey.uniqueTokens)
     {
         uint32_t a = static_cast<uint32_t>(uniqueToken.tokenId);
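Note that the salt only perturbs the seed of the first block in a sequence (parentHash == 0); because every later block hash chains from its parent, re-keying the root re-keys the entire sequence. A standalone sketch of the same mixing — a splitmix64-style finalizer folded into the seed with a boost-style hash_combine; the function name and test values here are illustrative, not part of the PR:

#include <cstdint>
#include <cstdio>

// Finalize the salt with splitmix64-style mixing, then fold it into the seed.
uint64_t mixSaltIntoSeed(uint64_t seed, uint64_t salt)
{
    uint64_t c = salt;
    c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
    c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb);
    c = c ^ (c >> 31);
    return seed ^ (c + UINT64_C(0x9e3779b9) + (seed << 6) + (seed >> 2));
}

int main()
{
    uint64_t const seed = 0x1234; // stand-in for the token-derived seed
    // Different salts yield different root hashes for the same token prefix.
    std::printf("%llx\n%llx\n",
        static_cast<unsigned long long>(mixSaltIntoSeed(seed, 1)),
        static_cast<unsigned long long>(mixSaltIntoSeed(seed, 2)));
    return 0;
}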
15 changes: 12 additions & 3 deletions cpp/tensorrt_llm/executor/request.cpp
@@ -25,7 +25,6 @@

 namespace tensorrt_llm::executor
 {
-// 36 parameters
 Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig,
     OutputConfig const& outputConfig, std::optional<SizeType32> const& endId, std::optional<SizeType32> const& padId,
     std::optional<std::vector<SizeType32>> positionIds, std::optional<std::list<VecTokens>> badWords,
@@ -41,7 +40,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
     std::optional<SizeType32> encoderOutputLength, std::optional<Tensor> crossAttentionMask,
     SizeType32 numReturnSequences, std::optional<EagleConfig> eagleConfig, std::optional<Tensor> skipCrossAttnBlocks,
     std::optional<GuidedDecodingParams> guidedDecodingParams, std::optional<SizeType32> languageAdapterUid,
-    std::optional<MillisecondsType> allottedTimeMs)
+    std::optional<MillisecondsType> allottedTimeMs, std::optional<CacheSaltIDType> cacheSaltID)
     : mImpl(std::make_unique<Impl>(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId,
         padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias),
         std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput),
@@ -50,7 +49,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming,
         std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type,
         std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask,
         numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid,
-        allottedTimeMs))
+        allottedTimeMs, cacheSaltID))
 {
 }

@@ -249,6 +248,11 @@ std::optional<SizeType32> Request::getLanguageAdapterUid() const
     return mImpl->getLanguageAdapterUid();
 }

+std::optional<CacheSaltIDType> Request::getCacheSaltID() const
+{
+    return mImpl->getCacheSaltID();
+}
+
 void Request::setStreaming(bool streaming)
 {
     mImpl->setStreaming(streaming);
@@ -413,4 +417,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid)
 {
     return mImpl->setLanguageAdapterUid(languageAdapterUid);
 }
+
+void Request::setCacheSaltID(CacheSaltIDType cacheSaltID)
+{
+    return mImpl->setCacheSaltID(cacheSaltID);
+}
 } // namespace tensorrt_llm::executor