From 327d42de7e8bfe899476c99bff5b1a6fd0b604f3 Mon Sep 17 00:00:00 2001
From: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Date: Wed, 16 Jul 2025 22:29:15 +0000
Subject: [PATCH 1/2] feat: add support for Modelopt fp8_pb_wo quantization
 scheme

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
---
 tensorrt_llm/_torch/model_config.py   | 2 ++
 tensorrt_llm/_torch/modules/linear.py | 8 +++++---
 tensorrt_llm/llmapi/llm_utils.py      | 6 +++++-
 3 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 671564baadc..89e0c46e2fe 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -202,6 +202,8 @@ def from_pretrained(cls,
             json_quant_configs = quant_config_dict['quantization']
 
             quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
+            if quant_config.quant_algo == "fp8_pb_wo":
+                quant_config.quant_algo = 'FP8_BLOCK_SCALES'
             quant_config.kv_cache_quant_algo = json_quant_configs.get(
                 'kv_cache_quant_algo', None)
             quant_config.group_size = json_quant_configs.get('group_size', None)
diff --git a/tensorrt_llm/_torch/modules/linear.py b/tensorrt_llm/_torch/modules/linear.py
index ca9cb6501d0..134f1c8ebf8 100644
--- a/tensorrt_llm/_torch/modules/linear.py
+++ b/tensorrt_llm/_torch/modules/linear.py
@@ -562,7 +562,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
 
         scale_name = self._get_scale_name(weights)
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank, module.tp_mode)
+                                         module.tp_rank,
+                                         module.tp_mode).squeeze()
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -582,7 +583,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
                                     module.tp_rank, module.tp_mode)
         v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+
         copy_weight(module.weight_scale, fused_fp8_block_scale)
 
     def load_weights_fused_gate_up_linear(self, module: Linear,
@@ -597,7 +599,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                        module.tp_rank, module.tp_mode)
         right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0)
+        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight_scale, fused_scale)
 
 
diff --git a/tensorrt_llm/llmapi/llm_utils.py b/tensorrt_llm/llmapi/llm_utils.py
index 31f853f3705..a62568a54e8 100644
--- a/tensorrt_llm/llmapi/llm_utils.py
+++ b/tensorrt_llm/llmapi/llm_utils.py
@@ -362,7 +362,11 @@ def _update_from_hf_quant_config(self) -> bool:
 
             hf_quant_algo = hf_quant_config.pop("quant_algo", None)
             if hf_quant_algo is not None:
-                hf_quant_algo = QuantAlgo(hf_quant_algo)
+                # fp8_pb_wo from modelopt is the same as fp8_block_scales
+                if hf_quant_algo == "fp8_pb_wo":
+                    hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+                else:
+                    hf_quant_algo = QuantAlgo(hf_quant_algo)
                 if quant_config.quant_algo is None:
                     logger.info(
                         f"Setting quant_algo={hf_quant_algo} form HF quant config."
From 58c5594f851ec7cbcd2ae1b019934447e3891231 Mon Sep 17 00:00:00 2001
From: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
Date: Fri, 18 Jul 2025 02:10:48 +0000
Subject: [PATCH 2/2] address review

Signed-off-by: Aurelien Chartier <2567591+achartier@users.noreply.github.com>
---
 tensorrt_llm/_torch/model_config.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tensorrt_llm/_torch/model_config.py b/tensorrt_llm/_torch/model_config.py
index 89e0c46e2fe..3de3edd3a9b 100644
--- a/tensorrt_llm/_torch/model_config.py
+++ b/tensorrt_llm/_torch/model_config.py
@@ -202,6 +202,7 @@ def from_pretrained(cls,
             json_quant_configs = quant_config_dict['quantization']
 
             quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
+            # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
             if quant_config.quant_algo == "fp8_pb_wo":
                 quant_config.quant_algo = 'FP8_BLOCK_SCALES'
             quant_config.kv_cache_quant_algo = json_quant_configs.get(
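
A minimal standalone sketch of the quant-algo mapping these patches introduce,
assuming a ModelOpt-style hf_quant_config.json; resolve_quant_algo is a
hypothetical helper used only for illustration and is not part of the patches.

    # Sketch: fold ModelOpt's "fp8_pb_wo" label into FP8_BLOCK_SCALES.
    import json


    def resolve_quant_algo(hf_quant_config_path: str) -> str:
        """Read a ModelOpt hf_quant_config.json and return the quant algo name."""
        with open(hf_quant_config_path) as f:
            quantization = json.load(f)["quantization"]
        algo = quantization.get("quant_algo")
        # fp8_pb_wo (per-block, weight-only FP8) is handled as FP8_BLOCK_SCALES
        if algo == "fp8_pb_wo":
            return "FP8_BLOCK_SCALES"
        return algo


    if __name__ == "__main__":
        print(resolve_quant_algo("hf_quant_config.json"))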