
Conversation

qixiang-99
Collaborator

@qixiang-99 qixiang-99 commented Jul 10, 2025

fix/improve kvcache allocation in PyTorch runtime

Background

We found significant memory under-allocation when using the VSWA KvCacheManager. One reason is that the PyTorch runtime uses an estimated 'max_tokens' to constrain KV cache allocation, while 'max_tokens' is an awkward (and potentially ambiguous) concept for the Variable Sliding Window Attention (VSWA) KV cache.

TL;DR

Previously, we used max_tokens to limit KV cache memory usage, but that method didn't work for VSWA's KV cache. This PR refines the KV cache estimation logic (mainly in KvCacheCreator) by directly enforcing a memory size limit.
In addition, this PR contains these minor fixes:

  1. Improve the estimation phase to avoid potential OOM issues for long sequences with CUDA Graph.
  2. Add an optional torch.cuda.empty_cache pytest fixture to avoid potential OOM in CI (a minimal sketch follows below).

Big shout-out to @jaedeok-nvidia, @ixlmar and @Funatiq for their suggestions and reviews!
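For reference, such a fixture typically looks like the minimal sketch below; the actual fixture in tests/integration/defs/conftest.py may differ in details, and the gc.collect() call is an assumption.

import gc

import pytest
import torch


@pytest.fixture
def torch_empty_cache():
    """Release cached CUDA memory before a test runs (sketch only)."""
    if torch.cuda.is_available():
        gc.collect()
        torch.cuda.empty_cache()
    yield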

Description

Currently, the PyTorch path includes a step that estimates available free memory for the KV cache, calculates the corresponding 'max_tokens', and updates the kv_cache_config accordingly.

However, the existing logic significantly under-allocates memory, as it assumes a homogeneous attention model (nvbug).

Additionally, the 'max_tokens' concept doesn't align well with VSWA scenarios, as its definition is "the maximum number of tokens that should be stored in the KV cache". In the VSWA case, "tokens stored in the KV cache" is too coarse-grained: in-window tokens are fully in the KV cache, while out-of-window tokens are only partially in the KV cache because the attention windows differ across layers.
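As a rough illustration (the layer counts, head sizes, and window sizes below are made up, not taken from any real model), the same sequence occupies different amounts of KV memory per layer under VSWA, so a single max_tokens value no longer maps to a well-defined byte budget:

# Hypothetical 4-layer VSWA model: two layers with a 512-token sliding window,
# two layers with full attention. All numbers are illustrative only.
bytes_per_token_per_layer = 2 * 8 * 128 * 2  # K+V, 8 KV heads, head_dim 128, fp16
window_sizes = [512, 512, 4096, 4096]

seq_len = 4096
# Each layer only keeps the tokens inside its own window:
tokens_held = [min(seq_len, w) for w in window_sizes]  # [512, 512, 4096, 4096]
kv_bytes = sum(t * bytes_per_token_per_layer for t in tokens_held)
print(f"KV bytes for seq_len={seq_len}: {kv_bytes / 2**20:.1f} MiB")
# The most recent 512 tokens are cached by every layer ("fully in KV cache"),
# while older tokens are cached only by the full-attention layers ("partially
# in KV cache"), so "number of tokens stored" is ambiguous.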

Proposed Solution

An alternative approach is to compute the KV cache memory budget in bytes directly during the estimation step.

Advantages:

  1. Aligns naturally with the block distribution function KVCacheManagerCpp.calculate_max_num_blocks (link), which already expects a memory size parameter.
  2. Provides greater flexibility, particularly beneficial for future scenarios such as dynamic memory allocation in VSWA.

Implementation Details

  1. This PR introduces a new field, mMaxGpuTotalBytes (exposed as max_gpu_total_bytes in Python), in KvCacheConfig.
  2. Currently, KvCacheCreator is responsible for building and estimating KV caches. With this PR, it also calculates the estimated in-memory size of the KV cache and uses the max_gpu_total_bytes setting from KvCacheConfig to limit memory allocation (a simplified sketch follows this list).
  3. Within KVCacheManager, the VSWA path uses a C++ binding function (see calculate_max_num_blocks_from_cpp) that takes a memory size as input to determine KV blocks. For non-VSWA cases, it still uses its own Python function, but future efforts aim to unify both approaches, as discussed here.
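A minimal Python sketch of this flow; the function and parameter names are illustrative, not the actual KvCacheCreator API, and the clamping order is simplified.

def plan_kv_cache_budget(free_bytes: int, fraction: float,
                         max_gpu_total_bytes: int, kv_bytes_per_token: int,
                         is_vswa: bool) -> dict:
    """Sketch of memory-based KV cache sizing under assumed semantics."""
    # Budget from the free-memory fraction measured during the estimation run.
    budget = int(free_bytes * fraction)
    # An explicit byte cap, if provided (> 0), further limits the budget.
    if max_gpu_total_bytes > 0:
        budget = min(budget, max_gpu_total_bytes)

    if is_vswa:
        # VSWA: hand the byte budget straight to the C++ block calculator;
        # max_tokens stays unset because it is ambiguous here.
        return {"max_gpu_total_bytes": budget, "max_tokens": None}
    # Non-VSWA: the Python path still derives max_tokens from the byte budget.
    return {"max_gpu_total_bytes": budget,
            "max_tokens": budget // kv_bytes_per_token}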

"Hack" to achieve this

One hack is to interpret freeGpuMemoryFraction as the fraction of total GPU memory reserved for the KV cache. This allows us to avoid adding an extra field, but we should clearly document this usage to prevent confusion. (Perhaps add a clarifying comment in the code?)
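For concreteness, the two readings differ roughly as below (toy numbers only; today free_gpu_memory_fraction is defined against the memory that is still free, not the device total):

GiB = 1024**3
total_gpu_memory = 80 * GiB        # illustrative device size
free_after_profiling = 40 * GiB    # illustrative free memory after weights/activations
fraction = 0.9

# Current meaning: fraction of the currently free memory.
kv_budget_current = int(free_after_profiling * fraction)  # 36 GiB
# "Hack" reading: fraction of total GPU memory reserved for the KV cache.
kv_budget_hack = int(total_gpu_memory * fraction)         # 72 GiB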

Test Coverage

tests/unittest/_torch/test_resource_manager.py

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--disable-fail-fast --skip-test --stage-list "A10-1, xxx" --gpu-type "A30, H100_PCIe" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-[Post-Merge]-1, xxx"]

Launch build/test pipelines. All previously running jobs will be killed.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests. Will also run L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-[Post-Merge]-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-[Post-Merge]-1, xxx".

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Summary by CodeRabbit

  • New Features

    • Add option to cap KV-cache GPU memory via max_gpu_total_bytes (used with existing fraction limit).
    • KV-cache stats now report allocated GPU bytes; exposed in Python.
    • Automatic KV-cache sizing moved to a memory-based approach, accounting for main/draft models and VSWA.
  • Refactor

    • Public API renamed: estimate_max_tokens → configure_kv_cache_capacity; executor creation uses the new flow.
  • Tests

    • New fixture clears CUDA cache before tests; applied to an existing disaggregated single-GPU test.
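A usage sketch for the new cap mentioned above; the import path is an assumption, while the field names match the Python KvCacheConfig described in this PR.

from tensorrt_llm.llmapi import KvCacheConfig  # import path assumed

# Cap the KV cache at 8 GiB; if the fraction-based limit would allow more,
# the smaller of the two limits is applied.
kv_cache_config = KvCacheConfig(
    free_gpu_memory_fraction=0.9,
    max_gpu_total_bytes=8 * 1024**3,
)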

@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #11576 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #11576 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #8571 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@qixiang-99 qixiang-99 changed the title from "Feat/improve vswa kvcache" to "fix/improve kvcache allocation in PyTorch runtime" on Jul 11, 2025
@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from 8628139 to e658165 on July 11, 2025 23:29
@qixiang-99 qixiang-99 requested a review from symphonylyh on July 14, 2025 21:01
@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from e658165 to 5036ed3 on July 20, 2025 20:56
Contributor

coderabbitai bot commented Jul 20, 2025

📝 Walkthrough

Walkthrough

Tracks and exposes GPU KV-cache allocated bytes; adds max_gpu_total_bytes configuration (C++ and Python bindings); switches Python KV-capacity estimation from token-based to memory-based (VSWA-aware); updates resource manager to honor explicit byte limits; updates tests and fixtures for GPU memory handling.

Changes

Cohort / File(s) Summary
KV cache stats & manager (C++)
cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h, cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Add KvCacheStats::allocatedBytes and KVCacheManager::mAllocatedBytes; compute and store total allocated KV-cache bytes during pool allocation; populate via getKvCacheStats() and log usage.
Executor KV cache config (C++ core)
cpp/include/tensorrt_llm/executor/executor.h, cpp/tensorrt_llm/executor/kvCacheConfig.cpp
Add maxGpuTotalBytes ctor arg + getter/setter; change setMaxTokens to accept std::optional<SizeType32>; apply provided max GPU bytes during construction.
Pybind bindings (KV cache & config)
cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp, cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
Expose KvCacheStats.allocated_bytes (read-only); extend KvCacheConfig binding with max_gpu_total_bytes arg/property; update constructor binding and pickle (__getstate__/__setstate__).
Nanobind bindings (KV cache & config)
cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp, cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
Mirror pybind updates: add allocated_bytes property; extend KvCacheConfig ctor and add max_gpu_total_bytes property; update __getstate__/__setstate__.
Python KV capacity logic
tensorrt_llm/_torch/pyexecutor/_util.py
Replace token-based estimation with memory-based flow: add _get_kv_size_per_token, _cal_max_memory; rename estimate_max_tokens → configure_kv_cache_capacity; compute/clamp KV cache capacity in bytes using allocated bytes, fraction, and optional max_gpu_total_bytes; VSWA-aware behavior.
Executor creation hook (Python)
tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
Call configure_kv_cache_capacity(py_executor) during MODEL_EXTRA stage instead of estimate_max_tokens.
Resource manager (Python)
tensorrt_llm/_torch/pyexecutor/resource_manager.py
Refine VSWA detection; adjust_window_sizes_for_vswa now returns (dict, max_attention_window_vec); propagate updated max_attention_window_vec; introduce _primary/_secondary_pool_memory_bytes honoring explicit max_gpu_total_bytes vs fraction; pass pool bytes into C++ allocation calls.
Public API args & validation (Python)
tensorrt_llm/llmapi/llm_args.py
Add max_gpu_total_bytes field to KvCacheConfig with validation; add max_attention_window validator; include new arg when converting to pybind config.
Tests & fixtures
tests/integration/defs/conftest.py, tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py, tests/unittest/_torch/executor/test_resource_manager.py
Add torch_empty_cache fixture (clears CUDA cache) and apply to an integration test; expand unit tests for VSWA, pool sizing, and _primary/_secondary_pool_memory_bytes behavior; add new test helpers and a test_calculate_max_num_blocks_from_cpp.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant App
  participant Creator as KvCacheCreator (Py)
  participant Exec as PyExecutor
  participant CppMgr as KVCacheManager (C++)
  participant Cfg as KvCacheConfig

  App->>Creator: configure_kv_cache_capacity(py_executor)
  Creator->>Exec: get kv managers (main, draft)
  loop per KV cache
    Creator->>CppMgr: getKvCacheStats()
    CppMgr-->>Creator: {allocatedBytes, ...}
  end
  Creator->>Creator: per_token_size = _get_kv_size_per_token()
  Creator->>Creator: max_mem = _cal_max_memory(free, total, fraction, allocated_bytes)
  alt max_gpu_total_bytes provided
    Creator->>Creator: max_mem = min(max_mem, max_gpu_total_bytes)
  end
  alt VSWA detected
    Creator->>Cfg: setMaxTokens(std::nullopt)
  else
    Creator->>Cfg: setMaxTokens(floor(max_mem / per_token_size))
  end
  Creator->>Cfg: setMaxGpuTotalBytes(if provided)
  Creator-->>App: configuration applied

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • qiaoxj07
  • pcastonguay
  • ixlmar
  • jaedeok-nvidia
  • Superjomn


@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12391 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12391 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9210 completed with status: 'FAILURE'

@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from 5036ed3 to 7432901 on July 21, 2025 02:01
@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12399 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12399 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9218 completed with status: 'FAILURE'

Collaborator

@jaedeok-nvidia jaedeok-nvidia left a comment


Thanks @qixiang-99 for your contribution to the VSWA fixes. Hopefully this PR fixes our underutilization issue in VSWA models. And let's consider how to improve the efficiency further.

@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from 7432901 to 5a0d5de on July 21, 2025 07:32
@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12427 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12427 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9240 completed with status: 'FAILURE'

@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #12473 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #12473 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #9278 completed with status: 'FAILURE'

@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
cpp/tensorrt_llm/executor/kvCacheConfig.cpp (1)

43-46: Avoid implicit construction of std::optional; pass the optional directly to setMaxTokens

setMaxTokens now accepts std::optional<SizeType32>. Passing maxTokens.value() relies on implicit construction back to optional. Pass the optional directly to avoid ambiguity and future refactor hazards.

-        setMaxTokens(maxTokens.value());
+        setMaxTokens(maxTokens);
🧹 Nitpick comments (3)
cpp/tensorrt_llm/executor/kvCacheConfig.cpp (3)

31-31: Prefer pass-by-value for uint64_t; consider using std::optional<uint64_t> for consistency with other config fields

Passing a trivially copyable 64-bit value by const& is unnecessary. Also, other optional fields use std::optional; consider aligning maxGpuTotalBytes with that pattern to avoid sentinel semantics (0 meaning “unlimited”) leaking into the API surface.

Apply this minimal change to pass by value (keeps current semantics):

-    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults, uint64_t const& maxGpuTotalBytes)
+    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults, uint64_t maxGpuTotalBytes)

67-70: Redundant setter call for mMaxGpuTotalBytes in constructor

mMaxGpuTotalBytes is already initialized in the initializer list. The extra call to setMaxGpuTotalBytes(maxGpuTotalBytes) adds no validation and is redundant. Remove for clarity.

-    if (maxGpuTotalBytes)
-    {
-        setMaxGpuTotalBytes(maxGpuTotalBytes);
-    }

235-238: Clarify semantics for max GPU bytes; optionally validate

If 0 means “no limit” and non-zero applies a hard cap, please document it near the setter (and in the header). Optionally, add validation if there are known lower bounds.

 void KvCacheConfig::setMaxGpuTotalBytes(uint64_t maxGpuTotalBytes)
 {
+    // 0 disables the limit; non-zero applies a hard cap in bytes.
     mMaxGpuTotalBytes = maxGpuTotalBytes;
 }
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 69922ef and d1c4516.

📒 Files selected for processing (4)
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h (3 hunks)
  • cpp/include/tensorrt_llm/executor/executor.h (4 hunks)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp (1 hunks)
  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
  • cpp/include/tensorrt_llm/executor/executor.h
🧰 Additional context used
📓 Path-based instructions (4)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
🧠 Learnings (1)
📚 Learning: 2025-08-14T21:04:50.208Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.208Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
cpp/tensorrt_llm/executor/kvCacheConfig.cpp (1)

145-148: Getter LGTM

The getter is straightforward and consistent with the new member.

@tensorrt-cicd
Collaborator

PR_Github #15464 [ run ] triggered by Bot

@tensorrt-cicd
Collaborator

PR_Github #15464 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #11649 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

Collaborator

@eopXD eopXD left a comment


LGTM, thank you for the test coverage!

@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from d1c4516 to 6e60db3 on August 21, 2025 20:27
@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #16082 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (1)
cpp/include/tensorrt_llm/executor/executor.h (1)

1031-1031: API change ripple effects: ensure call sites now pass optional

setMaxTokens(std::optional<SizeType32>) changes the call contract. Please ensure all call sites are updated (explicit std::nullopt or std::optional<SizeType32>{value}) to avoid surprising implicit conversions.

Run:

#!/bin/bash
# Find all setMaxTokens call sites and show context to confirm optional is passed intentionally
rg -n -C2 --type=cpp '\bsetMaxTokens\s*\('

echo
echo "Heuristic: calls that might pass raw integers/constants (review manually):"
rg -n --type=cpp -P '\bsetMaxTokens\s*\(\s*(?!std::optional|std::nullopt|\{)' | sed -n '1,200p'
🧹 Nitpick comments (9)
cpp/tensorrt_llm/executor/kvCacheConfig.cpp (4)

31-31: Pass small scalar by value, not const&

uint64_t is a small trivially copyable scalar; prefer pass-by-value for consistency and to avoid binding to a temporary. Align with the header once updated.

Apply in this file:

-    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults, uint64_t const& maxGpuTotalBytes)
+    std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults, uint64_t maxGpuTotalBytes)

41-42: Avoid double write of mMaxGpuTotalBytes

mMaxGpuTotalBytes is initialized in the ctor initializer list and immediately set again via the setter. Since the setter has no extra validation, the second write is redundant.

Remove the redundant setter call:

-    if (maxGpuTotalBytes)
-    {
-        setMaxGpuTotalBytes(maxGpuTotalBytes);
-    }

Also applies to: 67-70


43-46: Pass optional directly to setter; drop implicit conversion and redundant guard

setMaxTokens now accepts std::optional<SizeType32>. Passing maxTokens.value() relies on an implicit conversion and the surrounding if is no longer needed.

Simplify:

-    if (maxTokens)
-    {
-        setMaxTokens(maxTokens.value());
-    }
+    setMaxTokens(maxTokens);

Also applies to: 165-172


235-238: Clarify semantics; consider lightweight validation

setMaxGpuTotalBytes silently accepts any value. At minimum, document that 0 means “no cap/unset” (consistent with the ctor default) and, if feasible, guard against pathological inputs via a check (e.g., extremely small non-zero values that cannot fit a single block will be handled elsewhere, but a comment here helps).

If you prefer not to add checks here, please add a one-line comment in the header on mMaxGpuTotalBytes documenting that 0 disables the cap.

cpp/include/tensorrt_llm/executor/executor.h (5)

1-15: Update copyright year to 2025 (guideline compliance)

Header still shows 2022–2024, while new code landed in 2025 and guidelines require current year.

- * Copyright (c) 2022-2024, NVIDIA CORPORATION.  All rights reserved.
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION.  All rights reserved.

1009-1011: Prefer passing uint64_t by value; also document 0 as “no cap”

Passing scalars by const& is inconsistent with nearby parameters and unnecessary. Also, the public API should explicitly document the sentinel semantics for 0.

-        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
-        uint64_t const& maxGpuTotalBytes = 0);
+        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
+        uint64_t maxGpuTotalBytes = 0);

And in the Doxygen of the new field (see lines 1101–1105), add “0 disables the cap” (see separate diff below).


1042-1042: Setter name and type are fine; add brief docs to mirror getter

Consider adding a short Doxygen note to setMaxGpuTotalBytes that 0 disables the cap and that the effective budget is clamped by other limits (see field docs below).


1101-1105: Document precedence and sentinel semantics; keep behavior unambiguous

The new field docs mention the min with mFreeGpuMemoryFraction, but do not state:

  • what 0 means, and
  • how this interacts with mMaxTokens (converted to bytes).

Clarify that the effective KV-cache budget is the minimum of all configured limits, and that 0 means “unset/no cap” for mMaxGpuTotalBytes.

Suggested doc tweak:

-    /// @brief The maximum size in bytes of GPU memory that can be allocated for the KV cache.
-    /// If both mMaxGpuTotalBytes and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will
-    /// be allocated.
+    /// @brief The maximum size in bytes of GPU memory that can be allocated for the KV cache.
+    /// Semantics:
+    ///  - A value of 0 disables this cap (unset).
+    ///  - The effective KV cache budget is the minimum of applicable limits:
+    ///      min( cap_from_mMaxTokens_bytes, cap_from_mFreeGpuMemoryFraction, mMaxGpuTotalBytes (if non-zero) ).

17-17: Guideline note: prefer include guards over #pragma once

Project guidelines state header files must use include guards named TRTLLM_<FILENAME>_H. Transitioning this header is optional here but recommended for consistency.

If you choose to migrate later, ensure guards are unique and added across headers in a single sweep to avoid churn.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d1c4516 and 6e60db3.

📒 Files selected for processing (4)
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h (3 hunks)
  • cpp/include/tensorrt_llm/executor/executor.h (4 hunks)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp (1 hunks)
  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp (6 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
🧰 Additional context used
📓 Path-based instructions (5)
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu}: Avoid literal values except for 0, nullptr, true, false; use named constexpr for other literals
Place semicolon of empty for/while loop on a new line
Always use brace-delimited bodies for switch/while/do-for/if/else
Use inline C comments in argument lists when parameter meaning is unclear (e.g., /* checkForErrors = */ false)
Do not use assignment in subexpressions (e.g., if (x = y) ... is forbidden)
Switch on enums should enumerate all values and omit default to catch new values at compile time
Structure switch statements; prohibit fallthrough except between empty cases; each case ends with break or throw; return at end of case not allowed; put break inside braces for compound case
Prefer anonymous namespaces over static for internal linkage of functions
Every defined function must be called at least once (no unused methods)

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Header files must use include guards named TRTLLM_<FILENAME>_H without leading or trailing underscores (e.g., TRTLLM_FOO_BAR_HELLO_H)

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
🧠 Learnings (2)
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/include/tensorrt_llm/executor/executor.h
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
cpp/tensorrt_llm/executor/kvCacheConfig.cpp (1)

145-148: Getter looks correct

getMaxGpuTotalBytes cleanly returns the configured cap. No issues.

cpp/include/tensorrt_llm/executor/executor.h (1)

1026-1026: LGTM: [[nodiscard]] getter for max GPU bytes cap

The accessor is consistent with other getters and marked [[nodiscard]]. No further action.

@tensorrt-cicd
Collaborator

PR_Github #16082 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12093 completed with status: 'FAILURE'

Signed-off-by: qixiang-99 <[email protected]>

feat: Add max_free_gpu_memory_size support for KV cache configuration

- Introduced max_free_gpu_memory_size to manage GPU memory allocation for KV cache.
- Updated KvCacheConfig and related methods to handle the new parameter.
- Modified estimation logic in KvCacheCreator to utilize max_free_gpu_memory_size for VSWA cases.
- Adjusted resource management to ensure compatibility with the new memory allocation strategy.

Signed-off-by: qixiang-99 <[email protected]>

address comments:
- vswa path should work now
- modify `KvCacheCreator` so it always provides KvCacheManager a memory size instead of max_tokens
- Handle user-provided `max_tokens` inside `KvCacheCreator` -- it will raise a warning in the VSWA case and translate `max_tokens` to a memory size in the non-VSWA case

Signed-off-by: qixiang-99 <[email protected]>

Respect max_gpu_total_bytes in KVCacheManager for non-VSWA case

Signed-off-by: qixiang-99 <[email protected]>

add memory estimation for attention metadata to solve an OOM issue when CUDA Graph is enabled and the max window size is large

Signed-off-by: qixiang-99 <[email protected]>

Enhance KVCacheManager to maintain adjusted max attention window sizes. Introduced an adjusted dictionary to track window size mappings and updated the logic to reflect these changes in the max attention window vector. Updated unit tests to validate the new behavior and ensure expected outputs for various memory configurations.

Signed-off-by: qixiang-99 <[email protected]>

fix: Draft model calculation in KVCacheManager and KvCacheCreator
- Introduced a method to calculate KV size per token and updated memory estimation to ensure proper handling of max_gpu_total_bytes.
- Changed KVCacheManager to still use `max_tokens` in the non-VSWA case, as its `calculate_max_num_blocks` doesn't account for draft model tokens.

Signed-off-by: qixiang-99 <[email protected]>

minor fix for description

Signed-off-by: qixiang-99 <[email protected]>

Enhance KVCacheManager and related components to track allocated GPU memory for KV-cache. Added `allocatedBytes` to `KvCacheStats` and updated memory allocation logic in `KVCacheManager` and `KvCacheCreator` to ensure accurate memory usage reporting. Adjusted Python bindings to expose the new `allocated_bytes` attribute.

Signed-off-by: qixiang-99 <[email protected]>

update logging

Signed-off-by: qixiang-99 <[email protected]>

update Nanobind accordingly

Signed-off-by: qixiang-99 <[email protected]>

Update secondary pool memory allocation in KVCacheManager to use host_cache_size from configuration if provided.

Signed-off-by: qixiang-99 <[email protected]>

Refactor the fix for CUDA Graph OOM issue.

Signed-off-by: qixiang-99 <[email protected]>

Add torch_empty_cache fixture to clear CUDA cache before tests

Signed-off-by: qixiang-99 <[email protected]>

minor fix for rebase

Signed-off-by: qixiang-99 <[email protected]>

Add validation for max_gpu_total_bytes in KvCacheConfig

- Introduced a field validator to ensure that max_gpu_total_bytes is non-negative, enhancing input validation for the configuration.

Signed-off-by: qixiang-99 <[email protected]>
- Added a field validator for max_attention_window to ensure it is a non-empty list of positive integers.
- Implemented checks to prevent redundant patterns in the list, ensuring only minimal repeating patterns are accepted.
- Updated KvCacheCreator to use a set for determining VSWA status based on max_attention_window.

Signed-off-by: qixiang-99 <[email protected]>
- Changed primary and secondary pool memory allocation to use instance variables instead of local variables for better clarity and maintainability.
- Updated logging to reflect the new instance variable usage.
- Added unit tests to validate memory allocation behavior in KVCacheManager.

Signed-off-by: qixiang-99 <[email protected]>
- we shouldn't use max_seq_len as the KV config max_tokens, as it doesn't need that much
- and it makes the preparation phase more prone to OOM, especially for long sequences.

Signed-off-by: qixiang-99 <[email protected]>
@qixiang-99 qixiang-99 force-pushed the feat/improve-vswa-kvcache branch from 6e60db3 to 6e93c1b on August 22, 2025 05:26
@qixiang-99
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Collaborator

PR_Github #16130 [ run ] triggered by Bot

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

♻️ Duplicate comments (3)
tensorrt_llm/llmapi/llm_args.py (1)

994-999: Docs/runtime mismatch: resource_manager currently prefers max_gpu_total_bytes over fraction

The docstring states the minimum of max_gpu_total_bytes and free_gpu_memory_fraction should be used when both are set. In resource_manager.calculate_max_num_blocks_from_cpp, the code currently picks max_gpu_total_bytes if set, ignoring the fraction even when the fraction implies a smaller budget. This can over-allocate relative to the documented behavior.

Recommend aligning the runtime to use min(max_gpu_total_bytes, free_mem * fraction).

Also applies to: 1012-1030

tensorrt_llm/_torch/pyexecutor/_util.py (2)

295-311: Do not apply max_tokens in VSWA; and fix pre-clamp max_tokens inconsistency.

  • For VSWA, token-based limits are ambiguous; current code warns but still clamps by max_tokens. Ignore tokens entirely in VSWA to avoid under-allocation or misleading configs.
  • Also, assigning max_tokens here is premature; it can become inconsistent after the later max_gpu_total_bytes clamp.
         # ---------------------------handle max_tokens---------------------------------
         # if user provided max_tokens, calculate max memory from max_tokens
-        if self._max_kv_tokens_in is not None:
-            # raise error if it is VSWA case
-            if is_vswa:
-                logger.warning(
-                    "max_tokens should not be set for VSWA case as it is ambiguous concept for VSWA."
-                )
-            # calculate max memory from max_tokens
-            kv_cache_max_memory_from_max_tokens = self._max_kv_tokens_in * self._get_kv_size_per_token(
-            )
-            kv_cache_max_memory = min(kv_cache_max_memory,
-                                      kv_cache_max_memory_from_max_tokens)
-            logger.info(
-                f"max_tokens={self._max_kv_tokens_in} is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
-            )
-        if is_vswa:
-            # For VSWA KvCacheManager now it can only use max_gpu_total_bytes
-            executor_config.kv_cache_config.max_tokens = None
-        else:
-            # For non-VSWA KvCacheManager, its logic still relies on max_tokens, need to improve in the future.
-            executor_config.kv_cache_config.max_tokens = int(
-                kv_cache_max_memory // self._get_kv_size_per_token())
+        if self._max_kv_tokens_in is not None:
+            if is_vswa:
+                logger.warning(
+                    "max_tokens is ignored for VSWA because token-based limits "
+                    "are ambiguous under variable sliding windows."
+                )
+            else:
+                # calculate max memory from max_tokens and clamp
+                kv_cache_max_memory_from_max_tokens = (
+                    self._max_kv_tokens_in * self._get_kv_size_per_token()
+                )
+                kv_cache_max_memory = min(
+                    kv_cache_max_memory, kv_cache_max_memory_from_max_tokens
+                )
+                logger.info(
+                    f"max_tokens={self._max_kv_tokens_in} is provided, "
+                    f"max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
+                )
         # ---------------------------handle max_tokens---------------------------------

320-335: Finalize max_tokens after the final memory clamp; fix long log line (E501).

  • Compute/assign max_tokens only after applying the user-provided max_gpu_total_bytes clamp to keep them consistent.
  • Break the long log line on 327 and keep the “Estimated max memory” message concise.
         # ---------------------------handle max_gpu_total_bytes---------------------------------
         # if user provided max_gpu_total_bytes, set max memory from max_gpu_total_bytes
         if executor_config.kv_cache_config.max_gpu_total_bytes > 0:
             kv_cache_max_memory = min(
                 kv_cache_max_memory,
                 executor_config.kv_cache_config.max_gpu_total_bytes)
-            logger.info(
-                f"max_gpu_total_bytes={executor_config.kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
-            )
+            logger.info(
+                f"max_gpu_total_bytes="
+                f"{executor_config.kv_cache_config.max_gpu_total_bytes / (GB):.2f} GiB "
+                f"is provided, max_memory is set to {kv_cache_max_memory / (GB):.2f} GiB"
+            )
 
-        logger.info(
-            f"Estimated max memory in KV cache : {kv_cache_max_memory / (GB):.2f} GiB"
-        )
-        # set max_gpu_total_bytes
-        executor_config.kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
+        # Finalize max_tokens after all memory clamps
+        if is_vswa:
+            # VSWA path uses byte-based capacity only
+            executor_config.kv_cache_config.max_tokens = None
+        else:
+            # Keep Python path using max_tokens consistent with the final memory budget
+            executor_config.kv_cache_config.max_tokens = int(
+                kv_cache_max_memory // self._get_kv_size_per_token()
+            )
+
+        logger.info(
+            f"Estimated max memory in KV cache: {kv_cache_max_memory / (GB):.2f} GiB"
+        )
+        # set max_gpu_total_bytes
+        executor_config.kv_cache_config.max_gpu_total_bytes = kv_cache_max_memory
         # ---------------------------handle max_gpu_total_bytes---------------------------------
🧹 Nitpick comments (12)
cpp/include/tensorrt_llm/executor/executor.h (3)

1009-1011: Prefer pass-by-value for trivial types; avoid const-ref on uint64_t

maxGpuTotalBytes should be passed by value, not uint64_t const&. Passing a fundamental type by const-ref adds indirection without benefit and is inconsistent with surrounding params.

Apply this diff to the constructor signature:

-        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
-        uint64_t const& maxGpuTotalBytes = 0);
+        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
+        uint64_t maxGpuTotalBytes = 0);

1031-1031: Unify setter parameter style for optionals

setMaxTokens(std::optional<SizeType32> maxTokens) is passed by value, while other setters for optionals (e.g., setFreeGpuMemoryFraction, setCrossKvCacheFraction) take const&. Choose one convention for consistency; here, taking const& matches the rest.

If you keep by-value for move semantics internally, please mirror that across other setters or document the deviation.

-    void setMaxTokens(std::optional<SizeType32> maxTokens);
+    void setMaxTokens(std::optional<SizeType32> const& maxTokens);

1052-1055: Docs: clarify min-of-three semantics including maxGpuTotalBytes

Several comments still mention only mMaxTokens and mFreeGpuMemoryFraction. With mMaxGpuTotalBytes added, the effective allocated memory should be the minimum across all specified limits. Please update the comments to prevent confusion.

@@
-    /// If both mMaxTokens and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will be
-    /// allocated.
+    /// If mMaxTokens, mFreeGpuMemoryFraction, and/or mMaxGpuTotalBytes are specified, the memory corresponding to the
+    /// minimum of the applicable limits will be allocated.
@@
-    /// If both mMaxTokens and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will be
-    /// allocated.
+    /// If mMaxTokens, mFreeGpuMemoryFraction, and/or mMaxGpuTotalBytes are specified, the memory corresponding to the
+    /// minimum of the applicable limits will be allocated.
@@
-    /// If both mMaxGpuTotalBytes and mFreeGpuMemoryFraction are specified, memory corresponding to the minimum will
-    /// be allocated.
+    /// If mMaxTokens, mFreeGpuMemoryFraction, and/or mMaxGpuTotalBytes are specified, the memory corresponding to the
+    /// minimum of the applicable limits will be allocated.

Also applies to: 1066-1069, 1101-1104

tensorrt_llm/llmapi/llm_args.py (2)

994-999: New field is well-integrated; minor doc reflow to meet 120-col limit

The addition of max_gpu_total_bytes and its docs look good and mirror the C++ binding. Please wrap the description to satisfy the 120-character guideline.

-    max_gpu_total_bytes: int = Field(
-        default=0,
-        description=
-        "The maximum size in bytes of GPU memory that can be allocated for the KV cache. If both `max_gpu_total_bytes` and `free_gpu_memory_fraction` are specified, memory corresponding to the minimum will be allocated."
-    )
+    max_gpu_total_bytes: int = Field(
+        default=0,
+        description=(
+            "The maximum size in bytes of GPU memory that can be allocated for the KV cache. "
+            "If both `max_gpu_total_bytes` and `free_gpu_memory_fraction` are specified, the "
+            "minimum of the applicable limits will be used."
+        ),
+    )

1031-1038: Validation is correct; consider also rejecting booleans (bool is a subclass of int)

True/False can slip through as integers in Python. If you want to be strict, guard against isinstance(v, bool).

     def validate_max_gpu_total_bytes(cls, v: int):
-        if v < 0:
+        if isinstance(v, bool) or v < 0:
             raise ValueError(
                 "kv_cache_config.max_gpu_total_bytes must be non-negative")
         return v
tensorrt_llm/_torch/pyexecutor/resource_manager.py (3)

834-836: Wrap debug log to 120 cols

Long f-string with newline exceeds the style limit. Split into two calls or format the values beforehand.

-        logger.debug(
-            f"primary_pool_memory_bytes is set to {self._primary_pool_memory_bytes/1024**3}GB, \n"
-            f"secondary_pool_memory_bytes is set to {self._secondary_pool_memory_bytes/1024**3}GB"
-        )
+        gb = 1024 ** 3
+        logger.debug(
+            f"primary_pool_memory_bytes: {self._primary_pool_memory_bytes / gb:.3f} GB"
+        )
+        logger.debug(
+            f"secondary_pool_memory_bytes: {self._secondary_pool_memory_bytes / gb:.3f} GB"
+        )

686-696: Build adjusted vector once after computing the full mapping

adjust_window_sizes_for_vswa rebuilds adjusted_max_attention_window_vec on each loop iteration, which is O(L*N). Construct it once after adjusted_dict is complete to reduce overhead and make intent clearer.

@@
-        adjusted_dict = {}
-        adjusted_max_attention_window_vec = max_attention_window_vec.copy()
+        adjusted_dict = {}
@@
-            adjusted_dict[window_size] = accum_max_tokens
-            # also update adjusted_max_attention_window_vec
-            adjusted_max_attention_window_vec = [
-                adjusted_dict.get(v, v)
-                for v in adjusted_max_attention_window_vec
-            ]
+            adjusted_dict[window_size] = accum_max_tokens
@@
-        return (adjusted_window_size_to_layers,
-                adjusted_max_attention_window_vec)
+        adjusted_max_attention_window_vec = [
+            adjusted_dict.get(v, v) for v in max_attention_window_vec
+        ]
+        return adjusted_window_size_to_layers, adjusted_max_attention_window_vec

Also applies to: 780-786, 790-791


253-257: VSWA flow looks correct; consider parity for non-VSWA path in future

  • The VSWA path now uses memory budgets from bytes; good.
  • Non-VSWA path still uses the legacy token-based estimator, which ignores max_gpu_total_bytes. This divergence can surprise users.

If you want, I can propose a small patch to have calculate_max_num_blocks also take max_gpu_total_bytes into account (min of bytes- and token-derived budgets) to keep behavior consistent across both paths.

Also applies to: 274-286, 840-851

tests/unittest/_torch/executor/test_resource_manager.py (2)

523-536: Update test to new return signature and rewrap long lines

Adapting to the tuple return is good. Please also wrap the assertion messages to respect the 120-col limit.

-        for memory_bytes, expected_window_sizes, expected_max_attention_window_vec, max_tokens, description in test_cases:
+        for (memory_bytes,
+             expected_window_sizes,
+             expected_max_attention_window_vec,
+             max_tokens,
+             description) in test_cases:
@@
-                self.assertEqual(
-                    adjusted, expected_window_sizes,
-                    f"Test case '{description}' failed.\n"
-                    f"Memory bytes: {memory_bytes}\n"
-                    f"Actual: {adjusted}\n"
-                    f"Expected: {expected_window_sizes}")
+                self.assertEqual(
+                    adjusted,
+                    expected_window_sizes,
+                    "\n".join([
+                        f"Test case '{description}' failed.",
+                        f"Memory bytes: {memory_bytes}",
+                        f"Actual: {adjusted}",
+                        f"Expected: {expected_window_sizes}",
+                    ]),
+                )
@@
-                self.assertEqual(
-                    adjusted_max_attention_window_vec,
-                    expected_max_attention_window_vec,
-                    f"Test case '{description}' failed.\n"
-                    f"Memory bytes: {memory_bytes}\n"
-                    f"Actual: {adjusted_max_attention_window_vec}\n"
-                    f"Expected: {expected_max_attention_window_vec}")
+                self.assertEqual(
+                    adjusted_max_attention_window_vec,
+                    expected_max_attention_window_vec,
+                    "\n".join([
+                        f"Test case '{description}' failed.",
+                        f"Memory bytes: {memory_bytes}",
+                        f"Actual: {adjusted_max_attention_window_vec}",
+                        f"Expected: {expected_max_attention_window_vec}",
+                    ]),
+                )

648-687: Assertions on private attributes work but are brittle; prefer a public hook if available

The test introspects _primary_pool_memory_bytes and _secondary_pool_memory_bytes. If a public stats accessor becomes available (e.g., via bindings), prefer that to avoid tight coupling to implementation details.

If there's an accessor in KVCacheManagerCpp that exposes pool budgets, consider switching to that in a follow-up. I can help patch the test once available.

tensorrt_llm/_torch/pyexecutor/_util.py (2)

101-110: Add type hints and a short docstring for _get_kv_size_per_token (clarify units are bytes/token).

This helper is newly introduced and used in capacity math; make the return units explicit and self-document to avoid misinterpretation across call sites.

-    def _get_kv_size_per_token(self):
+    def _get_kv_size_per_token(self) -> int:
+        """Return total KV-cache bytes per token for main model (+ draft model if present)."""
         model_config = self._model_engine.model.model_config
         mapping = self._mapping
         kv_size_per_token = self._get_cache_size_per_token(
             model_config, mapping)
         if self._draft_model_engine is not None:
             draft_model_config = self._draft_model_engine.model.model_config
             kv_size_per_token += self._get_cache_size_per_token(
                 draft_model_config, mapping)
         return kv_size_per_token
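
For a sense of scale, a back-of-envelope version of the per-token figure looks like the sketch below; the constants are illustrative, and the real `_get_cache_size_per_token` derives them from the model config and the parallel mapping:

```python
def approx_kv_bytes_per_token(num_layers: int, num_kv_heads: int,
                              head_dim: int, dtype_bytes: int = 2) -> int:
    """Illustrative estimate: K and V tensors per layer, per KV head, per head dim."""
    return 2 * num_layers * num_kv_heads * head_dim * dtype_bytes


# e.g. 32 layers, 8 KV heads of dim 128, fp16:
# 2 * 32 * 8 * 128 * 2 = 131072 bytes (~128 KiB) per token.
print(approx_kv_bytes_per_token(32, 8, 128))
```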

112-129: Avoid discounting preallocated KV bytes by the free-memory fraction; clamp, fix logging labels, and resolve E501.

  • The docstring says “add this amount back in” for allocated_bytes; the current formula multiplies it by the fraction, which paradoxically shrinks memory that is already required.
  • Clamp negative intermediate results.
  • Rename “tmp kv_mem” to “allocated KV cache memory” and break long lines (Ruff E501: lines 117, 126).
-    def _cal_max_memory(self, peak_memory, total_gpu_memory, fraction,
-                        allocated_bytes: int) -> int:
+    def _cal_max_memory(self,
+                        peak_memory: int,
+                        total_gpu_memory: int,
+                        fraction: float,
+                        allocated_bytes: int) -> int:
         """
         Calculate the max KV cache capacity.
 
-        NOTE: `allocated_bytes` is the total KV-cache memory that must be pre-allocated during the estimation phase (for both the main and draft models) so the estimation run can complete successfully. When computing `available_kv_mem`, add this amount back in.
+        NOTE: `allocated_bytes` is the total KV-cache memory that must be
+        pre-allocated during the estimation phase (for both the main and draft
+        models) so the estimation run can complete successfully. When computing
+        `available_kv_mem`, add this amount back in.
         """
         kv_size_per_token = self._get_kv_size_per_token()
-
-        available_kv_mem = (total_gpu_memory - peak_memory +
-                            allocated_bytes) * fraction
-        logger.info(
-            f"Peak memory during memory usage profiling (torch + non-torch): {peak_memory / (GB):.2f} GiB, "
-            f"available KV cache memory when calculating max tokens: {available_kv_mem / (GB):.2f} GiB, "
-            f"fraction is set {fraction}, kv size is {kv_size_per_token}. device total memory {total_gpu_memory / (GB):.2f} GiB, "
-            f", tmp kv_mem { (allocated_bytes) / (GB):.2f} GiB")
-        return int(available_kv_mem)
+        free_after_peak = max(total_gpu_memory - peak_memory, 0)
+        # Add back pre-allocated KV bytes (not subject to the safety fraction),
+        # then apply the safety fraction to the remaining free memory.
+        available_kv_mem = allocated_bytes + free_after_peak * fraction
+        logger.info(
+            "Peak memory during memory usage profiling (torch + non-torch): "
+            f"{peak_memory / (GB):.2f} GiB, "
+            "available KV cache memory for capacity planning: "
+            f"{available_kv_mem / (GB):.2f} GiB, "
+            f"fraction={fraction}, kv bytes/token={kv_size_per_token}, "
+            f"device total memory={total_gpu_memory / (GB):.2f} GiB, "
+            f"allocated KV cache memory={allocated_bytes / (GB):.2f} GiB"
+        )
+        return int(available_kv_mem)
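
A quick numeric check of the revised formula (illustrative numbers only, not values from an actual profiling run):

```python
GB = 1 << 30

total_gpu_memory = 80 * GB   # device total
peak_memory = 30 * GB        # torch + non-torch peak during profiling
allocated_bytes = 2 * GB     # KV cache pre-allocated for the estimation run
fraction = 0.9               # safety fraction on the remaining free memory

free_after_peak = max(total_gpu_memory - peak_memory, 0)
available_kv_mem = allocated_bytes + free_after_peak * fraction
print(f"{available_kv_mem / GB:.1f} GiB")  # 2 + 50 * 0.9 = 47.0 GiB
```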
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 6e60db3 and 6e93c1b.

📒 Files selected for processing (15)
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h (3 hunks)
  • cpp/include/tensorrt_llm/executor/executor.h (4 hunks)
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp (1 hunks)
  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp (6 hunks)
  • cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp (1 hunks)
  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp (4 hunks)
  • cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp (1 hunks)
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp (3 hunks)
  • tensorrt_llm/_torch/pyexecutor/_util.py (4 hunks)
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py (1 hunks)
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py (7 hunks)
  • tensorrt_llm/llmapi/llm_args.py (2 hunks)
  • tests/integration/defs/conftest.py (1 hunks)
  • tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py (1 hunks)
  • tests/unittest/_torch/executor/test_resource_manager.py (9 hunks)
🚧 Files skipped from review as they are similar to previous changes (10)
  • cpp/tensorrt_llm/pybind/batch_manager/kvCacheManager.cpp
  • cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
  • tests/integration/defs/disaggregated/test_disaggregated_single_gpu.py
  • cpp/tensorrt_llm/nanobind/batch_manager/kvCacheManager.cpp
  • tensorrt_llm/_torch/pyexecutor/py_executor_creator.py
  • cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
  • tests/integration/defs/conftest.py
  • cpp/tensorrt_llm/executor/kvCacheConfig.cpp
  • cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp
  • cpp/tensorrt_llm/pybind/executor/executorConfig.cpp
🧰 Additional context used
📓 Path-based instructions (5)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tests/unittest/_torch/executor/test_resource_manager.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tests/unittest/_torch/executor/test_resource_manager.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • cpp/include/tensorrt_llm/executor/executor.h
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh}: In C++, close namespaces with a comment naming the namespace (e.g., } // namespace foo)
Prefer const/constexpr variables over #define for constants
Declare variables const if not modified after initialization
Use Allman brace style in C++
C++ filenames use lowerCamelCase and must be case-insensitively unique within a build target
C++ type names use UpperCamelCase
Local variables, methods, and namespaces use lowerCamelCase
Global non-static variables not in anonymous namespace use gPrefix lowerCamelCase (e.g., gExample)
Static globals or globals in anonymous namespaces use sPrefix lowerCamelCase
Locally visible static variables start with 's' (e.g., static std::once_flag sFlag;)
Member variables use mPrefix lowerCamelCase; public members may omit but are encouraged to use 'm'
Constants (enums, global/static/function-scope magic numbers) use kPREFIXED_UPPER_SNAKE (e.g., kDIGIT_NUM)
If macros are unavoidable, use UPPER_SNAKE_CASE (prefer constants over #define)
Constructor parameter that conflicts with a public member name gets trailing underscore (foo_)
Literal suffixes should be uppercase (e.g., 1234L not 1234l)
C++: use spaces only; indent 4 spaces
Run clang-format (LLVM style) before submitting; wrap lines at 120 characters
If formatting must be bypassed, use // clang-format off/on around the section
Prefer smart pointers; use unique_ptr for sole ownership, shared_ptr for shared; weak_ptr only in exceptional cases
Do not use deprecated pre-C++11 smart pointers
Use C++ style comments; avoid C comments except special inline cases; prefer // single-line
Capitalize and punctuate full-sentence comments
Follow Doxygen rules: use //! for comments and //!< for members in C++
Disable code with #if/#endif and mnemonic conditions; avoid commented-out code; avoid dead code
Do not throw exceptions across library boundaries
Use least-forceful casts; avoid removing const/volatile; avoid C-style and functional casts (except constructors); p...

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{h,hpp,hxx,hh,cuh,cpp,cxx,cc,cu}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Parameter names must be consistent between declarations and definitions

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
**/*.{h,hpp,hxx,hh,cuh}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Header files must use include guards named TRTLLM_<NAME>_H with no leading or trailing underscores (e.g., TRTLLM_FOO_BAR_HELLO_H)

Files:

  • cpp/include/tensorrt_llm/executor/executor.h
🧠 Learnings (5)
📚 Learning: 2025-08-14T21:04:50.248Z
Learnt from: thorjohnsen
PR: NVIDIA/TensorRT-LLM#6910
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-14T21:04:50.248Z
Learning: In KV cache onboarding logic during prefill in cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, when calculating which blocks fall within the attention window, use getTokensPerBlock() to advance token indices rather than block->getUniqueTokens().size(), because the calculation needs to consider the post-prefill state where blocks will be filled to capacity, not their current token count.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • cpp/include/tensorrt_llm/executor/executor.h
📚 Learning: 2025-08-15T06:46:54.897Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:54.897Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp addToken function, newly allocated blocks are unshared by design. The beam search path in addToken (when sequence.getNumTokens() > windowSize) is currently broken/non-functional with SWA, so the block allocation doesn't follow a shared-then-unshared pattern.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
📚 Learning: 2025-08-20T06:56:02.889Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:577-579
Timestamp: 2025-08-20T06:56:02.889Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, maxSequenceLength is now enforced as a non-optional argument in the BlockManager constructor, so concerns about std::nullopt defaulting to 0 are not applicable. When windowSize > maxSequenceLength, a warning should be added instead of handling optional parameter cases.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
  • cpp/include/tensorrt_llm/executor/executor.h
📚 Learning: 2025-08-21T09:41:49.327Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6768
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:2010-2045
Timestamp: 2025-08-21T09:41:49.327Z
Learning: In cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp, updateSequenceCacheBlockOffsets is specifically for updating bookkeeping when blocks are added during the context phase, not for refreshing offsets after detach operations. During detach operations, GenerationRequest::removeFrontBlock handles the necessary cache block bookkeeping internally.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
📚 Learning: 2025-08-15T06:46:53.813Z
Learnt from: eopXD
PR: NVIDIA/TensorRT-LLM#6767
File: cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp:0-0
Timestamp: 2025-08-15T06:46:53.813Z
Learning: In the TensorRT-LLM KV cache manager, SWA (Sliding Window Attention) combined with beam search is currently in a broken/non-functional state and is planned for future rework. During preparatory refactoring phases, code related to SWA+beam search may intentionally remain in a non-working state until the broader rework is completed.

Applied to files:

  • tensorrt_llm/_torch/pyexecutor/_util.py
  • tensorrt_llm/_torch/pyexecutor/resource_manager.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/_util.py

117-117: Line too long (261 > 120)

(E501)


126-126: Line too long (130 > 120)

(E501)


309-309: Line too long (125 > 120)

(E501)


327-327: Line too long (178 > 120)

(E501)

tests/unittest/_torch/executor/test_resource_manager.py

523-523: Line too long (122 > 120)

(E501)


636-636: Line too long (129 > 120)

(E501)

tensorrt_llm/_torch/pyexecutor/resource_manager.py

829-829: Line too long (128 > 120)

(E501)


830-830: Line too long (130 > 120)

(E501)

tensorrt_llm/llmapi/llm_args.py

994-994: Line too long (131 > 120)

(E501)


1050-1051: One-line docstring should fit on one line

Reformat to one line

(D200)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (8)
cpp/include/tensorrt_llm/executor/executor.h (1)

1000-1011: Validation complete – KvCacheConfig maxGpuTotalBytes fully wired through

All checks have passed:

  • The KvCacheConfig constructor in cpp/tensorrt_llm/executor/kvCacheConfig.cpp uses the same parameters (including uint64_t const& maxGpuTotalBytes) as declared in the header.
  • The getter (getMaxGpuTotalBytes()) and setter (setMaxGpuTotalBytes(uint64_t)) are implemented in kvCacheConfig.cpp and declared at lines 1026 and 1042 of the header.
  • In the pybind bindings (cpp/tensorrt_llm/pybind/executor/executorConfig.cpp), the constructor’s py::init includes py::arg("max_gpu_total_bytes"), there is a .def_property("max_gpu_total_bytes", &getMaxGpuTotalBytes, &setMaxGpuTotalBytes), and the pickling closures capture and restore getMaxGpuTotalBytes.
  • In the nanobind bindings (cpp/tensorrt_llm/nanobind/executor/executorConfig.cpp), the constructor’s nb::init and nb::arg("max_gpu_total_bytes") are present, there is a .def_prop_rw("max_gpu_total_bytes", &getMaxGpuTotalBytes, &setMaxGpuTotalBytes), and the __getstate__/__setstate__ lambdas include the maxGpuTotalBytes field.

No further updates are needed.
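
From the Python side, the new knob would then be used roughly as follows (a sketch that assumes the llmapi KvCacheConfig re-exports the field under the same max_gpu_total_bytes name):

```python
from tensorrt_llm.llmapi import KvCacheConfig

# Sketch: cap the KV cache at an absolute byte budget rather than a token count.
kv_cache_config = KvCacheConfig(
    max_gpu_total_bytes=16 * (1 << 30),  # assumed 16 GiB budget for all KV pools
)
```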

tensorrt_llm/llmapi/llm_args.py (2)

1039-1061: Validation for max_attention_window is reasonable

Checks for non-empty, positive integers are appropriate and match the relaxed stance discussed in prior reviews. No action needed.
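
Purely for illustration, the shape of the check being validated is along these lines (the actual validator lives in llm_args.py and may differ in details):

```python
from typing import List


def validate_max_attention_window(windows: List[int]) -> List[int]:
    # Illustrative only: require a non-empty list of positive integers.
    if not windows:
        raise ValueError("max_attention_window must be non-empty")
    if any(not isinstance(w, int) or w <= 0 for w in windows):
        raise ValueError("max_attention_window entries must be positive integers")
    return windows
```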


1012-1030: The above script will check:

  • That the pybind binding for tle::KvCacheConfig exists.
  • That max_gpu_total_bytes is exposed as a property in pybind.
  • That a py::pickle(...) binding is present in pybind for KvCacheConfig.
  • That nanobind defines kvCacheConfigGetstate and kvCacheConfigSetstate for pickling.

Once we have the results, we can confirm whether both the pybind and nanobind pickling implementations include the new max_gpu_total_bytes field or if further updates are needed.

tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)

840-851: Side-effect assignment is fine; ensure callers read back the adjusted vector

You correctly reassign self.max_attention_window_vec from the function's return. Just a note for future maintainers: this vector now reflects any memory-driven clamping and should be the source of truth downstream.

tensorrt_llm/_torch/pyexecutor/_util.py (4)

200-205: Good rename and scope: configure_kv_cache_capacity accurately reflects behavior.

Name and VSWA note improve clarity. No functional concerns.


269-279: KV stats aggregation for main + draft managers looks correct.

Summing allocated_bytes across both managers matches the estimation objective. Good guard on draft presence.


280-284: Capacity derived via _cal_max_memory is appropriate.

No issues; this is the right place to centralize the math.


285-287: Use of unique-value check for VSWA detection is correct.

len(set(max_attention_window)) > 1 avoids false positives like [128, 128].
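
A minimal illustration of the check (not the actual call site):

```python
from typing import List


def is_vswa(max_attention_window: List[int]) -> bool:
    # More than one distinct window size means variable sliding-window attention.
    return len(set(max_attention_window)) > 1


assert not is_vswa([128, 128])       # homogeneous windows: not VSWA
assert is_vswa([512, 512, 32768])    # mixed windows: VSWA
```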

@qixiang-99
Copy link
Collaborator Author

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16201 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16130 [ run ] completed with state ABORTED

@eopXD
Copy link
Collaborator

eopXD commented Aug 25, 2025

/bot run --disable-fail-fast

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16414 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #16414 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12336 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@HuiGao-NV HuiGao-NV merged commit b165f8b into NVIDIA:main Aug 26, 2025
5 checks passed