[None][fix] mxfp4 padding bug for TRT-LLM and CUTLASS MoE backends #7214
Conversation
📝 Walkthrough

Adds a centralized alignment helper and applies pre-shard padding for weights, biases, and scales before sharding across the MXFP4 and TRT fused MoE paths. Updates weight/scale loading to use the alignment and extends the MXFP4WeightFusedMoEMethod.create_weights signature to accept weight_alignment and bias_dtype.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant Caller
    participant Quant as quantization.py
    participant Align as _get_weight_alignment
    participant Pad as maybe_pad_for_mxfp4
    participant Shard as load_weight_shard
    participant Backend as MXFP4/TRT FusedMoE
    Caller->>Quant: create_weights(..., weight_alignment, bias_dtype)
    Quant->>Align: compute alignment(weight_alignment, scaling_vec_size, tp_size, shard_dim_size)
    Align-->>Quant: alignment
    Quant->>Pad: pad weights/scales/biases to alignment
    Pad-->>Quant: padded tensors
    Quant->>Shard: shard(padded tensors)
    Shard-->>Quant: shards
    Quant->>Backend: load shards
    Backend-->>Caller: initialized weights
```
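To make the flow above concrete, here is a minimal, self-contained sketch of the pad-then-shard idea. `pad_for_mxfp4_sketch` is a hypothetical stand-in that follows the `(tensor, col_alignment, row_alignment)` call pattern quoted in the review comments below; it is not the repository's `maybe_pad_for_mxfp4`, `torch.chunk` stands in for `load_weight_shard`, and the alignment values are illustrative.

```python
import torch
import torch.nn.functional as F


def pad_for_mxfp4_sketch(t: torch.Tensor, col_alignment: int,
                         row_alignment: int = 1) -> torch.Tensor:
    """Zero-pad the last dim to col_alignment and (for 2D inputs) the
    second-to-last dim to row_alignment; a no-op when already aligned."""
    col_pad = -t.shape[-1] % col_alignment
    if t.dim() == 1:
        return F.pad(t, (0, col_pad)) if col_pad else t
    row_pad = -t.shape[-2] % row_alignment
    if col_pad == 0 and row_pad == 0:
        return t
    return F.pad(t, (0, col_pad, 0, row_pad))


# Pad a packed MXFP4 weight (two elements per uint8) before TP sharding.
tp_size, weight_alignment, alignment = 2, 128, 256  # illustrative values
w = torch.zeros(2880, 1440, dtype=torch.uint8)      # row count not 256-aligned
w = pad_for_mxfp4_sketch(w, weight_alignment // 2, alignment)
shard = torch.chunk(w, tp_size, dim=0)[0]            # stand-in for load_weight_shard
assert shard.shape[0] % weight_alignment == 0        # per-shard rows stay aligned
```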
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels
Suggested reviewers
Actionable comments posted: 8
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)
1-1: Missing NVIDIA Apache-2.0 copyright header
Per repository guidelines, prepend the NVIDIA Apache-2.0 header with the current year to all Python sources. Please add it at the top of this file.
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
2088-2109: Add type hints and a docstring; guard against invalid inputs in _get_weight_alignment
The LCM-based alignment looks sound. Please add type hints, a brief docstring, and input validation to make behavior explicit and catch misconfiguration early.
```diff
-def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
-                          shard_dim_size):
+def _get_weight_alignment(weight_alignment: int,
+                          scaling_vector_size: int,
+                          tp_size: int,
+                          shard_dim_size: int) -> int:
+    """
+    Compute a global alignment for pre-shard padding so that:
+    - The full pre-pad dimension is a multiple of alignment.
+    - After sharding across tp_size, the per-shard dimension is a multiple of weight_alignment.
+    - Alignment is compatible with scaling_vector_size to avoid fractional scale groups.
+    """
+    if weight_alignment <= 0 or scaling_vector_size <= 0 or tp_size <= 0:
+        raise ValueError(
+            "weight_alignment, scaling_vector_size, and tp_size must be positive integers")
+    if shard_dim_size < 0:
+        raise ValueError("shard_dim_size must be non-negative")
```
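For context, an LCM-based helper along the lines described could look like the following sketch. This illustrates the idea rather than the code under review; the widening step simply enforces the per-shard requirement from the docstring suggestion.

```python
import math


def _lcm(a: int, b: int) -> int:
    return a * b // math.gcd(a, b)


def get_weight_alignment_sketch(weight_alignment: int,
                                scaling_vector_size: int,
                                tp_size: int,
                                shard_dim_size: int) -> int:
    # 1) The padded full dimension must be divisible by the backend weight
    #    alignment, the scale group size, and the TP degree.
    alignment = _lcm(_lcm(weight_alignment, scaling_vector_size), tp_size)
    # 2) After padding and splitting across tp_size ranks, each shard must
    #    still be a multiple of weight_alignment; widen the alignment if not.
    padded = math.ceil(shard_dim_size / alignment) * alignment
    if (padded // tp_size) % weight_alignment != 0:
        alignment = _lcm(alignment, weight_alignment * tp_size)
    return alignment
```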
2113-2169: Minor: create_weights defaults; document weight_alignment and bias_dtype
The extended signature is useful. Please document the semantics of weight_alignment with respect to element vs. byte alignment (MXFP4 packs 2 elems/byte) and when bias_dtype should be float32 (TRT-LLM) vs. module.dtype (CUTLASS). A short Google-style docstring is sufficient.
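As an illustration, a Google-style docstring covering those two points might look like this; the signature is reconstructed from the call sites quoted elsewhere in this review, and the defaults are assumptions.

```python
def create_weights(self, module, weight_dtype, weight_vec_size,
                   block_scales_dtype, block_scales_vec_size,
                   weight_alignment=1, bias_dtype=None):
    """Create fused-MoE weight, bias, and block-scale parameters.

    Args:
        weight_alignment: Required alignment of the sharded weight dimension,
            counted in elements rather than packed bytes (MXFP4 stores two
            elements per uint8).
        bias_dtype: Dtype of the bias parameters. The TRT-LLM backend expects
            torch.float32; when None, the CUTLASS path uses module.dtype.
    """
    ...
```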
2275-2346: Add unit tests for pre-shard padding across tp sizes and shapes
Given the correctness sensitivity, add tests that cover the following (a parameterized sketch follows this comment):
- intermediate_size and hidden_size not divisible by alignment (e.g., 2880, 3001)
- tp_size in {1,2,4,8}
- scaling_vector_size {32}
- Both 2D weights and 1D biases
- Scales alignment for MXFP4 in both Cutlass and TRT-LLM paths
I can help generate parameterized tests with synthetic tensors to assert final per-shard shapes and that scaling blocks align with scaling_vector_size.
Also applies to: 2352-2421, 2608-2716, 2725-2832
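A possible shape for such a parameterized test, using synthetic tensors only; the import path follows the file under review, the weight_alignment constant is illustrative, and the asserts encode the invariants the helper is expected to provide rather than verified behavior.

```python
import pytest
import torch

from tensorrt_llm._torch.modules.fused_moe.quantization import (
    _get_weight_alignment, maybe_pad_for_mxfp4)


@pytest.mark.parametrize("dim_size", [2880, 3001])
@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
def test_pre_shard_padding_shapes(dim_size, tp_size):
    scaling_vector_size = 32
    weight_alignment = 128  # illustrative; e.g. TRTLLM-Gen requires 256

    alignment = _get_weight_alignment(weight_alignment, scaling_vector_size,
                                      tp_size, dim_size)

    # 2D packed weight: two MXFP4 elements per uint8 along the last dim.
    weight = torch.zeros(dim_size, 1440, dtype=torch.uint8)
    padded = maybe_pad_for_mxfp4(weight, weight_alignment // 2, alignment)
    assert padded.shape[0] % alignment == 0
    assert (padded.shape[0] // tp_size) % weight_alignment == 0

    # 1D bias is padded along its only dimension.
    bias = torch.zeros(dim_size)
    padded_bias = maybe_pad_for_mxfp4(bias, alignment)
    assert padded_bias.shape[0] % alignment == 0

    # Scale groups stay whole: padded rows divide evenly by the vector size.
    assert padded.shape[0] % scaling_vector_size == 0
```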
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix with k_ for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in __init__
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks
Files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA Apache-2.0 copyright header with current year to all source files
Files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (2)
tensorrt_llm/_torch/distributed/communicator.py (1)
tp_size (46-47)
tensorrt_llm/_torch/modules/linear.py (2)
load_weight_shard (59-103)
TensorParallelMode (44-56)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
My local test results:
Code LGTM. It might be possible to eliminate some duplicated code, but let's not worry about it for now.
Signed-off-by: Nikita Korobov <[email protected]>
27d7b4a to 89b6f48 (Compare)
/bot run
PR_Github #16588 [ run ] triggered by Bot
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
2113-2121: ⚠️ Action Required: Correct MXFP4WeightTRTLLMGenFusedMoEMethod alignment and bias parameters

The TRTLLM-Gen subclass is not supplying the required alignment or bias dtype to the extended create_weights signature. Specifically:
- In quantization.py at MXFP4WeightTRTLLMGenFusedMoEMethod (around line 2569), there is no class attribute overriding weight_alignment, so self.weight_alignment defaults to 1 instead of the required 256.
- The create_weights override in MXFP4WeightTRTLLMGenFusedMoEMethod does not pass bias_dtype, so it falls back to None instead of the intended torch.float32.

Please update as follows:
```diff
 class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = torch.uint8
     block_scales_dtype = torch.uint8
-    # TRTLLM-Gen backend requires weight elements to be 256 aligned.
+    # TRTLLM-Gen backend requires:
+    #   • weight elements aligned to 256
+    #   • bias in float32
+    weight_alignment = 256

     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8
-        super().create_weights(module,
-                               self.weight_dtype,
-                               weight_vec_size,
-                               self.block_scales_dtype,
-                               block_scales_vec_size,
-                               self.weight_alignment)
+        super().create_weights(
+            module,
+            self.weight_dtype,
+            weight_vec_size,
+            self.block_scales_dtype,
+            block_scales_vec_size,
+            self.weight_alignment,
+            bias_dtype=torch.float32,
+        )
         self.setup_quant_scales(module)
```
- Ensure no external callers still rely on the old 5-argument signature.
- Verify that all other MXFP4 subclasses (Cutlass, Base) continue to pass correct alignment and bias values.
♻️ Duplicate comments (8)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (8)
2275-2305: Convert safetensors slices before F.pad; keep results contiguous.
Calling F.pad on lazy safetensors slices will fail. Materialize to torch.Tensor first and ensure .contiguous() after padding.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, w1_weight.shape[0])
+        # Ensure torch tensor (safetensors lazy slices won't work with F.pad)
+        if not isinstance(w1_weight, torch.Tensor):
+            w1_weight = w1_weight[:]
+        if not isinstance(w3_weight, torch.Tensor):
+            w3_weight = w3_weight[:]
         if len(w1_weight.shape) == 2:
             # Pad weights
             # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
             assert w1_weight.dtype == torch.uint8
-            w1_weight = maybe_pad_for_mxfp4(w1_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
             assert w3_weight.dtype == torch.uint8
-            w3_weight = maybe_pad_for_mxfp4(w3_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
         else:
             # Pad bias.
             assert len(w1_weight.shape) == 1
             assert len(w3_weight.shape) == 1
-            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment)
-            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).contiguous()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).contiguous()
```
2325-2339
: W2 pre-shard: convert lazy tensors; clarify packed-dim computation; keep contiguous.
- Materialize w2_weight if safetensors.
- After padding, call .contiguous().
- Optional: add a one-line comment showing why shard_w2_weight_dim doubles for packed bytes.
```diff
-        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
-            w2_weight.shape) == 2 else w2_weight.shape[0]
+        # For 2D weights, col dim stores 2 MXFP4 per uint8 => element count = 2 * bytes
+        shard_w2_weight_dim = (2 * w2_weight.shape[1]) if (len(w2_weight.shape) == 2) else w2_weight.shape[0]
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, shard_w2_weight_dim)
-        if len(w2_weight.shape) == 2:
+        # Ensure torch tensor before padding (safetensors slice won't work with F.pad)
+        if not isinstance(w2_weight, torch.Tensor):
+            w2_weight = w2_weight[:]
+        if len(w2_weight.shape) == 2:
             assert w2_weight.dtype == torch.uint8
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
-                                            self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            alignment // 2,
+                                            self.weight_alignment).contiguous()
         else:
             # Pad bias.
             assert len(w2_weight.shape) == 1
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            self.weight_alignment).contiguous()
```
2353-2364: W3/W1 weight scales: convert lazy tensors before pad; keep contiguous.
Same safetensors hazard applies to scales; also ensure the padded result is contiguous.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w3_weight_scale.shape[0])
-        w1_weight_scale = maybe_pad_for_mxfp4(
+        if not isinstance(w1_weight_scale, torch.Tensor):
+            w1_weight_scale = w1_weight_scale[:]
+        if not isinstance(w3_weight_scale, torch.Tensor):
+            w3_weight_scale = w3_weight_scale[:]
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
             w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
+        w3_weight_scale = maybe_pad_for_mxfp4(
             w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
```
2396-2404: W2 weight scales: convert lazy tensors before pad; keep contiguous.
Mirror the fix for W3/W1 scales.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w2_weight_scale.shape[-1])
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale, alignment // module.scaling_vector_size,
-            self.weight_alignment)
+        if not isinstance(w2_weight_scale, torch.Tensor):
+            w2_weight_scale = w2_weight_scale[:]
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale,
+            alignment // module.scaling_vector_size,
+            self.weight_alignment).contiguous()
```
2608-2638: TRT-LLM W3/W1 pre-shard: convert lazy tensors; keep contiguous (bias stays float32).
Materialize safetensors before F.pad and call .contiguous() to avoid non-strided tensors; keep the float32 cast for biases.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, w1_weight.shape[0])
+        if not isinstance(w1_weight, torch.Tensor):
+            w1_weight = w1_weight[:]
+        if not isinstance(w3_weight, torch.Tensor):
+            w3_weight = w3_weight[:]
         if len(w1_weight.shape) == 2:
             # Pad weights
             # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
             assert w1_weight.dtype == torch.uint8
-            w1_weight = maybe_pad_for_mxfp4(w1_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
             assert w3_weight.dtype == torch.uint8
-            w3_weight = maybe_pad_for_mxfp4(w3_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             assert len(w1_weight.shape) == 1
             assert len(w3_weight.shape) == 1
-            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
-            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).contiguous().float()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).contiguous().float()
```
2676-2693: Handle 1D W2 bias separately in TRT-LLM path to avoid 2D shuffle; convert and early-return.
The shuffle path assumes 2D weights. For 1D bias, pad/convert/scale, copy to device, and return before computing permute indices.
```diff
-        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
-            w2_weight.shape) == 2 else w2_weight.shape[0]
+        shard_w2_weight_dim = (2 * w2_weight.shape[1]) if (len(w2_weight.shape) == 2) else w2_weight.shape[0]
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, shard_w2_weight_dim)
-        if len(w2_weight.shape) == 2:
+        # Ensure torch tensor before padding
+        if not isinstance(w2_weight, torch.Tensor):
+            w2_weight = w2_weight[:]
+        if len(w2_weight.shape) == 2:
             assert w2_weight.dtype == torch.uint8
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
-                                            self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            alignment // 2,
+                                            self.weight_alignment).contiguous()
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             # Divide bias by tp_size as we shard along the hidden dimension.
             # The bias is applied at each TP rank before the final accumulation.
-            assert len(w2_weight.shape) == 1
-            w2_weight = maybe_pad_for_mxfp4(
-                w2_weight, self.weight_alignment).float() / module.tp_size
+            assert len(w2_weight.shape) == 1
+            w2_weight = (maybe_pad_for_mxfp4(
+                w2_weight, self.weight_alignment).contiguous().float() / module.tp_size)
+            # Copy to device buffer and exit early (no shuffle for 1D bias)
+            dst_w2_weight.copy_(w2_weight.to(device=dst_w2_weight.device,
+                                             dtype=dst_w2_weight.dtype),
+                                non_blocking=True)
+            return
```

Also applies to: 2694-2716
2725-2736: TRT-LLM W3/W1 scales: convert lazy scales before pad; keep contiguous.
Mirror the Cutlass scales fix for the TRT-LLM path.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w3_weight_scale.shape[0])
-        w1_weight_scale = maybe_pad_for_mxfp4(
+        if not isinstance(w1_weight_scale, torch.Tensor):
+            w1_weight_scale = w1_weight_scale[:]
+        if not isinstance(w3_weight_scale, torch.Tensor):
+            w3_weight_scale = w3_weight_scale[:]
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
             w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
+        w3_weight_scale = maybe_pad_for_mxfp4(
             w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
```
2788-2796: TRT-LLM W2 scales: convert lazy scales before pad; keep contiguous.
Same pattern as above.
```diff
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w2_weight_scale.shape[-1])
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale, alignment // module.scaling_vector_size,
-            self.weight_alignment)
+        if not isinstance(w2_weight_scale, torch.Tensor):
+            w2_weight_scale = w2_weight_scale[:]
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale,
+            alignment // module.scaling_vector_size,
+            self.weight_alignment).contiguous()
```
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)
1-4: Add NVIDIA copyright header (2025) at file top.
Per repo guidelines, prepend the standard NVIDIA copyright header for 2025.
```diff
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import math
 from abc import ABC, abstractmethod
 from typing import Dict, List, NamedTuple, Optional, Union
```
2088-2109: Alignment helper is sound; document units and intent.
Good use of LCM to couple weight_alignment, scaling_vector_size, and tp_size. Please add a short docstring clarifying:
- alignment is returned in “elements” of the sharded dimension (not packed bytes),
- why the second adjustment ensures per-shard length is a multiple of weight_alignment.
```diff
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
+    """
+    Compute a row-alignment (in elements of the sharded dimension) such that:
+    1) total length is a multiple of lcm(weight_alignment, scaling_vector_size, tp_size), and
+    2) per-shard length (after padding and TP split) is a multiple of weight_alignment.
+    This prevents fractional scale groups per shard for MXFP4/NVFP4 layouts.
+    """
```
141-151: Optional: clarify maybe_pad_for_mxfp4 doc and units.
Given the mix of packed-bytes and element counts across callers, a short docstring noting:
- col_alignment is in “storage units” of last dim (bytes for packed weights, elements for scales/bias),
- row_alignment is in elements of the second-to-last dim,
would prevent future misuse.
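For illustration, such a docstring could read as follows. The wording is a suggestion; the parameter names match the calls quoted above, and the optional third parameter is inferred from the two- and three-argument call sites.

```python
def maybe_pad_for_mxfp4(tensor, col_alignment, row_alignment=1):
    """Right-pad a weight/scale/bias tensor for MXFP4 layouts.

    Args:
        tensor: 1D bias or 2D weight/scale tensor.
        col_alignment: Alignment of the last dim, in storage units
            (bytes for packed uint8 weights, elements for scales/biases).
        row_alignment: Alignment of the second-to-last dim, in elements;
            ignored for 1D inputs.
    """
```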
2275-2305: Reduce duplication: add a tiny helper to materialize lazy tensors.
You repeat isinstance(..., torch.Tensor) checks in several places. Consider a local helper to keep this change tight and consistent.
Add once near the top of this module (after imports):
```python
def _as_torch_tensor(x):
    # Safetensors slice (lazy) exposes slicing; x[:] materializes as torch.Tensor
    return x if isinstance(x, torch.Tensor) else x[:]
```

Then replace repeated checks with simple calls:
```python
w1_weight = _as_torch_tensor(w1_weight)
w3_weight = _as_torch_tensor(w3_weight)
...
```

Also applies to: 2325-2339, 2353-2364, 2396-2404, 2608-2638, 2676-2716, 2725-2736, 2788-2796
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in __init__
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures
Files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)
Files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.
Applied to files:
tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
tensorrt_llm/_torch/modules/linear.py (2)
load_weight_shard (58-102)
TensorParallelMode (43-55)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
PR_Github #16588 [ run ] completed with state
/bot run
PR_Github #16625 [ run ] triggered by Bot
PR_Github #16625 [ run ] completed with state
Summary by CodeRabbit
New Features
Bug Fixes
Description
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user-friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
- --reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
- --disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.
- --disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
- --skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
- --stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
- --gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
- --test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
- --only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
- --disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
- --add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
- --post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
- --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
- --detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
- --debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
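For example, commenting /bot run --disable-fail-fast --stage-list "A10-PyTorch-1" (using the illustrative stage name from the examples above) runs only that stage with fail-fast disabled.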
kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.