
Commit 812243b

feat: add support for Modelopt fp8_pb_wo quantization scheme (#6106)
Signed-off-by: Aurelien Chartier <[email protected]>
Co-authored-by: Haohang Huang <[email protected]>
1 parent 992b273 commit 812243b

File tree

3 files changed: 13 additions & 4 deletions


tensorrt_llm/_torch/model_config.py

Lines changed: 3 additions & 0 deletions
@@ -202,6 +202,9 @@ def from_pretrained(cls,
         json_quant_configs = quant_config_dict['quantization']

         quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
+        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+        if quant_config.quant_algo == "fp8_pb_wo":
+            quant_config.quant_algo = 'FP8_BLOCK_SCALES'
         quant_config.kv_cache_quant_algo = json_quant_configs.get(
             'kv_cache_quant_algo', None)
         quant_config.group_size = json_quant_configs.get('group_size', None)
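The hunk above normalizes the algorithm name as soon as the checkpoint's quantization config is parsed: ModelOpt labels the scheme "fp8_pb_wo", and TensorRT-LLM's PyTorch backend treats it as its existing FP8_BLOCK_SCALES scheme. Below is a minimal standalone sketch of that alias handling; the helper name normalize_quant_algo and the example dict are illustrative only, not part of the patch.

from typing import Optional

def normalize_quant_algo(json_quant_configs: dict) -> Optional[str]:
    """Resolve ModelOpt's quant_algo alias to the name TensorRT-LLM expects."""
    quant_algo = json_quant_configs.get('quant_algo', None)
    # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
    if quant_algo == "fp8_pb_wo":
        quant_algo = 'FP8_BLOCK_SCALES'
    return quant_algo

# Example: a ModelOpt-style quantization section resolves to the existing scheme.
print(normalize_quant_algo({'quant_algo': 'fp8_pb_wo'}))  # FP8_BLOCK_SCALES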

tensorrt_llm/_torch/modules/linear.py

Lines changed: 5 additions & 3 deletions
@@ -562,7 +562,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:

         scale_name = self._get_scale_name(weights)
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank, module.tp_mode)
+                                         module.tp_rank,
+                                         module.tp_mode).squeeze()
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -582,7 +583,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
                                     module.tp_rank, module.tp_mode)
         v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+
         copy_weight(module.weight_scale, fused_fp8_block_scale)

     def load_weights_fused_gate_up_linear(self, module: Linear,
@@ -597,7 +599,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                        module.tp_rank, module.tp_mode)
         right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0)
+        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight_scale, fused_scale)
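The linear.py changes append .squeeze() wherever the FP8 block scales are loaded, so a checkpoint that stores each scale tensor with an extra singleton dimension (which the ModelOpt fp8_pb_wo export appears to do) still matches the shape of the module's weight_scale parameter. A torch-only sketch of the effect, with made-up shapes:

import torch

# Hypothetical per-block scale shards with a trailing singleton dimension;
# the real shapes depend on the checkpoint layout and block size.
q_scale = torch.ones(16, 8, 1)
k_scale = torch.ones(4, 8, 1)
v_scale = torch.ones(4, 8, 1)

fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
print(fused_fp8_block_scale.shape)  # torch.Size([24, 8])

Tensor.squeeze() with no arguments drops every size-1 dimension and is a no-op otherwise, so checkpoints that already ship 2-D scales load exactly as before.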

tensorrt_llm/llmapi/llm_utils.py

Lines changed: 5 additions & 1 deletion
@@ -362,7 +362,11 @@ def _update_from_hf_quant_config(self) -> bool:

         hf_quant_algo = hf_quant_config.pop("quant_algo", None)
         if hf_quant_algo is not None:
-            hf_quant_algo = QuantAlgo(hf_quant_algo)
+            # fp8_pb_wo from modelopt is the same as fp8_block_scales
+            if hf_quant_algo == "fp8_pb_wo":
+                hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+            else:
+                hf_quant_algo = QuantAlgo(hf_quant_algo)
             if quant_config.quant_algo is None:
                 logger.info(
                     f"Setting quant_algo={hf_quant_algo} form HF quant config."
