
Commit eb27ed5

dongxuy04 authored and dominicshanshan committed
Add MTP support for Online EPLB (NVIDIA#5213)
Signed-off-by: Dongxu Yang <[email protected]>
1 parent 7b7d3e4 commit eb27ed5

7 files changed: +125 additions, −11 deletions


tensorrt_llm/_torch/expert_statistic.py

Lines changed: 4 additions & 1 deletion

@@ -92,4 +92,7 @@ def _maybe_add_info(self, expert_count: int,
         counts = torch.bincount(token_selected_experts.flatten(),
                                 minlength=expert_count)
         key = f"{self.current_iter_id}_{self.current_layer}"
-        self._records[key] = counts.cpu()
+        if key not in self._records:
+            self._records[key] = counts.cpu()
+        else:
+            self._records[key] += counts.cpu()
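
With MTP enabled, the same MoE layer can run more than once per iteration, so the statistic for a given (iteration, layer) key is now accumulated across repeated forwards instead of being overwritten. A minimal standalone sketch of that accumulation behavior, using hypothetical tensors and a hypothetical key (not the module itself):

import torch

# Hypothetical routing results from two repeated forwards of the same layer.
expert_count = 4
first_pass = torch.tensor([0, 1, 1, 3])
second_pass = torch.tensor([2, 3, 3, 3])

records = {}
key = "iter0_layer5"  # stands in for f"{current_iter_id}_{current_layer}"
for selected in (first_pass, second_pass):
    counts = torch.bincount(selected.flatten(), minlength=expert_count)
    if key not in records:
        records[key] = counts.cpu()
    else:
        records[key] += counts.cpu()

print(records[key])  # tensor([1, 2, 1, 4]) -- counts summed over both passes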

tensorrt_llm/_torch/models/modeling_deepseekv3.py

Lines changed: 3 additions & 1 deletion

@@ -54,7 +54,8 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 WideEPMoE, create_moe)
+                                 WideEPMoE, create_moe,
+                                 moe_load_balancer_set_repeated_for_next_layer)
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -1076,6 +1077,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.num_hidden_layers = self.config.num_hidden_layers
         assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
         if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
+            moe_load_balancer_set_repeated_for_next_layer(model_nextn)
             mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
                                       self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
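
The setter has to run before the MTP layer is constructed: the pending repeat count is consumed by the next add_layer call the load balancer sees, and the WideEPMoE inside the MTP layer registers itself during its own __init__ (see the fused_moe_wide_ep.py hunk below). A hedged sketch of that ordering constraint, with a hypothetical helper standing in for the inline model code:

from tensorrt_llm._torch.modules.fused_moe import (
    moe_load_balancer_set_repeated_for_next_layer)

def append_mtp_layer(model, model_config, model_nextn, num_hidden_layers,
                     make_mtp_layer):
    # Hypothetical helper: in the commit this logic lives inline in
    # modeling_deepseekv3.py. The MTP module shares one set of weights but
    # runs `model_nextn` times per iteration, so the repeat count must be
    # registered before the layer (and its load-balancer slot) is created.
    moe_load_balancer_set_repeated_for_next_layer(model_nextn)
    mtp_layer = make_mtp_layer(model_config, num_hidden_layers)
    model.layers.append(mtp_layer)
    return mtp_layer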

tensorrt_llm/_torch/modules/fused_moe/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -4,7 +4,8 @@
 from .fused_moe_vanilla import VanillaMoE
 from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
-from .moe_load_balancer import MoeLoadBalancer
+from .moe_load_balancer import (MoeLoadBalancer,
+                                moe_load_balancer_set_repeated_for_next_layer)
 from .quantization import FusedMoEQuantScalesFP8
 from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
                       DefaultMoeRoutingMethod,
@@ -23,6 +24,7 @@
     "get_moe_cls",
     "Llama4RenormalizeMoeRoutingMethod",
     "LoadBalancedMoeRoutingMethod",
+    "moe_load_balancer_set_repeated_for_next_layer",
     "MoE",
     "MoeLoadBalancer",
     "MoEWeightLoadingMode",

tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

old mode 100644
new mode 100755
Lines changed: 13 additions & 4 deletions

@@ -87,6 +87,8 @@ def __init__(

         moe_load_balancer = get_moe_load_balancer()
         self.layer_load_balancer = None
+        self.repeat_idx = 0
+        self.repeat_count = 1

         moe_load_balancer_config = model_config.moe_load_balancer
         init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
@@ -102,6 +104,7 @@ def __init__(
             self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
             self.layer_load_balancer = moe_load_balancer.add_layer(
                 self.num_experts, top_k, self.expert_size_per_partition)
+            self.repeat_count = self.layer_load_balancer.get_repeat_count()
             loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
                 self.layer_idx)
             self.num_slots = moe_load_balancer_config.num_slots
@@ -434,6 +437,8 @@ def forward_chunk(
         )

         x_sf = None
+        x_row = x.shape[0]
+        x_col = x.shape[1]
         sf_swizzle = True
         if self.has_any_quant:
             if self.has_fp8_qdq:
@@ -479,7 +484,7 @@ def forward_chunk(
                 dim=0,
                 sizes=None if use_dp_padding else all_rank_num_tokens)
             # use separate allgather since doesn't have sizes, can be optimized but in allgather path it is OK
-            if is_last_call:
+            if is_last_call and loadbalancer_local_statistic_info is not None:
                 gathered_loadbalancer_local_statistic_info = allgather(
                     loadbalancer_local_statistic_info, self.mapping, dim=0)
         # Fp4 gemm has extra scaling factor
@@ -668,13 +673,16 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            is_first_call = self.repeat_idx == 0
+            is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
                 x,
                 router_logits,
                 cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
-                use_dp_padding=use_dp_padding)
+                use_dp_padding=use_dp_padding,
+                repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -717,8 +725,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
             for idx_chunk, (x, router_logits) in enumerate(
                     zip(x_list, router_logits_list)):
-                is_first_call = idx_chunk == 0
-                is_last_call = idx_chunk == num_chunks - 1
+                is_first_call = idx_chunk == 0 and self.repeat_idx == 0
+                is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
                 if not self.enable_alltoall:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
@@ -777,6 +785,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
         if self.use_dp:
             rank = self.mapping.tp_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
+        self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs

     def alltoall_prepare_maybe_dispatch(
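
The load-balancer hooks now fire on the first chunk of the first repeated forward and the last chunk of the last one: repeat_idx advances once per forward and wraps back to zero after the final repeat. A standalone sketch of that bookkeeping (the attribute names mirror the diff, but the class below is only illustrative, not WideEPMoE):

class RepeatTracker:
    """Illustrative stand-in for the repeat bookkeeping added to WideEPMoE."""

    def __init__(self, repeat_count: int = 1):
        self.repeat_idx = 0
        self.repeat_count = repeat_count

    def forward(self, num_chunks: int):
        calls = []
        for idx_chunk in range(num_chunks):
            # First/last call now span all repeats, not just one forward's chunks.
            is_first_call = idx_chunk == 0 and self.repeat_idx == 0
            is_last_call = (idx_chunk == num_chunks - 1
                            and self.repeat_idx == self.repeat_count - 1)
            calls.append((is_first_call, is_last_call))
        # Advance to the next repeat, wrapping after the last one.
        self.repeat_idx = (0 if self.repeat_idx == self.repeat_count - 1 else
                           self.repeat_idx + 1)
        return calls

tracker = RepeatTracker(repeat_count=2)  # e.g. main forward + one MTP forward
print(tracker.forward(num_chunks=2))     # [(True, False), (False, False)]
print(tracker.forward(num_chunks=2))     # [(False, False), (False, True)]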

tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py

Lines changed: 40 additions & 4 deletions

@@ -182,10 +182,11 @@ def finalize_layer_weights(self):
         offset = 0
         for name in self.names:
             for expert_id in range(self.expert_start, self.expert_end):
-                t = self.shared_tensors[(expert_id, name)]
+                t = self.shared_tensors[(expert_id, name)].contiguous().cpu()
                 data_size = t.numel() * t.element_size()
                 aligned_size = self.align_size(data_size)
-                shm.buf[offset:offset + data_size] = t.numpy().tobytes()
+                shm.buf[offset:offset + data_size] = t.flatten().view(
+                    torch.int8).numpy().tobytes()
                 dtype = t.dtype
                 tensor_shape = t.shape
                 elt_count = t.numel()
@@ -270,7 +271,8 @@ def __init__(
             single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer,
             shared_mpi_comm: MPI.Comm,
             expert_count: int,
-            updates_enabled: bool = True):
+            updates_enabled: bool = True,
+            repeated_count=1):
         """
         Initialize a SingleLayerMoeLoadBalancer instance.

@@ -279,6 +281,7 @@ def __init__(
             shared_mpi_comm: The MPI communicator for shared memory
             expert_count: total number of experts
             updates_enabled: whether to enable weight updates
+            repeated_count: the repeated count of current layer, used when forward is repeated more than once like MTP.
         """
         self.single_layer_load_balancer_impl = single_layer_load_balancer_impl
         self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer(
@@ -306,6 +309,7 @@ def __init__(

         self.cudagraph_stream = None
         self.cudagraph_event = None
+        self.repeated_count = repeated_count

         self.statistic_stream = None
         self.statistic_event = None
@@ -317,6 +321,9 @@ def get_load_expert_ids(self):
         assert self.updates_enabled, "should not call get_load_expert_ids when using statistic routing"
         return self.load_expert_ids

+    def get_repeat_count(self):
+        return self.repeated_count
+
     def is_static_routing(self):
         return not self.updates_enabled

@@ -675,6 +682,8 @@ def __init__(self,
         self.enable_statistic = False
         self.enable_update_weights = False

+        self.next_layer_repeated_count = None
+
     def __del__(self):
         if not self.is_shutdown:
             self.shutdown()
@@ -696,6 +705,16 @@ def _setup_mpi_comm(self):
     def set_use_gpu_memcpy(self, use_gpu_memcpy: bool):
         self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy)

+    def set_repeated_for_next_layer(self, repeated_count: int):
+        """
+        Set repeat count for next layer.
+
+        Args:
+            repeated_count: The repeat count for next layer
+        """
+        assert repeated_count > 0, "repeat count must be greater than 0"
+        self.next_layer_repeated_count = repeated_count
+
     def add_layer(self, expert_count: int, top_k: int,
                   slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
         """
@@ -712,11 +731,16 @@ def add_layer(self, expert_count: int, top_k: int,
         single_layer_load_balancer_impl = self.load_balancer_impl.add_layer(
             expert_count, top_k, slot_count_per_rank)
         updates_enabled = not self.is_static_routing()
+        repeat_count = 1
+        if self.next_layer_repeated_count is not None:
+            repeat_count = self.next_layer_repeated_count
+            self.next_layer_repeated_count = None
         single_layer_load_balancer = SingleLayerMoeLoadBalancer(
             single_layer_load_balancer_impl,
             self.shared_mpi_comm,
             expert_count,
-            updates_enabled=updates_enabled)
+            updates_enabled=updates_enabled,
+            repeated_count=repeat_count)
         single_layer_load_balancer.set_shared_memory_base_name(
             self.shared_memory_base_name)
         self.single_layer_load_balancers.append(single_layer_load_balancer)
@@ -934,6 +958,18 @@ def get_moe_load_balancer() -> Optional[MoeLoadBalancer]:
     return None


+def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int):
+    """
+    Set repeated count for next Single Layer created.
+
+    Args:
+        repeat_count: repeated count
+    """
+    load_balancer = get_moe_load_balancer()
+    if load_balancer is not None:
+        load_balancer.set_repeated_for_next_layer(repeat_count)
+
+
 def moe_load_balancer_add_single_layer(
         expert_count: int, top_k: int,
         slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]:
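
The repeat count is threaded through in three steps: the module-level setter stashes it on the active balancer, the next add_layer call consumes it exactly once, and the per-layer wrapper exposes it via get_repeat_count for WideEPMoE to read. A condensed sketch of just that handoff, using standalone toy classes rather than the real MoeLoadBalancer:

class _ToySingleLayer:
    def __init__(self, repeated_count: int = 1):
        self.repeated_count = repeated_count

    def get_repeat_count(self) -> int:
        return self.repeated_count


class _ToyLoadBalancer:
    def __init__(self):
        self.next_layer_repeated_count = None

    def set_repeated_for_next_layer(self, repeated_count: int):
        assert repeated_count > 0, "repeat count must be greater than 0"
        self.next_layer_repeated_count = repeated_count

    def add_layer(self) -> _ToySingleLayer:
        # Consume the pending repeat count exactly once, then reset it so
        # layers added afterwards default back to one forward per iteration.
        repeat_count = self.next_layer_repeated_count or 1
        self.next_layer_repeated_count = None
        return _ToySingleLayer(repeated_count=repeat_count)


balancer = _ToyLoadBalancer()
balancer.set_repeated_for_next_layer(2)   # e.g. the MTP nextn count
mtp_layer = balancer.add_layer()
regular_layer = balancer.add_layer()
print(mtp_layer.get_repeat_count(), regular_layer.get_repeat_count())  # 2 1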

tests/integration/defs/accuracy/test_llm_api_pytorch.py

Lines changed: 60 additions & 0 deletions

@@ -794,6 +794,7 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
             initial_global_assignments=initial_global_assignments,
             layer_updates_per_iter=0)
         pytorch_backend_options = dict(use_cuda_graph=True,
+                                       moe_backend="WIDEEP",
                                        moe_load_balancer=eplb_config)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   tensor_parallel_size=4,
@@ -807,6 +808,65 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)

+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["GB200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_bfloat16_4gpus_online_eplb(self, mtp_nextn):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
+        num_slots = 80
+        eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
+                                            layer_updates_per_iter=2)
+        pytorch_config = dict(use_cuda_graph=True,
+                              moe_backend="WIDEEP",
+                              moe_load_balancer=eplb_config)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        llm = LLM(self.MODEL_PATH,
+                  tensor_parallel_size=4,
+                  moe_expert_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  enable_attention_dp=True,
+                  **pytorch_config,
+                  speculative_config=mtp_config)
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["GB200"])
+    @parametrize_with_ids("fp8kv", [True, False])
+    def test_nvfp4_4gpus_online_eplb(self, fp8kv):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
+        num_slots = 80
+        eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
+                                            layer_updates_per_iter=2)
+        pytorch_backend_options = dict(use_cuda_graph=True,
+                                       moe_backend="WIDEEP",
+                                       moe_load_balancer=eplb_config)
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.NVFP4
+        if fp8kv:
+            quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            pytorch_backend_options["kv_cache_dtype"] = "fp8"
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
+                  tensor_parallel_size=4,
+                  moe_expert_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  **pytorch_backend_options,
+                  enable_attention_dp=True,
+                  quant_config=quant_config)
+        with llm:
+            # No need to run MMLU for fp8kv
+            if not fp8kv:
+                task = MMLU(self.MODEL_NAME)
+                task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @skip_pre_blackwell
     @parametrize_with_ids(
         "torch_compile",

tests/integration/test_lists/test-db/l0_gb200.yml

Lines changed: 2 additions & 0 deletions

@@ -64,3 +64,5 @@ l0_gb200:
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
+- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]
