diff --git a/benchmarks/cpp/disaggServerBenchmark.cpp b/benchmarks/cpp/disaggServerBenchmark.cpp index ab009802757..37394869a35 100644 --- a/benchmarks/cpp/disaggServerBenchmark.cpp +++ b/benchmarks/cpp/disaggServerBenchmark.cpp @@ -542,7 +542,8 @@ texec::Request makeExecutorContextRequest(Sample const& sample, SizeType32 const std::nullopt, // kvCacheRetentionConfig std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor - encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt); + encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, + std::nullopt); // cacheSaltID request.setRequestType(tensorrt_llm::executor::RequestType::REQUEST_TYPE_CONTEXT_ONLY); return request; } diff --git a/benchmarks/cpp/gptManagerBenchmark.cpp b/benchmarks/cpp/gptManagerBenchmark.cpp index a586610f154..553c98bb78b 100644 --- a/benchmarks/cpp/gptManagerBenchmark.cpp +++ b/benchmarks/cpp/gptManagerBenchmark.cpp @@ -837,7 +837,8 @@ texec::Request makeExecutorRequest(Sample const& sample, SizeType32 const& beamW std::nullopt, // kvCacheRetentionConfig std::nullopt, // logitsPostProcessorName std::nullopt, // logitsPostProcessor - encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt); + encoderInputTokenIds.has_value() ? encoderInputTokenIds : std::nullopt, + std::nullopt); // cacheSaltID } void benchmarkExecutor(std::optional const& decoderEngineDir, diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index 8940b160a15..1c6ed78b3e9 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -69,6 +69,7 @@ using UniqueToken = tensorrt_llm::runtime::UniqueToken; using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens; using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType; using BlocksPerWindow = std::map>; +using CacheSaltIDType = tensorrt_llm::runtime::CacheSaltIDType; // Type alias for multimodal hash key (hash array + start offset) using MmKey = std::pair, SizeType32>; @@ -115,6 +116,7 @@ struct BlockKey // Extra keys for multimodal data (similar to VLLM's approach) // Each extra key is a pair of (mm_hash, start_offset_in_block) std::vector extraKeys; + std::optional cacheSaltID = std::nullopt; BlockKey() = default; @@ -129,24 +131,25 @@ struct BlockKey } explicit BlockKey(bool usesExtraIds, std::optional loraTaskId, VecUniqueTokens uniqueTokens, - std::vector extraKeys = {}) + std::vector extraKeys = {}, std::optional cacheSaltID = std::nullopt) : usesExtraIds{usesExtraIds} , loraTaskId{loraTaskId} , uniqueTokens{std::move(uniqueTokens)} , extraKeys{std::move(extraKeys)} + , cacheSaltID{cacheSaltID} { } bool operator==(BlockKey const& other) const noexcept { return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId - && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys); + && uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID); } int partialMatch(BlockKey const& other) const noexcept { SizeType32 numMatched{0}; - if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys) + if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys && cacheSaltID == other.cacheSaltID) { auto [matchEnd, otherMatchEnd] = std::mismatch( uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end()); diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h 
b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index f069e3ac7f5..d818149d734 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -97,8 +97,8 @@ class GenericLlmRequest RequestIdType, TensorPtr&, BeamTokens const&, TStream const&, std::optional)>; using RequestPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; + using CacheSaltIDType = runtime::CacheSaltIDType; - // 49 parameters, 56 items in initialization list GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::shared_ptr const& inputTokens, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional endId = std::nullopt, std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, @@ -134,7 +134,8 @@ class GenericLlmRequest std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens->size()) , mMaxNewTokens(maxNewTokens) @@ -191,6 +192,7 @@ class GenericLlmRequest , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) + , mCacheSaltID(cacheSaltID) { if (mEncoderTokens.has_value() || encoderInputFeatures.has_value()) { @@ -200,7 +202,6 @@ class GenericLlmRequest initialize(*inputTokens, returnLogProbs); } - // 32 parameters, 39 items in initialization list GenericLlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, VecTokens const& inputTokens, runtime::SamplingConfig const& samplingConfig, bool isStreaming, std::optional endId = std::nullopt, std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, @@ -218,7 +219,8 @@ class GenericLlmRequest bool returnEncoderOutput = false, std::optional clientId = std::nullopt, executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : mRequestId(requestId) , mPromptLen(inputTokens.size()) , mMaxNewTokens(maxNewTokens) @@ -258,6 +260,7 @@ class GenericLlmRequest , mContextPhaseParams(contextPhaseParams) , mNumReturnSequences(numReturnSequences) , mLanguageAdapterUid(languageAdapterUid) + , mCacheSaltID(cacheSaltID) { if (mEncoderTokens.has_value()) { @@ -266,7 +269,6 @@ class GenericLlmRequest initialize(inputTokens, returnLogProbs); } - // 29 items in initialization list GenericLlmRequest(RequestIdType requestId, executor::Request const& req) : mRequestId(requestId) , mPromptLen(req.getInputTokenIds().size()) @@ -297,6 +299,7 @@ class GenericLlmRequest , mGuidedDecodingParams(req.getGuidedDecodingParams()) , mLanguageAdapterUid(req.getLanguageAdapterUid()) , mAllottedTimeMs(req.getAllottedTimeMs()) + , mCacheSaltID(req.getCacheSaltID()) { if (req.getRequestType() == executor::RequestType::REQUEST_TYPE_GENERATION_ONLY) { @@ -1761,6 +1764,11 @@ class GenericLlmRequest return mLanguageAdapterUid; } + [[nodiscard]] std::optional getCacheSaltID() const + { + return mCacheSaltID; + } + std::vector getLanguageAdapterRouting( SizeType32 const reqNumLanguages, SizeType32 const inputLength) const { @@ 
-2039,6 +2047,9 @@ class GenericLlmRequest bool mUseDraftModel{false}; + // Cache salt id for each request. + std::optional mCacheSaltID{std::nullopt}; + private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { @@ -2219,7 +2230,8 @@ class LlmRequest : public GenericLlmRequest std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : Base(requestId, maxNewTokens, std::move(inputTokens), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds), std::move(promptEmbeddingTable), promptVocabSize, std::move(multimodalHashes), @@ -2231,7 +2243,8 @@ class LlmRequest : public GenericLlmRequest std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, std::move(encoderInputFeatures), std::move(encoderOutputLength), std::move(crossAttentionMask), llmRequestType, std::move(inputTokenExtraIds), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks), - returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams) + returnPerfMetrics, std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, + cacheSaltID) { } @@ -2269,7 +2282,8 @@ class LlmRequest : public GenericLlmRequest std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : Base(requestId, maxNewTokens, std::make_shared>(std::move(inputTokens)), samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), @@ -2299,7 +2313,7 @@ class LlmRequest : public GenericLlmRequest inputTokenExtraIds ? 
std::make_optional(std::make_shared(std::move(*inputTokenExtraIds))) : std::optional>(std::nullopt), numReturnSequences, std::move(eagleConfig), skipCrossAttnBlocks, returnPerfMetrics, - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams) + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, contextPhaseParams, cacheSaltID) { } @@ -2321,14 +2335,15 @@ class LlmRequest : public GenericLlmRequest bool returnEncoderOutput = false, std::optional clientId = std::nullopt, executor::PriorityType priority = executor::Request::kDefaultPriority, SizeType32 numReturnSequences = 1, std::optional languageAdapterUid = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : Base(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, endId, padId, std::move(embeddingBias), std::move(badWordsList), std::move(stopWordsList), std::move(positionIds), std::move(promptEmbeddingTable), promptVocabSize, loraTaskId, std::move(loraWeights), std::move(loraConfig), lookaheadConfig, returnLogProbs, returnContextLogits, returnGenerationLogits, std::move(draftTokens), std::move(draftLogits), excludeInputFromOutput, std::move(logitsPostProcessor), applyLogitsPostProcessorBatched, std::move(encoderInputTokens), returnEncoderOutput, clientId, priority, - numReturnSequences, languageAdapterUid, contextPhaseParams) + numReturnSequences, languageAdapterUid, contextPhaseParams, cacheSaltID) { } diff --git a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 28c69074a3c..d9b115bf8fe 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -670,7 +670,7 @@ class Request /// @param allottedTimeMs The allotted time in milliseconds after which the request is cancelled with a timedOut /// finish reason. The request may exceed this time slightly, but at most by 1 forward pass (in pipeline parallelism /// that may involve multiple micro-batches). A request can be timed-out before ever being scheduled. - // 34 parameters + /// @param cacheSaltID Salt ID for the request's KV cache blocks; KV cache reuse is limited to requests that share the same cache salt ID.
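For reference, a minimal usage sketch of this new field through the Python executor bindings that are extended later in this diff. The module path and the max_tokens keyword are assumptions; only the cache_salt_id keyword argument and property come from the bindings added below:

    import tensorrt_llm.bindings.executor as trtllm

    # Two identical prompts with different salts: their KV cache block keys hash
    # differently, so the second request does not reuse blocks written by the first.
    req_a = trtllm.Request(input_token_ids=[100, 101, 102, 103], max_tokens=8,
                           cache_salt_id=12345)
    req_b = trtllm.Request(input_token_ids=[100, 101, 102, 103], max_tokens=8)
    req_b.cache_salt_id = 67890  # property maps to Request::setCacheSaltID
    assert req_a.cache_salt_id != req_b.cache_salt_id
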
Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming = false, SamplingConfig const& samplingConfig = SamplingConfig(), OutputConfig const& outputConfig = OutputConfig(), std::optional const& endId = std::nullopt, std::optional const& padId = std::nullopt, @@ -697,7 +697,8 @@ class Request std::optional eagleConfig = std::nullopt, std::optional skipCrossAttnBlocks = std::nullopt, std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, - std::optional allottedTimeMs = std::nullopt); + std::optional allottedTimeMs = std::nullopt, + std::optional cacheSaltID = std::nullopt); /// @brief This logits postprocessor name will dispatch to the batched logits postprocessor static auto constexpr kBatchedPostProcessorName = "batched"; @@ -745,6 +746,7 @@ class Request [[nodiscard]] std::optional getGuidedDecodingParams() const; [[nodiscard]] std::optional getLanguageAdapterUid() const; [[nodiscard]] std::optional getAllottedTimeMs() const; + [[nodiscard]] std::optional getCacheSaltID() const; [[nodiscard]] std::optional> getAdditionalOutputNames() const; void setStreaming(bool streaming); @@ -780,6 +782,7 @@ class Request void setGuidedDecodingParams(GuidedDecodingParams const& guidedDecodingParams); void setLanguageAdapterUid(SizeType32 languageAdapterUid); void setAllottedTimeMs(MillisecondsType allottedTimeMs); + void setCacheSaltID(CacheSaltIDType cacheSaltID); private: friend class Serialization; diff --git a/cpp/include/tensorrt_llm/executor/types.h b/cpp/include/tensorrt_llm/executor/types.h index 217b0260df2..41df1c9c7a4 100644 --- a/cpp/include/tensorrt_llm/executor/types.h +++ b/cpp/include/tensorrt_llm/executor/types.h @@ -58,6 +58,7 @@ using RandomSeedType = std::uint64_t; using VecLogProbs = std::vector; using StreamPtr = std::shared_ptr; using MillisecondsType = std::chrono::milliseconds; +using CacheSaltIDType = std::uint64_t; using LogitsPostProcessor = std::function)>; using LogitsPostProcessorMap = std::unordered_map; diff --git a/cpp/include/tensorrt_llm/runtime/common.h b/cpp/include/tensorrt_llm/runtime/common.h index 2cda8821c13..7a3079d0bd7 100644 --- a/cpp/include/tensorrt_llm/runtime/common.h +++ b/cpp/include/tensorrt_llm/runtime/common.h @@ -44,6 +44,7 @@ using TokenIdType = std::int32_t; using LoraTaskIdType = std::uint64_t; using TokenExtraIdType = std::uint64_t; using VecTokenExtraIds = std::vector; +using CacheSaltIDType = std::uint64_t; struct UniqueToken { diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 748fcbbe09d..da6315bde4b 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -131,7 +131,7 @@ std::vector generateBlockHashExtraKeys( // Check if this multimodal content overlaps with the current block if (endTokenIdx > startPos && startTokenIdx < startPos + length) { - SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos; + uint64_t mmStartInBlock = (startPos >= startTokenIdx) ? 
0 : static_cast(startTokenIdx - startPos); extraKeys.emplace_back(mmHashArray, mmStartInBlock); } } @@ -151,7 +151,7 @@ std::vector buildBlockKeys( currentTokenIdx += uniqueTokens.size(); blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), - std::move(uniqueTokens), std::move(extraKeys)); + std::move(uniqueTokens), std::move(extraKeys), llmRequest.getCacheSaltID()); } return blockKeys; } @@ -167,6 +167,16 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no // Constants provide very good distribution - each input bit affects each output bit with ~50% probability. size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9); + if (parentHash == 0 && blockKey.cacheSaltID) + { + // Only hashing the cache salt ID for the first block in the sequence + uint64_t c = blockKey.cacheSaltID.value(); + c = (c ^ (c >> 30)) * UINT64_C(0xbf58476d1ce4e5b9); + c = (c ^ (c >> 27)) * UINT64_C(0x94d049bb133111eb); + c = c ^ (c >> 31); + seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2); + } + for (auto const& uniqueToken : blockKey.uniqueTokens) { uint32_t a = static_cast(uniqueToken.tokenId); diff --git a/cpp/tensorrt_llm/executor/request.cpp b/cpp/tensorrt_llm/executor/request.cpp index 2c2a9c26fc2..987eeef894e 100644 --- a/cpp/tensorrt_llm/executor/request.cpp +++ b/cpp/tensorrt_llm/executor/request.cpp @@ -25,7 +25,6 @@ namespace tensorrt_llm::executor { -// 36 parameters Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, SamplingConfig const& samplingConfig, OutputConfig const& outputConfig, std::optional const& endId, std::optional const& padId, std::optional> positionIds, std::optional> badWords, @@ -41,7 +40,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::optional encoderOutputLength, std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, std::optional languageAdapterUid, - std::optional allottedTimeMs) + std::optional allottedTimeMs, std::optional cacheSaltID) : mImpl(std::make_unique(std::move(inputTokenIds), maxTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput), @@ -50,7 +49,7 @@ Request::Request(VecTokens inputTokenIds, SizeType32 maxTokens, bool streaming, std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, type, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, crossAttentionMask, numReturnSequences, eagleConfig, skipCrossAttnBlocks, std::move(guidedDecodingParams), languageAdapterUid, - allottedTimeMs)) + allottedTimeMs, cacheSaltID)) { } @@ -249,6 +248,11 @@ std::optional Request::getLanguageAdapterUid() const return mImpl->getLanguageAdapterUid(); } +std::optional Request::getCacheSaltID() const +{ + return mImpl->getCacheSaltID(); +} + void Request::setStreaming(bool streaming) { mImpl->setStreaming(streaming); @@ -413,4 +417,9 @@ void Request::setLanguageAdapterUid(SizeType32 languageAdapterUid) { return mImpl->setLanguageAdapterUid(languageAdapterUid); } + +void Request::setCacheSaltID(CacheSaltIDType cacheSaltID) +{ + return mImpl->setCacheSaltID(cacheSaltID); +} } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/requestImpl.h 
b/cpp/tensorrt_llm/executor/requestImpl.h index 111927b07ca..94de53a7817 100644 --- a/cpp/tensorrt_llm/executor/requestImpl.h +++ b/cpp/tensorrt_llm/executor/requestImpl.h @@ -32,7 +32,6 @@ class Request::Impl { public: - // 36 parameters, 36 items in initialization list Impl(VecTokens inputTokenIds, SizeType32 maxNewTokens, bool streaming, SamplingConfig const& samplingConfig, OutputConfig outputConfig, std::optional const& endId, std::optional const& padId, std::optional> positionIds, std::optional> badWords, @@ -48,7 +47,8 @@ class Request::Impl std::optional encoderInputFeatures, std::optional encoderOutputLength, std::optional crossAttentionMask, SizeType32 numReturnSequences, std::optional eagleConfig, std::optional skipCrossAttnBlocks, std::optional guidedDecodingParams, - std::optional languageAdapterUid, std::optional allottedTimeMs) + std::optional languageAdapterUid, std::optional allottedTimeMs, + std::optional cacheSaltID) : mInputTokenIds(std::move(inputTokenIds)) , mMaxNewTokens(maxNewTokens) , mStreaming(streaming) @@ -85,6 +85,7 @@ class Request::Impl , mGuidedDecodingParams(std::move(guidedDecodingParams)) , mLanguageAdapterUid(languageAdapterUid) , mAllottedTimeMs(allottedTimeMs) + , mCacheSaltID(cacheSaltID) { validate(); } @@ -296,6 +297,11 @@ class Request::Impl return mLanguageAdapterUid; } + [[nodiscard]] std::optional getCacheSaltID() const + { + return mCacheSaltID; + } + void setStreaming(bool streaming) { mStreaming = streaming; @@ -470,6 +476,11 @@ class Request::Impl mLanguageAdapterUid = languageAdapterUid; } + void setCacheSaltID(CacheSaltIDType cacheSaltID) + { + mCacheSaltID = cacheSaltID; + } + private: void validate() { @@ -543,6 +554,7 @@ class Request::Impl lambda(mGuidedDecodingParams); lambda(mLanguageAdapterUid); lambda(mAllottedTimeMs ? std::make_optional(mAllottedTimeMs->count()) : std::nullopt); + lambda(mCacheSaltID); } VecTokens mInputTokenIds; @@ -581,6 +593,7 @@ class Request::Impl std::optional mGuidedDecodingParams; std::optional mLanguageAdapterUid; std::optional mAllottedTimeMs; + std::optional mCacheSaltID; }; } // namespace tensorrt_llm::executor diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index bba8d19e2f6..428c00bdcdb 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -711,8 +711,8 @@ Request Serialization::deserializeRequest(std::istream& is) auto allottedTimeMs = allottedTimeInt ? 
std::optional(std::chrono::milliseconds(*allottedTimeInt)) : std::nullopt; + auto cacheSaltID = su::deserialize>(is); - // 35 parameters return Request(std::move(inputTokenIds), maxNewTokens, streaming, samplingConfig, outputConfig, endId, padId, std::move(positionIds), std::move(badWords), std::move(stopWords), std::move(embeddingBias), std::move(externalDraftTokensConfig), std::move(pTuningConfig), std::move(multimodalInput), @@ -721,7 +721,7 @@ Request Serialization::deserializeRequest(std::istream& is) std::move(encoderInputTokenIds), clientId, returnAllGeneratedTokens, priority, requestType, std::move(contextPhaseParams), std::move(encoderInputFeatures), encoderOutputLength, std::move(crossAttentionMask), numReturnSequences, std::move(eagleConfig), std::move(skipCrossAttnBlocks), - std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs); + std::move(guidedDecodingParams), languageAdapterUid, allottedTimeMs, cacheSaltID); } void Serialization::serialize(Request const& request, std::ostream& os) diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp index c170ca81015..94f15939f02 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/bindings.cpp @@ -189,6 +189,7 @@ void initBindings(nb::module_& m) .def_prop_ro("llm_request_type", &GenLlmReq::getLlmRequestType) .def_prop_ro("parent_request_id", &GenLlmReq::getParentRequestId) .def_prop_ro("is_child", &GenLlmReq::isChild) + .def_prop_ro("cache_salt_id", &GenLlmReq::getCacheSaltID) .def_prop_ro("multimodal_hashes", [](GenLlmReq& self) { @@ -287,7 +288,8 @@ void initBindings(nb::module_& m) std::optional guided_decoding_params, std::optional language_adapter_uid, std::optional allotted_time_ms, - std::optional context_phase_params) + std::optional context_phase_params, + std::optional cache_salt_id) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -316,7 +318,6 @@ void initBindings(nb::module_& m) auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask); auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); - // 49 parameters new (self) tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming, end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, multimodal_hashes, @@ -328,7 +329,8 @@ void initBindings(nb::module_& m) encoder_input_tokens, return_encoder_output, client_id, priority, encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, - guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params}; + guided_decoding_params, language_adapter_uid, allotted_time_ms, context_phase_params, + cache_salt_id}; }, nb::arg("request_id"), nb::arg("max_new_tokens"), nb::arg("input_tokens"), nb::arg("sampling_config"), nb::arg("is_streaming"), nb::arg("end_id") = std::nullopt, nb::arg("pad_id") = std::nullopt, @@ -353,7 +355,7 @@ void initBindings(nb::module_& m) nb::arg("eagle_config") = std::nullopt, nb::arg("skip_cross_attn_blocks") = std::nullopt, nb::arg("return_perf_metrics") = false, nb::arg("guided_decoding_params") = std::nullopt, nb::arg("language_adapter_uid") = std::nullopt, 
nb::arg("allotted_time_ms") = std::nullopt, - nb::arg("context_phase_params") = std::nullopt) + nb::arg("context_phase_params") = std::nullopt, nb::arg("cache_salt_id") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, nb::arg("vocab_size")) .def(nb::init()) .def("validate", &tb::LlmRequest::validate, nb::arg("max_input_len"), nb::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp index d8f45cb865f..46bfa0de64a 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.cpp @@ -76,7 +76,6 @@ std::shared_ptr LlmRequest::toTrtLlm() const ? std::make_shared>(*mEncoderTokens.value().get()) : nullptr; auto const optEncoderInputTokens = std::optional>>(encoderInputTokens); - // 49 parameters return std::make_shared( // mRequestId, // mMaxNewTokens, // @@ -126,6 +125,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mGuidedDecodingParams, // mLanguageAdapterUid, // mAllottedTimeMs, // - mContextPhaseParams // + mContextPhaseParams, // + mCacheSaltID // ); } diff --git a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h index 624dc55112d..b3d6f04aef8 100644 --- a/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/nanobind/batch_manager/llmRequest.h @@ -51,7 +51,6 @@ class LlmRequest : public tb::GenericLlmRequest using VecTokenExtraIds = Base::VecTokenExtraIds; using LogitsPostProcessor = Base::LogitsPostProcessor; - // 49 parameters LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt, std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, @@ -85,7 +84,8 @@ class LlmRequest : public tb::GenericLlmRequest std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // @@ -146,7 +146,8 @@ class LlmRequest : public tb::GenericLlmRequest guidedDecodingParams, // languageAdapterUid, // allottedTimeMs, // - contextPhaseParams // + contextPhaseParams, // + cacheSaltID // ) { } diff --git a/cpp/tensorrt_llm/nanobind/executor/request.cpp b/cpp/tensorrt_llm/nanobind/executor/request.cpp index e56341b53e2..d26c8dd70e0 100644 --- a/cpp/tensorrt_llm/nanobind/executor/request.cpp +++ b/cpp/tensorrt_llm/nanobind/executor/request.cpp @@ -573,11 +573,11 @@ void initRequestBindings(nb::module_& m) self.getClientId(), self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams()); + self.getGuidedDecodingParams(), self.getCacheSaltID()); }; auto requestSetstate = [](tle::Request& self, nb::tuple const& state) { - if (state.size() != 33) + if (state.size() != 34) { throw std::runtime_error("Invalid Request state!"); } @@ -601,7 +601,8 @@ void initRequestBindings(nb::module_& m) nb::cast>(state[27]), nb::cast>(state[28]), nb::cast>(state[29]), 1, 
nb::cast>(state[30]), nb::cast>(state[31]), - nb::cast>(state[32])); + nb::cast>(state[32]), + nb::cast>(state[33])); }; nb::class_ request(m, "Request", nb::dynamic_attr()); @@ -641,7 +642,8 @@ void initRequestBindings(nb::module_& m) std::optional, // skipCrossAttnBlocks std::optional, // guidedDecodingParams std::optional, // languageAdapterUid - std::optional // allottedTimeMs + std::optional, // allottedTimeMs + std::optional // cacheSaltID >(), // clang-format off nb::arg("input_token_ids"), @@ -680,8 +682,9 @@ void initRequestBindings(nb::module_& m) nb::arg("skip_cross_attn_blocks") = nb::none(), nb::arg("guided_decoding_params") = nb::none(), nb::arg("language_adapter_uid") = nb::none(), - nb::arg("allotted_time_ms") = nb::none() - ) // clang-format on + nb::arg("allotted_time_ms") = nb::none(), + nb::arg("cache_salt_id") = nb::none() + ) // clang-format on .def_prop_ro("input_token_ids", &tle::Request::getInputTokenIds) .def_prop_ro("max_tokens", &tle::Request::getMaxTokens) .def_prop_rw("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -723,6 +726,7 @@ void initRequestBindings(nb::module_& m) .def_prop_rw( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_prop_rw("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) + .def_prop_rw("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) .def_prop_rw("context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def("__getstate__", requestGetstate) .def("__setstate__", requestSetstate); diff --git a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp index 53c9ec7ef6d..dffe8ad1977 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/bindings.cpp @@ -196,6 +196,7 @@ void initBindings(pybind11::module_& m) .def_property_readonly("llm_request_type", &GenLlmReq::getLlmRequestType) .def_property_readonly("parent_request_id", &GenLlmReq::getParentRequestId) .def_property_readonly("is_child", &GenLlmReq::isChild) + .def_property_readonly("cache_salt_id", &GenLlmReq::getCacheSaltID) .def_property_readonly("multimodal_hashes", [](GenLlmReq& self) { @@ -293,7 +294,8 @@ void initBindings(pybind11::module_& m) std::optional guided_decoding_params, std::optional language_adapter_uid, std::optional allotted_time_ms, - std::optional context_phase_params) + std::optional context_phase_params, + std::optional cache_salt_id) { auto makeOptionalTensor = [](std::optional const& atTensor, bool unsqueeze = false) { @@ -322,7 +324,6 @@ void initBindings(pybind11::module_& m) auto cross_attention_mask_tensor_ptr = makeOptionalTensor(cross_attention_mask); auto skip_cross_attn_blocks_tensor_ptr = makeOptionalTensor(skip_cross_attn_blocks); - // 49 parameters return tb::LlmRequest{request_id, max_new_tokens, input_tokens, sampling_config, is_streaming, end_id, pad_id, embedding_bias_tensor_ptr, bad_words_list_tensor_ptr, stop_words_list_tensor_ptr, position_ids, prompt_embedding_table_tensor_ptr, prompt_vocab_size, @@ -335,7 +336,7 @@ void initBindings(pybind11::module_& m) encoder_input_features_tensor_ptr, encoder_output_length, cross_attention_mask_tensor_ptr, llm_request_type, input_token_extra_ids, num_return_sequences, eagle_config, skip_cross_attn_blocks_tensor_ptr, return_perf_metrics, guided_decoding_params, - language_adapter_uid, allotted_time_ms, 
context_phase_params}; + language_adapter_uid, allotted_time_ms, context_phase_params, cache_salt_id}; }), py::arg("request_id"), py::arg("max_new_tokens"), py::arg("input_tokens"), py::arg("sampling_config"), py::arg("is_streaming"), py::arg("end_id") = std::nullopt, py::arg("pad_id") = std::nullopt, @@ -361,7 +362,7 @@ void initBindings(pybind11::module_& m) py::arg("eagle_config") = std::nullopt, py::arg("skip_cross_attn_blocks") = std::nullopt, py::arg("return_perf_metrics") = false, py::arg("guided_decoding_params") = std::nullopt, py::arg("language_adapter_uid") = std::nullopt, py::arg("allotted_time_ms") = std::nullopt, - py::arg("context_phase_params") = std::nullopt) + py::arg("context_phase_params") = std::nullopt, py::arg("cache_salt_id") = std::nullopt) .def("check_token_id_range", &tb::LlmRequest::checkTokenIdRange, py::arg("vocab_size")) .def(py::init()) .def("validate", &tb::LlmRequest::validate, py::arg("max_input_len"), py::arg("max_seq_len"), diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp index bce35ed5ee0..9b5c4bc1298 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.cpp @@ -75,7 +75,6 @@ std::shared_ptr LlmRequest::toTrtLlm() const ? std::make_shared>(*mEncoderTokens.value().get()) : nullptr; auto const optEncoderInputTokens = std::optional>>(encoderInputTokens); - // 49 parameters return std::make_shared( // mRequestId, // mMaxNewTokens, // @@ -125,6 +124,7 @@ std::shared_ptr LlmRequest::toTrtLlm() const mGuidedDecodingParams, // mLanguageAdapterUid, // mAllottedTimeMs, // - mContextPhaseParams // + mContextPhaseParams, // + mCacheSaltID // ); } diff --git a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h index 3cc12f9e889..8d004cb304f 100644 --- a/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h +++ b/cpp/tensorrt_llm/pybind/batch_manager/llmRequest.h @@ -49,8 +49,8 @@ class LlmRequest : public tb::GenericLlmRequest using VecTokens = Base::VecTokens; using VecTokenExtraIds = Base::VecTokenExtraIds; using LogitsPostProcessor = Base::LogitsPostProcessor; + using CacheSaltIDType = Base::CacheSaltIDType; - // 49 parameters LlmRequest(RequestIdType requestId, SizeType32 maxNewTokens, std::vector inputTokens, runtime::SamplingConfig samplingConfig, bool isStreaming, std::optional endId = std::nullopt, std::optional padId = std::nullopt, std::optional embeddingBias = std::nullopt, @@ -84,7 +84,8 @@ class LlmRequest : public tb::GenericLlmRequest std::optional guidedDecodingParams = std::nullopt, std::optional languageAdapterUid = std::nullopt, std::optional allottedTimeMs = std::nullopt, - std::optional const& contextPhaseParams = std::nullopt) + std::optional const& contextPhaseParams = std::nullopt, + std::optional cacheSaltID = std::nullopt) : Base(requestId, // maxNewTokens, // std::make_shared>(std::move(inputTokens)), // @@ -145,7 +146,8 @@ class LlmRequest : public tb::GenericLlmRequest guidedDecodingParams, // languageAdapterUid, // allottedTimeMs, // - contextPhaseParams // + contextPhaseParams, // + cacheSaltID // ) { } diff --git a/cpp/tensorrt_llm/pybind/executor/request.cpp b/cpp/tensorrt_llm/pybind/executor/request.cpp index 904410c253b..4eb61ecde98 100644 --- a/cpp/tensorrt_llm/pybind/executor/request.cpp +++ b/cpp/tensorrt_llm/pybind/executor/request.cpp @@ -526,11 +526,11 @@ void initRequestBindings(pybind11::module_& m) self.getClientId(), 
self.getReturnAllGeneratedTokens(), self.getPriority(), self.getRequestType(), self.getContextPhaseParams(), self.getEncoderInputFeatures(), self.getEncoderOutputLength(), self.getCrossAttentionMask(), self.getEagleConfig(), self.getSkipCrossAttnBlocks(), - self.getGuidedDecodingParams()); + self.getGuidedDecodingParams(), self.getCacheSaltID()); }; auto requestSetstate = [](py::tuple const& state) { - if (state.size() != 33) + if (state.size() != 34) { throw std::runtime_error("Invalid Request state!"); } @@ -550,7 +550,8 @@ void initRequestBindings(pybind11::module_& m) state[25].cast(), state[26].cast>(), state[27].cast>(), state[28].cast>(), state[29].cast>(), 1, state[30].cast>(), - state[31].cast>(), state[32].cast>()); + state[31].cast>(), state[32].cast>(), + state[33].cast>()); }; py::class_ request(m, "Request", pybind11::dynamic_attr()); @@ -590,7 +591,8 @@ void initRequestBindings(pybind11::module_& m) std::optional, // skipCrossAttnBlocks std::optional, // guidedDecodingParams std::optional, // languageAdapterUid - std::optional // allottedTimeMs + std::optional, // allottedTimeMs + std::optional // cacheSaltID >(), // clang-format off py::arg("input_token_ids"), @@ -630,8 +632,9 @@ void initRequestBindings(pybind11::module_& m) py::arg("skip_cross_attn_blocks") = py::none(), py::arg("guided_decoding_params") = py::none(), py::arg("language_adapter_uid") = py::none(), - py::arg("allotted_time_ms") = py::none() - ) // clang-format on + py::arg("allotted_time_ms") = py::none(), + py::arg("cache_salt_id") = py::none() + ) // clang-format on .def_property_readonly("input_token_ids", &tle::Request::getInputTokenIds) .def_property_readonly("max_tokens", &tle::Request::getMaxTokens) .def_property("streaming", &tle::Request::getStreaming, &tle::Request::setStreaming) @@ -675,6 +678,7 @@ void initRequestBindings(pybind11::module_& m) .def_property( "guided_decoding_params", &tle::Request::getGuidedDecodingParams, &tle::Request::setGuidedDecodingParams) .def_property("allotted_time_ms", &tle::Request::getAllottedTimeMs, &tle::Request::setAllottedTimeMs) + .def_property("cache_salt_id", &tle::Request::getCacheSaltID, &tle::Request::setCacheSaltID) .def_property( "context_phase_params", &tle::Request::getContextPhaseParams, &tle::Request::setContextPhaseParams) .def(py::pickle(requestGetstate, requestSetstate)); diff --git a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp index 0a52ae84852..5dc36b703d5 100644 --- a/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp +++ b/cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp @@ -1686,6 +1686,207 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdAndLoraTaskIdTest) EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); } +TEST_F(KVCacheManagerTest, BlockManagerReuseWithCacheSaltIdTest) +{ + // Test that cache_salt_id prevents KV cache reuse between requests with same tokens + // but different cache_salt_id values. 
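Before the individual test cases, a conceptual sketch of the behavior being tested. The Python below mirrors the salt handling added to BlockKeyHasher::hash in the kvCacheManager.cpp hunk above: the salt is mixed into the hash seed of the first block only (parentHash == 0) and reaches later blocks through the parent-hash chain. This is an illustration of the scheme, not the production hash:

    MASK64 = (1 << 64) - 1

    def _mix_salt(seed: int, cache_salt_id: int) -> int:
        # splitmix64-style finalizer on the salt, then a boost-style hash combine.
        c = cache_salt_id & MASK64
        c = ((c ^ (c >> 30)) * 0xBF58476D1CE4E5B9) & MASK64
        c = ((c ^ (c >> 27)) * 0x94D049BB133111EB) & MASK64
        c ^= c >> 31
        return (seed ^ (c + 0x9E3779B9 + ((seed << 6) & MASK64) + (seed >> 2))) & MASK64

    def first_block_seed(num_unique_tokens: int, cache_salt_id=None) -> int:
        parent_hash = 0  # the first block of a sequence has no parent
        seed = (num_unique_tokens ^ ((parent_hash * 0xBF58476D1CE4E5B9) & MASK64)) & MASK64
        if parent_hash == 0 and cache_salt_id is not None:
            seed = _mix_salt(seed, cache_salt_id)
        return seed

    # Same tokens, different salts -> different seeds -> separate reuse domains.
    assert first_block_seed(4, 12345) != first_block_seed(4, 67890)
    # No salt on either request -> identical seeds -> reuse remains possible.
    assert first_block_seed(4, None) == first_block_seed(4, None)
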
+ using VecTokenExtraIds = LlmRequest::VecTokenExtraIds; + using CacheSaltIDType = LlmRequest::CacheSaltIDType; + + auto constexpr numLayers = 12; + auto constexpr numKvHeads = 6; + auto constexpr sizePerHead = 16; + auto constexpr tokensPerBlock = 4; + auto constexpr maxBlocksPerSeq = 4; + auto constexpr blocksInPrimaryPool = 16; + auto constexpr blocksInSecondaryPool = 0; + auto constexpr maxNumSequences = 8; + auto const stream = std::make_shared(); + auto constexpr onboardBlocks = true; + auto constexpr numReturnSequences = 1; + auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq; + auto constexpr beamWidth = 1; + + auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}}; + + BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, + maxNumSequences, stream, maxAttentionWindow, beamWidth, + std::vector{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0, + onboardBlocks); + blockManager.allocatePools(false); + + EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock); + EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + SizeType32 constexpr maxNewTokens{0}; + tr::SamplingConfig const samplingConfig{beamWidth}; + bool constexpr isStreaming{false}; + + // Create shared input tokens + auto inputTokens = std::make_shared(VecTokens{100, 101, 102, 103, 104, 105, 106, 107, 108}); + auto const inputLength = static_cast(inputTokens->size()); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 1: Request without cache_salt_id + LlmRequest::RequestIdType requestId{0}; + auto llmRequest0 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, + false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt_id + + GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // Add first request and get blocks 0, 1, 2 + auto constexpr beamIdx = 0; + auto promptLen0 = llmRequest0->getNumTokens(beamIdx); + auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow); + EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0); + EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2})); + + // Add generated tokens + llmRequest0->addNewToken(3, beamIdx); + llmRequest0->addNewToken(4, beamIdx); + auto numTokens = llmRequest0->getNumTokens(beamIdx); + auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); + EXPECT_EQ(numBlocks, 3); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + // Release blocks to make them available for reuse + 
blockManager.releaseBlocks(seq0, llmRequest0); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 2: Request with same tokens but with cache_salt_id = 12345 + requestId = 1; + CacheSaltIDType cacheSaltId1{12345}; + auto llmRequest1 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, + false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + cacheSaltId1); // With cache_salt_id = 12345 + + GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // Should NOT reuse blocks despite same tokens, because cache_salt_id is different + auto promptLen1 = llmRequest1->getNumTokens(beamIdx); + auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow); + EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 0); // No reuse, starts from scratch + EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 5})); + + llmRequest1->addNewToken(3, beamIdx); + llmRequest1->addNewToken(4, beamIdx); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + // Release blocks + blockManager.releaseBlocks(seq1, llmRequest1); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 3: Request with same tokens and same cache_salt_id = 12345 + requestId = 2; + auto llmRequest2 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, + false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + cacheSaltId1); // Same cache_salt_id = 12345 + + GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // SHOULD reuse blocks because both tokens and cache_salt_id match + auto promptLen2 = llmRequest2->getNumTokens(beamIdx); + auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow); + 
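+ // The 9-token prompt fills two complete blocks (2 * tokensPerBlock tokens) plus a
+ // partial third block; the assertions below expect the two complete blocks written
+ // for llmRequest1 to be reused and a fresh block to be allocated for the remainder.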
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 3,4 + EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({3, 4, 6})); + + llmRequest2->addNewToken(3, beamIdx); + llmRequest2->addNewToken(4, beamIdx); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + // Release blocks + blockManager.releaseBlocks(seq2, llmRequest2); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 4: Request with same tokens but different cache_salt_id = 67890 + requestId = 3; + CacheSaltIDType cacheSaltId2{67890}; + auto llmRequest3 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, + false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + cacheSaltId2); // Different cache_salt_id = 67890 + + GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // Should NOT reuse blocks from any previous request because cache_salt_id is different + auto promptLen3 = llmRequest3->getNumTokens(beamIdx); + auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow); + EXPECT_EQ(llmRequest3->getContextCurrentPosition(), 0); // No reuse + EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({7, 8, 9})); + + llmRequest3->addNewToken(5, beamIdx); + llmRequest3->addNewToken(6, beamIdx); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks); + + /////////////////////////////////////////////////////////////////////////// + // Test Case 5: Request without cache_salt_id again + requestId = 4; + auto llmRequest4 = std::make_shared(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt, std::nullopt, false, std::nullopt, + false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt, std::nullopt, + LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences, std::nullopt, + std::nullopt, false, std::nullopt, std::nullopt, std::nullopt, std::nullopt, + std::nullopt); // No cache_salt_id + + GenerationRequest seq4{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()}; + + // Should reuse blocks from request0 (blocks 0,1) because both have no cache_salt_id + auto 
promptLen4 = llmRequest4->getNumTokens(beamIdx); + auto numContextBlocks4 = tc::ceilDiv(promptLen4, blockManager.getTokensPerBlock()); + blockManager.addSequence(seq4, promptLen4, numContextBlocks4, *llmRequest4, maxAttentionWindow); + EXPECT_EQ(llmRequest4->getContextCurrentPosition(), 2 * tokensPerBlock); // Reuse blocks 0,1 + EXPECT_THAT(seq4.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 10})); + + llmRequest4->addNewToken(7, beamIdx); + numTokens = llmRequest4->getNumTokens(beamIdx); + numBlocks = tc::ceilDiv(numTokens, tokensPerBlock); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2); + + // Clean up + blockManager.releaseBlocks(seq3, llmRequest3); + blockManager.releaseBlocks(seq4, llmRequest4); + EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0); + EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool); +} + TEST_F(KVCacheManagerTest, KVCacheManagerPerRequestStatsTest) { auto constexpr numLayers = 12; diff --git a/tensorrt_llm/_torch/pyexecutor/llm_request.py b/tensorrt_llm/_torch/pyexecutor/llm_request.py index fb0670e3756..2bddb6fd585 100644 --- a/tensorrt_llm/_torch/pyexecutor/llm_request.py +++ b/tensorrt_llm/_torch/pyexecutor/llm_request.py @@ -546,6 +546,7 @@ def executor_request_to_llm_request( priority=0.5, llm_request_type=llm_request_type, context_phase_params=executor_request.context_phase_params, + cache_salt_id=executor_request.cache_salt_id, py_multimodal_data=getattr(executor_request, "py_multimodal_data", None)) if child_req_ids: diff --git a/tensorrt_llm/executor/executor.py b/tensorrt_llm/executor/executor.py index 6fd70d0a01b..c9d55a7cfc1 100644 --- a/tensorrt_llm/executor/executor.py +++ b/tensorrt_llm/executor/executor.py @@ -124,6 +124,7 @@ def generate_async( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, + cache_salt_id: Optional[int] = None, ) -> GenerationResult: """Generate output for the given prompt token ids in the asynchronous mode. Asynchronous generation accepts single prompt only. 
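Downstream of the executor plumbing above, the LLM API (llm.py, below) accepts a human-readable cache_salt string and converts it to a 64-bit cache_salt_id via get_cache_salt_id. A usage sketch of that user-facing path; the model name is a placeholder and the .result() call assumes the usual asynchronous RequestOutput interface:

    from tensorrt_llm import LLM, SamplingParams

    llm = LLM(model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # placeholder model
    params = SamplingParams(max_tokens=32)

    # Same prompt, different salts: the salt is hashed into the first KV cache
    # block key, so the second call does not reuse blocks cached by the first.
    out_a = llm.generate_async("What is the capital of France?",
                               sampling_params=params, cache_salt="tenant-A").result()
    out_b = llm.generate_async("What is the capital of France?",
                               sampling_params=params, cache_salt="tenant-B").result()

On the OpenAI server side the same effect is exposed through the cache_salt extra field of a chat completion request (see the openai_protocol.py and openai_server.py hunks below).
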
@@ -147,7 +148,8 @@ def generate_async( kv_cache_retention_config=kv_cache_retention_config, disaggregated_params=disaggregated_params, multimodal_params=multimodal_params, - scheduling_params=scheduling_params) + scheduling_params=scheduling_params, + cache_salt_id=cache_salt_id) result = self.submit(request) # release memory in time if hasattr(request, "multimodal_params"): diff --git a/tensorrt_llm/executor/request.py b/tensorrt_llm/executor/request.py index 00b5deb2eed..1030e57f091 100644 --- a/tensorrt_llm/executor/request.py +++ b/tensorrt_llm/executor/request.py @@ -97,6 +97,7 @@ def __init__( postproc_params: Optional[PostprocParams] = None, multimodal_params: Optional[MultimodalParams] = None, scheduling_params: Optional[SchedulingParams] = None, + cache_salt_id: Optional[int] = None, ): if isinstance(prompt_token_ids, list): self.prompt_token_ids = prompt_token_ids @@ -122,6 +123,7 @@ def __init__( self.id: Optional[int] = None self.disaggregated_params = disaggregated_params self.scheduling_params = scheduling_params + self.cache_salt_id = cache_salt_id def set_id(self, id): assert self.id is None, f"Request ID is already set: {self.id}" diff --git a/tensorrt_llm/executor/worker.py b/tensorrt_llm/executor/worker.py index f4cd66d6f7d..af8b7a8ac5d 100644 --- a/tensorrt_llm/executor/worker.py +++ b/tensorrt_llm/executor/worker.py @@ -572,7 +572,8 @@ def _deduce_max_tokens(request: GenerationRequest, request.sampling_params.logits_processor, kv_cache_retention_config=request.kv_cache_retention_config, context_phase_params=context_phase_params, - type=request_type) + type=request_type, + cache_salt_id=request.cache_salt_id) executor_request.py_lora_path = py_lora_path if self._is_pytorch_backend and request.multimodal_params is not None: diff --git a/tensorrt_llm/inputs/__init__.py b/tensorrt_llm/inputs/__init__.py index 070b8449cee..e7d47b98797 100644 --- a/tensorrt_llm/inputs/__init__.py +++ b/tensorrt_llm/inputs/__init__.py @@ -11,7 +11,8 @@ add_multimodal_placeholders, apply_chat_template, async_load_audio, async_load_image, async_load_video, convert_image_mode, default_multimodal_input_loader, - encode_base64_content_from_url, load_image, load_video) + encode_base64_content_from_url, get_cache_salt_id, + load_image, load_video) __all__ = [ "ALL_SUPPORTED_MULTIMODAL_MODELS", @@ -43,4 +44,5 @@ "encode_base64_content_from_url", "load_image", "load_video", + "get_cache_salt_id", ] diff --git a/tensorrt_llm/inputs/utils.py b/tensorrt_llm/inputs/utils.py index 458b0a11d88..f935d2ffe05 100644 --- a/tensorrt_llm/inputs/utils.py +++ b/tensorrt_llm/inputs/utils.py @@ -17,6 +17,7 @@ from transformers import AutoProcessor, ProcessorMixin from transformers.utils import logging +from tensorrt_llm.inputs.multimodal import default_hasher from tensorrt_llm.inputs.registry import (MULTIMODAL_PLACEHOLDER_REGISTRY, MultimodalPlaceholderPlacement) from tensorrt_llm.llmapi.llm_utils import ModelLoader @@ -610,3 +611,14 @@ def convert_to_conversation_message( inputs.append(input) return inputs + + +def get_cache_salt_id(cache_salt: str) -> int: + b = cache_salt.encode("utf-8") + h = default_hasher(b).digest(length=8) + cache_salt_id = int.from_bytes(h, "little", signed=False) + if cache_salt_id < 0 or cache_salt_id >= (1 << 64): + raise ValueError( + f"cache_salt_id must be in [0, 2**64 - 1], got {cache_salt_id}.") + + return cache_salt_id diff --git a/tensorrt_llm/llmapi/llm.py b/tensorrt_llm/llmapi/llm.py index edb16cac001..b4c7b4cfbe0 100644 --- a/tensorrt_llm/llmapi/llm.py +++ 
b/tensorrt_llm/llmapi/llm.py @@ -27,7 +27,8 @@ from ..executor.utils import (create_mpi_comm_session, get_spawn_proxy_process_env) from ..inputs import (PromptInputs, create_input_processor, - create_input_processor_with_hash, prompt_inputs) + create_input_processor_with_hash, get_cache_salt_id, + prompt_inputs) from ..logger import logger from ..sampling_params import SamplingParams from ..scheduling_params import SchedulingParams @@ -325,6 +326,7 @@ def generate_async( disaggregated_params: Optional[DisaggregatedParams] = None, _postproc_params: Optional[PostprocParams] = None, scheduling_params: Optional[SchedulingParams] = None, + cache_salt: Optional[str] = None, ) -> RequestOutput: """Generate output for the given prompt in the asynchronous mode. Asynchronous generation accepts single prompt only. @@ -339,7 +341,7 @@ def generate_async( kv_cache_retention_config (tensorrt_llm.bindings.executor.KvCacheRetentionConfig, optional): Configuration for the request's retention in the KV Cache. Defaults to None. disaggregated_params (tensorrt_llm.disaggregated_params.DisaggregatedParams, optional): Disaggregated parameters. Defaults to None. scheduling_params (tensorrt_llm.scheduling_params.SchedulingParams, optional): Scheduling parameters. Defaults to None. - + cache_salt (str, optional): If specified, the KV cache is salted with the provided string, and KV cache reuse is limited to requests that share the same cache salt. Defaults to None. Returns: tensorrt_llm.llmapi.RequestOutput: The output data of the completion request to the LLM. """ @@ -349,7 +351,8 @@ def generate_async( raise RuntimeError("LLM is shutting down") sampling_params = self._prepare_sampling_params(sampling_params) - + cache_salt_id = get_cache_salt_id( + cache_salt) if cache_salt is not None else None # With pytorch backend, py_executor has logic to handle max_tokens of 1, # so set to 1 to avoid allocating unnecessary KV cache blocks for single request # TODO: Also support for trt backend @@ -444,6 +447,7 @@ def generate_async( postproc_params=_postproc_params, multimodal_params=multimodal_params, scheduling_params=scheduling_params, + cache_salt_id=cache_salt_id, ) return RequestOutput._from_generation_result(result, prompt, diff --git a/tensorrt_llm/serve/openai_protocol.py b/tensorrt_llm/serve/openai_protocol.py index acfbff14d23..b482818475c 100644 --- a/tensorrt_llm/serve/openai_protocol.py +++ b/tensorrt_llm/serve/openai_protocol.py @@ -19,7 +19,8 @@ from openai.types.responses.tool import Tool from openai.types.shared import Metadata, Reasoning from openai_harmony import ReasoningEffort -from pydantic import BaseModel, ConfigDict, Field, model_validator +from pydantic import (BaseModel, ConfigDict, Field, field_validator, + model_validator) from typing_extensions import Annotated, Required, TypeAlias, TypedDict from tensorrt_llm.executor.request import LoRARequest @@ -592,6 +593,13 @@ class ChatCompletionRequest(OpenAIBaseModel): description=("Parameters for disaggregated serving"), ) + cache_salt: Optional[str] = Field( + default=None, + description= + ("If specified, KV cache will be salted with the provided string " + "to limit KV cache reuse to requests with the same string."
+ )) + # doc: end-chat-completion-extra-params def to_sampling_params(self, vocab_size: int = 32000) -> SamplingParams: @@ -671,6 +679,16 @@ def check_suffix(cls, data): raise ValueError("suffix is not supported") return data + @field_validator("cache_salt") + @classmethod + def check_cache_salt_support(cls, v): + if v is not None: + if not isinstance(v, str) or not v.strip(): + raise ValueError( + "Parameter 'cache_salt' must be a non-empty string if provided." + ) + return v + ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam, ResponseReasoningItem, diff --git a/tensorrt_llm/serve/openai_server.py b/tensorrt_llm/serve/openai_server.py index de245046359..c622aed63f3 100644 --- a/tensorrt_llm/serve/openai_server.py +++ b/tensorrt_llm/serve/openai_server.py @@ -462,7 +462,8 @@ async def create_chat_response( _postproc_params=postproc_params if self.postproc_worker_enabled else None, streaming=request.stream, lora_request=request.lora_request, - disaggregated_params=disaggregated_params + disaggregated_params=disaggregated_params, + cache_salt=request.cache_salt, ) asyncio.create_task(self.await_disconnected(raw_request, promise)) if not self.postproc_worker_enabled: diff --git a/tests/integration/defs/test_e2e.py b/tests/integration/defs/test_e2e.py index 152676bb9ab..0def4787de6 100644 --- a/tests/integration/defs/test_e2e.py +++ b/tests/integration/defs/test_e2e.py @@ -1557,6 +1557,18 @@ def test_openai_misc_example(llm_root, llm_venv, backend: str): ]) +def test_openai_cache_salt(llm_root, llm_venv): + example_root = Path(os.path.join(llm_root, "examples", "serve")) + test_root = unittest_path() / "llmapi" / "apps" + llm_venv.run_cmd([ + "-m", "pip", "install", "-r", + os.path.join(example_root, "requirements.txt") + ]) + llm_venv.run_cmd( + ["-m", "pytest", + str(test_root / "_test_openai_cache_salt.py")]) + + @pytest.mark.parametrize("backend", ["pytorch", "trt"]) def test_openai_completions_example(llm_root, llm_venv, backend: str): test_root = unittest_path() / "llmapi" / "apps" diff --git a/tests/unittest/api_stability/references/llm.yaml b/tests/unittest/api_stability/references/llm.yaml index 7154238f6a8..56573d1fcec 100644 --- a/tests/unittest/api_stability/references/llm.yaml +++ b/tests/unittest/api_stability/references/llm.yaml @@ -196,6 +196,9 @@ methods: annotation: Optional[tensorrt_llm.scheduling_params.SchedulingParams] default: null status: prototype + cache_salt: + annotation: Optional[str] + default: null return_annotation: tensorrt_llm.llmapi.llm.RequestOutput get_kv_cache_events: parameters: diff --git a/tests/unittest/llmapi/apps/_test_openai_cache_salt.py b/tests/unittest/llmapi/apps/_test_openai_cache_salt.py new file mode 100644 index 00000000000..0799b6c2831 --- /dev/null +++ b/tests/unittest/llmapi/apps/_test_openai_cache_salt.py @@ -0,0 +1,231 @@ +"""Test cache_salt functionality in OpenAI API to ensure it prevents cache reuse""" + +import os +import tempfile + +import openai +import pytest +import yaml + +from ..test_llm import get_model_path +from .openai_server import RemoteOpenAIServer + +pytestmark = pytest.mark.threadleak(enabled=False) + + +@pytest.fixture(scope="module", ids=["TinyLlama-1.1B-Chat"]) +def model_name() -> str: + return "llama-models-v2/TinyLlama-1.1B-Chat-v1.0" + + +@pytest.fixture(scope="module") +def temp_extra_llm_api_options_file(): + """Create temporary config file with KV cache enabled for testing""" + temp_dir = tempfile.gettempdir() + temp_file_path = os.path.join(temp_dir, "cache_salt_test_options.yaml") + try: + 
extra_llm_api_options_dict = { + # Enable KV cache reuse + "kv_cache_config": { + "enable_block_reuse": True, + }, + # Enable performance metrics to get cache hit rate + "return_perf_metrics": True, + "enable_iter_perf_stats": True, + "enable_iter_req_stats": True, + # Disable CUDA graph for compatibility + "cuda_graph_config": None, + } + + with open(temp_file_path, 'w') as f: + yaml.dump(extra_llm_api_options_dict, f) + + yield temp_file_path + finally: + if os.path.exists(temp_file_path): + os.remove(temp_file_path) + + +@pytest.fixture(scope="module") +def server(model_name: str, + temp_extra_llm_api_options_file: str) -> RemoteOpenAIServer: + model_path = get_model_path(model_name) + args = [] + args.extend(["--backend", "pytorch"]) + args.extend(["--extra_llm_api_options", temp_extra_llm_api_options_file]) + with RemoteOpenAIServer(model_path, args) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server: RemoteOpenAIServer) -> openai.OpenAI: + return server.get_client() + + +def get_cache_hit_rate(client: openai.OpenAI) -> float: + """Get cache hit rate from the metrics endpoint""" + import httpx + + # Get the base URL from the OpenAI client (it includes /v1) + # We need to go up one level to access /metrics + base_url = str(client.base_url).rstrip('/') + if base_url.endswith('/v1'): + base_url = base_url[:-3] # Remove /v1 + + # Make a direct HTTP request to the metrics endpoint + with httpx.Client() as http_client: + response = http_client.get(f"{base_url}/metrics", timeout=5.0) + + # Check if metrics endpoint is available + if response.status_code != 200: + raise RuntimeError( + f"Metrics endpoint returned status {response.status_code}") + + metrics = response.json() + + # Validate that we have metrics data + if not isinstance(metrics, list) or len(metrics) == 0: + raise ValueError("No metrics data available") + + # Get the most recent stats + latest_stats = metrics[-1] + + # Extract KV cache statistics + kv_cache_stats = latest_stats.get("kvCacheStats", {}) + if not kv_cache_stats: + raise ValueError("No KV cache statistics available in metrics") + + try: + print(f"kv_cache_stats reused: {kv_cache_stats['reusedBlocks']}") + print(f"kv_cache_stats missed: {kv_cache_stats['missedBlocks']}") + print(f"kv_cache_stats hit rate: {kv_cache_stats['cacheHitRate']}") + return kv_cache_stats["cacheHitRate"] + except Exception as e: + print(f"Warning: Could not get cache metrics: {e}") + return 0.0 + + +def test_cache_salt_prevents_reuse_chat(client: openai.OpenAI, model_name: str): + """Test that different cache_salt values prevent KV cache reuse in chat completions""" + + # Common messages that will be used across all requests + messages = [{ + "role": "system", + "content": "You are a helpful assistant. Keep responses brief." + }, { + "role": + "user", + "content": + "What is the capital of France? Answer in one sentence." 
+ }] + + # Test configuration + max_tokens = 30 + temperature = 0.0 # Deterministic for testing + + # Track responses for comparison + responses = [] + + # Test Case 1: First request without cache_salt (baseline) + print("\n=== Test Case 1: First request without cache_salt ===") + response1 = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + responses.append(response1.choices[0].message.content) + print(f"Response 1: {response1.choices[0].message.content[:100]}...") + + # Display initial cache metrics + initial_hit_rate = get_cache_hit_rate(client) + print(f"Initial cache hit rate: {initial_hit_rate:.2%}") + + # Test Case 2: Same messages without cache_salt (should reuse cache) + print( + "\n=== Test Case 2: Same messages without cache_salt (should reuse) ===" + ) + response2 = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + ) + responses.append(response2.choices[0].message.content) + print(f"Response 2: {response2.choices[0].message.content[:100]}...") + + # Check if metrics are available + hit_rate_after_reuse = get_cache_hit_rate(client) + print(f"Cache hit rate after reuse: {hit_rate_after_reuse:.2%}") + assert hit_rate_after_reuse >= initial_hit_rate, \ + "Cache hit rate should increase when reusing cache without salt" + + # Test Case 3: Same messages with cache_salt="user_123" (should NOT reuse) + print( + "\n=== Test Case 3: Same messages with cache_salt='user_123' (no reuse) ===" + ) + response3 = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + extra_body={"cache_salt": "user_123"}) + responses.append(response3.choices[0].message.content) + print(f"Response 3: {response3.choices[0].message.content[:100]}...") + + # Record metrics after request with different salt + hit_rate_after_salt1 = get_cache_hit_rate(client) + print(f"Cache hit rate after salt 'user_123': {hit_rate_after_salt1:.2%}") + assert hit_rate_after_salt1 < hit_rate_after_reuse, \ + "Cache hit rate should decrease when using a different salt" + + # Test Case 4: Same messages with same cache_salt="user_123" (should reuse) + print( + "\n=== Test Case 4: Same messages with same cache_salt='user_123' (should reuse) ===" + ) + response4 = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + extra_body={"cache_salt": "user_123"} # Same salt should enable reuse + ) + responses.append(response4.choices[0].message.content) + print(f"Response 4: {response4.choices[0].message.content[:100]}...") + + # Cache hit rate should increase again when using same salt + hit_rate_after_salt1_reuse = get_cache_hit_rate(client) + print( + f"Cache hit rate after reusing salt 'user_123': {hit_rate_after_salt1_reuse:.2%}" + ) + assert hit_rate_after_salt1_reuse >= hit_rate_after_salt1, \ + "Cache hit rate should increase when reusing same salt" + + # Test Case 5: Same messages with different cache_salt="user_456" (should NOT reuse) + print( + "\n=== Test Case 5: Same messages with cache_salt='user_456' (no reuse) ===" + ) + response5 = client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + temperature=temperature, + extra_body={"cache_salt": "user_456"}) + responses.append(response5.choices[0].message.content) + print(f"Response 5: {response5.choices[0].message.content[:100]}...") + + # Cache 
hit rate should decrease when using a different salt + hit_rate_after_salt2 = get_cache_hit_rate(client) + print(f"Cache hit rate after salt 'user_456': {hit_rate_after_salt2:.2%}") + assert hit_rate_after_salt2 < hit_rate_after_salt1_reuse, \ + "Cache hit rate should decrease when using a different salt" + + # Test empty string (should be rejected) + with pytest.raises(Exception) as exc_info: + client.chat.completions.create( + model=model_name, + messages=messages, + max_tokens=max_tokens, + extra_body={"cache_salt": ""} # Empty string should be rejected + ) + print(f"Empty string rejected as expected: {exc_info.value}")
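As a companion to the get_cache_salt_id helper introduced above, here is a minimal, self-contained sketch of the salt-string-to-64-bit-id mapping it performs. hashlib.blake2b is used below as a stand-in for tensorrt_llm.inputs.multimodal.default_hasher (an assumption; the real hasher may differ), so this illustrates the shape of the mapping rather than the exact ids produced by TensorRT-LLM.

import hashlib


def cache_salt_to_id(cache_salt: str) -> int:
    # Hash the UTF-8 bytes of the salt down to 8 bytes, then interpret the
    # digest as a little-endian unsigned integer; 8 bytes always fit in
    # [0, 2**64 - 1], the range expected for cache_salt_id.
    digest = hashlib.blake2b(cache_salt.encode("utf-8"), digest_size=8).digest()
    return int.from_bytes(digest, "little", signed=False)


# Distinct salts map to distinct ids with overwhelming probability, so two
# requests only share a cache-salt scope when they pass the same string.
print(cache_salt_to_id("user_123"))
print(cache_salt_to_id("user_456"))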
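For completeness, a minimal client-side sketch of exercising the new cache_salt field against the OpenAI-compatible chat endpoint, mirroring the integration test above; the base URL, model name, and API key are placeholders for whatever trtllm-serve instance is actually running.

import openai

# Placeholder endpoint and model name; adjust to the running server.
client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")
messages = [{"role": "user", "content": "What is the capital of France?"}]

# Requests that send the same cache_salt may reuse each other's KV cache
# blocks; requests with a different salt (or none) are kept isolated.
for salt in ("user_123", "user_123", "user_456"):
    resp = client.chat.completions.create(
        model="TinyLlama-1.1B-Chat-v1.0",
        messages=messages,
        max_tokens=30,
        temperature=0.0,
        extra_body={"cache_salt": salt},
    )
    print(salt, resp.choices[0].message.content)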