
Commit a419b77

[None][fix] mxfp4 padding bug for TRT-LLM and CUTLASS MoE backends (#7214)
Signed-off-by: Nikita Korobov <[email protected]>
1 parent 1e644fa commit a419b77

1 file changed: +166 −96

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 166 additions & 96 deletions
@@ -2085,6 +2085,29 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                             non_blocking=True)
 
 
+def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
+                          shard_dim_size):
+
+    def lcm(a, b):
+        return abs(a * b) // math.gcd(a, b)
+
+    # The alignment should be the least common multiple of weight_alignment, scaling_vector_size,
+    # and tp_size. scaling_vector_size and tp_size must be considered
+    # to avoid fractional scaling factors.
+    alignment = lcm(weight_alignment, scaling_vector_size)
+    alignment = lcm(alignment, tp_size)
+
+    # If after the alignment, the sharding dim per shard is not a multiple of weight_alignment,
+    # we need to pad the weights to make it a multiple of weight_alignment.
+    padded_weights_dim = math.ceil(shard_dim_size / alignment) * alignment
+    per_shard = padded_weights_dim // tp_size
+    if per_shard % weight_alignment != 0:
+        alignment = weight_alignment * math.ceil(
+            per_shard / weight_alignment) * tp_size
+
+    return alignment
+
+
 class MXFP4WeightFusedMoEMethod(FusedMoEMethodBase):
 
     def create_weights(self,
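
For reference, here is a minimal standalone sketch of the new _get_weight_alignment helper run on the example cited in the comments of the hunks below (intermediate_size = 2880, tp_size = 4, scaling_vector_size = 32). The weight_alignment value of 128 is an assumed illustrative value, not necessarily what either backend configures.

import math

# Standalone copy of the helper above, for illustration only.
def get_weight_alignment(weight_alignment, scaling_vector_size, tp_size, shard_dim_size):

    def lcm(a, b):
        return abs(a * b) // math.gcd(a, b)

    # LCM of weight_alignment, scaling_vector_size and tp_size avoids fractional
    # scaling factors per tensor-parallel shard.
    alignment = lcm(lcm(weight_alignment, scaling_vector_size), tp_size)

    # Bump the alignment further if the per-shard size would still not be a
    # multiple of weight_alignment after padding.
    padded = math.ceil(shard_dim_size / alignment) * alignment
    per_shard = padded // tp_size
    if per_shard % weight_alignment != 0:
        alignment = weight_alignment * math.ceil(per_shard / weight_alignment) * tp_size
    return alignment

# Example values from the in-code comment; weight_alignment = 128 is an assumption.
alignment = get_weight_alignment(128, 32, 4, 2880)
padded = math.ceil(2880 / alignment) * alignment
print(alignment, padded, padded // 4, padded // 4 // 32)  # 3072 3072 768 24 -> 768 elements and 24 scales per rank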
@@ -2249,32 +2272,44 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  w1_weight: torch.Tensor,
                                  w3_weight: torch.Tensor,
                                  dst_w3_w1_weight: torch.Tensor):
+        # We pad before the sharding. This is done to avoid fractional scaling factors
+        # per shard.
+        #
+        # E.g. if we pad after the sharding, with intermediate_size = 2880,
+        # tp_size = 4, scaling_vector_size = 32, each shard gets 720 elements and
+        # 22.5 scaling factors. After padding, each shard gets 768 in
+        # intermediate_size, and 24 in scaling factors's intermediate_size.
+        # The 2nd rank will start loading the 23rd scaling factor,
+        # while it should've loaded 22nd for the first 16 elements only.
+        # We pad the weights before the sharding to avoid this issue.
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, w1_weight.shape[0])
+        if len(w1_weight.shape) == 2:
+            # Pad weights
+            # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
+            assert w1_weight.dtype == torch.uint8
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+            assert w3_weight.dtype == torch.uint8
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+        else:
+            # Pad bias.
+            assert len(w1_weight.shape) == 1
+            assert len(w3_weight.shape) == 1
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment)
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment)
+
         w1_weight_shard = load_weight_shard(w1_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN)
         w3_weight_shard = load_weight_shard(w3_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN)
 
-        if len(w1_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor 2 for we pad 2 MXFP4 into Uint8.
-            assert w1_weight_shard.dtype == torch.uint8
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-            assert w3_weight_shard.dtype == torch.uint8
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            # Pad bias.
-            assert len(w1_weight_shard.shape) == 1
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment)
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment)
-
         w31_weight_shard = torch.cat([w3_weight_shard, w1_weight_shard], dim=0)
         dst_w3_w1_weight.copy_(w31_weight_shard.view(dst_w3_w1_weight.dtype),
                                non_blocking=True)
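
A quick arithmetic sketch of the failure mode described in the comment above: sharding the unpadded 2880-wide dimension leaves each rank with a fractional number of 32-element scaling vectors, while padding the full dimension first does not. The padded size of 3072 assumes weight_alignment = 128, as in the earlier sketch.

intermediate_size, tp_size, scaling_vector_size = 2880, 4, 32

# Old behaviour: shard first, pad each shard afterwards.
per_shard = intermediate_size / tp_size              # 720.0 elements per rank
scales_per_shard = per_shard / scaling_vector_size   # 22.5 -> rank 1 would start mid scaling vector

# New behaviour: pad the full dimension first, then shard.
padded_size = 3072                                   # assumes weight_alignment = 128
per_shard_padded = padded_size // tp_size                           # 768 elements per rank
scales_per_shard_padded = per_shard_padded // scaling_vector_size   # 24 whole scaling vectors per rank
print(scales_per_shard, scales_per_shard_padded)                    # 22.5 24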
@@ -2287,23 +2322,25 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
         Load w2 weight for each expert.
         Override this method if you need to preprocess the weights differently.
         """
+        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
+            w2_weight.shape) == 2 else w2_weight.shape[0]
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, shard_w2_weight_dim)
+
+        if len(w2_weight.shape) == 2:
+            assert w2_weight.dtype == torch.uint8
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
+                                            self.weight_alignment)
+        else:
+            # Pad bias.
+            assert len(w2_weight.shape) == 1
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, self.weight_alignment)
+
         w2_weight_shard = load_weight_shard(w2_weight, module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW)
 
-        if len(w2_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor 2 for we pad two MXFP4 into Uint8.
-            assert w2_weight_shard.dtype == torch.uint8
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            assert len(w2_weight_shard.shape) == 1
-            # Pad bias.
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment)
-
         dst_w2_weight.copy_(w2_weight_shard.view(dst_w2_weight.dtype),
                             non_blocking=True)
 
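The shard_w2_weight_dim expression above accounts for MXFP4 packing: two 4-bit values share one uint8 byte, so the packed 2-D weight's last dimension is half of the logical sharded dimension, while the 1-D bias is unpacked. A small sketch with made-up shapes:

import torch

sharded_dim = 2880   # logical size of the dimension that gets row-sharded
num_rows = 512       # arbitrary illustrative value

w2_weight = torch.zeros(num_rows, sharded_dim // 2, dtype=torch.uint8)  # packed MXFP4, two values per byte
w2_bias = torch.zeros(sharded_dim)                                      # unpacked bias

def logical_shard_dim(t: torch.Tensor) -> int:
    # Mirrors the shard_w2_weight_dim expression used above.
    return 2 * t.shape[1] if len(t.shape) == 2 else t.shape[0]

assert logical_shard_dim(w2_weight) == logical_shard_dim(w2_bias) == sharded_dim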

@@ -2312,6 +2349,19 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                              w3_weight_scale: torch.Tensor,
                                              dst_w3_w1_weight_scale: torch.Tensor):
         device = dst_w3_w1_weight_scale.device
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w3_weight_scale.shape[0])
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
+            w1_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+        w3_weight_scale = maybe_pad_for_mxfp4(
+            w3_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2322,14 +2372,6 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        w1_weight_scale = maybe_pad_for_mxfp4(
-            w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
-            w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk(
@@ -2350,15 +2392,21 @@ def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
                                           dst_w2_weight_scale: torch.Tensor):
         device = dst_w2_weight_scale.device
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w2_weight_scale.shape[-1])
+
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale, alignment // module.scaling_vector_size,
+            self.weight_alignment)
+
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
-        w2_weight_scale = maybe_pad_for_mxfp4(
-            w2_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w2_weight_scale.copy_(
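
The weight-scale tensors store one scale per scaling_vector_size weights, which is why the padding above divides the alignment by module.scaling_vector_size. Continuing the earlier example (weight_alignment = 128 assumed):

import math

scaling_vector_size = 32
alignment = 3072                                    # from the _get_weight_alignment sketch above
scale_cols = 2880 // scaling_vector_size            # 90 scale entries before padding
scale_alignment = alignment // scaling_vector_size  # 96: one entry per padded scaling vector

padded_scale_cols = math.ceil(scale_cols / scale_alignment) * scale_alignment
print(padded_scale_cols, padded_scale_cols // 4)    # 96 per tensor, 24 per TP rank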
@@ -2556,6 +2604,38 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                  dst_w3_w1_weight: torch.Tensor):
         device = dst_w3_w1_weight.device
         assert device.type == "cuda"
+
+        # We pad before the sharding. This is done to avoid fractional scaling factors
+        # per shard.
+        #
+        # E.g. if we pad after the sharding, with intermediate_size = 2880,
+        # tp_size = 4, scaling_vector_size = 32, each shard gets 720 elements and
+        # 22.5 scaling factors. After padding, each shard gets 768 in
+        # intermediate_size, and 24 in scaling factors's intermediate_size.
+        # The 2nd rank will start loading the 23rd scaling factor,
+        # while it should've loaded 22nd for the first 16 elements only.
+        # We pad the weights before the sharding to avoid this issue.
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, w1_weight.shape[0])
+        if len(w1_weight.shape) == 2:
+            # Pad weights
+            # We already satisfy alignment factor of 2 for we pack two MXFP4 into Uint8.
+            assert w1_weight.dtype == torch.uint8
+            w1_weight = maybe_pad_for_mxfp4(w1_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+            assert w3_weight.dtype == torch.uint8
+            w3_weight = maybe_pad_for_mxfp4(w3_weight,
+                                            self.weight_alignment // 2,
+                                            alignment)
+        else:
+            # Pad bias, TRTLLM backend expects float32 bias.
+            assert len(w1_weight.shape) == 1
+            assert len(w3_weight.shape) == 1
+            w1_weight = maybe_pad_for_mxfp4(w1_weight, alignment).float()
+            w3_weight = maybe_pad_for_mxfp4(w3_weight, alignment).float()
+
         w1_weight_shard = load_weight_shard(w1_weight,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2566,26 +2646,6 @@ def load_expert_w3_w1_weight(self, module: torch.nn.Module,
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-
-        if len(w1_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8.
-            assert w1_weight_shard.dtype == torch.uint8
-            w1_weight_shard = maybe_pad_for_mxfp4(w1_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-            assert w3_weight_shard.dtype == torch.uint8
-            w3_weight_shard = maybe_pad_for_mxfp4(w3_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
-        else:
-            # Pad bias, TRTLLM backend expects float32 bias.
-            assert len(w1_weight_shard.shape) == 1
-            w1_weight_shard = maybe_pad_for_mxfp4(
-                w1_weight_shard, self.weight_alignment).float()
-            w3_weight_shard = maybe_pad_for_mxfp4(
-                w3_weight_shard, self.weight_alignment).float()
-
         # FIXME: this depends on the kernel internals
         epilogue_tile_m = 128
 
@@ -2612,26 +2672,30 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                               dst_w2_weight: torch.Tensor):
         device = dst_w2_weight.device
         assert device.type == "cuda"
-        w2_weight_shard = load_weight_shard(w2_weight,
-                                            module.tp_size,
-                                            module.tp_rank,
-                                            TensorParallelMode.ROW,
-                                            device=device)
 
-        if len(w2_weight_shard.shape) == 2:
-            # Pad weights
-            # We already satisfy alignment factor of 2 for we pad two MXFP4 into Uint8.
-            assert w2_weight_shard.dtype == torch.uint8
-            w2_weight_shard = maybe_pad_for_mxfp4(w2_weight_shard,
-                                                  self.weight_alignment // 2,
-                                                  self.weight_alignment)
+        shard_w2_weight_dim = 2 * w2_weight.shape[1] if len(
+            w2_weight.shape) == 2 else w2_weight.shape[0]
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size, shard_w2_weight_dim)
+
+        if len(w2_weight.shape) == 2:
+            assert w2_weight.dtype == torch.uint8
+            w2_weight = maybe_pad_for_mxfp4(w2_weight, alignment // 2,
+                                            self.weight_alignment)
         else:
             # Pad bias, TRTLLM backend expects float32 bias.
             # Divide bias by tp_size as we shard along the hidden dimension.
             # The bias is applied at each TP rank before the final accumulation.
-            assert len(w2_weight_shard.shape) == 1
-            w2_weight_shard = maybe_pad_for_mxfp4(
-                w2_weight_shard, self.weight_alignment).float() / module.tp_size
+            assert len(w2_weight.shape) == 1
+            w2_weight = maybe_pad_for_mxfp4(
+                w2_weight, self.weight_alignment).float() / module.tp_size
+
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
 
         # FIXME: this depends on the kernel internals
         epilogue_tile_m = 128
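
The comment about dividing the w2 bias by tp_size reflects how row-parallel outputs are combined: each rank adds its share of the bias before the partial results are all-reduced, so the reduced sum carries exactly one full bias. A toy check with made-up shapes and values:

import torch

tp_size = 4
hidden_size = 8
bias = torch.randn(hidden_size)
partial_outputs = [torch.randn(hidden_size) for _ in range(tp_size)]  # per-rank partial matmul results

reference = sum(partial_outputs) + bias                               # bias applied once after the reduction
per_rank = sum(out + bias / tp_size for out in partial_outputs)       # bias / tp_size applied on every rank
assert torch.allclose(reference, per_rank, atol=1e-6)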
@@ -2657,6 +2721,19 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                              dst_w3_w1_weight_scale: torch.Tensor):
         device = dst_w3_w1_weight_scale.device
         assert device.type == "cuda"
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w3_weight_scale.shape[0])
+
+        w1_weight_scale = maybe_pad_for_mxfp4(
+            w1_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+        w3_weight_scale = maybe_pad_for_mxfp4(
+            w3_weight_scale,
+            self.weight_alignment // module.scaling_vector_size, alignment)
+
         w1_weight_scale = load_weight_shard(w1_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
@@ -2667,14 +2744,6 @@ def load_expert_w3_w1_weight_scale_mxfp4(
                                             module.tp_rank,
                                             TensorParallelMode.COLUMN,
                                             device=device)
-        w1_weight_scale = maybe_pad_for_mxfp4(
-            w1_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
-        w3_weight_scale = maybe_pad_for_mxfp4(
-            w3_weight_scale,
-            self.weight_alignment // module.scaling_vector_size,
-            self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w3_weight_scale, dst_w1_weight_scale = dst_w3_w1_weight_scale.chunk(
@@ -2715,20 +2784,21 @@ def load_expert_w2_weight_scale_mxfp4(self, module: torch.nn.Module,
                                           dst_w2_weight_scale: torch.Tensor):
         device = dst_w2_weight_scale.device
         assert device.type == "cuda"
-        # The last rank might get not full tensor, but its remainder.
-        # E.g. TP=8 and w2_weight_scale.shape[1] = 90, the last rank will get 6 elements.
-        # Take the original width, pad it to the self.weight_alignment // module.scaling_vector_size,
-        # Use this value as padding for the weight scales.
-        original_width = math.ceil(w2_weight_scale.shape[1] / module.tp_size)
-        sfs_alignment = self.weight_alignment // module.scaling_vector_size
-        padded_width = math.ceil(original_width / sfs_alignment) * sfs_alignment
+
+        alignment = _get_weight_alignment(self.weight_alignment,
+                                          module.scaling_vector_size,
+                                          module.tp_size,
+                                          w2_weight_scale.shape[-1])
+
+        w2_weight_scale = maybe_pad_for_mxfp4(
+            w2_weight_scale, alignment // module.scaling_vector_size,
+            self.weight_alignment)
+
         w2_weight_scale = load_weight_shard(w2_weight_scale,
                                             module.tp_size,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
-        w2_weight_scale = maybe_pad_for_mxfp4(w2_weight_scale, padded_width,
-                                              self.weight_alignment)
 
         # Keep weights in device buffer
         dst_w2_weight_scale.copy_(
