@@ -559,7 +559,7 @@ def create_weights(self, module: Linear, in_features: int,
                                            dtype=torch.float8_e4m3fn),
                                   requires_grad=False)

-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             scale_shape = (math.ceil(in_features / 512),
                            math.ceil(out_features))
             module.weight_scale = Parameter(torch.empty(scale_shape,
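For context, a minimal standalone sketch of the allocation this hunk gates. The 512-wide scale shape comes from the hunk itself; the fallback 128x128 block-scale shape, the float32 scale dtype, and the flag being a plain bool are assumptions, not part of the diff.

import math

import torch
from torch.nn import Parameter


def make_weight_scale(in_features: int, out_features: int, sm_version: int,
                      use_cute_dsl_blockscaling_mm: bool) -> Parameter:
    if sm_version == 100 and not use_cute_dsl_blockscaling_mm:
        # Layout used by the SM100 fp8_block_scaling_gemm path in the hunk:
        # one scale per 512 input elements, per output feature.
        scale_shape = (math.ceil(in_features / 512), out_features)
    else:
        # Assumed default: one scale per 128x128 weight block (taken here to
        # cover the CUTE-DSL block-scaling MM and pre-SM100 kernels).
        scale_shape = (math.ceil(out_features / 128),
                       math.ceil(in_features / 128))
    return Parameter(torch.empty(scale_shape, dtype=torch.float32),
                     requires_grad=False)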
@@ -595,6 +595,7 @@ def apply(self, module: Linear, input: torch.Tensor,
             # TODO (@lmin): replace with cute_dsl gemm
             act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(
                 input)
+            print(module.weight_scale.dtype)
             output = torch.ops.trtllm.fp8_block_scaling_gemm(
                 act_input_fp8, module.weight, act_input_sf,
                 module.weight_scale)
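A condensed sketch of the call sequence this hunk sits in. The two custom ops and their argument order are taken directly from the hunk; the function boundary and everything outside the hunk are assumptions.

import torch
# Assumes the TensorRT-LLM torch extension has been loaded so that the
# torch.ops.trtllm namespace is registered (e.g. via `import tensorrt_llm`).


def fp8_block_scaling_forward(module, input: torch.Tensor) -> torch.Tensor:
    # Quantize activations to FP8 with 1x128 per-block scale factors.
    act_input_fp8, act_input_sf = torch.ops.trtllm.fp8_quantize_1x128(input)
    # Block-scaled GEMM against the pre-quantized FP8 weight; weight_scale
    # must already be in the layout this kernel expects.
    return torch.ops.trtllm.fp8_block_scaling_gemm(act_input_fp8, module.weight,
                                                   act_input_sf,
                                                   module.weight_scale)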
@@ -649,7 +650,7 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
         weight_scale = load_weight_shard(weights[0][scale_name], module.tp_size,
                                          module.tp_rank,
                                          module.tp_mode).squeeze()
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             weight_scale = fp8_utils.transform_sf_into_required_layout(
                 weight_scale,
                 mn=module.weight.shape[0],
@@ -692,7 +693,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
                                         module.tp_rank, module.tp_mode)
         fused_scale = torch.cat([left_scale, right_scale], dim=0).squeeze()
         copy_weight(module.weight, fused_weight)
-        if get_sm_version() == 100:
+        if get_sm_version() == 100 and not module.use_cute_dsl_blockscaling_mm:
             fused_scale = fp8_utils.transform_sf_into_required_layout(
                 fused_scale,
                 mn=fused_weight.shape[0],
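Both load paths now share the same gate before re-laying-out the loaded scale factors. A self-contained sketch of that condition follows; approximating get_sm_version() with torch.cuda.get_device_capability() is an assumption about its behavior, and the transform_sf_into_required_layout call is truncated in the hunks, so it is not reproduced here.

import torch


def needs_sf_relayout(use_cute_dsl_blockscaling_mm: bool) -> bool:
    # Compute capability 10.0 corresponds to get_sm_version() == 100 in the
    # PR. The CUTE-DSL block-scaling MM consumes the loaded scales as-is, so
    # only the non-CUTE-DSL SM100 path needs the transformed layout.
    major, minor = torch.cuda.get_device_capability()
    return major * 10 + minor == 100 and not use_cute_dsl_blockscaling_mm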