
Conversation

@nekorobov nekorobov (Collaborator) commented Aug 25, 2025

Summary by CodeRabbit

  • New Features

    • Added configurable weight alignment and bias data type when creating quantized weights.
    • Centralized alignment computation and standardized pre-shard padding for weights and scales across fused MoE backends.
  • Bug Fixes

    • Eliminated fractional scaling and post-shard padding errors that caused inconsistent shard distributions.
    • Fixed 1D bias and mixed-precision loading behaviors for MXFP4/TRT paths, improving reliability.

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with the Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. This ensures that all builds and tests run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
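
For example, a typical invocation combining the options documented above might be:

/bot run --disable-fail-fast --stage-list "A10-PyTorch-1"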

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@nekorobov nekorobov requested a review from a team as a code owner August 25, 2025 13:01
@nekorobov nekorobov requested a review from HuiGao-NV August 25, 2025 13:01
coderabbitai bot (Contributor) commented Aug 25, 2025

📝 Walkthrough

Walkthrough

Adds a centralized alignment helper and applies pre-shard padding to weights, biases, and scales before sharding across the MXFP4 and TRT fused MoE paths. Updates weight/scale loading to use the alignment and extends the MXFP4WeightFusedMoEMethod.create_weights signature to accept weight_alignment and bias_dtype. All edits are in tensorrt_llm/_torch/modules/fused_moe/quantization.py.

Changes

Cohort / File(s): Alignment & pre-shard padding + backend updates (tensorrt_llm/_torch/modules/fused_moe/quantization.py)

Summary:

  • Added _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size, shard_dim_size) to compute LCM-based alignment and padding.
  • Introduced centralized pre-shard padding via maybe_pad_for_mxfp4 and applied it to weights, 1D biases, and weight scales.
  • Refactored the load paths (load_expert_w3_w1_weight, load_expert_w2_weight, their scale variants, and the quant scale loaders) to pad before load_weight_shard instead of after sharding.
  • Propagated the alignment logic into the MXFP4 and TRT fused MoE methods.
  • Updated the MXFP4WeightFusedMoEMethod.create_weights signature to accept weight_alignment and bias_dtype, and added internal padding/shape rounding for created weights and scales.
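
A minimal sketch of the idea, assuming a conservative LCM-based alignment; the real _get_weight_alignment and maybe_pad_for_mxfp4 in quantization.py may compute a tighter alignment and handle more cases:

import math

import torch
import torch.nn.functional as F


def alignment_sketch(weight_alignment: int, scaling_vector_size: int,
                     tp_size: int) -> int:
    # Conservative pre-shard alignment: the padded dimension splits evenly
    # across tp_size ranks, and each per-rank shard stays a multiple of both
    # weight_alignment and the MXFP4 scaling group size.
    lcm = weight_alignment * scaling_vector_size // math.gcd(
        weight_alignment, scaling_vector_size)
    return tp_size * lcm


def pad_rows_sketch(weight: torch.Tensor, row_alignment: int) -> torch.Tensor:
    # Zero-pad the second-to-last (row) dimension of a 2D weight up to a
    # multiple of row_alignment before calling load_weight_shard, so every
    # TP rank receives whole, aligned scale groups.
    pad = (-weight.shape[-2]) % row_alignment
    return F.pad(weight, (0, 0, 0, pad)) if pad else weight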

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Caller
  participant Quant as quantization.py
  participant Align as _get_weight_alignment
  participant Pad as maybe_pad_for_mxfp4
  participant Shard as load_weight_shard
  participant Backend as MXFP4/TRT FusedMoE

  Caller->>Quant: create_weights(..., weight_alignment, bias_dtype)
  Quant->>Align: compute alignment(weight_alignment, scaling_vec_size, tp_size, shard_dim_size)
  Align-->>Quant: alignment
  Quant->>Pad: pad weights/scales/biases to alignment
  Pad-->>Quant: padded tensors
  Quant->>Shard: shard(padded tensors)
  Shard-->>Quant: shards
  Quant->>Backend: load shards
  Backend-->>Caller: initialized weights

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested labels

weekly_release_blocker

Suggested reviewers

  • HuiGao-NV
  • litaotju
  • chzblych
  • yuxianq


@nekorobov nekorobov requested a review from dongfengy August 25, 2025 13:03
@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 8

🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)

1-1: Missing NVIDIA Apache-2.0 copyright header

Per repository guidelines, prepend the NVIDIA Apache-2.0 header with the current year to all Python sources. Please add it at the top of this file.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

2088-2109: Add type hints and a docstring; guard against invalid inputs in _get_weight_alignment

The LCM-based alignment looks sound. Please add type hints, a brief docstring, and input validation to make behavior explicit and catch misconfiguration early.

-def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
-                          shard_dim_size):
+def _get_weight_alignment(weight_alignment: int,
+                          scaling_vector_size: int,
+                          tp_size: int,
+                          shard_dim_size: int) -> int:
+    """
+    Compute a global alignment for pre-shard padding so that:
+      - The full pre-pad dimension is a multiple of alignment.
+      - After sharding across tp_size, the per-shard dimension is a multiple of weight_alignment.
+      - Alignment is compatible with scaling_vector_size to avoid fractional scale groups.
+    """
+    if weight_alignment <= 0 or scaling_vector_size <= 0 or tp_size <= 0:
+        raise ValueError("weight_alignment, scaling_vector_size, and tp_size must be positive integers")
+    if shard_dim_size < 0:
+        raise ValueError("shard_dim_size must be non-negative")

2113-2169: Minor: create_weights defaults—document weight_alignment and bias_dtype

The extended signature is useful. Please document the semantics of weight_alignment with respect to element vs. byte alignment (MXFP4 packs 2 elems/byte) and when bias_dtype should be float32 (TRT-LLM) vs module.dtype (CUTLASS). Short Google-style docstring is sufficient.
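
One possible wording for that docstring (the parameter order is taken from the super().create_weights call shown in a later comment; the defaults are assumptions):

def create_weights(self, module, weight_dtype, weight_vec_size,
                   block_scales_dtype, block_scales_vec_size,
                   weight_alignment=1, bias_dtype=None):
    """Create MXFP4 fused-MoE weights with pre-shard padding.

    Args:
        weight_alignment: Required per-shard alignment, in elements of the
            padded dimension. MXFP4 packs two elements per uint8, so the
            packed byte dimension only needs weight_alignment // 2.
        bias_dtype: Dtype for bias parameters: torch.float32 for the
            TRT-LLM backend, module.dtype for the CUTLASS path; None falls
            back to the backend default.
    """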


2275-2346: Add unit tests for pre-shard padding across tp sizes and shapes

Given the correctness sensitivity, add tests that cover:

  • intermediate_size and hidden_size not divisible by alignment (e.g., 2880, 3001)
  • tp_size in {1,2,4,8}
  • scaling_vector_size {32}
  • Both 2D weights and 1D biases
  • Scales alignment for MXFP4 in both Cutlass and TRT-LLM paths

I can help generate parameterized tests with synthetic tensors to assert final per-shard shapes and that scaling blocks align with scaling_vector_size.

Also applies to: 2352-2421, 2608-2716, 2725-2832
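
Following up on the coverage bullets above, a rough parameterized sketch (pytest assumed; it only checks the per-shard guarantees described in this review, not the full loading path):

import pytest

from tensorrt_llm._torch.modules.fused_moe.quantization import \
    _get_weight_alignment


@pytest.mark.parametrize("tp_size", [1, 2, 4, 8])
@pytest.mark.parametrize("dim", [2880, 3001])
def test_per_shard_alignment(tp_size, dim):
    weight_alignment, scaling_vector_size = 256, 32
    alignment = _get_weight_alignment(weight_alignment, scaling_vector_size,
                                      tp_size, dim)
    padded = (dim + alignment - 1) // alignment * alignment
    # Each rank should receive an equal slice made of whole weight_alignment
    # blocks and whole scaling groups.
    assert padded % tp_size == 0
    assert (padded // tp_size) % weight_alignment == 0
    assert (padded // tp_size) % scaling_vector_size == 0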

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between bea5e07 and 27d7b4a.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks

Files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA Apache-2.0 copyright header with current year to all source files

Files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (2)
tensorrt_llm/_torch/distributed/communicator.py (1)
  • tp_size (46-47)
tensorrt_llm/_torch/modules/linear.py (2)
  • load_weight_shard (59-103)
  • TensorParallelMode (44-56)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@dongfengy (Collaborator) commented:

My local test results:

dp4_CUTLASS.log:[08/25/2025-21:37:32] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 89.39
dp4_CUTLASS.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-CUTLASS]
dp4_CUTLASS.log:=================== 1 passed, 1 warning in 145.35s (0:02:25) ===================
dp4_TRTLLM.log:[08/25/2025-21:40:05] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 89.69
dp4_TRTLLM.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[dp4-TRTLLM]
dp4_TRTLLM.log:=================== 1 passed, 1 warning in 129.97s (0:02:09) ===================
ep4_CUTLASS.log:[08/25/2025-21:32:11] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 89.92
ep4_CUTLASS.log:0.09s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-CUTLASS]
ep4_CUTLASS.log:=================== 1 passed, 1 warning in 142.01s (0:02:22) ===================
ep4_TRTLLM.log:[08/25/2025-21:34:43] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.75
ep4_TRTLLM.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[ep4-TRTLLM]
ep4_TRTLLM.log:=================== 1 passed, 1 warning in 129.56s (0:02:09) ===================
tp1_CUTLASS.log:[08/25/2025-21:05:21] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.22
tp1_CUTLASS.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp1-CUTLASS]
tp1_CUTLASS.log:=================== 1 passed, 1 warning in 200.61s (0:03:20) ===================
tp1_TRTLLM.log:[08/25/2025-21:08:49] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.37
tp1_TRTLLM.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp1-TRTLLM]
tp1_TRTLLM.log:=================== 1 passed, 1 warning in 184.94s (0:03:04) ===================
tp2_CUTLASS.log:[08/25/2025-21:11:43] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.14
tp2_CUTLASS.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp2-CUTLASS]
tp2_CUTLASS.log:=================== 1 passed, 1 warning in 151.13s (0:02:31) ===================
tp2_TRTLLM.log:[08/25/2025-21:15:19] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.83
tp2_TRTLLM.log:0.11s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp2-TRTLLM]
tp2_TRTLLM.log:=================== 1 passed, 1 warning in 165.58s (0:02:45) ===================
tp4_CUTLASS.log:[08/25/2025-21:18:56] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.90
tp4_CUTLASS.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-CUTLASS]
tp4_CUTLASS.log:=================== 1 passed, 1 warning in 169.14s (0:02:49) ===================
tp4_TRTLLM.log:[08/25/2025-21:22:45] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.14
tp4_TRTLLM.log:0.12s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp4-TRTLLM]
tp4_TRTLLM.log:=================== 1 passed, 1 warning in 179.67s (0:02:59) ===================
tp8_CUTLASS.log:[08/25/2025-21:26:09] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 90.07
tp8_CUTLASS.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp8-CUTLASS]
tp8_CUTLASS.log:=================== 1 passed, 1 warning in 164.34s (0:02:44) ===================
tp8_TRTLLM.log:[08/25/2025-21:29:21] [TRT-LLM] [I] lm-eval gsm8k exact_match,flexible-extract accuracy: 89.84
tp8_TRTLLM.log:0.10s setup    accuracy/test_llm_api_pytorch.py::TestGPTOSS::test_w4_4gpus[tp8-TRTLLM]
tp8_TRTLLM.log:=================== 1 passed, 1 warning in 164.16s (0:02:44) ===================

@dongfengy dongfengy (Collaborator) left a comment

Code LGTM. It might be possible to eliminate some duplicated code, but let's not worry about it for now.

@nekorobov nekorobov force-pushed the user/nkorobov/fix-mxfp4-padding-bug branch from 27d7b4a to 89b6f48 Compare August 26, 2025 18:42
@nekorobov nekorobov changed the title fix: mxfp4 padding bug for TRT-LLM and CUTLASS MoE backends [None][fix] mxfp4 padding bug for TRT-LLM and CUTLASS MoE backends Aug 26, 2025
@nekorobov (Collaborator, Author) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #16588 [ run ] triggered by Bot

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)

2113-2121: ⚠️ Action Required: Correct MXFP4WeightTRTLLMGenFusedMoEMethod Alignment and Bias Parameters

The TRTLLM-Gen subclass is not supplying the required alignment or bias dtype to the extended create_weights signature. Specifically:

• In quantization.py at MXFP4WeightTRTLLMGenFusedMoEMethod (around line 2569), there is no class attribute overriding weight_alignment, so self.weight_alignment defaults to 1 instead of the required 256.
• The create_weights override in MXFP4WeightTRTLLMGenFusedMoEMethod does not pass bias_dtype, so it falls back to None instead of the intended torch.float32.

Please update as follows:

 class MXFP4WeightTRTLLMGenFusedMoEMethod(MXFP4WeightFusedMoEMethod):
     weight_dtype = torch.uint8
     block_scales_dtype = torch.uint8
-    # TRTLLM-Gen backend requires weight elements to be 256 aligned.
+    # TRTLLM-Gen backend requires:
+    #   • weight elements aligned to 256
+    #   • bias in float32
+    weight_alignment = 256

     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
         block_scales_vec_size = torch.iinfo(self.block_scales_dtype).bits // 8
-        super().create_weights(module,
-                               self.weight_dtype,
-                               weight_vec_size,
-                               self.block_scales_dtype,
-                               block_scales_vec_size,
-                               self.weight_alignment)
+        super().create_weights(
+            module,
+            self.weight_dtype,
+            weight_vec_size,
+            self.block_scales_dtype,
+            block_scales_vec_size,
+            self.weight_alignment,
+            bias_dtype=torch.float32,
+        )
         self.setup_quant_scales(module)
  • Ensure no external callers still rely on the old 5-argument signature.
  • Verify that all other MXFP4 subclasses (Cutlass, Base) continue to pass correct alignment and bias values.
♻️ Duplicate comments (8)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (8)

2275-2305: Convert safetensors slices before F.pad; keep results contiguous.

Calling F.pad on lazy safetensors slices will fail. Materialize to torch.Tensor first and ensure .contiguous() after padding.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, w1_weight.shape[0])
+        # Ensure torch tensor (safetensors lazy slices won't work with F.pad)
+        if not isinstance(w1_weight, torch.Tensor):
+            w1_weight = w1_weight[:]
+        if not isinstance(w3_weight, torch.Tensor):
+            w3_weight = w3_weight[:]
         if len(w1_weight.shape) == 2:
             # Pad weights
             # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
             assert w1_weight.dtype == torch.uint8
-            w1_weight = maybe_pad_for_mxfp4(w1_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
             assert w3_weight.dtype == torch.uint8
-            w3_weight = maybe_pad_for_mxfp4(w3_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
         else:
             # Pad bias.
             assert len(w1_weight.shape) == 1
             assert len(w3_weight.shape) == 1
-            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment)
-            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).contiguous()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).contiguous()

2325-2339: W2 pre-shard: convert lazy tensors; clarify packed-dim computation; keep contiguous.

  • Materialize w2_weight if safetensors.
  • After padding, call .contiguous().
  • Optional: add a one-line comment showing why shard_w2_weight_dim doubles for packed bytes.
-        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
-            w2_weight.shape) == 2 else w2_weight.shape[0]
+        # For 2D weights, col dim stores 2 MXFP4 per uint8 => element count = 2 * bytes
+        shard_w2_weight_dim = (2 * w2_weight.shape[1]) if (len(w2_weight.shape) == 2) else w2_weight.shape[0]
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, shard_w2_weight_dim)
 
-        if len(w2_weight.shape) == 2:
+        # Ensure torch tensor before padding (safetensors slice won't work with F.pad)
+        if not isinstance(w2_weight, torch.Tensor):
+            w2_weight = w2_weight[:]
+        if len(w2_weight.shape) == 2:
             assert w2_weight.dtype == torch.uint8
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
-                                            self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            alignment // 2,
+                                            self.weight_alignment).contiguous()
         else:
             # Pad bias.
             assert len(w2_weight.shape) == 1
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            self.weight_alignment).contiguous()

2353-2364: W3/W1 weight scales: convert lazy tensors before pad; keep contiguous.

Same safetensors hazard applies to scales; also ensure the padded result is contiguous.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w3_weight_scale.shape[0])
 
-        w1_weight_scale = maybe_pad_for_mxfp4(
+        if not isinstance(w1_weight_scale, torch.Tensor):
+            w1_weight_scale = w1_weight_scale[:]
+        if not isinstance(w3_weight_scale, torch.Tensor):
+            w3_weight_scale = w3_weight_scale[:]
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
             w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
+        w3_weight_scale = maybe_pad_for_mxfp4(
             w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()

2396-2404: W2 weight scales: convert lazy tensors before pad; keep contiguous.

Mirror the fix for W3/W1 scales.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w2_weight_scale.shape[-1])
 
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale, alignment // module.scaling_vector_size,
-            self.weight_alignment)
+        if not isinstance(w2_weight_scale, torch.Tensor):
+            w2_weight_scale = w2_weight_scale[:]
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale,
+            alignment // module.scaling_vector_size,
+            self.weight_alignment).contiguous()

2608-2638: TRT-LLM W3/W1 pre-shard: convert lazy tensors; keep contiguous (bias stays float32).

Materialize safetensors before F.pad and call .contiguous() to avoid non-strided tensors; keep the float32 cast for biases.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, w1_weight.shape[0])
+        if not isinstance(w1_weight, torch.Tensor):
+            w1_weight = w1_weight[:]
+        if not isinstance(w3_weight, torch.Tensor):
+            w3_weight = w3_weight[:]
         if len(w1_weight.shape) == 2:
             # Pad weights
             # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
             assert w1_weight.dtype == torch.uint8
-            w1_weight = maybe_pad_for_mxfp4(w1_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
             assert w3_weight.dtype == torch.uint8
-            w3_weight = maybe_pad_for_mxfp4(w3_weight,
-                                            self.weight_alignment // 2,
-                                            alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment).contiguous()
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             assert len(w1_weight.shape) == 1
             assert len(w3_weight.shape) == 1
-            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
-            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).contiguous().float()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).contiguous().float()

2676-2693: Handle 1D W2 bias separately in TRT-LLM path to avoid 2D shuffle; convert and early-return.

The shuffle path assumes 2D weights. For 1D bias, pad/convert/scale, copy to device, and return before computing permute indices.

-        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
-            w2_weight.shape) == 2 else w2_weight.shape[0]
+        shard_w2_weight_dim = (2 * w2_weight.shape[1]) if (len(w2_weight.shape) == 2) else w2_weight.shape[0]
         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size, shard_w2_weight_dim)
 
-        if len(w2_weight.shape) == 2:
+        # Ensure torch tensor before padding
+        if not isinstance(w2_weight, torch.Tensor):
+            w2_weight = w2_weight[:]
+        if len(w2_weight.shape) == 2:
             assert w2_weight.dtype == torch.uint8
-            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
-                                            self.weight_alignment)
+            w2_weight = maybe_pad_for_mxfp4(w2_weight,
+                                            alignment // 2,
+                                            self.weight_alignment).contiguous()
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             # Divide bias by tp_size as we shard along the hidden dimension.
             # The bias is applied at each TP rank before the final accumulation.
-            assert len(w2_weight.shape) == 1
-            w2_weight = maybe_pad_for_mxfp4(
-                w2_weight, self.weight_alignment).float() / module.tp_size
+            assert len(w2_weight.shape) == 1
+            w2_weight = (maybe_pad_for_mxfp4(
+                w2_weight, self.weight_alignment).contiguous().float() / module.tp_size)
+            # Copy to device buffer and exit early (no shuffle for 1D bias)
+            dst_w2_weight.copy_(w2_weight.to(device=dst_w2_weight.device,
+                                             dtype=dst_w2_weight.dtype),
+                                non_blocking=True)
+            return

Also applies to: 2694-2716


2725-2736: TRT-LLM W3/W1 scales: convert lazy scales before pad; keep contiguous.

Mirror Cutlass scales fix for the TRT-LLM path.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w3_weight_scale.shape[0])
 
-        w1_weight_scale = maybe_pad_for_mxfp4(
+        if not isinstance(w1_weight_scale, torch.Tensor):
+            w1_weight_scale = w1_weight_scale[:]
+        if not isinstance(w3_weight_scale, torch.Tensor):
+            w3_weight_scale = w3_weight_scale[:]
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
             w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()
+        w3_weight_scale = maybe_pad_for_mxfp4(
             w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size, alignment)
+            self.weight_alignment // module.scaling_vector_size, alignment).contiguous()

2788-2796: TRT-LLM W2 scales: convert lazy scales before pad; keep contiguous.

Same pattern as above.

         alignment = _get_weight_alignment(self.weight_alignment,
                                           module.scaling_vector_size,
                                           module.tp_size,
                                           w2_weight_scale.shape[-1])
 
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale, alignment // module.scaling_vector_size,
-            self.weight_alignment)
+        if not isinstance(w2_weight_scale, torch.Tensor):
+            w2_weight_scale = w2_weight_scale[:]
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale,
+            alignment // module.scaling_vector_size,
+            self.weight_alignment).contiguous()
🧹 Nitpick comments (4)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (4)

1-4: Add NVIDIA copyright header (2025) at file top.

Per repo guidelines, prepend the standard NVIDIA copyright header for 2025.

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+
 import math
 from abc import ABC, abstractmethod
 from typing import Dict, List, NamedTuple, Optional, Union

2088-2109: Alignment helper is sound; document units and intent.

Good use of LCM to couple weight_alignment, scaling_vector_size, and tp_size. Please add a short docstring clarifying:

  • alignment is returned in “elements” of the sharded dimension (not packed bytes),
  • why the second adjustment ensures per-shard length is a multiple of weight_alignment.
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
+    """
+    Compute a row-alignment (in elements of the sharded dimension) such that:
+      1) total length is a multiple of lcm(weight_alignment, scaling_vector_size, tp_size), and
+      2) per-shard length (after padding and TP split) is a multiple of weight_alignment.
+    This prevents fractional scale groups per shard for MXFP4/NVFP4 layouts.
+    """

141-151: Optional: clarify maybe_pad_for_mxfp4 doc and units.

Given the mix of packed-bytes and element counts across callers, a short docstring noting:

  • col_alignment is in “storage units” of last dim (bytes for packed weights, elements for scales/bias),
  • row_alignment is in elements of the second-to-last dim,
    would prevent future misuse.
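
A possible docstring along those lines; the signature shown, including the row_alignment default, is assumed from the call sites above rather than taken from the actual helper:

def maybe_pad_for_mxfp4(tensor, col_alignment, row_alignment=1):
    """Pad `tensor` so its last dim is a multiple of col_alignment and, for
    2D inputs, its second-to-last dim is a multiple of row_alignment.

    col_alignment is expressed in the storage units of the last dim (uint8
    bytes for packed MXFP4 weights, elements for scales and 1D biases);
    row_alignment is always in elements.
    """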

2275-2305: Reduce duplication: add a tiny helper to materialize lazy tensors.

You repeat isinstance(..., torch.Tensor) checks in several places. Consider a local helper to keep this change tight and consistent.

Add once near the top of this module (after imports):

def _as_torch_tensor(x):
    # Safetensors slice (lazy) exposes slicing; x[:] materializes as torch.Tensor
    return x if isinstance(x, torch.Tensor) else x[:]

Then replace repeated checks with simple calls:

w1_weight = _as_torch_tensor(w1_weight)
w3_weight = _as_torch_tensor(w3_weight)
...

Also applies to: 2325-2339, 2353-2364, 2396-2404, 2608-2638, 2676-2716, 2725-2736, 2788-2796

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 27d7b4a and 2a80cf9.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/modules/fused_moe/quantization.py (9 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧠 Learnings (3)
📓 Common learnings
Learnt from: djns99
PR: NVIDIA/TensorRT-LLM#6915
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:4010-4012
Timestamp: 2025-08-14T23:23:27.449Z
Learning: For MOE (Mixture of Experts) code reviews in TensorRT-LLM, avoid repeatedly suggesting finalize fusion validation checks and safety assertions. The user djns99 has indicated these suggestions are repetitive and unwanted across multiple MOE-related changes.
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
📚 Learning: 2025-08-08T22:03:40.707Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu:1198-1209
Timestamp: 2025-08-08T22:03:40.707Z
Learning: In the CUTLASS MoE kernels (cpp/tensorrt_llm/cutlass_extensions), when `layout_info.fusion` is set to `TmaWarpSpecializedGroupedGemmInput::EpilogueFusion::FINALIZE`, the `router_scales` parameter must be non-null by design. The fused finalize kernel epilogue does not perform nullptr checks and requires valid router scales to function correctly. This is an implicit contract that callers must satisfy when enabling the FINALIZE fusion mode.

Applied to files:

  • tensorrt_llm/_torch/modules/fused_moe/quantization.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/modules/fused_moe/quantization.py (1)
tensorrt_llm/_torch/modules/linear.py (2)
  • load_weight_shard (58-102)
  • TensorParallelMode (43-55)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@tensorrt-cicd (Collaborator) commented:

PR_Github #16588 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12453 completed with status: 'FAILURE'

@dongfengy (Collaborator) commented:

/bot run

@tensorrt-cicd (Collaborator) commented:

PR_Github #16625 [ run ] triggered by Bot

@tensorrt-cicd (Collaborator) commented:

PR_Github #16625 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #12483 completed with status: 'SUCCESS'

@hlu1 hlu1 merged commit a419b77 into NVIDIA:main Aug 28, 2025
5 of 6 checks passed