imp(torchsampler):support sample params temperature/topp/topk #6451
Conversation
Walkthrough
The sampling logic was refactored to unify multiple sampling strategies into a new pipeline-based implementation.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~40 minutes
Actionable comments posted: 0
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
154-160: Add input validation and handle edge cases.

The temperature sampling implementation is correct, but consider these improvements for robustness:
- Handle 1D logits for consistency with other sampling functions
- Consider special handling for temperature approaching 0 (should behave like greedy)
- Add upper bound validation for temperature to prevent numerical instability
```diff
 def temperature_sampling_batch(logits: torch.Tensor, temperature: float):
     assert temperature > 0, "Temperature must be positive"
+
+    logits_dim = logits.dim()
+    if logits_dim == 1:
+        logits = logits.unsqueeze(0)
+
+    # Handle near-zero temperature as greedy sampling
+    if temperature < 1e-8:
+        next_tokens = torch.argmax(logits, dim=-1)
+        softmax_probs = torch.softmax(logits, dim=-1)
+        return next_tokens, softmax_probs
+
     scaled_logits = logits / temperature
     softmax_probs = torch.softmax(scaled_logits, dim=-1)
     next_tokens = torch.multinomial(softmax_probs, num_samples=1).squeeze(-1)
+
+    # Restore original shape if input was 1D
+    if logits_dim == 1:
+        next_tokens = next_tokens.squeeze(0)
+        softmax_probs = softmax_probs.squeeze(0)
+
     return next_tokens, softmax_probs
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.
Files:
tensorrt_llm/_torch/pyexecutor/sampler.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Files:
tensorrt_llm/_torch/pyexecutor/sampler.py
🔇 Additional comments (5)
tensorrt_llm/_torch/pyexecutor/sampler.py (5)
170-173: LGTM!

The new `Temperature` type definition follows the established pattern and is properly integrated into the `Strategy` union type. The implementation maintains consistency with existing sampling strategy types.
176-194: LGTM!

The strategy selection logic is well-implemented (a rough illustrative sketch follows the list below):
- Clear priority order (top_p → top_k → temperature → greedy)
- Appropriate condition for temperature sampling (temperature != 1.0)
- Helpful docstring documenting the default values
- Maintains backward compatibility
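For illustration only, here is a minimal sketch of what that priority order implies; the function name and the ("kind", value) tuple shape are assumptions modeled on the `Strategy` aliases quoted later in this thread, not the PR's exact code:

```python
from typing import Optional

# Hypothetical sketch of the priority-ordered selection described above.
def select_strategy(top_p: Optional[float],
                    top_k: Optional[int],
                    temperature: Optional[float]):
    if top_p is not None:
        return ("top_p", top_p)      # top_p takes precedence
    if top_k is not None:
        return ("top_k", top_k)
    if temperature is not None and temperature != 1.0:
        return ("temperature", temperature)
    return ("greedy", None)          # default when nothing else applies

print(select_strategy(0.9, 50, 0.7))     # ('top_p', 0.9)
print(select_strategy(None, None, 1.0))  # ('greedy', None): temperature 1.0 is a no-op
```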
207-208: LGTM!

The temperature strategy case is correctly implemented, following the established pattern of the existing match statement.
447-447: LGTM! Consistent token adjustment implementation.

The addition of `append_eagle3` calls ensures consistent token adjustment for both batched and per-request sampling flows. This addresses the issue mentioned in the PR objectives where token adjustments were not being applied uniformly.

Also applies to: 455-455
436-436: No changes required for the eagle3 sampling assertion.

A search for any eagle3 support of non-greedy (temperature) sampling in tensorrt_llm/_torch/pyexecutor/sampler.py returned only the existing assertion. Since there is no implementation or documentation indicating eagle3 can handle non-greedy strategies yet, the assertion remains accurate and can stay as is.
2d4654d to 0095f6a (compare)
0095f6a to d5a9f64 (compare)
Actionable comments posted: 1
🧹 Nitpick comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
101-101: Remove debug print statement before production.

Debug print statements should not be committed to production code. Consider using proper logging instead.
- print(f"Debug INFO Apply top_k_sampling_batch with top_k:{top_k}") + # Optionally use logger.debug() if logging is needed
125-125: Remove debug print statement before production.

Debug print statements should not be committed to production code. Consider using proper logging instead.
- print(f"Debug INFO Apply top_p_sampling_batch with top_p:{top_p}") + # Optionally use logger.debug() if logging is needed
180-194: Refactor parameter extraction for better readability.

The parameter extraction logic is verbose and repetitive. Consider refactoring for better maintainability.
```diff
 def request_strategy(request: LlmRequest) -> Strategy:
-    top_p = request.sampling_config.top_p[0] if request.sampling_config.top_p is not None and len(
-        request.sampling_config.top_p) > 0 else None
-    top_k = request.sampling_config.top_k[0] if request.sampling_config.top_k is not None and len(
-        request.sampling_config.top_k) > 0 else None
-    temperature = request.sampling_config.temperature[0] if request.sampling_config.temperature is not None and len(
-        request.sampling_config.temperature) > 0 and request.sampling_config.temperature[0] > 0 else None
+    def get_first_param(param_list):
+        return param_list[0] if param_list is not None and len(param_list) > 0 else None
+
+    config = request.sampling_config
+    top_p = get_first_param(config.top_p)
+    top_k = get_first_param(config.top_k)
+    temperature = get_first_param(config.temperature)
+
+    # Only use temperature if it's positive
+    if temperature is not None and temperature <= 0:
+        temperature = None
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (7 hunks)
🧰 Additional context used
📓 Path-based instructions (2): **/*.py and **/*.{cpp,h,cu,py}, repeating the CODING_GUIDELINES.md rules listed in the first review above.
Files: tensorrt_llm/_torch/pyexecutor/sampler.py
🔇 Additional comments (6)
tensorrt_llm/_torch/pyexecutor/sampler.py (6)
173-173: LGTM! Temperature type definitions are correct.

The `Temperature` type definition and `Strategy` union extension properly follow the existing pattern.

Also applies to: 176-176
207-208: LGTM! Temperature strategy handling is correct.

The temperature case in the sample function correctly calls the `temperature_sampling_batch` function.
437-442: LGTM! Improved readability of batched strategy logic.

The reformatted code maintains the same logic while improving readability. The condition clearly shows when batched sampling can be used.
447-447: LGTM! Consistent token adjustment with `append_eagle3`.

Good addition to ensure tokens are properly adjusted for eagle3 in the batched sampling path.
455-455: LGTM! Consistent token adjustment with `append_eagle3`.

Good addition to ensure tokens are properly adjusted for eagle3 in the per-request sampling path.
191-192: Temperature edge cases verified.

- In tensorrt_llm/_torch/pyexecutor/sampler.py, `temperature_sampling_batch` asserts `temperature > 0`.
- In tensorrt_llm/runtime/model_runner.py, a zero temperature is converted to `temperature = 1.0` (and `top_k = 1`) to prevent overflow.
- A temperature of 1.0 correctly falls back to greedy sampling.
- Very small or very large temperature values simply scale logits as expected (no explicit clamping), and negative or zero inputs are safely handled.

No code changes required.
aa06be8 to 6528042 (compare)
```python
@@ -98,8 +98,8 @@ def update_requests(self, state: SampleState) -> None:

def top_k_sampling_batch(logits, top_k=50):
    logits_dim = logits.dim()
    if logits_dim == 1:
    # logits_dim = logits.dim()
```
remove comment
done
```diff
 Greedy = tuple[Literal["greedy"], None]
 GREEDY: Greedy = ("greedy", None)
-Strategy = TopK | TopP | Greedy
+Strategy = TopK | TopP | Greedy | Temperature
```
I'm assuming people might want to use `Temperature` and TopK/TopP together, not mutually exclusive, no?
I have refactored the sampling strategy code to support multiple sampling strategies.
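As a rough illustration of that refactor (a sketch with assumed names, not the PR's exact `TorchSamplingPipeline` code), each parameter can be treated as an independent logit transform, so temperature and top-k/top-p compose instead of excluding each other:

```python
from typing import Optional

import torch

def sample_with_steps(logits: torch.Tensor,
                      temperature: Optional[float] = None,
                      top_k: Optional[int] = None,
                      top_p: Optional[float] = None) -> torch.Tensor:
    # Temperature: rescale logits before any truncation.
    if temperature is not None and temperature > 0:
        logits = logits / temperature
    # Top-k: keep only the k largest logits per row.
    if top_k is not None and top_k > 1:
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth, float("-inf"))
    # Top-p: keep the smallest prefix of the sorted distribution with mass >= top_p.
    if top_p is not None and top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        probs = torch.softmax(sorted_logits, dim=-1)
        remove = (probs.cumsum(dim=-1) - probs) >= top_p
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits = logits.scatter(-1, sorted_idx, sorted_logits)
    # One multinomial draw over whatever survived the transforms.
    return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1).squeeze(-1)
```

With this shape, setting both a temperature and a top-k simply chains both transforms in sequence.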
```python
@@ -404,29 +422,20 @@ def _process_requests(self,
    num_steps = [1 + len(req.py_draft_tokens) for req in requests]
    sum_steps = sum(num_steps)
    no_draft_tokens = len(requests) == sum_steps
    fast_path = not self.enable_mixed_sampler and no_draft_tokens and gen_logits_host is None and log_probs_host is None
```
This is a breaking change - you cannot do this without breaking current code. Rather, within the fast path, you can call the sampling strategies if not greedy
fixed it.
6528042 to 8994348 (compare)
Actionable comments posted: 2
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
463-470: Fix line length issue and improve readability.

Line 467 exceeds the maximum line length limit. Split it for better readability.
```diff
 if fast_path:
     logits = raw_logits[:len(requests)]
     if batched_strategy is not None:
         next_tokens, _ = sample(batched_strategy, logits)
     else:
-        next_tokens = torch.empty(len(requests), dtype=torch.long, device=logits.device, pin_memory=logits.device.type == 'cuda')
+        next_tokens = torch.empty(
+            len(requests),
+            dtype=torch.long,
+            device=logits.device,
+            pin_memory=logits.device.type == 'cuda'
+        )
         for i, strategy in enumerate(strategies):
             next_tokens[i], _ = sample(strategy, logits[i])
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2): **/*.py and **/*.{cpp,h,cu,py}, repeating the CODING_GUIDELINES.md rules listed in the first review above.
Files: tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use C++ style for single-line comments (// This is a single line comment).
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : C++ comments are required. C comments are not allowed except for special cases (inline).
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use #if / #endif to disable code, preferably with a mnemonic condition. Do NOT use comments to disable code.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py
467-467: Line too long (136 > 120) (E501)
🔇 Additional comments (3)
tensorrt_llm/_torch/pyexecutor/sampler.py (3)
101-101: Consider the type alias change impact.

The `Strategy` type alias has been changed from tuple-based to string literals. While this is cleaner, ensure this doesn't break any external code that might depend on this type.
104-116: Well-structured dataclass implementation.

The `TorchSamplingConfig` dataclass provides a clean replacement for the tuple-based strategy representation, with proper type hints and equality comparison.
232-236: Clean implementation of the new sampling interface.

The refactored `sample` function properly uses the pipeline pattern to apply sampling strategies.
```python
def apply_top_p_sampling(self,logits: torch.Tensor, top_p: float = 0.9):
    logits_dim = logits.dim()
    if logits_dim == 1:
        logits = logits.unsqueeze(0)
    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
```
Replace assertion with proper validation.
The assertion at line 182 should be replaced with proper error handling for production code.
```diff
-    def apply_top_p_sampling(self,logits: torch.Tensor, top_p: float = 0.9):
-        logits_dim = logits.dim()
-        if logits_dim == 1:
-            logits = logits.unsqueeze(0)
-        assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    def apply_top_p_sampling(self,logits: torch.Tensor, top_p: float = 0.9):
+        if logits.dim() == 1:
+            logits = logits.unsqueeze(0)
+        if logits.dim() != 2:
+            raise ValueError(f"logits should be 2D: [batch_size, vocab_size], got {logits.dim()}D")
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py around lines 178 to 183, replace
the assertion that checks if logits is 2D with proper error handling by raising
a descriptive exception if the condition is not met. This ensures the code
handles invalid input gracefully in production instead of abruptly stopping with
an assertion error.
8994348 to 6dc21dd (compare)
Actionable comments posted: 1
♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
178-183: Replace assertion with proper validation.

The assertion at line 182 should be replaced with proper error handling for production code, as noted in previous reviews.
```diff
-        assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+        if logits.dim() != 2:
+            raise ValueError(f"logits should be 2D: [batch_size, vocab_size], got {logits.dim()}D")
```
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
458-470: Fix line length and approve sampling logic updates.

The sampling logic correctly uses the new pipeline-based approach. However, line 466 exceeds the 120-character limit.
```diff
-            next_tokens = torch.empty(len(requests), dtype=torch.long, device=logits.device,pin_memory=logits.device.type == 'cuda')
+            next_tokens = torch.empty(len(requests), dtype=torch.long,
+                                      device=logits.device, pin_memory=logits.device.type == 'cuda')
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2): **/*.py and **/*.{cpp,h,cu,py}, repeating the CODING_GUIDELINES.md rules listed in the first review above.
Files: tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (4)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use C++ style for single-line comments (// This is a single line comment).
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : C++ comments are required. C comments are not allowed except for special cases (inline).
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-07-30T06:11:42.362Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use #if / #endif to disable code, preferably with a mnemonic condition. Do NOT use comments to disable code.
Learnt from: amitz-nv
PR: #5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks is_adapter_in_cpu_cache()
and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py
466-466: Line too long (136 > 120) (E501)
🔇 Additional comments (8)
tensorrt_llm/_torch/pyexecutor/sampler.py (8)
4-5: LGTM! Good type simplification.

The change from the complex tuple-based `Strategy` type to simple string literals makes the code more readable and maintainable. The new imports support the refactored sampling system effectively.

Also applies to: 101-101
104-115: LGTM! Well-designed dataclass.

The `TorchSamplingConfig` dataclass provides a clean, maintainable replacement for the previous tuple-based strategy representation. The custom `__eq__` method ensures proper comparison semantics.
133-146: LGTM! Clean pipeline application logic.

The `apply` method correctly implements the sequential application of sampling steps, with proper fallback to greedy sampling when no steps are configured. The separation of logit transformation and token sampling is well-designed.
149-157: LGTM! Clean strategy dispatch using pattern matching.

The use of `match-case` for strategy dispatch is clean and extensible. The fallback return ensures graceful handling of unrecognized strategies.
160-176: LGTM! Correct top-k sampling implementation.

The top-k sampling logic correctly handles dimension expansion and efficiently masks logits using `torch.topk`. The implementation properly preserves the top-k tokens while setting others to `-inf`.
204-209: LGTM! Clean temperature scaling implementation.

The temperature sampling correctly scales logits by dividing by temperature. The positive temperature assertion is appropriate as a logical constraint, and the method properly handles dimension expansion.
217-235: LGTM! Functions correctly updated for the new sampling system.

The sampling functions have been properly updated to use the new `TorchSamplingConfig` dataclass. The bug previously flagged, where `top_p` was read from the `top_k` field, has been correctly fixed.
479-479: LGTM! Simplified batched strategy detection.

The logic correctly identifies when batched sampling is possible by checking for a single strategy, which is more straightforward than the previous tuple-based comparison approach.
6dc21dd to 3f647d6 (compare)
Actionable comments posted: 0
♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
178-201: Top-p sampling implementation needs assertion fix.

The nucleus sampling logic is mathematically correct, but the assertion should be replaced with proper error handling as noted in previous reviews.
Apply the suggested fix from the previous review:
```diff
-    assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+    if logits.dim() != 2:
+        raise ValueError(f"logits should be 2D: [batch_size, vocab_size], got {logits.dim()}D")
```
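For intuition, a toy walk-through of the nucleus cutoff itself (an illustration, not the file's code):

```python
import torch

probs = torch.tensor([0.5, 0.3, 0.15, 0.05])  # sorted descending
cum = probs.cumsum(dim=-1)                     # [0.50, 0.80, 0.95, 1.00]
top_p = 0.9
# A token survives while the mass *before* it is still below top_p, so the
# smallest prefix whose total mass reaches top_p is kept (here the first three).
keep = (cum - probs) < top_p
print(keep)  # tensor([ True,  True,  True, False])
```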
🧹 Nitpick comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
458-470: Address line length violation and approve integration logic.

The batching logic correctly handles both uniform and mixed sampling strategies.
Fix the line length violation on line 466:
```diff
-            next_tokens = torch.empty(len(requests), dtype=torch.long, device=logits.device,pin_memory=logits.device.type == 'cuda')
+            next_tokens = torch.empty(len(requests), dtype=torch.long,
+                                      device=logits.device,
+                                      pin_memory=logits.device.type == 'cuda')
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
🧰 Additional context used
🧠 Learnings (1)
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py
466-466: Line too long (136 > 120) (E501)
🔇 Additional comments (11)
tensorrt_llm/_torch/pyexecutor/sampler.py (11)
4-5: LGTM! Clean import additions for the new typing requirements.

The addition of `Literal`, `Optional`, and `Any` from typing aligns well with the new dataclass-based approach. However, verify that the `math` import is actually used in the implementation.
101-101: Excellent refactoring of the Strategy type.

The change from tuple-based strategy definitions to string literals significantly improves code readability and maintainability.
104-115: Well-designed configuration dataclass.

The `TorchSamplingConfig` dataclass provides a clean interface for sampling parameters, with appropriate optional typing and a custom equality method for comparison logic.
117-130: Excellent pipeline design for composable sampling strategies.

The sequential application of sampling steps (temperature → top_k → top_p) follows a logical order, and the parameter validation appropriately filters out ineffective values (temperature ≤ 0, top_k ≤ 1).
133-146: Solid implementation of the sampling pipeline application.

The method correctly handles both greedy (no steps) and probabilistic sampling cases, applying transformations sequentially and using appropriate sampling methods for each case.
149-157: Clean strategy dispatch using pattern matching.

The match statement provides clear and maintainable routing to the different sampling methods.
160-176: Mathematically correct top-k sampling implementation.

The implementation properly identifies the top-k values and masks out lower-probability tokens by setting them to negative infinity.
204-214: Temperature and greedy sampling implementations look good.

Both methods are implemented correctly: temperature scaling is now efficient, and greedy sampling uses the standard argmax approach.
217-225: Proper extraction of sampling parameters from requests.

The function correctly handles optional parameters and validates array lengths before accessing elements.
228-235: Clean integration of the new sampling pipeline.

The updated functions properly utilize the new `TorchSamplingPipeline` while maintaining the same interface contracts.
479-487: Consistent application of the new sampling approach.

The mixed sampler path correctly applies the same batching logic as the fast path, ensuring consistent behavior across different execution modes.
Signed-off-by: xq25478 [email protected]
3f647d6 to 208456c (compare)
Actionable comments posted: 2
♻️ Duplicate comments (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (1)
182-182: Replace assertion with proper error handling.

As noted in previous reviews, this assertion should be replaced with proper exception handling for production code.
```diff
-        assert logits_dim == 2, "logits should be 2D: [batch_size, vocab_size]"
+        if logits.dim() != 2:
+            raise ValueError(f"logits should be 2D: [batch_size, vocab_size], got {logits.dim()}D")
```
🧹 Nitpick comments (2)
tensorrt_llm/_torch/pyexecutor/sampler.py (2)
5-5: Remove unused import.

The `math` import is not used anywhere in the code and should be removed.

```diff
-import math
```
466-466: Fix line length violation.

This line exceeds the 120-character limit flagged by static analysis.
```diff
-            next_tokens = torch.empty(len(requests), dtype=torch.long, device=logits.device,pin_memory=logits.device.type == 'cuda')
+            next_tokens = torch.empty(len(requests), dtype=torch.long, device=logits.device,
+                                      pin_memory=logits.device.type == 'cuda')
```
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
tensorrt_llm/_torch/pyexecutor/sampler.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2): **/*.py and **/*.{cpp,h,hpp,cc,cxx,cu,py}, repeating the CODING_GUIDELINES.md rules listed in the first review above.
Files: tensorrt_llm/_torch/pyexecutor/sampler.py
🧠 Learnings (5)
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.582Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : C++ comments are required. C comments are not allowed except for special cases (inline). Use C++ style for single-line comments (//).
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.582Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx} : Use #if / #endif to disable code, preferably with a mnemonic condition. Do NOT use comments to disable code.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
Learnt from: amitz-nv
PR: NVIDIA/TensorRT-LLM#5616
File: tensorrt_llm/executor/worker.py:375-384
Timestamp: 2025-07-17T09:01:27.402Z
Learning: In tensorrt_llm/executor/worker.py, the LoRA adapter cache optimization logic that checks `is_adapter_in_cpu_cache()` and conditionally passes None for weights/config has a known race condition issue that cannot be solved with simple error handling or verification checks. This is a known limitation that requires a more comprehensive solution.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.582Z
Learning: Applies to **/*.{h,hpp} : Use a preprocessor guard in header files. The guard name must have prefix TRTLLM_ followed by the filename, all in caps, and no trailing underscore.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
Learnt from: CR
PR: NVIDIA/TensorRT-LLM#0
File: CODING_GUIDELINES.md:0-0
Timestamp: 2025-08-04T02:12:17.582Z
Learning: Applies to **/*.{cpp,h,hpp,cc,cxx,cu,py} : All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.
Applied to files:
tensorrt_llm/_torch/pyexecutor/sampler.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/pyexecutor/sampler.py
466-466: Line too long (136 > 120) (E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (9)
tensorrt_llm/_torch/pyexecutor/sampler.py (9)
101-101: LGTM! Improved type definition.

The new `Strategy` type alias using string literals is much cleaner and more maintainable than the previous tuple-based approach.
104-115: LGTM! Well-designed configuration dataclass.

The dataclass design, with optional fields and custom equality comparison, is appropriate for representing sampling configurations. The implementation follows Python best practices.
133-146: LGTM! Well-structured pipeline application.

The sequential application of sampling steps followed by appropriate token generation (greedy vs multinomial) is correctly implemented. The design allows for composable sampling strategies.
160-175: LGTM! Correct top-k sampling implementation.

The top-k sampling correctly identifies the k highest logits and masks the rest to negative infinity, preserving the relative probabilities among the top-k tokens.
204-208: LGTM! Clean temperature sampling implementation.

The temperature sampling correctly scales logits by dividing by temperature, with appropriate input validation and dimension handling.
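A quick toy illustration of the scaling effect (not the file's code):

```python
import torch

# T < 1 sharpens the distribution, T > 1 flattens it, T == 1 leaves it unchanged.
logits = torch.tensor([2.0, 1.0, 0.0])
for t in (0.5, 1.0, 2.0):
    print(t, torch.softmax(logits / t, dim=-1))
```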
217-225: LGTM! Robust parameter extraction.

The function correctly extracts sampling parameters from requests, with proper null checks, length validation, and an additional temperature positivity check.
232-235: LGTM! Clean sampling interface.

The sample function provides a clean interface that encapsulates pipeline creation and application, maintaining the expected return signature.
458-474: LGTM! Well-integrated sampling strategy handling.

The integration correctly extracts strategies, determines batching feasibility, and applies appropriate sampling methods while preserving the fast-path optimization.
479-479: LGTM! Correct logic for mixed sampler handling.

The change from checking that all strategies are identical to checking that there is exactly one unique strategy is more accurate for determining when batched sampling can be used.
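A toy illustration of that check (illustration only, not the file's code):

```python
# Batched sampling is possible only when every request resolves to the same strategy.
strategies = [("top_k", 50), ("top_k", 50), ("top_k", 50)]
batched = strategies[0] if len(set(strategies)) == 1 else None
print(batched)  # ('top_k', 50): one batched sampling call suffices

strategies = [("top_k", 50), ("greedy", None)]
batched = strategies[0] if len(set(strategies)) == 1 else None
print(batched)  # None: mixed strategies fall back to per-request sampling
```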
```diff
@@ -1,7 +1,8 @@
 from abc import ABC, abstractmethod
```
Add required NVIDIA copyright header.
According to coding guidelines, all TensorRT-LLM source files must contain an NVIDIA copyright header with the current year.
Add the copyright header at the beginning of the file:
```diff
+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from abc import ABC, abstractmethod
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py at line 1, add the required NVIDIA
copyright header at the very top of the file before any imports. Include the
current year in the header as per the coding guidelines to ensure compliance
with the project's licensing requirements.
steps.append(("temperature", config.temperature)) | ||
if config.top_k is not None and config.top_k > 1: | ||
steps.append(("top_k", config.top_k)) | ||
if config.top_p is not None and int(config.top_p) != 1: |
Fix top_p validation logic.
The condition `int(config.top_p) != 1` is incorrect, as it casts the float to int and loses precision: a top_p value like 0.9 is truncated to 0, so the check compares a truncated integer rather than the actual float value.
```diff
-    if config.top_p is not None and int(config.top_p) != 1:
+    if config.top_p is not None and config.top_p != 1.0:
```
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/pyexecutor/sampler.py at line 128, the condition uses
int(config.top_p) which loses precision for float values like 0.9. Replace the
int cast with a direct float comparison to 1, such as checking if config.top_p
is not None and config.top_p != 1, to correctly validate the top_p value without
losing precision.
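A tiny snippet illustrating the truncation pitfall (toy code, not the file's):

```python
top_p = 0.9
print(int(top_p))        # 0: int() truncates toward zero
print(int(top_p) != 1)   # True, but only via truncation
print(top_p != 1.0)      # True: the direct float comparison states the intent
print(int(1.5) != 1)     # False: an out-of-range 1.5 would silently skip the step
print(1.5 != 1.0)        # True
```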
imp(torchsampler):support sample params temperature/topp/topk
Currently, TorchSampler in TensorRT-LLM cannot perform effective TOP_P and TOP_K sampling: regardless of how the user passes sampling params, it always uses greedy search. This PR adds an implementation of temperature sampling and modifies the relevant code to support TOP_P and TOP_K sampling.