5 changes: 4 additions & 1 deletion tensorrt_llm/_torch/expert_statistic.py
@@ -92,4 +92,7 @@ def _maybe_add_info(self, expert_count: int,
counts = torch.bincount(token_selected_experts.flatten(),
minlength=expert_count)
key = f"{self.current_iter_id}_{self.current_layer}"
self._records[key] = counts.cpu()
if key not in self._records:
self._records[key] = counts.cpu()
else:
self._records[key] += counts.cpu()
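With MTP enabled the same expert layer can run more than once per iteration, so ExpertStatistic now accumulates per-expert counts under the same iteration/layer key instead of overwriting them. A minimal standalone sketch of the accumulation pattern (function name and the "7_5" key are illustrative, not the actual class):

```python
import torch

def accumulate_expert_counts(records: dict, key: str,
                             token_selected_experts: torch.Tensor,
                             expert_count: int) -> None:
    """Accumulate per-expert token counts under one (iteration, layer) key."""
    counts = torch.bincount(token_selected_experts.flatten(),
                            minlength=expert_count).cpu()
    if key not in records:
        records[key] = counts       # first forward of this layer in the iteration
    else:
        records[key] += counts      # repeated forward (e.g. MTP) adds on top

records = {}
accumulate_expert_counts(records, "7_5", torch.tensor([0, 2, 2, 3]), expert_count=4)
accumulate_expert_counts(records, "7_5", torch.tensor([1, 2, 3, 3]), expert_count=4)
print(records["7_5"])  # tensor([1, 1, 3, 3])
```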
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -53,7 +53,8 @@
from ..modules.decoder_layer import DecoderLayer
from ..modules.embedding import Embedding
from ..modules.fused_moe import (CutlassFusedMoE, DeepSeekV3MoeRoutingMethod,
WideEPMoE, create_moe)
WideEPMoE, create_moe,
moe_load_balancer_set_repeated_for_next_layer)
from ..modules.gated_mlp import GatedMLP
from ..modules.linear import Linear, TensorParallelMode, WeightsLoadingConfig
from ..modules.multi_stream_utils import maybe_execute_in_parallel
@@ -1069,6 +1070,7 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig]):
self.num_hidden_layers = self.config.num_hidden_layers
assert ckpt_nextn > 0, "There is not MTP modules in the checkpoint."
if ckpt_nextn == 1 and not model_config.spec_config.use_mtp_vanilla:
moe_load_balancer_set_repeated_for_next_layer(model_nextn)
mtp_layer = DeepseekV3MTP(model_config, self.num_hidden_layers,
self.model.aux_stream_dict)
self.model.layers.append(mtp_layer)
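The pending repeat count is consumed by the next MoeLoadBalancer.add_layer() call, so the helper has to run before DeepseekV3MTP constructs its MoE layer. A hedged sketch of that ordering (the surrounding construction is simplified; only the two calls in the hunk above come from this change, and the import paths are assumed from the modules touched here):

```python
from tensorrt_llm._torch.models.modeling_deepseekv3 import DeepseekV3MTP
from tensorrt_llm._torch.modules.fused_moe import (
    moe_load_balancer_set_repeated_for_next_layer)

def append_mtp_layer(model, model_config, num_hidden_layers, aux_stream_dict,
                     model_nextn: int):
    # Mark the *next* registered MoE layer as running model_nextn times per
    # iteration; DeepseekV3MTP builds that layer during construction.
    moe_load_balancer_set_repeated_for_next_layer(model_nextn)
    mtp_layer = DeepseekV3MTP(model_config, num_hidden_layers, aux_stream_dict)
    model.layers.append(mtp_layer)
    return mtp_layer
```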
4 changes: 3 additions & 1 deletion tensorrt_llm/_torch/modules/fused_moe/__init__.py
@@ -4,7 +4,8 @@
from .fused_moe_vanilla import VanillaMoE
from .fused_moe_wide_ep import WideEPMoE
from .interface import MoE, MoEWeightLoadingMode
from .moe_load_balancer import MoeLoadBalancer
from .moe_load_balancer import (MoeLoadBalancer,
moe_load_balancer_set_repeated_for_next_layer)
from .quantization import FusedMoEQuantScalesFP8
from .routing import (BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod,
DefaultMoeRoutingMethod,
@@ -23,6 +24,7 @@
"get_moe_cls",
"Llama4RenormalizeMoeRoutingMethod",
"LoadBalancedMoeRoutingMethod",
"moe_load_balancer_set_repeated_for_next_layer",
"MoE",
"MoeLoadBalancer",
"MoEWeightLoadingMode",
17 changes: 13 additions & 4 deletions tensorrt_llm/_torch/modules/fused_moe/fused_moe_wide_ep.py
100644 → 100755
@@ -87,6 +87,8 @@ def __init__(

moe_load_balancer = get_moe_load_balancer()
self.layer_load_balancer = None
self.repeat_idx = 0
self.repeat_count = 1

moe_load_balancer_config = model_config.moe_load_balancer
init_expert_size_per_partition = moe_load_balancer_config.num_local_slots if moe_load_balancer_config else self.num_experts // self.ep_size
@@ -102,6 +104,7 @@ def __init__(
self.expert_size_per_partition = moe_load_balancer_config.num_local_slots
self.layer_load_balancer = moe_load_balancer.add_layer(
self.num_experts, top_k, self.expert_size_per_partition)
self.repeat_count = self.layer_load_balancer.get_repeat_count()
loaded_initial_global_assignments = moe_load_balancer_config.get_layer_initial_global_assignments(
self.layer_idx)
self.num_slots = moe_load_balancer_config.num_slots
@@ -434,6 +437,8 @@ def forward_chunk(
)

x_sf = None
x_row = x.shape[0]
x_col = x.shape[1]
sf_swizzle = True
if self.has_any_quant:
if self.has_fp8_qdq:
@@ -479,7 +484,7 @@
dim=0,
sizes=None if use_dp_padding else all_rank_num_tokens)
# use separate allgather since doesn't have sizes, can be optimized but in allgather path it is OK
if is_last_call:
if is_last_call and loadbalancer_local_statistic_info is not None:
gathered_loadbalancer_local_statistic_info = allgather(
loadbalancer_local_statistic_info, self.mapping, dim=0)
# Fp4 gemm has extra scaling factor
@@ -668,13 +673,16 @@ def forward(
else:
all_rank_num_tokens_padded = all_rank_num_tokens
if num_chunks == 1:
is_first_call = self.repeat_idx == 0
is_last_call = self.repeat_idx == self.repeat_count - 1
outputs = self.forward_chunk(
x,
router_logits,
cutlass_min_latency_mode,
output_dtype,
all_rank_num_tokens=all_rank_num_tokens_padded,
use_dp_padding=use_dp_padding)
use_dp_padding=use_dp_padding,
repeating_info=(is_first_call, is_last_call))
outputs = self.reducescatter_or_allreduce(
outputs,
all_rank_num_tokens=all_rank_num_tokens_padded,
@@ -717,8 +725,8 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
# Postpone reduce-scatter/all-reduce to the next iteration to achieve better overlap
for idx_chunk, (x, router_logits) in enumerate(
zip(x_list, router_logits_list)):
is_first_call = idx_chunk == 0
is_last_call = idx_chunk == num_chunks - 1
is_first_call = idx_chunk == 0 and self.repeat_idx == 0
is_last_call = idx_chunk == num_chunks - 1 and self.repeat_idx == self.repeat_count - 1
if not self.enable_alltoall:
if idx_chunk % 2 == 0:
with torch.cuda.stream(self.aux_stream):
@@ -777,6 +785,7 @@ def split_chunk(split_token_num: int, split_num_chunks: int):
if self.use_dp:
rank = self.mapping.tp_rank
outputs = outputs[:all_rank_num_tokens[rank]]
self.repeat_idx = 0 if self.repeat_idx == self.repeat_count - 1 else self.repeat_idx + 1
return outputs

def alltoall_prepare_maybe_dispatch(
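is_first_call and is_last_call now account for both the chunk index and the repeat index, and repeat_idx wraps around after the last repeat of an iteration. An illustrative sketch of that bookkeeping (the real logic lives in WideEPMoE.forward; the assertions assume an MTP layer repeated twice with its input split into three chunks):

```python
def call_flags(idx_chunk: int, num_chunks: int,
               repeat_idx: int, repeat_count: int) -> tuple:
    # Statistics start on the very first chunk of the first repeat and are
    # gathered/applied on the very last chunk of the last repeat.
    is_first_call = idx_chunk == 0 and repeat_idx == 0
    is_last_call = (idx_chunk == num_chunks - 1
                    and repeat_idx == repeat_count - 1)
    return is_first_call, is_last_call

def next_repeat_idx(repeat_idx: int, repeat_count: int) -> int:
    # Wrap back to 0 once the last repeat of this iteration has finished.
    return 0 if repeat_idx == repeat_count - 1 else repeat_idx + 1

assert call_flags(0, 3, 0, 2) == (True, False)   # first chunk, first repeat
assert call_flags(2, 3, 0, 2) == (False, False)  # last chunk, but not last repeat
assert call_flags(2, 3, 1, 2) == (False, True)   # last chunk of the last repeat
assert next_repeat_idx(1, 2) == 0
```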
44 changes: 40 additions & 4 deletions tensorrt_llm/_torch/modules/fused_moe/moe_load_balancer.py
@@ -182,10 +182,11 @@ def finalize_layer_weights(self):
offset = 0
for name in self.names:
for expert_id in range(self.expert_start, self.expert_end):
t = self.shared_tensors[(expert_id, name)]
t = self.shared_tensors[(expert_id, name)].contiguous().cpu()
data_size = t.numel() * t.element_size()
aligned_size = self.align_size(data_size)
shm.buf[offset:offset + data_size] = t.numpy().tobytes()
shm.buf[offset:offset + data_size] = t.flatten().view(
torch.int8).numpy().tobytes()
dtype = t.dtype
tensor_shape = t.shape
elt_count = t.numel()
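The serialization now goes through contiguous().cpu() and an int8 view, so tensors that are non-contiguous, live on GPU, or use a dtype without a NumPy counterpart (such as bfloat16) copy into shared memory byte for byte. A standalone sketch of the same pattern, assuming Python's multiprocessing.shared_memory rather than the balancer's own shared-memory helper:

```python
import torch
from multiprocessing import shared_memory

def write_tensor(shm: shared_memory.SharedMemory, offset: int,
                 t: torch.Tensor) -> int:
    """Copy a tensor's raw bytes into shared memory; return the next offset."""
    t = t.contiguous().cpu()                              # handle GPU / strided inputs
    data = t.flatten().view(torch.int8).numpy().tobytes() # bf16-safe byte view
    shm.buf[offset:offset + len(data)] = data
    return offset + len(data)

t = torch.arange(32.0).reshape(4, 8).to(torch.bfloat16).t()  # non-contiguous bf16
shm = shared_memory.SharedMemory(create=True, size=1024)
assert write_tensor(shm, 0, t) == t.numel() * t.element_size()
shm.close()
shm.unlink()
```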
@@ -270,7 +271,8 @@ def __init__(
single_layer_load_balancer_impl: _tbr.SingleLayerMoeLoadBalancer,
shared_mpi_comm: MPI.Comm,
expert_count: int,
updates_enabled: bool = True):
updates_enabled: bool = True,
repeated_count=1):
"""
Initialize a SingleLayerMoeLoadBalancer instance.

@@ -279,6 +281,7 @@ def __init__(
shared_mpi_comm: The MPI communicator for shared memory
expert_count: total number of experts
updates_enabled: whether to enable weight updates
repeated_count: the repeated count of current layer, used when forward is repeated more than once like MTP.
"""
self.single_layer_load_balancer_impl = single_layer_load_balancer_impl
self.single_layer_load_balancer_ptr = single_layer_load_balancer_impl.get_pointer(
@@ -306,6 +309,7 @@ def __init__(

self.cudagraph_stream = None
self.cudagraph_event = None
self.repeated_count = repeated_count

self.statistic_stream = None
self.statistic_event = None
@@ -317,6 +321,9 @@ def get_load_expert_ids(self):
assert self.updates_enabled, "should not call get_load_expert_ids when using statistic routing"
return self.load_expert_ids

def get_repeat_count(self):
return self.repeated_count

def is_static_routing(self):
return not self.updates_enabled

@@ -675,6 +682,8 @@ def __init__(self,
self.enable_statistic = False
self.enable_update_weights = False

self.next_layer_repeated_count = None

def __del__(self):
if not self.is_shutdown:
self.shutdown()
@@ -696,6 +705,16 @@ def _setup_mpi_comm(self):
def set_use_gpu_memcpy(self, use_gpu_memcpy: bool):
self.load_balancer_impl.set_use_gpu_memcpy(use_gpu_memcpy)

def set_repeated_for_next_layer(self, repeated_count: int):
"""
Set repeat count for next layer.

Args:
repeated_count: The repeat count for next layer
"""
assert repeated_count > 0, "repeat count must be greater than 0"
self.next_layer_repeated_count = repeated_count

def add_layer(self, expert_count: int, top_k: int,
slot_count_per_rank: int) -> SingleLayerMoeLoadBalancer:
"""
@@ -712,11 +731,16 @@ def add_layer(self, expert_count: int, top_k: int,
single_layer_load_balancer_impl = self.load_balancer_impl.add_layer(
expert_count, top_k, slot_count_per_rank)
updates_enabled = not self.is_static_routing()
repeat_count = 1
if self.next_layer_repeated_count is not None:
repeat_count = self.next_layer_repeated_count
self.next_layer_repeated_count = None
single_layer_load_balancer = SingleLayerMoeLoadBalancer(
single_layer_load_balancer_impl,
self.shared_mpi_comm,
expert_count,
updates_enabled=updates_enabled)
updates_enabled=updates_enabled,
repeated_count=repeat_count)
single_layer_load_balancer.set_shared_memory_base_name(
self.shared_memory_base_name)
self.single_layer_load_balancers.append(single_layer_load_balancer)
@@ -934,6 +958,18 @@ def get_moe_load_balancer() -> Optional[MoeLoadBalancer]:
return None


def moe_load_balancer_set_repeated_for_next_layer(repeat_count: int):
"""
Set repeated count for next Single Layer created.

Args:
repeat_count: repeated count
"""
load_balancer = get_moe_load_balancer()
if load_balancer is not None:
load_balancer.set_repeated_for_next_layer(repeat_count)


def moe_load_balancer_add_single_layer(
expert_count: int, top_k: int,
slot_count_per_rank: int) -> Optional[SingleLayerMoeLoadBalancer]:
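set_repeated_for_next_layer() stashes a pending value that the next add_layer() call consumes and then clears, so only the layer registered immediately afterwards (the MTP layer) gets a repeat count above 1. A reduced standalone sketch of that hand-off (method names mirror this change; the bodies are stripped down for illustration):

```python
class _LoadBalancerSketch:
    def __init__(self):
        self.next_layer_repeated_count = None
        self.layer_repeat_counts = []

    def set_repeated_for_next_layer(self, repeated_count: int) -> None:
        assert repeated_count > 0, "repeat count must be greater than 0"
        self.next_layer_repeated_count = repeated_count

    def add_layer(self) -> int:
        repeat_count = self.next_layer_repeated_count or 1
        self.next_layer_repeated_count = None      # consumed exactly once
        self.layer_repeat_counts.append(repeat_count)
        return repeat_count

lb = _LoadBalancerSketch()
lb.add_layer()                      # ordinary decoder layer -> repeat count 1
lb.set_repeated_for_next_layer(2)   # e.g. MTP with nextn == 2
lb.add_layer()                      # MTP layer -> repeat count 2
lb.add_layer()                      # subsequent layers fall back to 1
assert lb.layer_repeat_counts == [1, 2, 1]
```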
60 changes: 60 additions & 0 deletions tests/integration/defs/accuracy/test_llm_api_pytorch.py
@@ -787,6 +787,7 @@ def test_fp8_block_scales_4gpus_static_eplb(self):
initial_global_assignments=initial_global_assignments,
layer_updates_per_iter=0)
pytorch_backend_options = dict(use_cuda_graph=True,
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/fp8",
tensor_parallel_size=4,
@@ -800,6 +801,65 @@
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["GB200"])
@parametrize_with_ids("mtp_nextn", [0, 2])
def test_bfloat16_4gpus_online_eplb(self, mtp_nextn):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
num_slots = 80
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_config = dict(use_cuda_graph=True,
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
mtp_config = None
if mtp_nextn > 0:
mtp_config = MTPDecodingConfig(num_nextn_predict_layers=mtp_nextn)
llm = LLM(self.MODEL_PATH,
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
enable_attention_dp=True,
**pytorch_config,
speculative_config=mtp_config)
with llm:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@pytest.mark.skip_less_device(4)
@pytest.mark.skip_device_not_contain(["GB200"])
@parametrize_with_ids("fp8kv", [True, False])
def test_nvfp4_4gpus_online_eplb(self, fp8kv):
kv_cache_config = KvCacheConfig(free_gpu_memory_fraction=0.7)
num_slots = 80
eplb_config = MoeLoadBalancerConfig(num_slots=num_slots,
layer_updates_per_iter=2)
pytorch_backend_options = dict(use_cuda_graph=True,
moe_backend="WIDEEP",
moe_load_balancer=eplb_config)
quant_config = QuantConfig()
quant_config.quant_algo = QuantAlgo.NVFP4
if fp8kv:
quant_config.kv_cache_quant_algo = QuantAlgo.FP8
pytorch_backend_options["kv_cache_dtype"] = "fp8"

llm = LLM(f"{llm_models_root()}/DeepSeek-V3-Lite/nvfp4_moe_only",
tensor_parallel_size=4,
moe_expert_parallel_size=4,
kv_cache_config=kv_cache_config,
**pytorch_backend_options,
enable_attention_dp=True,
quant_config=quant_config)
with llm:
# No need to run MMLU for fp8kv
if not fp8kv:
task = MMLU(self.MODEL_NAME)
task.evaluate(llm)
task = GSM8K(self.MODEL_NAME)
task.evaluate(llm)

@skip_pre_blackwell
@parametrize_with_ids(
"torch_compile",
2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_gb200.yml
@@ -72,3 +72,5 @@ l0_gb200:
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTLASS-mtp_nextn=2-pp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-tp4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=TRTLLM-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus_online_eplb[fp8kv=True]
- accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_bfloat16_4gpus_online_eplb[mtp_nextn=2]