14 changes: 12 additions & 2 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -312,6 +312,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype):
     def get_supported_dtypes():
         return (torch.float16, torch.bfloat16, torch.float32)
 
+    # Check if MNNVL is supported
     @staticmethod
     def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool:
         from tensorrt_llm._mnnvl_utils import MnnvlMemory
@@ -455,8 +456,14 @@ def __init__(self,
         self.workspace = get_allreduce_workspace(self.mapping)
 
         # Initialize MNNVL AllReduce if needed
-        if self.strategy == AllReduceStrategy.MNNVL:
-            if MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
+        if self.strategy in (AllReduceStrategy.AUTO,
+                             AllReduceStrategy.MNNVL):
+            if self.mapping.tp_size != self.mapping.world_size:
+                logger.debug(
+                    f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} "
+                    f"!= world_size:{self.mapping.world_size}")
+                self.mnnvl_allreduce = None
+            elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype):
                 try:
                     self.mnnvl_allreduce = MNNVLAllReduce(
                         self.mapping, dtype) if dtype else None
@@ -474,6 +481,9 @@ def __init__(self,
                 )
                 self.mnnvl_allreduce = None
 
+    def is_mnnvl(self) -> bool:
+        return self.mnnvl_allreduce is not None
+
     def forward(
         self,
         input: torch.Tensor,
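The ops.py change widens MNNVL consideration from the explicit MNNVL strategy to AUTO as well, disables it when tensor parallelism does not span the full world size (for example when pipeline parallelism is active), and exposes the result through the new AllReduce.is_mnnvl() accessor. Below is a minimal, self-contained sketch of that gating; the Strategy enum and mnnvl_eligible helper are illustrative stand-ins rather than the library's API, and the hardware/topology check performed by MNNVLAllReduce.is_mnnvl via MnnvlMemory is deliberately omitted.

```python
# Illustrative sketch only: mirrors the constructor's gating, not TensorRT-LLM's API.
from enum import Enum, auto

import torch


class Strategy(Enum):
    AUTO = auto()
    MNNVL = auto()
    NCCL = auto()


def mnnvl_eligible(strategy: Strategy, tp_size: int, world_size: int,
                   dtype: torch.dtype) -> bool:
    """Return True when the MNNVL all-reduce path would even be considered."""
    if strategy not in (Strategy.AUTO, Strategy.MNNVL):
        return False
    if tp_size != world_size:
        # Mirrors the new tp_size != world_size guard: TP must span the whole world.
        return False
    # Mirrors MNNVLAllReduce.get_supported_dtypes().
    return dtype in (torch.float16, torch.bfloat16, torch.float32)


# TP=8 over an 8-rank world with bf16 passes the gate; TP=8 of a 16-rank world does not.
assert mnnvl_eligible(Strategy.AUTO, 8, 8, torch.bfloat16)
assert not mnnvl_eligible(Strategy.AUTO, 8, 16, torch.bfloat16)
```

Even when this gate passes, the constructor still wraps MNNVLAllReduce construction in try/except and falls back to self.mnnvl_allreduce = None on failure, which is what is_mnnvl() ultimately reports.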
24 changes: 15 additions & 9 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -611,7 +611,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                                        torch.cuda.Stream]):
         super().__init__()
         self.model_config = model_config
-        config = model_config.pretrained_config
+        self.config = model_config.pretrained_config
+        config = self.config
 
         self.hidden_size = config.hidden_size
         self.moe_intermediate_size = config.moe_intermediate_size
@@ -642,6 +643,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
         self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4()
 
         has_tp = mapping.has_tp()
+        self.allreduce = AllReduce(mapping=model_config.mapping,
+                                   strategy=model_config.allreduce_strategy,
+                                   dtype=config.torch_dtype)
+        self.moe_allreduce = MoEAllReduce(self.mapping)
 
         if (config.n_routed_experts is not None
                 and layer_idx >= config.first_k_dense_replace
@@ -694,10 +699,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
         self.layer_idx = layer_idx
-        self.allreduce = AllReduce(mapping=model_config.mapping,
-                                   strategy=model_config.allreduce_strategy,
-                                   dtype=config.torch_dtype)
-        self.moe_allreduce = MoEAllReduce(self.mapping)
         self.next_layer_layernorm: RMSNorm = None
 
     def _get_decoder_layer_quant_config(
@@ -743,10 +744,15 @@ def _compute_mlp_tp_size(self, intermediate_size: int,
             intermediate_size // block_size,
             self.mapping.tp_size,
         )
-        mlp_tp_size = math.gcd(
-            tp,
-            self.mapping.gpus_per_node,
-        ) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP
+
+        if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl(
+        ):
+            mlp_tp_size = math.gcd(
+                tp,
+                self.mapping.gpus_per_node,
+            ) # Avoid costly inter-node TP when MNNVL is not supported
+        else:
+            mlp_tp_size = tp
         return mlp_tp_size
 
     def forward(
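The _compute_mlp_tp_size change only caps MLP tensor parallelism at one node when MNNVL all-reduce is unavailable; with MNNVL active, inter-node TP is allowed to stand. A hedged standalone sketch of that decision follows (the free function below is illustrative; the real method reads self.mapping and self.allreduce.is_mnnvl()):

```python
# Illustrative sketch of the revised _compute_mlp_tp_size decision, not the real method.
import math


def compute_mlp_tp_size(intermediate_size: int, block_size: int, tp_size: int,
                        gpus_per_node: int, mnnvl_active: bool) -> int:
    # TP is first limited by the number of quantization blocks and the configured TP size.
    tp = math.gcd(intermediate_size // block_size, tp_size)
    if tp > gpus_per_node and not mnnvl_active:
        # Without MNNVL, shrink TP to fit in one node and avoid costly inter-node all-reduce.
        return math.gcd(tp, gpus_per_node)
    return tp  # With MNNVL (or TP already within a node), keep the full TP width.


# Example with hypothetical sizes: 16-way TP on 8-GPU nodes.
print(compute_mlp_tp_size(18432, 128, 16, 8, mnnvl_active=False))  # -> 8
print(compute_mlp_tp_size(18432, 128, 16, 8, mnnvl_active=True))   # -> 16
```

Keeping the wide TP layout when MNNVL is available is the apparent motivation for the new AllReduce.is_mnnvl() accessor in ops.py, and for moving the self.allreduce construction earlier in __init__ so that _compute_mlp_tp_size can consult it.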