3 changes: 3 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -202,6 +202,9 @@ def from_pretrained(cls,
         json_quant_configs = quant_config_dict['quantization']

         quant_config.quant_algo = json_quant_configs.get('quant_algo', None)
+        # fp8_pb_wo from modelopt is the same as FP8_BLOCK_SCALES
+        if quant_config.quant_algo == "fp8_pb_wo":
+            quant_config.quant_algo = 'FP8_BLOCK_SCALES'
         quant_config.kv_cache_quant_algo = json_quant_configs.get(
             'kv_cache_quant_algo', None)
         quant_config.group_size = json_quant_configs.get('group_size', None)
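For context, a minimal sketch of the case this change handles: a modelopt-exported hf_quant_config.json whose quantization section reports quant_algo as "fp8_pb_wo", which from_pretrained now normalizes to FP8_BLOCK_SCALES. The concrete field values below are illustrative assumptions, not taken from a real checkpoint.

    # Illustrative only: modelopt-style quant section with assumed values.
    quant_config_dict = {
        "quantization": {
            "quant_algo": "fp8_pb_wo",       # modelopt's name for FP8 block scaling
            "kv_cache_quant_algo": None,
            "group_size": 128,               # assumed
        }
    }

    json_quant_configs = quant_config_dict["quantization"]
    quant_algo = json_quant_configs.get("quant_algo", None)
    # Same normalization as the patch above.
    if quant_algo == "fp8_pb_wo":
        quant_algo = "FP8_BLOCK_SCALES"
    assert quant_algo == "FP8_BLOCK_SCALES"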
8 changes: 5 additions & 3 deletions tensorrt_llm/_torch/modules/linear.py
@@ -562,7 +562,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:

         scale_name = self._get_scale_name(weights)
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
-                                         module.tp_rank, module.tp_mode)
+                                         module.tp_rank,
+                                         module.tp_mode).squeeze()
         copy_weight(module.weight_scale, weight_scale)
         if "input_scale" in weights[0]:
             copy_weight(module.input_scale, weights[0]["input_scale"])
@@ -582,7 +583,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
                                     module.tp_rank, module.tp_mode)
         v_scale = load_weight_shard(weights[2][scale_name], module.tp_size,
                                     module.tp_rank, module.tp_mode)
-        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale))
+        fused_fp8_block_scale = torch.cat((q_scale, k_scale, v_scale)).squeeze()
+
         copy_weight(module.weight_scale, fused_fp8_block_scale)

     def load_weights_fused_gate_up_linear(self, module: Linear,
@@ -597,7 +599,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                        module.tp_rank, module.tp_mode)
         right_scale = load_weight_shard(weights[1][scale_name], module.tp_size,
                                         module.tp_rank, module.tp_mode)
-        fused_scale = torch.cat([left_scale, right_scale], dim=0)
+        fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight_scale, fused_scale)


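Why the added .squeeze() calls matter, as a minimal self-contained sketch (plain torch, not the TensorRT-LLM code path): it assumes the checkpoint stores each block-scale tensor with an extra size-1 dimension while the destination weight_scale parameter is 2-D, so squeeze() drops the size-1 dimension and lets the shapes line up. All shapes below are assumptions for illustration.

    import torch

    # Assumed checkpoint layout: [out_blocks, in_blocks, 1] per projection.
    q_scale = torch.ones(8, 4, 1)
    k_scale = torch.ones(2, 4, 1)
    v_scale = torch.ones(2, 4, 1)

    fused = torch.cat((q_scale, k_scale, v_scale))                      # [12, 4, 1]
    fused_squeezed = torch.cat((q_scale, k_scale, v_scale)).squeeze()   # [12, 4]

    # Assumed destination shape (module.weight_scale in the real code).
    weight_scale = torch.empty(12, 4)
    assert fused_squeezed.shape == weight_scale.shape
    assert fused.shape != weight_scale.shape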
6 changes: 5 additions & 1 deletion tensorrt_llm/llmapi/llm_utils.py
@@ -362,7 +362,11 @@ def _update_from_hf_quant_config(self) -> bool:

             hf_quant_algo = hf_quant_config.pop("quant_algo", None)
             if hf_quant_algo is not None:
-                hf_quant_algo = QuantAlgo(hf_quant_algo)
+                # fp8_pb_wo from modelopt is the same as fp8_block_scales
+                if hf_quant_algo == "fp8_pb_wo":
+                    hf_quant_algo = QuantAlgo.FP8_BLOCK_SCALES
+                else:
+                    hf_quant_algo = QuantAlgo(hf_quant_algo)
                 if quant_config.quant_algo is None:
                     logger.info(
                         f"Setting quant_algo={hf_quant_algo} form HF quant config."
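A minimal sketch of the new fallback in isolation, using a stand-in enum (the real QuantAlgo in TensorRT-LLM has many more members): the modelopt string "fp8_pb_wo" is presumably not a valid QuantAlgo value, so it is mapped to FP8_BLOCK_SCALES before the enum constructor runs.

    from enum import Enum

    # Stand-in for illustration only; not the real tensorrt_llm QuantAlgo.
    class QuantAlgo(str, Enum):
        FP8 = "FP8"
        FP8_BLOCK_SCALES = "FP8_BLOCK_SCALES"

    def normalize(hf_quant_algo: str) -> QuantAlgo:
        # fp8_pb_wo from modelopt is the same as fp8_block_scales, so map it
        # up front instead of letting QuantAlgo(...) fail on the unknown value.
        if hf_quant_algo == "fp8_pb_wo":
            return QuantAlgo.FP8_BLOCK_SCALES
        return QuantAlgo(hf_quant_algo)

    assert normalize("fp8_pb_wo") is QuantAlgo.FP8_BLOCK_SCALES
    assert normalize("FP8") is QuantAlgo.FP8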