Skip to content

Commit d184326

Browse files
committed
Fix for CuteDSL backend
Signed-off-by: Barry Kang <[email protected]>
1 parent 17a52d9 commit d184326

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tensorrt_llm/_torch/modules/linear.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -559,7 +559,7 @@ def create_weights(self, module: Linear, in_features: int,
559559
dtype=torch.float8_e4m3fn),
560560
requires_grad=False)
561561

562-
if get_sm_version() == 100:
562+
if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
563563
scale_shape = (math.ceil(in_features / 512),
564564
math.ceil(out_features))
565565
module.weight_scale = Parameter(torch.empty(scale_shape,
@@ -595,6 +595,7 @@ def apply(self, module: Linear, input: torch.Tensor,
595595
# TODO (@lmin): replace with cute_dsl gemm
596596
act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
597597
input)
598+
print(module.weight_scale.dtype)
598599
output = torch.ops.trtllm.fp8_block_scaling_gemm(
599600
act_input_fp8, module.weight, act_input_sf,
600601
module.weight_scale)
@@ -649,7 +650,7 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
649650
weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
650651
module.tp_rank,
651652
module.tp_mode).squeeze()
652-
if get_sm_version() == 100:
653+
if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
653654
weight_scale = fp8_utils.transform_sf_into_required_layout(
654655
weight_scale,
655656
mn=module.weight.shape[0],
@@ -692,7 +693,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
692693
module.tp_rank, module.tp_mode)
693694
fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
694695
copy_weight(module.weight, fused_weight)
695-
if get_sm_version() == 100:
696+
if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
696697
fused_scale = fp8_utils.transform_sf_into_required_layout(
697698
fused_scale,
698699
mn=fused_weight.shape[0],

0 commit comments

Comments (0)