
Conversation


@tiffany940107 tiffany940107 commented Aug 25, 2025

Summary by CodeRabbit

  • New Features
    • Optional FP8 Block Scale acceleration for vision and multimodal processing, disabled by default.
    • Enable via environment variable or configuration (see the sketch after this list); gracefully falls back when not enabled or unsupported.
    • Targeted linear layers are automatically replaced with pre-quantized FP8 variants when eligible, reducing GPU memory use and improving throughput.
    • Public check available to determine if FP8 Block Scale mode is active.
    • Added debug logs to trace enablement source and replacement progress for easier troubleshooting.
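
For illustration, here is a minimal sketch of how the opt-in toggle above could be resolved. The environment variable TLLM_ENABLE_FP8_BLOCK_SCALE and the config flag enable_fp8_block_scale are taken from the walkthrough below; the helper function itself is hypothetical and is not part of the PR's code.

import os
from typing import Any

def resolve_fp8_block_scale(pretrained_config: Any = None) -> bool:
    """Hypothetical helper: the env var takes effect first, then the config flag."""
    env_enable = os.getenv("TLLM_ENABLE_FP8_BLOCK_SCALE", "0") == "1"
    config_enable = bool(getattr(pretrained_config, "enable_fp8_block_scale", False))
    return env_enable or config_enable

# e.g. export TLLM_ENABLE_FP8_BLOCK_SCALE=1 before launching, or set
# enable_fp8_block_scale=True in the model config.
print(resolve_fp8_block_scale())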

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or from the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages that don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with the tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause the top of tree to break.

@tiffany940107 tiffany940107 requested review from a team as code owners August 25, 2025 07:59
Contributor

coderabbitai bot commented Aug 25, 2025

📝 Walkthrough

Walkthrough

Adds an opt-in FP8 Block Scale pathway to Qwen2VL vision and input processor classes, including env/config toggles, layer pattern configuration, runtime replacement of matching Linear layers with pre-quantized FP8 wrappers, FP8 quantize/GEMM forward path, public enablement checkers, and extensive debug prints. Changes are confined to one file.

Changes

Cohort / File(s): tensorrt_llm/_torch/models/modeling_qwen2vl.py (all changes are confined to this one file)

  • Feature toggle and configuration: Adds ENABLE_FP8_BLOCK_SCALE constant; reads TLLM_ENABLE_FP8_BLOCK_SCALE and config flag enable_fp8_block_scale; logs the source; supports pretrained_config.fp8_block_scale_patterns.
  • Layer discovery and replacement pipeline: Scans the model for Linear layers matching patterns (attn.qkv, attn.proj, mlp.gate_proj, mlp.down_proj, mlp.up_proj); replaces them with pre-quantized FP8 Block Scale wrappers via _replace_linear_layers_with_pre_quantization and _create_pre_quantized_fp8_block_linear.
  • Pre-quantized FP8 Linear implementation: Introduces PreQuantizedTrtllmFp8BlockLinear, which pre-quantizes weights block-wise, stores CPU tensors, performs FP8 input quantization and fp8_block_scaling_gemm, handles bias/reshape, validates dimensions, and requires tensorrt_llm ops.
  • Public API additions: Adds is_fp8_blockscale_enabled() to Qwen2VLInputProcessorBase and Qwen2VisionModelBase.
  • Diagnostics: Prints model structure (Linear layers), enablement source, and replacement progress; retains the original path when disabled.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant C as Config/Env
  participant IP as Qwen2VLInputProcessorBase
  participant VM as Qwen2VisionModelBase
  participant RP as Replacement Routine
  participant L as Linear Modules

  C->>IP: Provide TLLM_ENABLE_FP8_BLOCK_SCALE / enable_fp8_block_scale
  C->>VM: Provide patterns / flags
  IP->>IP: Determine enable_fp8_block_scale
  VM->>VM: Determine enable_fp8_block_scale
  alt FP8 Block Scale enabled
    VM->>RP: Scan model structure
    RP->>L: Match by patterns
    RP->>VM: Replace matched Linear with FP8 wrappers
  else Disabled
    VM-->>VM: No replacement
  end
sequenceDiagram
  autonumber
  participant FW as Forward()
  participant PQL as PreQuantizedTrtllmFp8BlockLinear
  participant Q as FP8 Quantize
  participant G as FP8 Block GEMM
  participant DQ as Dequant/Output

  FW->>PQL: x (input)
  alt in_features % 16 != 0 or ops missing
    PQL-->>FW: Fallback matmul (original dtype)
  else FP8 path
    PQL->>Q: fp8_quantize_1x128(x)
    Q-->>PQL: x_fp8, scale_x
    PQL->>G: fp8_block_scaling_gemm(x_fp8, W_fp8, scales)
    G-->>PQL: y_fp8
    PQL->>DQ: Convert y_fp8 to original dtype, add bias, reshape
    DQ-->>FW: y
  end

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • juney-nvidia
  • symphonylyh


Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 6

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (1)

1-4: Missing NVIDIA copyright header (2025) at file top.

Per coding guidelines, prepend the NVIDIA header to all source files.

Apply this diff at the very top of the file:

+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License in the LICENSE file at the root of this project.
🧹 Nitpick comments (5)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (5)

30-31: Prefer G_-prefixed global and avoid import-time behavior coupling; rename and use consistently.

  • Use G_-prefixed UPPER_SNAKE_CASE for globals.
  • Minor: reading env at import time makes behavior harder to override in tests; keeping the env read is fine, but at least rename for consistency and clarity.

Apply this diff:

-ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'
+G_ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'

And update usage:

- self.enable_fp8_block_scale = ENABLE_FP8_BLOCK_SCALE or config_enable
+ self.enable_fp8_block_scale = G_ENABLE_FP8_BLOCK_SCALE or config_enable

- if ENABLE_FP8_BLOCK_SCALE:
+ if G_ENABLE_FP8_BLOCK_SCALE:

Also applies to: 372-379


369-406: Replace prints with logger and reduce noise; demote structure dumps to debug.

Use the existing project logger instead of print; keep structure dumps behind debug level to avoid flooding stdout.

Apply this diff:

-        print(f"FP8 Block Scale mode: {'ENABLED' if self.enable_fp8_block_scale else 'DISABLED'}")
-        if ENABLE_FP8_BLOCK_SCALE:
-            print("  - Enabled via environment variable TLLM_ENABLE_FP8_BLOCK_SCALE=1")
-        elif config_enable:
-            print("  - Enabled via config file")
-        else:
-            print("  - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)")
+        logger.info("FP8 Block Scale mode: %s",
+                    "ENABLED" if self.enable_fp8_block_scale else "DISABLED")
+        if G_ENABLE_FP8_BLOCK_SCALE:
+            logger.info("  - Enabled via env TLLM_ENABLE_FP8_BLOCK_SCALE=1")
+        elif config_enable:
+            logger.info("  - Enabled via config flag enable_fp8_block_scale=True")
+        else:
+            logger.info("  - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)")
@@
-            print("Visual model structure:")
-            for name, module in self.visual.named_modules():
-                if isinstance(module, torch.nn.Linear):
-                    print(f"  Linear layer: {name}")
+            logger.debug("Visual model structure (Linear layers):")
+            for name, module in self.visual.named_modules():
+                if isinstance(module, torch.nn.Linear):
+                    logger.debug("  Linear layer: %s", name)
@@
-            print("Skipping FP8 Block Scale layer replacement, using original implementation")
+            logger.info("Skipping FP8 Block Scale layer replacement (disabled); using original implementation.")

Also addresses Ruff E501 in a few long lines.


528-560: Avoid importing unused package and check for custom op availability explicitly.

Ruff flags the unused import; also validating ops is more accurate than importing the package root.

Apply this diff inside the wrapper’s init:

-                try:
-                    import tensorrt_llm
-                    pass
-                except ImportError:
-                    raise ImportError("TensorRT-LLM is not installed.")
+                # Verify that required custom ops are available
+                trtllm_ops = getattr(torch.ops, "trtllm", None)
+                if trtllm_ops is None or not hasattr(trtllm_ops, "fp8_quantize_1x128") \
+                        or not hasattr(trtllm_ops, "fp8_block_scaling_gemm"):
+                    raise ImportError(
+                        "TensorRT-LLM custom ops (fp8_quantize_1x128/fp8_block_scaling_gemm) are not available."
+                    )
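
If only package presence (rather than the specific custom ops) needs to be verified, the importlib.util.find_spec approach that Ruff also flags in the findings below avoids the unused import entirely; a minimal sketch:

import importlib.util

# Check availability without importing tensorrt_llm at module scope.
if importlib.util.find_spec("tensorrt_llm") is None:
    raise ImportError("TensorRT-LLM is not installed.")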

729-737: Expose the same FP8-enable helper on the input processor for parity (optional).

If external callers query both sides, consider adding the same method on Qwen2VLInputProcessorBase for consistency.

Apply this addition near the end of Qwen2VLInputProcessorBase:

+    def is_fp8_blockscale_enabled(self) -> bool:
+        """Whether FP8 Block Scale is enabled for this pipeline."""
+        return os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1' or bool(
+            getattr(self.model_config, 'enable_fp8_block_scale', False)
+        )

498-498: Fix Ruff findings: long lines and stray whitespace.

  • E501: break long f-strings/log lines.
  • W293: remove trailing whitespace-only lines.

Example fixes (some already addressed above):

-                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
+                    logger.debug(
+                        "DEBUG: %s weight=%s in_features=%d out_features=%d in%%16=%d",
+                        name, tuple(weight.shape), in_features, out_features, in_features % 16
+                    )

And remove whitespace-only lines at the indicated locations.

Also applies to: 553-553, 569-569, 531-531, 534-534, 732-732

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between a1e03af and c91aee1.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_qwen2vl.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/models/modeling_qwen2vl.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (3)
tensorrt_llm/module.py (2)
  • named_modules (102-114)
  • Module (33-226)
tensorrt_llm/_torch/modules/linear.py (1)
  • Linear (1495-1704)
tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py (1)
  • fp8_linear (57-96)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_qwen2vl.py

498-498: Line too long (172 > 120)

(E501)


531-531: Blank line contains whitespace

Remove whitespace from blank line

(W293)


534-534: Blank line contains whitespace

Remove whitespace from blank line

(W293)


553-553: Line too long (135 > 120)

(E501)


556-556: tensorrt_llm imported but unused; consider using importlib.util.find_spec to test for availability

(F401)


569-569: Line too long (124 > 120)

(E501)


732-732: Blank line contains whitespace

Remove whitespace from blank line

(W293)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (1)

711-726: Properties proxying original Linear are fine.

Returning weight/bias/in_features/out_features from the wrapped layer preserves expected attributes for downstream code.

Comment on lines +381 to +396
        if self.enable_fp8_block_scale:
            # Define layer name patterns to be replaced with FP8 Block Scale
            # Now supports MLP layers, handling dimension mismatch through padding
            self.fp8_block_scale_patterns = [
                "blocks.*.attn.qkv",  # All block attention qkv
                "blocks.*.attn.proj",  # Re-enable attention projection, fix reshape logic
                "blocks.*.mlp.gate_proj",  # All block mlp gate_proj
                "blocks.*.mlp.down_proj",  # All block mlp down_proj
                "blocks.*.mlp.up_proj",  # All block mlp up_proj
            ]

            # Allow custom replacement patterns through configuration
            if hasattr(pretrained_config, 'fp8_block_scale_patterns'):
                self.fp8_block_scale_patterns = pretrained_config.fp8_block_scale_patterns

            # Print model structure for debugging

🛠️ Refactor suggestion

⚠️ Potential issue

Pattern matching will likely never hit; switch to fnmatch/glob or re.search with prefix allowance.

Using re.match later (with patterns like "blocks.*.attn.qkv") requires the string to start with "blocks". Most models name modules like "encoder.blocks.0.attn.qkv" so nothing will match and no layers will be replaced.

No diff here (see the function refactor below and the small matching sketch after this list), but please adopt one of:

  • Use fnmatch.fnmatchcase(name, pattern) and also try *.{pattern} to allow prefixes, or
  • Use re.search with r'(?:^|.*\.)blocks\.\d+\.attn\.(qkv|proj)$'-style compiled regexes.
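
For illustration only, a small sketch of both options against a typical fully qualified module name (the example name and patterns mirror the ones discussed above):

import fnmatch
import re

patterns = ["blocks.*.attn.qkv", "blocks.*.attn.proj"]
name = "encoder.blocks.0.attn.qkv"  # prefixed module name that anchored re.match would miss

# Option 1: fnmatch, also trying a "*." prefix so arbitrary parent prefixes match.
matched_fnmatch = any(
    fnmatch.fnmatchcase(name, pat) or fnmatch.fnmatchcase(name, f"*.{pat}")
    for pat in patterns
)

# Option 2: re.search with a prefix-tolerant, end-anchored regex.
pattern_re = re.compile(r"(?:^|.*\.)blocks\.\d+\.attn\.(qkv|proj)$")
matched_re = bool(pattern_re.search(name))

print(matched_fnmatch, matched_re)  # True True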
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 381 to 396, the
configured patterns like "blocks.*.attn.qkv" will not match typical module names
when later checked with re.match (which anchors at the start); update the
matching strategy so pattern checks succeed: either (preferred) switch to
fnmatch.fnmatchcase and update stored patterns to allow arbitrary prefixes (e.g.
prepending "*." when loading user config or automatically trying both raw and
"*."+pattern), or compile and use re.search with anchored-friendly regexes (e.g.
prepend "(?:^|.*\.)" and escape the pattern parts) so modules like
"encoder.blocks.0.attn.qkv" match; ensure you apply the same matching approach
for both default patterns and any pretrained_config.fp8_block_scale_patterns
provided by users and validate with a small unit test or assertion that example
module names match the intended patterns.

Comment on lines +472 to +527
    def _replace_linear_layers_with_pre_quantization(self):
        """
        Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
        """
        import re
        import torch.nn as nn

        # Directly iterate through all submodules of the visual module
        for name, module in self.visual.named_modules():
            # Check if it's a linear layer
            if isinstance(module, nn.Linear):
                # Check if it matches any pattern
                should_replace = False
                for pattern in self.fp8_block_scale_patterns:
                    # Convert pattern to regex
                    regex_pattern = pattern.replace("*", r"\d+")
                    if re.match(regex_pattern, name):
                        should_replace = True
                        break

                if should_replace:
                    # Check if weight dimensions meet TensorRT-LLM requirements
                    # For matrix multiplication input @ weight.T, N dimension is in_features
                    weight = module.weight
                    in_features = weight.shape[0]  # Input feature dimension
                    out_features = weight.shape[1]  # Output feature dimension
                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")

                    if in_features % 16 != 0:
                        print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
                        continue

                    try:
                        # Create pre-quantized FP8 Block Scale replacement
                        fp8_linear = self._create_pre_quantized_fp8_block_linear(module)

                        # Find parent module and child module names
                        parent_name = '.'.join(name.split('.')[:-1])
                        child_name = name.split('.')[-1]

                        if parent_name:
                            # Get parent module
                            parent_module = self.visual
                            for part in parent_name.split('.'):
                                parent_module = getattr(parent_module, part)

                            # Replace child module
                            setattr(parent_module, child_name, fp8_linear)
                        else:
                            # Direct replacement
                            setattr(self.visual, child_name, fp8_linear)

                        print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
                    except Exception as e:
                        print(f"Failed to replace {name}: {e}")


🛠️ Refactor suggestion

⚠️ Potential issue

Robust replacement: use named_modules_with_parent; fix ModuleList indexing, wrong in_features, and fragile regex.

  • Current traversal uses named_modules() then manual getattr walking; this breaks on ModuleList numeric indices (no attribute '0') and can throw.
  • You compute in_features = weight.shape[0], but nn.Linear.weight is [out_features, in_features]. Use module.in_features.
  • Replace during iteration safely using named_modules_with_parent provided by this repo to avoid stale references.

Apply this diff to rewrite the function:

-    def _replace_linear_layers_with_pre_quantization(self):
-        """
-        Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
-        """
-        import re
-        import torch.nn as nn
-        
-        # Directly iterate through all submodules of the visual module
-        for name, module in self.visual.named_modules():
-            # Check if it's a linear layer
-            if isinstance(module, nn.Linear):
-                # Check if it matches any pattern
-                should_replace = False
-                for pattern in self.fp8_block_scale_patterns:
-                    # Convert pattern to regex
-                    regex_pattern = pattern.replace("*", r"\d+")
-                    if re.match(regex_pattern, name):
-                        should_replace = True
-                        break
-                
-                if should_replace:
-                    # Check if weight dimensions meet TensorRT-LLM requirements
-                    # For matrix multiplication input @ weight.T, N dimension is in_features
-                    weight = module.weight
-                    in_features = weight.shape[0]  # Input feature dimension
-                    out_features = weight.shape[1]  # Output feature dimension
-                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
-                    
-                    if in_features % 16 != 0:
-                        print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
-                        continue
-                    
-                    try:
-                        # Create pre-quantized FP8 Block Scale replacement
-                        fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
-                        
-                        # Find parent module and child module names
-                        parent_name = '.'.join(name.split('.')[:-1])
-                        child_name = name.split('.')[-1]
-                        
-                        if parent_name:
-                            # Get parent module
-                            parent_module = self.visual
-                            for part in parent_name.split('.'):
-                                parent_module = getattr(parent_module, part)
-                            
-                            # Replace child module
-                            setattr(parent_module, child_name, fp8_linear)
-                        else:
-                            # Direct replacement
-                            setattr(self.visual, child_name, fp8_linear)
-                            
-                        print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
-                    except Exception as e:
-                        print(f"Failed to replace {name}: {e}")
+    def _replace_linear_layers_with_pre_quantization(self):
+        """
+        Replace target nn.Linear layers with pre-quantized FP8 Block-Scale wrappers.
+        Uses named_modules_with_parent to safely mutate the module tree while iterating.
+        """
+        import fnmatch
+        num_replaced = 0
+
+        for name, module, parent in self.visual.named_modules_with_parent(remove_duplicate=True):
+            if parent is None or not isinstance(module, nn.Linear):
+                continue
+
+            # Match both exact patterns and those preceded by any prefix (e.g., "encoder.")
+            matched = any(
+                fnmatch.fnmatchcase(name, pat) or fnmatch.fnmatchcase(name, f"*.{pat}")
+                for pat in self.fp8_block_scale_patterns
+            )
+            if not matched:
+                continue
+
+            in_features = module.in_features
+            if in_features % 16 != 0:
+                logger.debug("FP8 skip %s: in_features (%d) %% 16 != 0", name, in_features)
+                continue
+
+            try:
+                fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
+                if fp8_linear is None:
+                    logger.debug("FP8 skip %s: pre-quantization unavailable/failed.", name)
+                    continue
+
+                child_name = name.rsplit(".", 1)[-1]
+                setattr(parent, child_name, fp8_linear)
+                num_replaced += 1
+                logger.info("FP8 Block Scale: replaced layer %s", name)
+            except Exception as e:
+                logger.exception("FP8 Block Scale: failed to replace %s: %s", name, e)
+
+        logger.info("FP8 Block Scale: total Linear layers replaced: %d", num_replaced)

Committable suggestion skipped: line range outside the PR's diff.

🧰 Tools
🪛 Ruff (0.12.2)

498-498: Line too long (172 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 472-527, replace
the current traversal and replacement logic with a safe approach: iterate using
the repository-provided named_modules_with_parent to get (parent, name, module)
so you never need fragile getattr walks (this also preserves ModuleList
parents), compute dimensions from the layer attributes (use module.in_features
and module.out_features instead of weight.shape indices), match
fp8_block_scale_patterns using fnmatch.fnmatchcase(name, pattern) (or equivalent
wildcard matching) instead of fragile regex construction, and perform the
replacement on the parent correctly (if parent is a Module and has the attribute
name use setattr, if parent is a ModuleList or list and name.isdigit() assign
parent[int(name)] = fp8_linear) while wrapping creation in try/except and
logging failures.

Comment on lines +496 to +502
                    in_features = weight.shape[0]  # Input feature dimension
                    out_features = weight.shape[1]  # Output feature dimension
                    print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")

                    if in_features % 16 != 0:
                        print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
                        continue

⚠️ Potential issue

Wrong feature dims: nn.Linear.weight is [out_features, in_features].

You set:

  • in_features = weight.shape[0] (this is out_features)
  • out_features = weight.shape[1] (this is in_features)

This leads to wrong divisibility checks and misleading logs.

The function refactor above fixes this by using module.in_features and avoiding manual shape indexing.
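
For reference, a quick standalone check of the PyTorch convention (layer sizes here are arbitrary, not the actual Qwen2-VL dimensions):

import torch.nn as nn

layer = nn.Linear(in_features=1280, out_features=5120)
print(layer.weight.shape)   # torch.Size([5120, 1280]) -> [out_features, in_features]
print(layer.in_features)    # 1280, the value the %16 divisibility check should use
print(layer.out_features)   # 5120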

🧰 Tools
🪛 Ruff (0.12.2)

498-498: Line too long (172 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 496 to 502, the
code treats nn.Linear.weight shape as [in_features, out_features] but PyTorch
stores weights as [out_features, in_features]; swap usage accordingly and stop
relying on manual shape indexing: use module.in_features and module.out_features
(or module.weight.shape[1]/[0] if needed) for accurate values, update the debug
print to show the correct in/out features, and perform the divisibility check on
module.in_features (not weight.shape[0]); this ensures correct logging and
correct skip logic.

Comment on lines +545 to +554
# Pre-quantize weights and scaling factors
print(f"Pre-quantizing weights for layer with shape {original_linear.weight.shape}")
self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)

# Move quantized weights and scaling factors to CPU to save GPU memory
self.weight_fp8 = self.weight_fp8.cpu()
self.weight_scale = self.weight_scale.cpu()

print(f"Pre-quantization completed. Weight FP8 shape: {self.weight_fp8.shape}, Scale shape: {self.weight_scale.shape}")


🛠️ Refactor suggestion

⚠️ Potential issue

Do not move quantized weights to CPU; this induces per-forward HtoD copies. Keep them as (non-persistent) buffers on the target device.

Current code moves FP8 weights/scales to CPU, then copies to GPU every forward. This will dominate latency and negate any GEMM speedup.

Apply this diff:

-                self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)
-                
-                # Move quantized weights and scaling factors to CPU to save GPU memory
-                self.weight_fp8 = self.weight_fp8.cpu()
-                self.weight_scale = self.weight_scale.cpu()
+                q_weight, q_scale = self._pre_quantize_weight(original_linear.weight)
+                self._fp8_enabled = (q_weight.dtype == torch.float8_e4m3fn)
+                # Keep on the same device as the original weight to avoid runtime transfers
+                self.register_buffer("weight_fp8", q_weight.to(original_linear.weight.device), persistent=False)
+                self.register_buffer("weight_scale", q_scale.to(original_linear.weight.device), persistent=False)

If memory is a concern, consider moving the original FP16/BF16 weight to CPU when FP8 is enabled and using the original only as a fallback path.
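
As a minimal, self-contained sketch of the buffer registration this suggestion relies on (the Wrapper class and placeholder tensors are illustrative, not the PR's wrapper): non-persistent buffers follow module.to()/cuda() and stay out of the state_dict, so there are no per-forward host-to-device copies.

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        # Stand-ins for the FP8 weight and block scales.
        self.register_buffer("weight_fp8", weight.clone(), persistent=False)
        self.register_buffer("weight_scale", torch.ones(1), persistent=False)

m = Wrapper(torch.randn(4, 4))
print("weight_fp8" in m.state_dict())  # False: non-persistent buffers are not serialized
if torch.cuda.is_available():
    m.cuda()                           # buffers move with the module
print(m.weight_fp8.device)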

🧰 Tools
🪛 Ruff (0.12.2)

553-553: Line too long (135 > 120)

(E501)

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 545 to 554, do
not move the quantized FP8 weight and scale tensors to CPU (the .cpu() calls)
because that forces HtoD copies each forward; instead register them as
non-persistent buffers on the target device (e.g., device =
original_linear.weight.device) so they remain on GPU: remove the .cpu() calls,
call self.register_buffer("weight_fp8", self.weight_fp8.to(device),
persistent=False) and self.register_buffer("weight_scale",
self.weight_scale.to(device), persistent=False) (or set the tensors directly on
device if already assigned), and if memory is a concern optionally move the
original FP16/BF16 weight to CPU as fallback while keeping the FP8 tensors
resident on the target device.

Comment on lines +582 to +626
            def forward(self, input: torch.Tensor) -> torch.Tensor:
                """Forward method using pre-quantized weights"""
                # Get parameters from original linear layer
                bias = getattr(self.original_linear, 'bias', None)

                # Check if input dimensions meet requirements
                input_features = input.shape[-1]
                if input_features % 16 != 0:
                    print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
                    return self.original_linear(input)

                # Save original shape and data type
                origin_shape = input.shape
                origin_dtype = input.dtype
                input = input.to(torch.bfloat16)

                if input.dim() > 2:
                    input = input.reshape(-1, input.shape[-1])

                # Execute input FP8 quantization
                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)

                # Move pre-quantized weights and scaling factors to current device
                weight_fp8 = self.weight_fp8.to(input.device)
                weight_scale = self.weight_scale.to(input.device)

                # Execute FP8 GEMM
                output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
                output = output.to(origin_dtype)

                if bias is not None:
                    output = output + bias

                # Handle output shape
                if output.dim() == 2:
                    if len(origin_shape) == 3:
                        batch_size, seq_len, hidden_size = origin_shape
                        output = output.reshape(batch_size, seq_len, hidden_size)
                    elif len(origin_shape) == 2:
                        pass  # No reshape needed
                    else:
                        return self.original_linear(input)

                return output


⚠️ Potential issue

Forward path issues: missing fallback when weight isn’t FP8, CPU→GPU copies, and incorrect reshape dimension.

  • Fallback to the original linear if pre-quantization failed (weight dtype != FP8) or input features aren’t divisible by 16.
  • Avoid copying weights from CPU every call (fixed above).
  • Reshape must use out_features, not the input hidden size (a short demonstration follows the diff).

Apply this diff:

-            def forward(self, input: torch.Tensor) -> torch.Tensor:
+            def forward(self, input: torch.Tensor) -> torch.Tensor:
                 """Forward method using pre-quantized weights"""
-                # Get parameters from original linear layer
-                bias = getattr(self.original_linear, 'bias', None)
-                
-                # Check if input dimensions meet requirements
-                input_features = input.shape[-1]
-                if input_features % 16 != 0:
-                    print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
-                    return self.original_linear(input)
+                # Fallback if FP8 not enabled or input dim incompatible
+                input_features = input.shape[-1]
+                if not getattr(self, "_fp8_enabled", False) or (input_features % 16 != 0):
+                    return self.original_linear(input)
+                bias = getattr(self.original_linear, "bias", None)
                 
                 # Save original shape and data type
-                origin_shape = input.shape
-                origin_dtype = input.dtype
-                input = input.to(torch.bfloat16)
+                origin_shape = input.shape
+                origin_dtype = input.dtype
+                x = input.to(torch.bfloat16)
 
-                if input.dim() > 2:
-                    input = input.reshape(-1, input.shape[-1])
+                if x.dim() > 2:
+                    x = x.reshape(-1, x.shape[-1])
 
                 # Execute input FP8 quantization
-                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
-                
-                # Move pre-quantized weights and scaling factors to current device
-                weight_fp8 = self.weight_fp8.to(input.device)
-                weight_scale = self.weight_scale.to(input.device)
+                act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
+                weight_fp8 = self.weight_fp8
+                weight_scale = self.weight_scale
+                if weight_fp8.device != x.device:
+                    weight_fp8 = weight_fp8.to(x.device)
+                    weight_scale = weight_scale.to(x.device)
                 
                 # Execute FP8 GEMM
                 output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
                 output = output.to(origin_dtype)
 
                 if bias is not None:
                     output = output + bias
                     
                 # Handle output shape
                 if output.dim() == 2:
                     if len(origin_shape) == 3:
-                        batch_size, seq_len, hidden_size = origin_shape
-                        output = output.reshape(batch_size, seq_len, hidden_size)
+                        batch_size, seq_len, _ = origin_shape
+                        out_features = self.original_linear.out_features
+                        output = output.reshape(batch_size, seq_len, out_features)
                     elif len(origin_shape) == 2:
                         pass  # No reshape needed
                     else:
-                        return self.original_linear(input)
+                        return self.original_linear(input)
                     
                 return output
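
To see why the reshape has to use out_features rather than the input hidden size, here is a toy-sized standalone example (dimensions are arbitrary):

import torch
import torch.nn as nn

batch_size, seq_len, hidden = 2, 8, 1280
layer = nn.Linear(hidden, 5120)

x = torch.randn(batch_size, seq_len, hidden)
y2d = layer(x.reshape(-1, hidden))                        # [16, 5120] after flattening to 2-D
y = y2d.reshape(batch_size, seq_len, layer.out_features)  # correct feature dim is 5120
print(y.shape)                                            # torch.Size([2, 8, 5120])
# Reshaping to (batch_size, seq_len, hidden) would fail: 16 * 5120 != 2 * 8 * 1280.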
🤖 Prompt for AI Agents
tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 582-626: the forward
needs a robust fallback when pre-quantized weights are not available, must avoid
unnecessary CPU→GPU copies each call, and must reshape using the layer's
out_features (not input hidden size). Add a fast check that the module actually
has valid FP8 weights (e.g., self.weight_fp8 and self.weight_scale are present
and have the expected FP8 dtype/device/state) and if not, call and return
self.original_linear on the original input (restore original dtype/device/shape
before calling); only move weight_fp8/weight_scale to input.device when they
exist and are not already on that device (minimize copies) and preferably
register them as buffers when created so they live on the module device; after
GEMM, when restoring output shape, use self.original_linear.out_features (or the
linear's out_features attribute) for the feature dimension instead of input
hidden size; ensure any early returns use the original input/shape/dtype rather
than the transformed tensors.

Comment on lines +627 to +710
            def _create_blockwise_quantized_weight(
                self,
                param_value: torch.Tensor,
                block_size: int = 128,
            ):
                """
                Create block-wise quantized weights
                Reference: transformers fp8 128*128 block quantization
                Supports padding non-128-multiple matrices to 128 multiples
                """
                param_value = param_value.to(torch.float32)

                # Get FP8 min/max values
                fp8_min = torch.finfo(torch.float8_e4m3fn).min
                fp8_max = torch.finfo(torch.float8_e4m3fn).max

                rows, cols = param_value.shape[-2:]
                original_shape = param_value.shape

                # Check if N dimension is divisible by 16 (TensorRT-LLM FP8 GEMM requirement)
                # For matrix multiplication input @ weight.T, N dimension is cols (in_features)
                if cols % 16 != 0:
                    print(f"Warning: Matrix N dimension ({cols}) not divisible by 16, skipping FP8 quantization")
                    return param_value, torch.ones(1, device=param_value.device, dtype=torch.float32)

                # Calculate padding needed for rows and columns
                # Round up to block_size multiples
                target_rows = ((rows + block_size - 1) // block_size) * block_size
                target_cols = ((cols + block_size - 1) // block_size) * block_size
                pad_rows = target_rows - rows
                pad_cols = target_cols - cols

                # Perform padding if needed
                if pad_rows > 0 or pad_cols > 0:
                    print(f"Padding matrix from ({rows}, {cols}) to ({rows + pad_rows}, {cols + pad_cols})")

                    # Create padded weight matrix
                    padded_weight = torch.zeros(
                        rows + pad_rows, cols + pad_cols,
                        device=param_value.device, dtype=param_value.dtype
                    )

                    # Copy original weights to top-left corner of padded matrix
                    padded_weight[:rows, :cols] = param_value

                    # Use padded weights for quantization
                    param_value = padded_weight
                    rows, cols = rows + pad_rows, cols + pad_cols

                # Now matrix dimensions are multiples of 128, can perform block-wise quantization
                block_size_m, block_size_n = block_size, block_size
                param_value_orig_shape = param_value.shape
                param_value = param_value.reshape(
                    -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
                ).permute(0, 1, 3, 2, 4)

                # Calculate scaling factor for each block
                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
                scale = fp8_max / max_abs
                scale_orig_shape = scale.shape
                scale = scale.unsqueeze(-1).unsqueeze(-1)

                @torch.compiler.disable()
                def _quantize(param_value, scale, fp8_min, fp8_max):
                    # Quantize the weights
                    quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

                    quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
                    # Reshape back to matrix shape
                    quantized_param = quantized_param.reshape(param_value_orig_shape)

                    # Reshape scale to match the number of blocks
                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()

                    return quantized_param, scale

                quantized_param, scale = _quantize(param_value, scale, fp8_min, fp8_max)

                # If original matrix was padded, crop back to original size
                if pad_rows > 0 or pad_cols > 0:
                    quantized_param = quantized_param[:original_shape[0], :original_shape[1]]

                return quantized_param, scale


💡 Verification agent

🧩 Analysis chain

Blockwise quantization: handle zeros, crop scales consistently with weight cropping, and avoid shape-destroying squeeze.

  • When a row/col block has amax=0, scale = fp8_max / 0 → inf. Add epsilon to avoid inf/nan.
  • You crop quantized_param back to the original size but not scale, which can mismatch kernel expectations. Crop the block-scale grid accordingly.
  • squeeze() may drop needed dims when a dimension equals 1; avoid it.

Apply this diff:

+                import math
@@
-                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
-                scale = fp8_max / max_abs
+                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+                # Avoid division by zero for all-zero blocks
+                eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+                scale = fp8_max / torch.maximum(max_abs, eps)
                 scale_orig_shape = scale.shape
                 scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
-                    # Reshape scale to match the number of blocks
-                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+                    # Reshape scale to match the number of blocks and convert to dequant scale (amax/fp8_max)
+                    scale = scale.reshape(scale_orig_shape).reciprocal()
@@
-                # If original matrix was padded, crop back to original size
+                # If original matrix was padded, crop both quantized weight and block-scales
                 if pad_rows > 0 or pad_cols > 0:
                     quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+                    # Crop block scales to ceil(original/block_size)
+                    orig_block_rows = (original_shape[0] + block_size_m - 1) // block_size_m
+                    orig_block_cols = (original_shape[1] + block_size_n - 1) // block_size_n
+                    if scale.dim() == 3:
+                        scale = scale[:, :orig_block_rows, :orig_block_cols]
+                    elif scale.dim() == 2:
+                        scale = scale[:orig_block_rows, :orig_block_cols]

Please verify the shape that torch.ops.trtllm.fp8_block_scaling_gemm expects for weight_scale (per-block grid vs. flattened), and adjust accordingly.


🏁 Script executed:

#!/bin/bash
# Search for fp8_block_scaling_gemm usage or definitions
rg -n "fp8_block_scaling_gemm" -C3 . || true

# Search for torch.ops.trtllm references (to find where the operator is defined or wrapped)
rg -n "torch\.ops\.trtllm" -C3 . || true

# Search for any documentation or tests calling fp8_block_scaling_gemm to infer expected weight_scale shape
rg -n "fp8_block_scaling" -C3 . || true

Length of output: 262474


Mandatory Refactor: Strengthen Blockwise FP8 Quantization Robustness

To ensure correct, crash-free blockwise FP8 weight quantization, please update _create_blockwise_quantized_weight as follows:

  • Prevent division by zero when a block is all zeros
    Add a small epsilon before computing scale = fp8_max / max_abs to avoid inf/nan.
  • Crop the block‐scale grid in sync with the cropped weight matrix
    When you slice quantized_param back to original_shape, also trim scale to the corresponding number of blocks.
  • Preserve explicit block‐scale tensor dimensions
    Replace the broad .squeeze() call with a direct reshape on scale_orig_shape so you never accidentally drop a needed dimension.

Suggested diff (within _create_blockwise_quantized_weight):

+                import math
@@
-                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
-                scale = fp8_max / max_abs
+                max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+                # Avoid division by zero for all-zero blocks
+                eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+                scale = fp8_max / torch.maximum(max_abs, eps)
                 scale_orig_shape = scale.shape
                 scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
-                    # Reshape scale to match the number of blocks
-                    scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+                    # Reshape scale to match blocks and compute dequant factors
+                    scale = scale.reshape(scale_orig_shape).reciprocal()
@@
-                # If original matrix was padded, crop back to original size
+                # If we padded the matrix, crop both weight and block scales
                 if pad_rows > 0 or pad_cols > 0:
                     quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+                    # Determine original block counts
+                    orig_block_rows = math.ceil(original_shape[0] / block_size_m)
+                    orig_block_cols = math.ceil(original_shape[1] / block_size_n)
+                    # Trim scale to [orig_block_rows, orig_block_cols] (or include batch dim)
+                    if scale.dim() == 3:
+                        scale = scale[:, :orig_block_rows, :orig_block_cols]
+                    else:
+                        scale = scale[:orig_block_rows, :orig_block_cols]

Please verify whether torch.ops.trtllm.fp8_block_scaling_gemm expects the weight_scale tensor as a 2-D block grid ([row_blocks, col_blocks]) or with an extra leading batch dimension. Adjust the final squeeze/reshape accordingly.
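
Independent of the TensorRT-LLM kernels, here is a minimal standalone sketch of the epsilon guard on per-block scales (toy 256x256 weight with all-zero blocks; this does not reproduce the padding/permute logic above and assumes a PyTorch build with float8 support):

import torch

fp8_max = torch.finfo(torch.float8_e4m3fn).max
block = 128
w = torch.zeros(256, 256)
w[:128, :128] = torch.randn(128, 128)       # leave three blocks all-zero on purpose

# View as [row_blocks, col_blocks, block, block] and take the per-block amax.
blocks = w.reshape(2, block, 2, block).permute(0, 2, 1, 3)
max_abs = blocks.abs().amax(dim=(-1, -2))

naive_scale = fp8_max / max_abs                        # inf for the all-zero blocks
safe_scale = fp8_max / torch.clamp(max_abs, min=1e-6)  # finite everywhere
print(torch.isinf(naive_scale).any().item(),           # True
      torch.isinf(safe_scale).any().item())            # False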

🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 627-710, the
block-wise FP8 quantization can produce inf/nan when a block is all zeros and
incorrectly trims/reshapes the block-scale tensor; fix by adding a small epsilon
to max_abs before computing scale (e.g. max_abs = torch.clamp(max_abs,
min=eps)), ensure you crop the scale grid in sync with quantized_param when you
slice back to original_shape (compute the number of row/col blocks corresponding
to original_shape and slice scale accordingly), and replace the broad .squeeze()
with an explicit reshape using scale_orig_shape (or scale_orig_shape without the
last two singleton dims) before taking reciprocal so you preserve the exact
block-grid dimensions expected by torch.ops.trtllm.fp8_block_scaling_gemm
(verify whether it needs [row_blocks, col_blocks] or [batch, row_blocks,
col_blocks] and shape the final scale tensor accordingly).

@venkywonka
Collaborator

@tiffany940107 is this PR requiring reviews?
