@@ -562,7 +562,8 @@ def load_weights_vanilla(self, module: Linear, weights: List[Dict]) -> None:
562
562
563
563
scale_name = self ._get_scale_name (weights )
564
564
weight_scale = load_weight_shard (weights [0 ][scale_name ], module .tp_size ,
565
- module .tp_rank , module .tp_mode )
565
+ module .tp_rank ,
566
+ module .tp_mode ).squeeze ()
566
567
copy_weight (module .weight_scale , weight_scale )
567
568
if "input_scale" in weights [0 ]:
568
569
copy_weight (module .input_scale , weights [0 ]["input_scale" ])
@@ -582,7 +583,8 @@ def load_weights_fused_qkv_linear(self, module: Linear,
582
583
module .tp_rank , module .tp_mode )
583
584
v_scale = load_weight_shard (weights [2 ][scale_name ], module .tp_size ,
584
585
module .tp_rank , module .tp_mode )
585
- fused_fp8_block_scale = torch .cat ((q_scale , k_scale , v_scale ))
586
+ fused_fp8_block_scale = torch .cat ((q_scale , k_scale , v_scale )).squeeze ()
587
+
586
588
copy_weight (module .weight_scale , fused_fp8_block_scale )
587
589
588
590
def load_weights_fused_gate_up_linear (self , module : Linear ,
@@ -597,7 +599,7 @@ def load_weights_fused_gate_up_linear(self, module: Linear,
597
599
module .tp_rank , module .tp_mode )
598
600
right_scale = load_weight_shard (weights [1 ][scale_name ], module .tp_size ,
599
601
module .tp_rank , module .tp_mode )
600
- fused_scale = torch .cat ([left_scale , right_scale ], dim = 0 )
602
+ fused_scale = torch .cat ([left_scale , right_scale ], dim = 0 ). squeeze ()
601
603
copy_weight (module .weight_scale , fused_scale )
602
604
603
605
0 commit comments