Add blockscale fp8 gemm to qwen2.5vl vlm part #7204
base: main
Conversation
📝 Walkthrough

Adds an opt-in FP8 Block Scale pathway to the Qwen2VL vision and input processor classes, including env/config toggles, layer pattern configuration, runtime replacement of matching Linear layers with pre-quantized FP8 wrappers, an FP8 quantize/GEMM forward path, public enablement checkers, and extensive debug prints. Changes are confined to one file.
Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant C as Config/Env
    participant IP as Qwen2VLInputProcessorBase
    participant VM as Qwen2VisionModelBase
    participant RP as Replacement Routine
    participant L as Linear Modules
    C->>IP: Provide TLLM_ENABLE_FP8_BLOCK_SCALE / enable_fp8_block_scale
    C->>VM: Provide patterns / flags
    IP->>IP: Determine enable_fp8_block_scale
    VM->>VM: Determine enable_fp8_block_scale
    alt FP8 Block Scale enabled
        VM->>RP: Scan model structure
        RP->>L: Match by patterns
        RP->>VM: Replace matched Linear with FP8 wrappers
    else Disabled
        VM-->>VM: No replacement
    end

sequenceDiagram
    autonumber
    participant FW as Forward()
    participant PQL as PreQuantizedTrtllmFp8BlockLinear
    participant Q as FP8 Quantize
    participant G as FP8 Block GEMM
    participant DQ as Dequant/Output
    FW->>PQL: x (input)
    alt in_features % 16 != 0 or ops missing
        PQL-->>FW: Fallback matmul (original dtype)
    else FP8 path
        PQL->>Q: fp8_quantize_1x128(x)
        Q-->>PQL: x_fp8, scale_x
        PQL->>G: fp8_block_scaling_gemm(x_fp8, W_fp8, scales)
        G-->>PQL: y_fp8
        PQL->>DQ: Convert y_fp8 to original dtype, add bias, reshape
        DQ-->>FW: y
    end
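Roughly, the forward path in the second diagram maps onto the sketch below. The custom op names appear in this PR; the argument order, shapes, and fallback condition are assumptions for illustration, not the authoritative implementation.

import torch

def fp8_block_linear_forward(x: torch.Tensor,
                             weight_fp8: torch.Tensor,
                             weight_scale: torch.Tensor,
                             original_linear: torch.nn.Linear) -> torch.Tensor:
    """Sketch of the quantize -> block-scaled GEMM -> dequant path."""
    # Fallback: the FP8 kernels need the reduction dim to be a multiple of 16.
    if x.shape[-1] % 16 != 0:
        return original_linear(x)

    origin_shape, origin_dtype = x.shape, x.dtype
    x2d = x.to(torch.bfloat16).reshape(-1, x.shape[-1])

    # 1x128 activation quantization and block-scaled GEMM (TRT-LLM custom ops
    # named in this PR; argument order assumed from the excerpts further down).
    x_fp8, scale_x = torch.ops.trtllm.fp8_quantize_1x128(x2d)
    y = torch.ops.trtllm.fp8_block_scaling_gemm(x_fp8, weight_fp8, scale_x, weight_scale)

    y = y.to(origin_dtype)
    if original_linear.bias is not None:
        y = y + original_linear.bias
    # Restore leading dims; the last dim is the layer's out_features.
    return y.reshape(*origin_shape[:-1], original_linear.out_features)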
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (1)
Lines 1-4: Missing NVIDIA copyright header (2025) at file top. Per coding guidelines, prepend the NVIDIA header to all source files.
Apply this diff at the very top of the file:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License in the LICENSE file at the root of this project.
🧹 Nitpick comments (5)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (5)
Lines 30-31: Prefer a G_-prefixed global and avoid import-time behavior coupling; rename and use consistently.
- Use G_-prefixed UPPER_SNAKE_CASE for globals.
- Minor: reading env at import time makes behavior harder to override in tests; keeping the env read is fine, but at least rename for consistency and clarity.
Apply this diff:
-ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'
+G_ENABLE_FP8_BLOCK_SCALE = os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1'

And update usage:

-        self.enable_fp8_block_scale = ENABLE_FP8_BLOCK_SCALE or config_enable
+        self.enable_fp8_block_scale = G_ENABLE_FP8_BLOCK_SCALE or config_enable

-        if ENABLE_FP8_BLOCK_SCALE:
+        if G_ENABLE_FP8_BLOCK_SCALE:

Also applies to: 372-379
Lines 369-406: Replace prints with logger and reduce noise; demote structure dumps to debug. Use the existing project logger instead of print; keep structure dumps behind debug level to avoid flooding stdout.
Apply this diff:
- print(f"FP8 Block Scale mode: {'ENABLED' if self.enable_fp8_block_scale else 'DISABLED'}") - if ENABLE_FP8_BLOCK_SCALE: - print(" - Enabled via environment variable TLLM_ENABLE_FP8_BLOCK_SCALE=1") - elif config_enable: - print(" - Enabled via config file") - else: - print(" - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)") + logger.info("FP8 Block Scale mode: %s", + "ENABLED" if self.enable_fp8_block_scale else "DISABLED") + if G_ENABLE_FP8_BLOCK_SCALE: + logger.info(" - Enabled via env TLLM_ENABLE_FP8_BLOCK_SCALE=1") + elif config_enable: + logger.info(" - Enabled via config flag enable_fp8_block_scale=True") + else: + logger.info(" - Disabled (use TLLM_ENABLE_FP8_BLOCK_SCALE=1 or set enable_fp8_block_scale=True in config)") @@ - print("Visual model structure:") - for name, module in self.visual.named_modules(): - if isinstance(module, torch.nn.Linear): - print(f" Linear layer: {name}") + logger.debug("Visual model structure (Linear layers):") + for name, module in self.visual.named_modules(): + if isinstance(module, torch.nn.Linear): + logger.debug(" Linear layer: %s", name) @@ - print("Skipping FP8 Block Scale layer replacement, using original implementation") + logger.info("Skipping FP8 Block Scale layer replacement (disabled); using original implementation.")Also addresses Ruff E501 in a few long lines.
Lines 528-560: Avoid importing an unused package; check for custom op availability explicitly. Ruff flags the unused import, and validating the ops is more accurate than importing the package root.
Apply this diff inside the wrapper's __init__:
-        try:
-            import tensorrt_llm
-            pass
-        except ImportError:
-            raise ImportError("TensorRT-LLM is not installed.")
+        # Verify that required custom ops are available
+        trtllm_ops = getattr(torch.ops, "trtllm", None)
+        if trtllm_ops is None or not hasattr(trtllm_ops, "fp8_quantize_1x128") \
+                or not hasattr(trtllm_ops, "fp8_block_scaling_gemm"):
+            raise ImportError(
+                "TensorRT-LLM custom ops (fp8_quantize_1x128/fp8_block_scaling_gemm) are not available."
+            )
Lines 729-737: Expose the same FP8-enable helper on the input processor for parity (optional). If external callers query both sides, consider adding the same method on Qwen2VLInputProcessorBase for consistency.
Apply this addition near the end of Qwen2VLInputProcessorBase:
+    def is_fp8_blockscale_enabled(self) -> bool:
+        """Whether FP8 Block Scale is enabled for this pipeline."""
+        return os.getenv('TLLM_ENABLE_FP8_BLOCK_SCALE', '0') == '1' or bool(
+            getattr(self.model_config, 'enable_fp8_block_scale', False)
+        )
Line 498: Fix Ruff findings: long lines and stray whitespace.
- E501: break long f-strings/log lines.
- W293: remove trailing whitespace-only lines.
Example fixes (some already addressed above):
- print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
+ logger.debug(
+     "DEBUG: %s weight=%s in_features=%d out_features=%d in%%16=%d",
+     name, tuple(weight.shape), in_features, out_features, in_features % 16
+ )

And remove whitespace-only lines at the indicated locations.
Also applies to: 553-553, 569-569, 531-531, 534-534, 732-732
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (3 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py:
Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/models/modeling_qwen2vl.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/models/modeling_qwen2vl.py
🧬 Code graph analysis (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (3)
- tensorrt_llm/module.py (2): named_modules (102-114), Module (33-226)
- tensorrt_llm/_torch/modules/linear.py (1): Linear (1495-1704)
- tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py (1): fp8_linear (57-96)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/models/modeling_qwen2vl.py
498-498: Line too long (172 > 120)
(E501)
531-531: Blank line contains whitespace
Remove whitespace from blank line
(W293)
534-534: Blank line contains whitespace
Remove whitespace from blank line
(W293)
553-553: Line too long (135 > 120)
(E501)
556-556: tensorrt_llm imported but unused; consider using importlib.util.find_spec to test for availability
(F401)
569-569: Line too long (124 > 120)
(E501)
732-732: Blank line contains whitespace
Remove whitespace from blank line
(W293)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (1)
tensorrt_llm/_torch/models/modeling_qwen2vl.py (1)
Lines 711-726: Properties proxying the original Linear are fine. Returning weight/bias/in_features/out_features from the wrapped layer preserves the expected attributes for downstream code.
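For context, attribute proxying of this kind usually looks like the sketch below; the class name is hypothetical and the body is trimmed to the proxy properties only, not the PR's exact wrapper.

import torch.nn as nn

class LinearWrapperSketch(nn.Module):
    """Wraps an nn.Linear and re-exposes its public attributes."""

    def __init__(self, original_linear: nn.Linear):
        super().__init__()
        self.original_linear = original_linear

    @property
    def weight(self):
        return self.original_linear.weight

    @property
    def bias(self):
        return self.original_linear.bias

    @property
    def in_features(self) -> int:
        return self.original_linear.in_features

    @property
    def out_features(self) -> int:
        return self.original_linear.out_features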
if self.enable_fp8_block_scale:
    # Define layer name patterns to be replaced with FP8 Block Scale
    # Now supports MLP layers, handling dimension mismatch through padding
    self.fp8_block_scale_patterns = [
        "blocks.*.attn.qkv",       # All block attention qkv
        "blocks.*.attn.proj",      # Re-enable attention projection, fix reshape logic
        "blocks.*.mlp.gate_proj",  # All block mlp gate_proj
        "blocks.*.mlp.down_proj",  # All block mlp down_proj
        "blocks.*.mlp.up_proj",    # All block mlp up_proj
    ]

    # Allow custom replacement patterns through configuration
    if hasattr(pretrained_config, 'fp8_block_scale_patterns'):
        self.fp8_block_scale_patterns = pretrained_config.fp8_block_scale_patterns

    # Print model structure for debugging
🛠️ Refactor suggestion
Pattern matching will likely never hit; switch to fnmatch/glob or re.search with prefix allowance.
Using re.match later (with patterns like "blocks.*.attn.qkv") requires the string to start with "blocks". Most models name modules like "encoder.blocks.0.attn.qkv", so nothing will match and no layers will be replaced.
No diff here (see the function refactor below), but please adopt one of:
- Use fnmatch.fnmatchcase(name, pattern) and also try *.{pattern} to allow prefixes, or
- Use re.search with r'(?:^|.*\.)blocks\.\d+\.attn\.(qkv|proj)$'-style compiled regexes.
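A quick illustration of the difference, using the fully qualified name cited above:

import fnmatch
import re

name = "encoder.blocks.0.attn.qkv"
pattern = "blocks.*.attn.qkv"

# re.match anchors at the start of the string, so the "encoder." prefix breaks it.
print(bool(re.match(pattern.replace("*", r"\d+"), name)))  # False

# fnmatch with an extra "*." alternative tolerates arbitrary prefixes.
print(fnmatch.fnmatchcase(name, pattern)
      or fnmatch.fnmatchcase(name, f"*.{pattern}"))  # True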
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 381 to 396, the
configured patterns like "blocks.*.attn.qkv" will not match typical module names
when later checked with re.match (which anchors at the start); update the
matching strategy so pattern checks succeed: either (preferred) switch to
fnmatch.fnmatchcase and update stored patterns to allow arbitrary prefixes (e.g.
prepending "*." when loading user config or automatically trying both raw and
"*."+pattern), or compile and use re.search with anchored-friendly regexes (e.g.
prepend "(?:^|.*\.)" and escape the pattern parts) so modules like
"encoder.blocks.0.attn.qkv" match; ensure you apply the same matching approach
for both default patterns and any pretrained_config.fp8_block_scale_patterns
provided by users and validate with a small unit test or assertion that example
module names match the intended patterns.
def _replace_linear_layers_with_pre_quantization(self):
    """
    Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
    """
    import re
    import torch.nn as nn

    # Directly iterate through all submodules of the visual module
    for name, module in self.visual.named_modules():
        # Check if it's a linear layer
        if isinstance(module, nn.Linear):
            # Check if it matches any pattern
            should_replace = False
            for pattern in self.fp8_block_scale_patterns:
                # Convert pattern to regex
                regex_pattern = pattern.replace("*", r"\d+")
                if re.match(regex_pattern, name):
                    should_replace = True
                    break

            if should_replace:
                # Check if weight dimensions meet TensorRT-LLM requirements
                # For matrix multiplication input @ weight.T, N dimension is in_features
                weight = module.weight
                in_features = weight.shape[0]   # Input feature dimension
                out_features = weight.shape[1]  # Output feature dimension
                print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")

                if in_features % 16 != 0:
                    print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
                    continue

                try:
                    # Create pre-quantized FP8 Block Scale replacement
                    fp8_linear = self._create_pre_quantized_fp8_block_linear(module)

                    # Find parent module and child module names
                    parent_name = '.'.join(name.split('.')[:-1])
                    child_name = name.split('.')[-1]

                    if parent_name:
                        # Get parent module
                        parent_module = self.visual
                        for part in parent_name.split('.'):
                            parent_module = getattr(parent_module, part)

                        # Replace child module
                        setattr(parent_module, child_name, fp8_linear)
                    else:
                        # Direct replacement
                        setattr(self.visual, child_name, fp8_linear)

                    print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
                except Exception as e:
                    print(f"Failed to replace {name}: {e}")
🛠️ Refactor suggestion
Robust replacement: use named_modules_with_parent; fix ModuleList indexing, wrong in_features, and fragile regex.
- The current traversal uses named_modules() and then manual getattr walking; this breaks on ModuleList numeric indices (no attribute '0') and can throw.
- You compute in_features = weight.shape[0], but nn.Linear.weight is [out_features, in_features]. Use module.in_features.
- Replace during iteration safely using named_modules_with_parent provided by this repo to avoid stale references.
Apply this diff to rewrite the function:
- def _replace_linear_layers_with_pre_quantization(self):
- """
- Replace linear layers and pre-quantize weights to avoid repeated quantization during forward pass
- """
- import re
- import torch.nn as nn
-
- # Directly iterate through all submodules of the visual module
- for name, module in self.visual.named_modules():
- # Check if it's a linear layer
- if isinstance(module, nn.Linear):
- # Check if it matches any pattern
- should_replace = False
- for pattern in self.fp8_block_scale_patterns:
- # Convert pattern to regex
- regex_pattern = pattern.replace("*", r"\d+")
- if re.match(regex_pattern, name):
- should_replace = True
- break
-
- if should_replace:
- # Check if weight dimensions meet TensorRT-LLM requirements
- # For matrix multiplication input @ weight.T, N dimension is in_features
- weight = module.weight
- in_features = weight.shape[0] # Input feature dimension
- out_features = weight.shape[1] # Output feature dimension
- print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")
-
- if in_features % 16 != 0:
- print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
- continue
-
- try:
- # Create pre-quantized FP8 Block Scale replacement
- fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
-
- # Find parent module and child module names
- parent_name = '.'.join(name.split('.')[:-1])
- child_name = name.split('.')[-1]
-
- if parent_name:
- # Get parent module
- parent_module = self.visual
- for part in parent_name.split('.'):
- parent_module = getattr(parent_module, part)
-
- # Replace child module
- setattr(parent_module, child_name, fp8_linear)
- else:
- # Direct replacement
- setattr(self.visual, child_name, fp8_linear)
-
- print(f"Replaced Linear layer with Pre-quantized FP8 Block Scale: {name}")
- except Exception as e:
- print(f"Failed to replace {name}: {e}")
+ def _replace_linear_layers_with_pre_quantization(self):
+ """
+ Replace target nn.Linear layers with pre-quantized FP8 Block-Scale wrappers.
+ Uses named_modules_with_parent to safely mutate the module tree while iterating.
+ """
+ import fnmatch
+ num_replaced = 0
+
+ for name, module, parent in self.visual.named_modules_with_parent(remove_duplicate=True):
+ if parent is None or not isinstance(module, nn.Linear):
+ continue
+
+ # Match both exact patterns and those preceded by any prefix (e.g., "encoder.")
+ matched = any(
+ fnmatch.fnmatchcase(name, pat) or fnmatch.fnmatchcase(name, f"*.{pat}")
+ for pat in self.fp8_block_scale_patterns
+ )
+ if not matched:
+ continue
+
+ in_features = module.in_features
+ if in_features % 16 != 0:
+ logger.debug("FP8 skip %s: in_features (%d) %% 16 != 0", name, in_features)
+ continue
+
+ try:
+ fp8_linear = self._create_pre_quantized_fp8_block_linear(module)
+ if fp8_linear is None:
+ logger.debug("FP8 skip %s: pre-quantization unavailable/failed.", name)
+ continue
+
+ child_name = name.rsplit(".", 1)[-1]
+ setattr(parent, child_name, fp8_linear)
+ num_replaced += 1
+ logger.info("FP8 Block Scale: replaced layer %s", name)
+ except Exception as e:
+ logger.exception("FP8 Block Scale: failed to replace %s: %s", name, e)
+
+ logger.info("FP8 Block Scale: total Linear layers replaced: %d", num_replaced)
Committable suggestion skipped: line range outside the PR's diff.
🧰 Tools
🪛 Ruff (0.12.2)
498-498: Line too long (172 > 120)
(E501)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 472-527, replace
the current traversal and replacement logic with a safe approach: iterate using
the repository-provided named_modules_with_parent to get (parent, name, module)
so you never need fragile getattr walks (this also preserves ModuleList
parents), compute dimensions from the layer attributes (use module.in_features
and module.out_features instead of weight.shape indices), match
fp8_block_scale_patterns using fnmatch.fnmatchcase(name, pattern) (or equivalent
wildcard matching) instead of fragile regex construction, and perform the
replacement on the parent correctly (if parent is a Module and has the attribute
name use setattr, if parent is a ModuleList or list and name.isdigit() assign
parent[int(name)] = fp8_linear) while wrapping creation in try/except and
logging failures.
in_features = weight.shape[0]   # Input feature dimension
out_features = weight.shape[1]  # Output feature dimension
print(f"DEBUG: Checking {name}, weight.shape={weight.shape}, in_features={in_features}, out_features={out_features}, in_features%16={in_features % 16}")

if in_features % 16 != 0:
    print(f"Skipping {name}: in_features ({in_features}) not divisible by 16")
    continue
Wrong feature dims: nn.Linear.weight is [out_features, in_features].
You set:
- in_features = weight.shape[0] (this is out_features)
- out_features = weight.shape[1] (this is in_features)
This leads to wrong divisibility checks and misleading logs.
The function refactor above fixes this by using module.in_features
and avoiding manual shape indexing.
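A one-line check of the PyTorch convention (standard nn.Linear behavior; the dimensions here are made up):

import torch.nn as nn

layer = nn.Linear(in_features=1280, out_features=3420)
assert layer.weight.shape == (3420, 1280)   # stored as [out_features, in_features]
assert layer.in_features == 1280 and layer.out_features == 3420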
🧰 Tools
🪛 Ruff (0.12.2)
498-498: Line too long (172 > 120)
(E501)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 496 to 502, the
code treats nn.Linear.weight shape as [in_features, out_features] but PyTorch
stores weights as [out_features, in_features]; swap usage accordingly and stop
relying on manual shape indexing: use module.in_features and module.out_features
(or module.weight.shape[1]/[0] if needed) for accurate values, update the debug
print to show the correct in/out features, and perform the divisibility check on
module.in_features (not weight.shape[0]); this ensures correct logging and
correct skip logic.
# Pre-quantize weights and scaling factors
print(f"Pre-quantizing weights for layer with shape {original_linear.weight.shape}")
self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)

# Move quantized weights and scaling factors to CPU to save GPU memory
self.weight_fp8 = self.weight_fp8.cpu()
self.weight_scale = self.weight_scale.cpu()

print(f"Pre-quantization completed. Weight FP8 shape: {self.weight_fp8.shape}, Scale shape: {self.weight_scale.shape}")
🛠️ Refactor suggestion
Do not move quantized weights to CPU; this induces per-forward HtoD copies. Keep them as (non-persistent) buffers on the target device.
Current code moves FP8 weights/scales to CPU, then copies to GPU every forward. This will dominate latency and negate any GEMM speedup.
Apply this diff:
- self.weight_fp8, self.weight_scale = self._pre_quantize_weight(original_linear.weight)
-
- # Move quantized weights and scaling factors to CPU to save GPU memory
- self.weight_fp8 = self.weight_fp8.cpu()
- self.weight_scale = self.weight_scale.cpu()
+ q_weight, q_scale = self._pre_quantize_weight(original_linear.weight)
+ self._fp8_enabled = (q_weight.dtype == torch.float8_e4m3fn)
+ # Keep on the same device as the original weight to avoid runtime transfers
+ self.register_buffer("weight_fp8", q_weight.to(original_linear.weight.device), persistent=False)
+ self.register_buffer("weight_scale", q_scale.to(original_linear.weight.device), persistent=False)
If memory is a concern, consider moving the original FP16/BF16 weight to CPU when FP8 is enabled and using the original only as a fallback path.
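For reference, the standard PyTorch pattern for keeping derived tensors device-resident while leaving checkpoints untouched is a non-persistent buffer; the sketch below uses placeholder tensors rather than the PR's wrapper.

import torch
import torch.nn as nn

class Fp8WeightHolder(nn.Module):
    def __init__(self, weight_fp8: torch.Tensor, weight_scale: torch.Tensor):
        super().__init__()
        # Non-persistent buffers follow .to()/.cuda() moves with the module
        # and are excluded from state_dict(), so checkpoints stay unchanged.
        self.register_buffer("weight_fp8", weight_fp8, persistent=False)
        self.register_buffer("weight_scale", weight_scale, persistent=False)

holder = Fp8WeightHolder(torch.zeros(8, 8, dtype=torch.uint8),  # uint8 stands in for float8 here
                         torch.ones(1))
assert "weight_fp8" not in holder.state_dict()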
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
# Pre-quantize weights and scaling factors
print(f"Pre-quantizing weights for layer with shape {original_linear.weight.shape}")
q_weight, q_scale = self._pre_quantize_weight(original_linear.weight)
self._fp8_enabled = (q_weight.dtype == torch.float8_e4m3fn)
# Keep on the same device as the original weight to avoid runtime transfers
self.register_buffer("weight_fp8", q_weight.to(original_linear.weight.device), persistent=False)
self.register_buffer("weight_scale", q_scale.to(original_linear.weight.device), persistent=False)

print(f"Pre-quantization completed. Weight FP8 shape: {self.weight_fp8.shape}, Scale shape: {self.weight_scale.shape}")
🧰 Tools
🪛 Ruff (0.12.2)
553-553: Line too long (135 > 120)
(E501)
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 545 to 554, do
not move the quantized FP8 weight and scale tensors to CPU (the .cpu() calls)
because that forces HtoD copies each forward; instead register them as
non-persistent buffers on the target device (e.g., device =
original_linear.weight.device) so they remain on GPU: remove the .cpu() calls,
call self.register_buffer("weight_fp8", self.weight_fp8.to(device),
persistent=False) and self.register_buffer("weight_scale",
self.weight_scale.to(device), persistent=False) (or set the tensors directly on
device if already assigned), and if memory is a concern optionally move the
original FP16/BF16 weight to CPU as fallback while keeping the FP8 tensors
resident on the target device.
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Forward method using pre-quantized weights"""
    # Get parameters from original linear layer
    bias = getattr(self.original_linear, 'bias', None)

    # Check if input dimensions meet requirements
    input_features = input.shape[-1]
    if input_features % 16 != 0:
        print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
        return self.original_linear(input)

    # Save original shape and data type
    origin_shape = input.shape
    origin_dtype = input.dtype
    input = input.to(torch.bfloat16)

    if input.dim() > 2:
        input = input.reshape(-1, input.shape[-1])

    # Execute input FP8 quantization
    act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)

    # Move pre-quantized weights and scaling factors to current device
    weight_fp8 = self.weight_fp8.to(input.device)
    weight_scale = self.weight_scale.to(input.device)

    # Execute FP8 GEMM
    output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
    output = output.to(origin_dtype)

    if bias is not None:
        output = output + bias

    # Handle output shape
    if output.dim() == 2:
        if len(origin_shape) == 3:
            batch_size, seq_len, hidden_size = origin_shape
            output = output.reshape(batch_size, seq_len, hidden_size)
        elif len(origin_shape) == 2:
            pass  # No reshape needed
        else:
            return self.original_linear(input)

    return output
Forward path issues: missing fallback when weight isn’t FP8, CPU→GPU copies, and incorrect reshape dimension.
- Fallback to the original linear if pre-quantization failed (weight dtype != FP8) or input features aren’t divisible by 16.
- Avoid copying weights from CPU every call (fixed above).
- Reshape must use out_features, not input hidden size.
Apply this diff:
- def forward(self, input: torch.Tensor) -> torch.Tensor:
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
"""Forward method using pre-quantized weights"""
- # Get parameters from original linear layer
- bias = getattr(self.original_linear, 'bias', None)
-
- # Check if input dimensions meet requirements
- input_features = input.shape[-1]
- if input_features % 16 != 0:
- print(f"Using original linear layer: input_features ({input_features}) not divisible by 16")
- return self.original_linear(input)
+ # Fallback if FP8 not enabled or input dim incompatible
+ input_features = input.shape[-1]
+ if not getattr(self, "_fp8_enabled", False) or (input_features % 16 != 0):
+ return self.original_linear(input)
+ bias = getattr(self.original_linear, "bias", None)
# Save original shape and data type
- origin_shape = input.shape
- origin_dtype = input.dtype
- input = input.to(torch.bfloat16)
+ origin_shape = input.shape
+ origin_dtype = input.dtype
+ x = input.to(torch.bfloat16)
- if input.dim() > 2:
- input = input.reshape(-1, input.shape[-1])
+ if x.dim() > 2:
+ x = x.reshape(-1, x.shape[-1])
# Execute input FP8 quantization
- act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(input)
-
- # Move pre-quantized weights and scaling factors to current device
- weight_fp8 = self.weight_fp8.to(input.device)
- weight_scale = self.weight_scale.to(input.device)
+ act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
+ weight_fp8 = self.weight_fp8
+ weight_scale = self.weight_scale
+ if weight_fp8.device != x.device:
+ weight_fp8 = weight_fp8.to(x.device)
+ weight_scale = weight_scale.to(x.device)
# Execute FP8 GEMM
output = torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, weight_fp8, input_scale, weight_scale)
output = output.to(origin_dtype)
if bias is not None:
output = output + bias
# Handle output shape
if output.dim() == 2:
if len(origin_shape) == 3:
- batch_size, seq_len, hidden_size = origin_shape
- output = output.reshape(batch_size, seq_len, hidden_size)
+ batch_size, seq_len, _ = origin_shape
+ out_features = self.original_linear.out_features
+ output = output.reshape(batch_size, seq_len, out_features)
elif len(origin_shape) == 2:
pass # No reshape needed
else:
- return self.original_linear(input)
+ return self.original_linear(input)
return output
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def forward(self, input: torch.Tensor) -> torch.Tensor:
    """Forward method using pre-quantized weights"""
    # Fallback if FP8 not enabled or input dim incompatible
    input_features = input.shape[-1]
    if not getattr(self, "_fp8_enabled", False) or (input_features % 16 != 0):
        return self.original_linear(input)

    bias = getattr(self.original_linear, "bias", None)

    # Save original shape and data type
    origin_shape = input.shape
    origin_dtype = input.dtype
    x = input.to(torch.bfloat16)

    if x.dim() > 2:
        x = x.reshape(-1, x.shape[-1])

    # Execute input FP8 quantization
    act_input_fp8, input_scale = torch.ops.trtllm.fp8_quantize_1x128(x)
    weight_fp8 = self.weight_fp8
    weight_scale = self.weight_scale
    if weight_fp8.device != x.device:
        weight_fp8 = weight_fp8.to(x.device)
        weight_scale = weight_scale.to(x.device)

    # Execute FP8 GEMM
    output = torch.ops.trtllm.fp8_block_scaling_gemm(
        act_input_fp8, weight_fp8, input_scale, weight_scale
    )
    output = output.to(origin_dtype)

    if bias is not None:
        output = output + bias

    # Handle output shape
    if output.dim() == 2:
        if len(origin_shape) == 3:
            batch_size, seq_len, _ = origin_shape
            out_features = self.original_linear.out_features
            output = output.reshape(batch_size, seq_len, out_features)
        elif len(origin_shape) == 2:
            pass  # No reshape needed
        else:
            return self.original_linear(input)

    return output
🤖 Prompt for AI Agents
tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 582-626: the forward
needs a robust fallback when pre-quantized weights are not available, must avoid
unnecessary CPU→GPU copies each call, and must reshape using the layer's
out_features (not input hidden size). Add a fast check that the module actually
has valid FP8 weights (e.g., self.weight_fp8 and self.weight_scale are present
and have the expected FP8 dtype/device/state) and if not, call and return
self.original_linear on the original input (restore original dtype/device/shape
before calling); only move weight_fp8/weight_scale to input.device when they
exist and are not already on that device (minimize copies) and preferably
register them as buffers when created so they live on the module device; after
GEMM, when restoring output shape, use self.original_linear.out_features (or the
linear's out_features attribute) for the feature dimension instead of input
hidden size; ensure any early returns use the original input/shape/dtype rather
than the transformed tensors.
def _create_blockwise_quantized_weight(
    self,
    param_value: torch.Tensor,
    block_size: int = 128,
):
    """
    Create block-wise quantized weights
    Reference: transformers fp8 128*128 block quantization
    Supports padding non-128-multiple matrices to 128 multiples
    """
    param_value = param_value.to(torch.float32)

    # Get FP8 min/max values
    fp8_min = torch.finfo(torch.float8_e4m3fn).min
    fp8_max = torch.finfo(torch.float8_e4m3fn).max

    rows, cols = param_value.shape[-2:]
    original_shape = param_value.shape

    # Check if N dimension is divisible by 16 (TensorRT-LLM FP8 GEMM requirement)
    # For matrix multiplication input @ weight.T, N dimension is cols (in_features)
    if cols % 16 != 0:
        print(f"Warning: Matrix N dimension ({cols}) not divisible by 16, skipping FP8 quantization")
        return param_value, torch.ones(1, device=param_value.device, dtype=torch.float32)

    # Calculate padding needed for rows and columns
    # Round up to block_size multiples
    target_rows = ((rows + block_size - 1) // block_size) * block_size
    target_cols = ((cols + block_size - 1) // block_size) * block_size
    pad_rows = target_rows - rows
    pad_cols = target_cols - cols

    # Perform padding if needed
    if pad_rows > 0 or pad_cols > 0:
        print(f"Padding matrix from ({rows}, {cols}) to ({rows + pad_rows}, {cols + pad_cols})")

        # Create padded weight matrix
        padded_weight = torch.zeros(
            rows + pad_rows, cols + pad_cols,
            device=param_value.device, dtype=param_value.dtype
        )

        # Copy original weights to top-left corner of padded matrix
        padded_weight[:rows, :cols] = param_value

        # Use padded weights for quantization
        param_value = padded_weight
        rows, cols = rows + pad_rows, cols + pad_cols

    # Now matrix dimensions are multiples of 128, can perform block-wise quantization
    block_size_m, block_size_n = block_size, block_size
    param_value_orig_shape = param_value.shape
    param_value = param_value.reshape(
        -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
    ).permute(0, 1, 3, 2, 4)

    # Calculate scaling factor for each block
    max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
    scale = fp8_max / max_abs
    scale_orig_shape = scale.shape
    scale = scale.unsqueeze(-1).unsqueeze(-1)

    @torch.compiler.disable()
    def _quantize(param_value, scale, fp8_min, fp8_max):
        # Quantize the weights
        quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
        # Reshape back to matrix shape
        quantized_param = quantized_param.reshape(param_value_orig_shape)

        # Reshape scale to match the number of blocks
        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()

        return quantized_param, scale

    quantized_param, scale = _quantize(param_value, scale, fp8_min, fp8_max)

    # If original matrix was padded, crop back to original size
    if pad_rows > 0 or pad_cols > 0:
        quantized_param = quantized_param[:original_shape[0], :original_shape[1]]

    return quantized_param, scale
💡 Verification agent
🧩 Analysis chain
Blockwise quantization: handle zeros, crop scales consistently with weight cropping, and avoid shape-destroying squeeze.
- When a row/col block has amax=0, scale = fp8_max / 0 → inf. Add an epsilon to avoid inf/nan.
- You crop quantized_param back to the original size but not scale, which can mismatch kernel expectations. Crop the block-scale grid accordingly.
- squeeze() may drop needed dims when a dimension equals 1; avoid it.
Apply this diff:
+ import math
@@
- max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
- scale = fp8_max / max_abs
+ max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+ # Avoid division by zero for all-zero blocks
+ eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+ scale = fp8_max / torch.maximum(max_abs, eps)
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
- # Reshape scale to match the number of blocks
- scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+ # Reshape scale to match the number of blocks and convert to dequant scale (amax/fp8_max)
+ scale = scale.reshape(scale_orig_shape).reciprocal()
@@
- # If original matrix was padded, crop back to original size
+ # If original matrix was padded, crop both quantized weight and block-scales
if pad_rows > 0 or pad_cols > 0:
quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+ # Crop block scales to ceil(original/block_size)
+ orig_block_rows = (original_shape[0] + block_size_m - 1) // block_size_m
+ orig_block_cols = (original_shape[1] + block_size_n - 1) // block_size_n
+ if scale.dim() == 3:
+ scale = scale[:, :orig_block_rows, :orig_block_cols]
+ elif scale.dim() == 2:
+ scale = scale[:orig_block_rows, :orig_block_cols]
Please verify the shape that torch.ops.trtllm.fp8_block_scaling_gemm expects for weight_scale (per-block grid vs. flattened), and adjust accordingly.
🏁 Script executed:
#!/bin/bash
# Search for fp8_block_scaling_gemm usage or definitions
rg -n "fp8_block_scaling_gemm" -C3 . || true
# Search for torch.ops.trtllm references (to find where the operator is defined or wrapped)
rg -n "torch\.ops\.trtllm" -C3 . || true
# Search for any documentation or tests calling fp8_block_scaling_gemm to infer expected weight_scale shape
rg -n "fp8_block_scaling" -C3 . || true
Length of output: 262474
Mandatory Refactor: Strengthen Blockwise FP8 Quantization Robustness
To ensure correct, crash-free blockwise FP8 weight quantization, please update _create_blockwise_quantized_weight as follows:
- Prevent division by zero when a block is all zeros: add a small epsilon before computing scale = fp8_max / max_abs to avoid inf/nan.
- Crop the block-scale grid in sync with the cropped weight matrix: when you slice quantized_param back to original_shape, also trim scale to the corresponding number of blocks.
- Preserve explicit block-scale tensor dimensions: replace the broad .squeeze() call with a direct reshape based on scale_orig_shape so you never accidentally drop a needed dimension.
Suggested diff (within _create_blockwise_quantized_weight
):
+ import math
@@
- max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
- scale = fp8_max / max_abs
+ max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
+ # Avoid division by zero for all-zero blocks
+ eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
+ scale = fp8_max / torch.maximum(max_abs, eps)
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)
@@
- # Reshape scale to match the number of blocks
- scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+ # Reshape scale to match blocks and compute dequant factors
+ scale = scale.reshape(scale_orig_shape).reciprocal()
@@
- # If original matrix was padded, crop back to original size
+ # If we padded the matrix, crop both weight and block scales
if pad_rows > 0 or pad_cols > 0:
quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
+ # Determine original block counts
+ orig_block_rows = math.ceil(original_shape[0] / block_size_m)
+ orig_block_cols = math.ceil(original_shape[1] / block_size_n)
+ # Trim scale to [orig_block_rows, orig_block_cols] (or include batch dim)
+ if scale.dim() == 3:
+ scale = scale[:, :orig_block_rows, :orig_block_cols]
+ else:
+ scale = scale[:orig_block_rows, :orig_block_cols]
Please verify whether torch.ops.trtllm.fp8_block_scaling_gemm expects the weight_scale tensor as a 2-D block grid ([row_blocks, col_blocks]) or with an extra leading batch dimension, and adjust the final squeeze/reshape accordingly.
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
import math

# Calculate scaling factor for each block
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
# Avoid division by zero for all-zero blocks
eps = torch.tensor(1e-6, device=max_abs.device, dtype=max_abs.dtype)
scale = fp8_max / torch.maximum(max_abs, eps)
scale_orig_shape = scale.shape
scale = scale.unsqueeze(-1).unsqueeze(-1)

@torch.compiler.disable()
def _quantize(param_value, scale, fp8_min, fp8_max):
    # Quantize the weights
    quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
    # Reshape back to matrix shape
    quantized_param = quantized_param.reshape(param_value_orig_shape)

    # Reshape scale to match blocks and compute dequant factors
    scale = scale.reshape(scale_orig_shape).reciprocal()

    return quantized_param, scale

quantized_param, scale = _quantize(param_value, scale, fp8_min, fp8_max)

# If we padded the matrix, crop both weight and block scales
if pad_rows > 0 or pad_cols > 0:
    quantized_param = quantized_param[:original_shape[0], :original_shape[1]]
    # Determine original block counts
    orig_block_rows = math.ceil(original_shape[0] / block_size_m)
    orig_block_cols = math.ceil(original_shape[1] / block_size_n)
    # Trim scale to [orig_block_rows, orig_block_cols] (or include batch dim)
    if scale.dim() == 3:
        scale = scale[:, :orig_block_rows, :orig_block_cols]
    else:
        scale = scale[:orig_block_rows, :orig_block_cols]
🤖 Prompt for AI Agents
In tensorrt_llm/_torch/models/modeling_qwen2vl.py around lines 627-710, the
block-wise FP8 quantization can produce inf/nan when a block is all zeros and
incorrectly trims/reshapes the block-scale tensor; fix by adding a small epsilon
to max_abs before computing scale (e.g. max_abs = torch.clamp(max_abs,
min=eps)), ensure you crop the scale grid in sync with quantized_param when you
slice back to original_shape (compute the number of row/col blocks corresponding
to original_shape and slice scale accordingly), and replace the broad .squeeze()
with an explicit reshape using scale_orig_shape (or scale_orig_shape without the
last two singleton dims) before taking reciprocal so you preserve the exact
block-grid dimensions expected by torch.ops.trtllm.fp8_block_scaling_gemm
(verify whether it needs [row_blocks, col_blocks] or [batch, row_blocks,
col_blocks] and shape the final scale tensor accordingly).
@tiffany940107 does this PR require reviews?
Summary by CodeRabbit
Description
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run /bot [-h|--help] to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
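For example, commenting /bot run --disable-fail-fast --stage-list "A10-PyTorch-1" (the stage name here is the sample value used in the option descriptions below) would launch a pipeline that runs only that stage with fail-fast disabled.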
--reuse-test (optional)pipeline-id (OPTIONAL): Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.
--disable-reuse-test (OPTIONAL): Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensures that all builds and tests are run regardless of previous successes.
--disable-fail-fast (OPTIONAL): Disable fail fast on build/tests/infra failures.
--skip-test (OPTIONAL): Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.
--stage-list "A10-PyTorch-1, xxx" (OPTIONAL): Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.
--gpu-type "A30, H100_PCIe" (OPTIONAL): Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.
--test-backend "pytorch, cpp" (OPTIONAL): Skip test stages which don't match the specified backends. Only supports [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.
--only-multi-gpu-test (OPTIONAL): Only run the multi-GPU tests. Note: Does NOT update GitHub check status.
--disable-multi-gpu-test (OPTIONAL): Disable the multi-GPU tests. Note: Does NOT update GitHub check status.
--add-multi-gpu-test (OPTIONAL): Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.
--post-merge (OPTIONAL): Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.
--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL): Run the ordinary L0 pre-merge pipeline and the specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".
--detailed-log (OPTIONAL): Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.
--debug (OPTIONAL): Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.

kill
kill
Kill all running builds associated with pull request.

skip
skip --comment COMMENT
Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.