From 0308e8ef908a17f2606a6ff3c4e2a8c04a36acfb Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Sat, 14 Jun 2025 19:31:22 +0800
Subject: [PATCH 1/3] fix MTP for EPLB

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 .../_torch/models/modeling_deepseekv3.py      |  4 +-
 .../_torch/modules/fused_moe/__init__.py      |  4 +-
 .../modules/fused_moe/fused_moe_wide_ep.py    | 15 ++++-
 .../modules/fused_moe/moe_load_balancer.py    | 44 ++++++++++++--
 .../defs/accuracy/test_llm_api_pytorch.py     | 60 +++++++++++++++++++
 .../test_lists/test-db/l0_gb200.yml           |  2 +
 6 files changed, 120 insertions(+), 9 deletions(-)
 mode change 100644 => 100755 tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py

diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
index f66663e9203..5e1e7a1b140 100644
--- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py
+++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -53,7 +53,8 @@
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
 from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
-                                 WideEPMoE, create_moe)
+                                 WideEPMoE, create_moe,
+                                 moe_load_balancer_set_repeated_for_next_layer)
 from ..modules.gated_mlp import GatedMLP
 from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
 from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -1069,6 +1070,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
         self.num_hidden_layers = self.config.num_hidden_layers
         assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
         if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
+            moe_load_balancer_set_repeated_for_next_layer(model_nextn)
             mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
                                       self.model.aux_stream_dict)
             self.model.layers.append(mtp_layer)
diff --git a/tensorrt_llm/_torch/modules/fused_moe/__init__.py b/tensorrt_llm/_torch/modules/fused_moe/__init__.py
index bb8f047fecf..43f7cf05d0c 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/__init__.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/__init__.py
@@ -4,7 +4,8 @@
 from .fused_moe_vanilla import VanillaMoE
 from .fused_moe_wide_ep import WideEPMoE
 from .interface import MoE, MoEWeightLoadingMode
-from .moe_load_balancer import MoeLoadBalancer
+from .moe_load_balancer import (MoeLoadBalancer,
+                                moe_load_balancer_set_repeated_for_next_layer)
 from .quantization import FusedMoEQuantScalesFP8
 from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
                       DefaultMoeRoutingMethod,
@@ -23,6 +24,7 @@
     "get_moe_cls",
     "Llama4RenormalizeMoeRoutingMethod",
     "LoadBalancedMoeRoutingMethod",
+    "moe_load_balancer_set_repeated_for_next_layer",
     "MoE",
     "MoeLoadBalancer",
     "MoEWeightLoadingMode",
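
Note on the model change above: the repeat count has to be registered with the MoE load balancer before the repeated layer (here the MTP module reused across model_nextn draft steps) builds its MoE submodules, because the balancer consumes the pending value when that layer calls add_layer(). A minimal sketch of the calling pattern, assuming the package export added in __init__.py above; append_repeated_mtp_layer and make_mtp_layer are illustrative names, not part of this patch:

    from tensorrt_llm._torch.modules.fused_moe import (
        moe_load_balancer_set_repeated_for_next_layer)

    def append_repeated_mtp_layer(layers, model_nextn, make_mtp_layer):
        # Register the repeat count first: MoeLoadBalancer.add_layer() picks up
        # the pending value when the MTP layer's MoE module is constructed.
        moe_load_balancer_set_repeated_for_next_layer(model_nextn)
        layers.append(make_mtp_layer())
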
diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
old mode 100644
new mode 100755
index 3506845133b..384ebb44aa8
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -87,6 +87,8 @@ def __init__(
 
         moe_load_balancer = get_moe_load_balancer()
         self.layer_load_balancer = None
+        self.repeat_idx = 0
+        self.repeat_count = 1
 
         moe_load_balancer_config = model_config.moe_load_balancer
         init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
@@ -102,6 +104,7 @@ def __init__(
             self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
             self.layer_load_balancer = moe_load_balancer.add_layer(
                 self.num_experts, top_k, self.expert_size_per_partition)
+            self.repeat_count = self.layer_load_balancer.get_repeat_count()
             loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
                 self.layer_idx)
             self.num_slots = moe_load_balancer_config.num_slots
@@ -434,6 +437,8 @@ def forward_chunk(
             )
 
         x_sf = None
+        x_row = x.shape[0]
+        x_col = x.shape[1]
         sf_swizzle = True
         if self.has_any_quant:
             if self.has_fp8_qdq:
@@ -668,13 +673,16 @@ def forward(
         else:
             all_rank_num_tokens_padded = all_rank_num_tokens
         if num_chunks == 1:
+            is_first_call = self.repeat_idx == 0
+            is_last_call = self.repeat_idx == self.repeat_count - 1
             outputs = self.forward_chunk(
                 x,
                 router_logits,
                 cutlass_min_latency_mode,
                 output_dtype,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
-                use_dp_padding=use_dp_padding)
+                use_dp_padding=use_dp_padding,
+                repeating_info=(is_first_call, is_last_call))
             outputs = self.reducescatter_or_allreduce(
                 outputs,
                 all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -717,8 +725,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
             # Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
             for idx_chunk, (x, router_logits) in enumerate(
                     zip(x_list, router_logits_list)):
-                is_first_call = idx_chunk == 0
-                is_last_call = idx_chunk == num_chunks - 1
+                is_first_call = idx_chunk == 0 and self.repeat_idx == 0
+                is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
                 if not self.enable_alltoall:
                     if idx_chunk % 2 == 0:
                         with torch.cuda.stream(self.aux_stream):
@@ -777,6 +785,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
         if self.use_dp:
             rank = self.mapping.tp_rank
             outputs = outputs[:all_rank_num_tokens[rank]]
+        self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
         return outputs
 
     def alltoall_prepare_maybe_dispatch(
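
The forward-side bookkeeping in the hunks above is easy to miss in diff form: the layer advances a repeat index on every forward() call, so expert statistics are only started on the first chunk of the first repeated forward (is_first_call) and only flushed on the last chunk of the last one (is_last_call). A standalone sketch of that logic; the class is illustrative and has no TensorRT-LLM dependencies, but the variable names and the cycling expression mirror the diff:

    class RepeatAwareChunking:
        """Illustrative model of the repeat/chunk bookkeeping in WideEPMoE.forward."""

        def __init__(self, repeat_count: int = 1):
            self.repeat_idx = 0
            self.repeat_count = repeat_count

        def forward(self, num_chunks: int):
            calls = []
            for idx_chunk in range(num_chunks):
                # First/last flags now account for both chunking and repeats.
                is_first_call = idx_chunk == 0 and self.repeat_idx == 0
                is_last_call = (idx_chunk == num_chunks - 1
                                and self.repeat_idx == self.repeat_count - 1)
                calls.append((is_first_call, is_last_call))
            # Cycle the repeat index exactly as the diff does at the end of forward().
            self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
            return calls

    # With repeat_count=2 (e.g. an MTP layer run twice per iteration), only the
    # final repeat's last chunk reports is_last_call=True:
    layer = RepeatAwareChunking(repeat_count=2)
    assert layer.forward(num_chunks=2) == [(True, False), (False, False)]
    assert layer.forward(num_chunks=2) == [(False, False), (False, True)]
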
""" self.single_layer_load_balancer_impl = single_layer_load_balancer_impl self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer( @@ -306,6 +309,7 @@ def __init__( self.cudagraph_stream = None self.cudagraph_event = None + self.repeated_count = repeated_count self.statistic_stream = None self.statistic_event = None @@ -317,6 +321,9 @@ def get_load_expert_ids(self): assert self.updates_enabled, "should not call get_load_expert_ids when using statistic routing" return self.load_expert_ids + def get_repeat_count(self): + return self.repeated_count + def is_static_routing(self): return not self.updates_enabled @@ -675,6 +682,8 @@ def __init__(self, self.enable_statistic = False self.enable_update_weights = False + self.next_layer_repeated_count = None + def __del__(self): if not self.is_shutdown: self.shutdown() @@ -696,6 +705,16 @@ def _setup_mpi_comm(self): def set_use_gpu_memcpy(self, use_gpu_memcpy: bool): self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy) + def set_repeated_for_next_layer(self, repeated_count: int): + """ + Set repeat count for next layer. + + Args: + repeated_count: The repeat count for next layer + """ + assert repeated_count > 0, "repeat count must be greater than 0" + self.next_layer_repeated_count = repeated_count + def add_layer(self, expert_count: int, top_k: int, slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer: """ @@ -712,11 +731,16 @@ def add_layer(self, expert_count: int, top_k: int, single_layer_load_balancer_impl = self.load_balancer_impl.add_layer( expert_count, top_k, slot_count_per_rank) updates_enabled = not self.is_static_routing() + repeat_count = 1 + if self.next_layer_repeated_count is not None: + repeat_count = self.next_layer_repeated_count + self.next_layer_repeated_count = None single_layer_load_balancer = SingleLayerMoeLoadBalancer( single_layer_load_balancer_impl, self.shared_mpi_comm, expert_count, - updates_enabled=updates_enabled) + updates_enabled=updates_enabled, + repeated_count=repeat_count) single_layer_load_balancer.set_shared_memory_base_name( self.shared_memory_base_name) self.single_layer_load_balancers.append(single_layer_load_balancer) @@ -934,6 +958,18 @@ def get_moe_load_balancer() -> Optional[MoeLoadBalancer]: return None +def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int): + """ + Set repeated count for next Single Layer created. 
diff --git a/tests/integration/defs/accuracy/test_llm_api_pytorch.py b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
index 391d9a7dd19..8f6e546c40a 100644
--- a/tests/integration/defs/accuracy/test_llm_api_pytorch.py
+++ b/tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -787,6 +787,7 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
             initial_global_assignments=initial_global_assignments,
             layer_updates_per_iter=0)
         pytorch_backend_options = dict(use_cuda_graph=True,
+                                       moe_backend="WIDEEP",
                                        moe_load_balancer=eplb_config)
         llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
                   tensor_parallel_size=4,
@@ -800,6 +801,65 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
             task = GSM8K(self.MODEL_NAME)
             task.evaluate(llm)
 
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["GB200"])
+    @parametrize_with_ids("mtp_nextn", [0, 2])
+    def test_bfloat16_4gpus_online_eplb(self, mtp_nextn):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
+        num_slots = 80
+        eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
+                                            layer_updates_per_iter=2)
+        pytorch_config = dict(use_cuda_graph=True,
+                              moe_backend="WIDEEP",
+                              moe_load_balancer=eplb_config)
+        mtp_config = None
+        if mtp_nextn > 0:
+            mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
+        llm = LLM(self.MODEL_PATH,
+                  tensor_parallel_size=4,
+                  moe_expert_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  enable_attention_dp=True,
+                  **pytorch_config,
+                  speculative_config=mtp_config)
+        with llm:
+            task = MMLU(self.MODEL_NAME)
+            task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
+    @pytest.mark.skip_less_device(4)
+    @pytest.mark.skip_device_not_contain(["GB200"])
+    @parametrize_with_ids("fp8kv", [True, False])
+    def test_nvfp4_4gpus_online_eplb(self, fp8kv):
+        kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
+        num_slots = 80
+        eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
+                                            layer_updates_per_iter=2)
+        pytorch_backend_options = dict(use_cuda_graph=True,
+                                       moe_backend="WIDEEP",
+                                       moe_load_balancer=eplb_config)
+        quant_config = QuantConfig()
+        quant_config.quant_algo = QuantAlgo.NVFP4
+        if fp8kv:
+            quant_config.kv_cache_quant_algo = QuantAlgo.FP8
+            pytorch_backend_options["kv_cache_dtype"] = "fp8"
+
+        llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
+                  tensor_parallel_size=4,
+                  moe_expert_parallel_size=4,
+                  kv_cache_config=kv_cache_config,
+                  **pytorch_backend_options,
+                  enable_attention_dp=True,
+                  quant_config=quant_config)
+        with llm:
+            # No need to run MMLU for fp8kv
+            if not fp8kv:
+                task = MMLU(self.MODEL_NAME)
+                task.evaluate(llm)
+            task = GSM8K(self.MODEL_NAME)
+            task.evaluate(llm)
+
     @skip_pre_blackwell
     @parametrize_with_ids(
         "torch_compile",
diff --git a/tests/integration/test_lists/test-db/l0_gb200.yml b/tests/integration/test_lists/test-db/l0_gb200.yml
index 348268558f3..2061a04ed59 100644
--- a/tests/integration/test_lists/test-db/l0_gb200.yml
+++ b/tests/integration/test_lists/test-db/l0_gb200.yml
@@ -72,3 +72,5 @@ l0_gb200:
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
   - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
+  - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]

From b88974fc89c670b98b40c761f96897f7da237e15 Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Fri, 20 Jun 2025 10:57:15 +0800
Subject: [PATCH 2/3] add MTP and chunk support for ExpertStatistic

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/expert_statistic.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/expert_statistic.py b/tensorrt_llm/_torch/expert_statistic.py
index 5f18c5b95ef..98dc127f4e0 100644
--- a/tensorrt_llm/_torch/expert_statistic.py
+++ b/tensorrt_llm/_torch/expert_statistic.py
@@ -92,4 +92,7 @@ def _maybe_add_info(self, expert_count: int,
         counts = torch.bincount(token_selected_experts.flatten(),
                                 minlength=expert_count)
         key = f"{self.current_iter_id}_{self.current_layer}"
-        self._records[key] = counts.cpu()
+        if key not in self._records:
+            self._records[key] = counts.cpu()
+        else:
+            self._records[key] += counts.cpu()

From d093c63db2278dc84002cd64528ad668d0afa49f Mon Sep 17 00:00:00 2001
From: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
Date: Fri, 20 Jun 2025 21:25:58 +0800
Subject: [PATCH 3/3] fix allgather when statistic is not needed.

Signed-off-by: Dongxu Yang <78518666+dongxuy04@users.noreply.github.com>
---
 tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
index 384ebb44aa8..0b03ea19f04 100755
--- a/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
@@ -484,7 +484,7 @@ def forward_chunk(
                     dim=0,
                     sizes=None if use_dp_padding else all_rank_num_tokens)
             # use separate allgather since doesn't have sizes, can be optimized but in allgather path it is OK
-            if is_last_call:
+            if is_last_call and loadbalancer_local_statistic_info is not None:
                 gathered_loadbalancer_local_statistic_info = allgather(
                     loadbalancer_local_statistic_info, self.mapping, dim=0)
         # Fp4 gemm has extra scaling factor
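
A closing note on the ExpertStatistic change in the second patch: with repeated MTP forwards and chunked inputs, _maybe_add_info can now be called several times for the same (iteration, layer) key, so the per-expert histogram has to accumulate instead of being overwritten. A standalone illustration of that accumulation, with a plain dictionary standing in for ExpertStatistic._records:

    import torch

    records = {}  # stands in for ExpertStatistic._records

    def add_counts(iter_id: int, layer: int, token_selected_experts: torch.Tensor,
                   expert_count: int) -> None:
        """Accumulate per-expert token counts for one (iteration, layer) key."""
        counts = torch.bincount(token_selected_experts.flatten(),
                                minlength=expert_count)
        key = f"{iter_id}_{layer}"
        if key not in records:
            records[key] = counts.cpu()
        else:
            # Repeated MTP forwards and extra chunks land on the same key,
            # so add to the existing histogram instead of replacing it.
            records[key] += counts.cpu()

    # Two calls for the same iteration/layer (e.g. two chunks) add up.
    add_counts(0, 3, torch.tensor([[0, 1], [1, 2]]), expert_count=4)
    add_counts(0, 3, torch.tensor([[2, 2], [3, 0]]), expert_count=4)
    assert records["0_3"].tolist() == [2, 2, 3, 1]
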