@@ -2085,6 +2085,29 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                             non_blocking=True)
 
 
+def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
+                          shard_dim_size):
+
+    def lcm(a, b):
+        return abs(a * b) // math.gcd(a, b)
+
+    # The alignment should be the least common multiple of weight_alignment,
+    # scaling_vector_size, and tp_size. scaling_vector_size and tp_size must
+    # be considered to avoid fractional scaling factors per shard.
+    alignment = lcm(weight_alignment, scaling_vector_size)
+    alignment = lcm(alignment, tp_size)
+
+    # If, after this alignment, the per-shard size of the sharded dim is not a
+    # multiple of weight_alignment, pad further so that it becomes one.
+    padded_weights_dim = math.ceil(shard_dim_size / alignment) * alignment
+    per_shard = padded_weights_dim // tp_size
+    if per_shard % weight_alignment != 0:
+        alignment = weight_alignment * math.ceil(
+            per_shard / weight_alignment) * tp_size
+
+    return alignment
+
+
 class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self,
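A standalone sketch of the new `_get_weight_alignment` helper, evaluated with the numbers used in the comments of this patch (intermediate_size = 2880, tp_size = 4, scaling_vector_size = 32) and an assumed `weight_alignment` of 128; the real value is kernel-specific and may differ.

```python
import math


def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                          shard_dim_size):
    # Same logic as the helper added above, copied here so the example is
    # self-contained.
    def lcm(a, b):
        return abs(a * b) // math.gcd(a, b)

    alignment = lcm(lcm(weight_alignment, scaling_vector_size), tp_size)
    padded_weights_dim = math.ceil(shard_dim_size / alignment) * alignment
    per_shard = padded_weights_dim // tp_size
    if per_shard % weight_alignment != 0:
        alignment = weight_alignment * math.ceil(
            per_shard / weight_alignment) * tp_size
    return alignment


# lcm(128, 32, 4) = 128, so 2880 would first be padded to 2944, i.e. 736 per
# shard, which is not a multiple of 128. The alignment is therefore bumped to
# 128 * ceil(736 / 128) * 4 = 3072: each shard then owns 768 elements and
# 768 / 32 = 24 whole scaling factors.
print(_get_weight_alignment(128, 32, 4, 2880))  # 3072
```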
@@ -2249,32 +2272,44 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
                                  dst_w3_w1_weight: torch.Tensor):
+        # We pad before the sharding. This is done to avoid fractional scaling
+        # factors per shard.
+        #
+        # E.g. if we pad after the sharding, with intermediate_size = 2880,
+        # tp_size = 4, scaling_vector_size = 32, each shard gets 720 elements
+        # and 22.5 scaling factors. After padding, each shard gets 768 elements
+        # and 24 scaling factors along intermediate_size. The 2nd rank would
+        # then start loading from the 23rd scaling factor, while it should
+        # still use the 22nd one for its first 16 elements.
+        # Padding the weights before the sharding avoids this issue.
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, w1_weight.shape[0])
+        if len(w1_weight.shape) == 2:
+            # Pad weights. The alignment factor of 2 is already satisfied
+            # since we pack two MXFP4 values into each uint8.
+            assert w1_weight.dtype == torch.uint8
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+            assert w3_weight.dtype == torch.uint8
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+        else:
+            # Pad bias.
+            assert len(w1_weight.shape) == 1
+            assert len(w3_weight.shape) == 1
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment)
+
         w1_weight_shard = load_weight_shard(w1_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN)
         w3_weight_shard = load_weight_shard(w3_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN)
 
-        if len(w1_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor 2 for we pad 2 MXFP4 into Uint8.
-            assert w1_weight_shard.dtype == torch.uint8
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-            assert w3_weight_shard.dtype == torch.uint8
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            # Pad bias.
-            assert len(w1_weight_shard.shape) == 1
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment)
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment)
-
         w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)
         dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
                                non_blocking=True)
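To make the pad-before-shard comment concrete, a small arithmetic check with the same example numbers (weight_alignment = 128 is assumed, as in the sketch above):

```python
intermediate_size, tp_size, scaling_vector_size = 2880, 4, 32

# Pad-after-shard: shard boundaries do not land on scaling-factor boundaries.
per_shard = intermediate_size // tp_size      # 720 elements per rank
print(per_shard / scaling_vector_size)        # 22.5 scaling factors per rank
print(per_shard % scaling_vector_size)        # rank 1 starts 16 elements into
                                              # a scaling-factor block

# Pad-before-shard: 2880 is padded up to 3072 (see the helper sketch above),
# so every rank owns whole scaling-factor blocks.
padded = 3072
print((padded // tp_size) % scaling_vector_size)   # 0
print((padded // tp_size) // scaling_vector_size)  # 24 scaling factors per rank
```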
@@ -2287,23 +2322,25 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         Load w2 weight for each expert.
         Override this method if you need to preprocess the weights differently.
         """
+        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
+            w2_weight.shape) == 2 else w2_weight.shape[0]
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, shard_w2_weight_dim)
+
+        if len(w2_weight.shape) == 2:
+            assert w2_weight.dtype == torch.uint8
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
+                                            self.weight_alignment)
+        else:
+            # Pad bias.
+            assert len(w2_weight.shape) == 1
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
+
         w2_weight_shard = load_weight_shard(w2_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW)
 
-        if len(w2_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor 2 for we pad two MXFP4 into Uint8.
-            assert w2_weight_shard.dtype == torch.uint8
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            assert len(w2_weight_shard.shape) == 1
-            # Pad bias.
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment)
-
         dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                             non_blocking=True)
 
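A note on `shard_w2_weight_dim` in this hunk: the 2-D w2 weight arrives packed, two MXFP4 values per uint8, so its logical sharded dimension is twice the stored `shape[1]`, while the 1-D bias is unpacked and its length is used directly. A minimal sketch with made-up shapes:

```python
import torch

# Two MXFP4 values per uint8 byte: a logical [rows, K] w2 weight is stored as
# a [rows, K // 2] uint8 tensor. Shapes below are illustrative only.
K = 2880
w2_packed = torch.zeros(5760, K // 2, dtype=torch.uint8)

# Same expression as in the patch: recover the logical K for the 2-D case and
# fall back to shape[0] for the 1-D bias.
shard_w2_weight_dim = 2 * w2_packed.shape[1] if len(
    w2_packed.shape) == 2 else w2_packed.shape[0]
assert shard_w2_weight_dim == K
```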
@@ -2312,6 +2349,19 @@ def load_expert_w3_w1_weight_scale_mxfp4(
             w3_weight_scale: torch.Tensor,
             dst_w3_w1_weight_scale: torch.Tensor):
         device = dst_w3_w1_weight_scale.device
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w3_weight_scale.shape[0])
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
+            w1_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+        w3_weight_scale = maybe_pad_for_mxfp4(
+            w3_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2322,14 +2372,6 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        w1_weight_scale = maybe_pad_for_mxfp4(
-            w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
-            w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk(
@@ -2350,15 +2392,21 @@ def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
                                           dst_w2_weight_scale: torch.Tensor):
         device = dst_w2_weight_scale.device
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w2_weight_scale.shape[-1])
+
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale, alignment // module.scaling_vector_size,
+            self.weight_alignment)
+
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w2_weight_scale.copy_(
@@ -2556,6 +2604,38 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  dst_w3_w1_weight: torch.Tensor):
         device = dst_w3_w1_weight.device
         assert device.type == "cuda"
+
+        # We pad before the sharding. This is done to avoid fractional scaling
+        # factors per shard.
+        #
+        # E.g. if we pad after the sharding, with intermediate_size = 2880,
+        # tp_size = 4, scaling_vector_size = 32, each shard gets 720 elements
+        # and 22.5 scaling factors. After padding, each shard gets 768 elements
+        # and 24 scaling factors along intermediate_size. The 2nd rank would
+        # then start loading from the 23rd scaling factor, while it should
+        # still use the 22nd one for its first 16 elements.
+        # Padding the weights before the sharding avoids this issue.
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, w1_weight.shape[0])
+        if len(w1_weight.shape) == 2:
+            # Pad weights. The alignment factor of 2 is already satisfied
+            # since we pack two MXFP4 values into each uint8.
+            assert w1_weight.dtype == torch.uint8
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+            assert w3_weight.dtype == torch.uint8
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+        else:
+            # Pad bias, TRTLLM backend expects float32 bias.
+            assert len(w1_weight.shape) == 1
+            assert len(w3_weight.shape) == 1
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
+
         w1_weight_shard = load_weight_shard(w1_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2566,26 +2646,6 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-
-        if len(w1_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8.
-            assert w1_weight_shard.dtype == torch.uint8
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-            assert w3_weight_shard.dtype == torch.uint8
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            # Pad bias, TRTLLM backend expects float32 bias.
-            assert len(w1_weight_shard.shape) == 1
-            w1_weight_shard = maybe_pad_for_mxfp4(
-                w1_weight_shard, self.weight_alignment).float()
-            w3_weight_shard = maybe_pad_for_mxfp4(
-                w3_weight_shard, self.weight_alignment).float()
-
         # FIXME: this depends on the kernel internals
         epilogue_tile_m = 128
 
@@ -2612,26 +2672,30 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                               dst_w2_weight: torch.Tensor):
         device = dst_w2_weight.device
         assert device.type == "cuda"
-        w2_weight_shard = load_weight_shard(w2_weight,
-                                            module.tp_size,
-                                            module.tp_rank,
-                                            TensorParallelMode.ROW,
-                                            device=device)
 
-        if len(w2_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8.
-            assert w2_weight_shard.dtype == torch.uint8
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
+        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
+            w2_weight.shape) == 2 else w2_weight.shape[0]
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, shard_w2_weight_dim)
+
+        if len(w2_weight.shape) == 2:
+            assert w2_weight.dtype == torch.uint8
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
+                                            self.weight_alignment)
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             # Divide bias by tp_size as we shard along the hidden dimension.
             # The bias is applied at each TP rank before the final accumulation.
-            assert len(w2_weight_shard.shape) == 1
-            w2_weight_shard = maybe_pad_for_mxfp4(
-                w2_weight_shard, self.weight_alignment).float() / module.tp_size
+            assert len(w2_weight.shape) == 1
+            w2_weight = maybe_pad_for_mxfp4(
+                w2_weight, self.weight_alignment).float() / module.tp_size
+
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
 
         # FIXME: this depends on the kernel internals
         epilogue_tile_m = 128
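On the `/ module.tp_size` kept for the bias in this hunk: under ROW parallelism each rank adds its bias before the cross-rank reduction, so scaling it by 1/tp_size makes the full bias count exactly once after the sum. A quick numeric sanity check (made-up shapes, float64):

```python
import torch

torch.manual_seed(0)
tp_size, hidden, out_dim = 4, 64, 32
x = torch.randn(8, hidden, dtype=torch.float64)
w = torch.randn(hidden, out_dim, dtype=torch.float64)
b = torch.randn(out_dim, dtype=torch.float64)

# Each rank holds a slice of the reduction dim and adds b / tp_size locally;
# summing the per-rank partials reproduces x @ w + b with the bias applied once.
chunk = hidden // tp_size
partials = [
    x[:, r * chunk:(r + 1) * chunk] @ w[r * chunk:(r + 1) * chunk]
    + b / tp_size for r in range(tp_size)
]
torch.testing.assert_close(sum(partials), x @ w + b)
```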
@@ -2657,6 +2721,19 @@ def load_expert_w3_w1_weight_scale_mxfp4(
             dst_w3_w1_weight_scale: torch.Tensor):
         device = dst_w3_w1_weight_scale.device
         assert device.type == "cuda"
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w3_weight_scale.shape[0])
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
+            w1_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+        w3_weight_scale = maybe_pad_for_mxfp4(
+            w3_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2667,14 +2744,6 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        w1_weight_scale = maybe_pad_for_mxfp4(
-            w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
-            w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk(
@@ -2715,20 +2784,21 @@ def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module,
                                           dst_w2_weight_scale: torch.Tensor):
         device = dst_w2_weight_scale.device
         assert device.type == "cuda"
-        # The last rank might get not full tensor, but its remainder.
-        # E.g. TP=8 and w2_weight_scale.shape[1] = 90, the last rank will get 6 elements.
-        # Take the original width, pad it to the self.weight_alignment // module.scaling_vector_size,
-        # Use this value as padding for the weight scales.
-        original_width = math.ceil(w2_weight_scale.shape[1] / module.tp_size)
-        sfs_alignment = self.weight_alignment // module.scaling_vector_size
-        padded_width = math.ceil(original_width / sfs_alignment) * sfs_alignment
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w2_weight_scale.shape[-1])
+
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale, alignment // module.scaling_vector_size,
+            self.weight_alignment)
+
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
-        w2_weight_scale = maybe_pad_for_mxfp4(w2_weight_scale, padded_width,
-                                              self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w2_weight_scale.copy_(
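On the w2 scale path rewritten in this hunk: the removed code sized the padding from the last rank's post-shard remainder, while the new code derives one alignment from the full (unsharded) scale width so every rank ends up with the same whole number of scaling-factor columns. The arithmetic below uses the removed comment's example (width 90, TP = 8), an assumed weight_alignment of 128, and reads the second argument of `maybe_pad_for_mxfp4` as the target alignment of the scale tensor's last dimension, which is an assumption about that helper:

```python
import math

width, tp_size, scaling_vector_size, weight_alignment = 90, 8, 32, 128

# Old scheme: shard first. With ceil-sized shards, ranks 0..6 get 12 columns
# and rank 7 gets the 6-column remainder, so padding had to be computed
# per rank afterwards.
per_rank_old = [min(math.ceil(width / tp_size),
                    width - r * math.ceil(width / tp_size))
                for r in range(tp_size)]
print(per_rank_old)  # [12, 12, 12, 12, 12, 12, 12, 6]

# New scheme: _get_weight_alignment(128, 32, 8, 90) evaluates to 1024, so the
# full scale width is padded to a multiple of 1024 // 32 = 32 before sharding:
# 90 -> 96, and every rank then gets exactly 96 // 8 = 12 columns.
alignment = 1024
sf_alignment = alignment // scaling_vector_size
padded_width = math.ceil(width / sf_alignment) * sf_alignment
print(padded_width, padded_width // tp_size)  # 96 12
```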