From 45c7d52a10a75f67df5a5b702fa119f7bbf5aad5 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 13 Aug 2025 22:53:55 -0700 Subject: [PATCH 01/12] Avoid split mlp tp for mnnvl Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 192 +++++++----------- .../_torch/models/modeling_deepseekv3.py | 21 +- 2 files changed, 90 insertions(+), 123 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 74ac9590a38..b23ea8a837c 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -10,8 +10,7 @@ from tensorrt_llm._utils import mpi_barrier from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer -from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, - AllReduceStrategy, MoEAllReduceParams) +from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams, AllReduceStrategy, MoEAllReduceParams from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.plugin.plugin import CustomAllReduceHelper @@ -21,56 +20,48 @@ def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: - if not hasattr(_thread_local, f'allreduce_workspaces_{mapping.pp_rank}'): - setattr(_thread_local, f'allreduce_workspaces_{mapping.pp_rank}', {}) + if not hasattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}"): + setattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}", {}) - allreduce_workspaces = getattr(_thread_local, - f'allreduce_workspaces_{mapping.pp_rank}') + allreduce_workspaces = getattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}") if mapping not in allreduce_workspaces: ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace( mapping, - CustomAllReduceHelper.max_workspace_size_auto( - mapping.tp_size, support_deterministic=False), + CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size, support_deterministic=False), ) allreduce_workspaces[mapping] = (ipc_buffers, workspace) return allreduce_workspaces[mapping][1] def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None: - if not hasattr(_thread_local, 'lowprecision_allreduce_workspaces'): + if not hasattr(_thread_local, "lowprecision_allreduce_workspaces"): _thread_local.lowprecision_allreduce_workspaces = {} lowprecision_allreduce_workspaces = _thread_local.lowprecision_allreduce_workspaces if mapping not in lowprecision_allreduce_workspaces: ipc_buffers, workspace = CustomAllReduceHelper.allocate_lowprecision_workspace( mapping, - CustomAllReduceHelper.max_workspace_size_lowprecision( - mapping.tp_size), + CustomAllReduceHelper.max_workspace_size_lowprecision(mapping.tp_size), ) lowprecision_allreduce_workspaces[mapping] = (ipc_buffers, workspace) - CustomAllReduceHelper.initialize_lowprecision_buffers( - workspace, mapping.tp_size) + CustomAllReduceHelper.initialize_lowprecision_buffers(workspace, mapping.tp_size) return def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]: - if not hasattr(_thread_local, - f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'): - setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}', - {}) + if not hasattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}"): + setattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}", {}) force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - allreduce_mnnvl_workspaces = getattr( - _thread_local, 
f'allreduce_mnnvl_workspaces_{mapping.pp_rank}') + allreduce_mnnvl_workspaces = getattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}") if mapping not in allreduce_mnnvl_workspaces: # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize # Max hidden_size_to_support max_hidden_dim = 16384 - buffer_size_in_bytes = math.ceil( - 12_000_000 / (max_hidden_dim * stride)) * (max_hidden_dim * stride) + buffer_size_in_bytes = math.ceil(12_000_000 / (max_hidden_dim * stride)) * (max_hidden_dim * stride) max_num_elements = buffer_size_in_bytes // stride mcast_buffer = McastGPUBuffer( @@ -81,8 +72,7 @@ def get_allreduce_mnnvl_workspace( True, # mnNvlink ) - buffer = mcast_buffer.get_uc_buffer(mapping.tp_rank, - (3, 2, max_num_elements), dtype, 0) + buffer = mcast_buffer.get_uc_buffer(mapping.tp_rank, (3, 2, max_num_elements), dtype, 0) # Only initialize the buffer when we need to resize it buffer.fill_(-0.0) # CPU barrier since we assume this should not be called in cuda graph @@ -92,44 +82,32 @@ def get_allreduce_mnnvl_workspace( # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer # [Buffer_ptr, Clear_ptr, num_tokens_to_clear,atomic access counter] - buffer_flags = torch.tensor([0, 2, 0, 0], - dtype=torch.uint32, - device=torch.device("cuda", - mapping.local_rank)) + buffer_flags = torch.tensor([0, 2, 0, 0], dtype=torch.uint32, device=torch.device("cuda", mapping.local_rank)) - allreduce_mnnvl_workspaces[mapping] = (mcast_buffer, buffer, - buffer_flags, max_num_elements) + allreduce_mnnvl_workspaces[mapping] = (mcast_buffer, buffer, buffer_flags, max_num_elements) return allreduce_mnnvl_workspaces[mapping] -def userbuffers_allreduce_finalize( - input: torch.Tensor, - force_applying_finalize: bool = False) -> torch.Tensor: - output = torch.ops.trtllm.userbuffers_allreduce_finalize( - input, force_applying_finalize) +def userbuffers_allreduce_finalize(input: torch.Tensor, force_applying_finalize: bool = False) -> torch.Tensor: + output = torch.ops.trtllm.userbuffers_allreduce_finalize(input, force_applying_finalize) return output def get_output_info(input: torch.Tensor, dim: int) -> List[int]: dim = dim % input.ndim - output_shape = [ - val if idx != dim else -1 for idx, val in enumerate(input.shape) - ] + output_shape = [val if idx != dim else -1 for idx, val in enumerate(input.shape)] numel_base = -math.prod(output_shape) - return {'output_shape': output_shape, 'numel_base': numel_base} + return {"output_shape": output_shape, "numel_base": numel_base} -def filter_valid_input( - input_list: List[torch.Tensor] -) -> Tuple[List[torch.Tensor], List[bool]]: +def filter_valid_input(input_list: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[bool]]: func_valid = lambda x: x is not None valid_list = list(map(func_valid, input_list)) input_list = list(filter(func_valid, input_list)) return input_list, valid_list -def restore_full_output(valid_outputs: List[torch.Tensor], - valid_list: List[bool]) -> List[torch.Tensor]: +def restore_full_output(valid_outputs: List[torch.Tensor], valid_list: List[bool]) -> List[torch.Tensor]: idx = 0 full_outputs = [] for v in valid_list: @@ -144,7 +122,7 @@ def allgather( dim: int = -1, sizes: Optional[List[int]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - ''' + """ Add an operation that performs a collective all-gather. If 'sizes' is 'None', the input tensors in the different ranks must have the same shape. 
@@ -168,7 +146,7 @@ def allgather( sizes(Optional[List[int]]): An optional list indicating 'input.shape[dim]' in all ranks. By default None. Returns: The gathered tensor or tensor list. - ''' + """ if mapping.tp_size == 1: return input @@ -177,24 +155,18 @@ def allgather( if isinstance(input, torch.Tensor): assert input.shape[dim] == sizes[mapping.tp_rank] else: - assert all([ - val.shape[dim] == sizes[mapping.tp_rank] for val in input - if val is not None - ]) + assert all([val.shape[dim] == sizes[mapping.tp_rank] for val in input if val is not None]) # Inputs are reshaped in this way to pass necessary shape information to the allgather op if isinstance(input, torch.Tensor): torch_op = torch.ops.trtllm.allgather output_info = get_output_info(input, dim) - input = input.contiguous().view(-1, output_info['numel_base']) + input = input.contiguous().view(-1, output_info["numel_base"]) else: input, valid = filter_valid_input(input) torch_op = torch.ops.trtllm.allgather_list output_info = [get_output_info(val, dim) for val in input] - input = [ - val.contiguous().view(-1, val_info['numel_base']) - for val, val_info in zip(input, output_info) - ] + input = [val.contiguous().view(-1, val_info["numel_base"]) for val, val_info in zip(input, output_info)] output = torch_op( input, @@ -204,23 +176,19 @@ def allgather( def convert_output(x, x_info): if dim == 0: - x = x.view(x_info['output_shape']) + x = x.view(x_info["output_shape"]) else: if sizes is None: x_list = x.chunk(mapping.tp_size) else: x_list = x.split(sizes) - x = torch.cat([x.reshape(x_info['output_shape']) for x in x_list], - dim=dim) + x = torch.cat([x.reshape(x_info["output_shape"]) for x in x_list], dim=dim) return x if isinstance(input, torch.Tensor): output = convert_output(output, output_info) else: - output = [ - convert_output(val, val_info) - for val, val_info in zip(output, output_info) - ] + output = [convert_output(val, val_info) for val, val_info in zip(output, output_info)] output = restore_full_output(output, valid) return output @@ -240,21 +208,18 @@ def reducescatter( if isinstance(input, torch.Tensor): assert input.shape[dim] == sum_split_size else: - assert all([ - val.shape[dim] == sum_split_size for val in input - if val is not None - ]) + assert all([val.shape[dim] == sum_split_size for val in input if val is not None]) def convert_input(x, x_info): # Inputs are reshaped in this way to pass necessary shape information to the reducescatter op if dim == 0: - x = x.contiguous().view(-1, x_info['numel_base']) + x = x.contiguous().view(-1, x_info["numel_base"]) else: if sizes is None: x_list = x.chunk(mapping.tp_size, dim=dim) else: x_list = x.split(sizes, dim=dim) - x = torch.cat([x.reshape(-1, x_info['numel_base']) for x in x_list]) + x = torch.cat([x.reshape(-1, x_info["numel_base"]) for x in x_list]) return x if isinstance(input, torch.Tensor): @@ -265,10 +230,7 @@ def convert_input(x, x_info): input, valid = filter_valid_input(input) torch_op = torch.ops.trtllm.reducescatter_list output_info = [get_output_info(val, dim) for val in input] - input = [ - convert_input(val, val_info) - for val, val_info in zip(input, output_info) - ] + input = [convert_input(val, val_info) for val, val_info in zip(input, output_info)] output = torch_op( input, @@ -277,12 +239,9 @@ def convert_input(x, x_info): ) if isinstance(input, torch.Tensor): - output = output.view(output_info['output_shape']) + output = output.view(output_info["output_shape"]) else: - output = [ - val.view(val_info['output_shape']) - for val, val_info in 
zip(output, output_info) - ] + output = [val.view(val_info["output_shape"]) for val, val_info in zip(output, output_info)] output = restore_full_output(output, valid) return output @@ -300,13 +259,13 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): super().__init__() self.mapping = mapping self.dtype = dtype - assert ( - dtype in MNNVLAllReduce.get_supported_dtypes() - and (not mapping.has_cp()) + assert dtype in MNNVLAllReduce.get_supported_dtypes() and ( + not mapping.has_cp() ), "MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp." - self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace( - self.mapping, dtype) + self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = ( + get_allreduce_mnnvl_workspace(self.mapping, dtype) + ) @staticmethod def get_supported_dtypes(): @@ -318,9 +277,13 @@ def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch - return (dtype in MNNVLAllReduce.get_supported_dtypes() - and not mapping.has_cp() and mapping.is_multi_node() - and MnnvlMemory.supports_mnnvl() and is_on_aarch64) + return ( + dtype in MNNVLAllReduce.get_supported_dtypes() + and not mapping.has_cp() + and mapping.is_multi_node() + and MnnvlMemory.supports_mnnvl() + and is_on_aarch64 + ) def forward( self, @@ -346,9 +309,7 @@ def forward( max_num_tokens = self.max_num_elements_mnnvl // hidden_dim num_elements_in_use = max_num_tokens * hidden_dim if num_tokens > max_num_tokens: - logger.debug( - f"MNNVL AllReduce can't be enabled due to {num_tokens=} larger than {max_num_tokens=}." - ) + logger.debug(f"MNNVL AllReduce can't be enabled due to {num_tokens=} larger than {max_num_tokens=}.") return None # This should not happen but leave this check for future code changes @@ -359,9 +320,7 @@ def forward( return None output = torch.empty_like(input) - buffer_mnnvl = self.buffer_mnnvl.view(-1)[:(3 * 2 * - num_elements_in_use)].view( - 3, 2, -1, hidden_dim) + buffer_mnnvl = self.buffer_mnnvl.view(-1)[: (3 * 2 * num_elements_in_use)].view(3, 2, -1, hidden_dim) if fusion_op == AllReduceFusionOp.NONE: output = torch.ops.trtllm.mnnvl_twoshot_allreduce( @@ -373,8 +332,10 @@ def forward( ) return output.view(shape) # Fallback to use other allreduce if hidden_size is not supported - elif (fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM - and hidden_dim in MNNVLAllReduce.SUPPORTED_FUSION_HIDDEN_DIMS): + elif ( + fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM + and hidden_dim in MNNVLAllReduce.SUPPORTED_FUSION_HIDDEN_DIMS + ): torch.ops.trtllm.mnnvl_twoshot_allreduce( input, buffer_mnnvl, @@ -398,10 +359,12 @@ def forward( class AllReduce(nn.Module): - def __init__(self, - mapping: Mapping, - strategy: AllReduceStrategy = AllReduceStrategy.AUTO, - dtype: Optional[torch.dtype] = None): + def __init__( + self, + mapping: Mapping, + strategy: AllReduceStrategy = AllReduceStrategy.AUTO, + dtype: Optional[torch.dtype] = None, + ): super().__init__() """ AllReduce is a module that performs an all-reduce operation on a tensor. 
@@ -455,23 +418,22 @@ def __init__(self, self.workspace = get_allreduce_workspace(self.mapping) # Initialize MNNVL AllReduce if needed - if self.strategy == AllReduceStrategy.MNNVL: + if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): + if self.mapping.tp_size != self.mapping.world_size: + logger.debug(f"MNNVLAllReduce is disabled due to tp_size != world_size") + self.mnnvl_allreduce = None if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: - self.mnnvl_allreduce = MNNVLAllReduce( - self.mapping, dtype) if dtype else None + self.mnnvl_allreduce = MNNVLAllReduce(self.mapping, dtype) if dtype else None if self.mnnvl_allreduce: logger.debug(f"MNNVLAllReduce is enabled") else: logger.debug(f"MNNVLAllReduce is disabled") except Exception as e: - logger.debug( - f"MNNVL AllReduce can't be enabled due to {e}.") + logger.debug(f"MNNVL AllReduce can't be enabled due to {e}.") self.mnnvl_allreduce = None else: - logger.debug( - f"MNNVLAllReduce can't be enabled due to failing the is_mnnvl check." - ) + logger.debug(f"MNNVLAllReduce can't be enabled due to failing the is_mnnvl check.") self.mnnvl_allreduce = None def forward( @@ -480,7 +442,7 @@ def forward( *, all_reduce_params: Optional[AllReduceParams] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - ''' + """ The input tensors in the different ranks must have the same shape. The output tensor will have that same shape with the input tensor. The output tensor will be replicated among the TP group. @@ -502,10 +464,8 @@ def forward( RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual] RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual] RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual] - ''' - if self.mapping.tp_size == 1 or (all_reduce_params is not None - and all_reduce_params.enable_allreduce - == False): + """ + if self.mapping.tp_size == 1 or (all_reduce_params is not None and all_reduce_params.enable_allreduce == False): return input input = input.contiguous() # Underlying op requires contiguous input @@ -516,8 +476,7 @@ def forward( # Try MNNVL AllReduce first if available if self.mnnvl_allreduce: - mnnvl_output = self.mnnvl_allreduce( - input, all_reduce_params=all_reduce_params) + mnnvl_output = self.mnnvl_allreduce(input, all_reduce_params=all_reduce_params) if mnnvl_output is not None: return mnnvl_output @@ -536,8 +495,7 @@ def forward( strategy=allreduce_strategy, op=all_reduce_params.fusion_op, eps=all_reduce_params.eps, - trigger_completion_at_end=all_reduce_params. - trigger_completion_at_end, + trigger_completion_at_end=all_reduce_params.trigger_completion_at_end, ) return output if len(output) > 1 else output[0] @@ -616,15 +574,15 @@ def forward( eps=all_reduce_params.eps, ) else: - assert all_reduce_params.residual.shape[ - 0] <= self.max_token, "Num tokens must be less than or equal to max_token" + assert ( + all_reduce_params.residual.shape[0] <= self.max_token + ), "Num tokens must be less than or equal to max_token" return torch.ops.trtllm.moe_finalize_allreduce( input=input, residual=all_reduce_params.residual, norm_weight=all_reduce_params.norm_weight, - expanded_idx_to_permuted_idx=all_reduce_params. 
- expanded_idx_to_permuted_idx, + expanded_idx_to_permuted_idx=all_reduce_params.expanded_idx_to_permuted_idx, shared_expert_output=all_reduce_params.shared_expert_output, expert_scale_factor=all_reduce_params.expert_scale_factor, workspace=self.workspace, diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 5fdcb43be3a..43ea94a1155 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -41,7 +41,7 @@ from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version -from tensorrt_llm.functional import PositionEmbeddingType +from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -51,7 +51,8 @@ from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, - MoEAllReduce, MoEAllReduceParams, allgather) + MNNVLAllReduce, MoEAllReduce, MoEAllReduceParams, + allgather) from ..model_config import ModelConfig from ..modules.attention import MLA from ..modules.decoder_layer import DecoderLayer @@ -743,10 +744,18 @@ def _compute_mlp_tp_size(self, intermediate_size: int, intermediate_size // block_size, self.mapping.tp_size, ) - mlp_tp_size = math.gcd( - tp, - self.mapping.gpus_per_node, - ) if tp > self.mapping.gpus_per_node else tp # Avoid costly inter-node TP + + if tp > self.mapping.gpus_per_node and ( + self.allreduce.strategy not in ( + AllReduceStrategy.AUTO, + AllReduceStrategy.MNNVL, + ) or not MNNVLAllReduce.is_mnnvl()): + mlp_tp_size = math.gcd( + tp, + self.mapping.gpus_per_node, + ) # Avoid costly inter-node TP when MNNVL is not supported + else: + mlp_tp_size = tp return mlp_tp_size def forward( From dbbb44227926333e6dfedce05954b620b3777cd7 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 13 Aug 2025 23:11:17 -0700 Subject: [PATCH 02/12] Wrong branching Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index b23ea8a837c..7bd9e0f3cd1 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -422,7 +422,7 @@ def __init__( if self.mapping.tp_size != self.mapping.world_size: logger.debug(f"MNNVLAllReduce is disabled due to tp_size != world_size") self.mnnvl_allreduce = None - if MNNVLAllReduce.is_mnnvl(self.mapping, dtype): + elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: self.mnnvl_allreduce = MNNVLAllReduce(self.mapping, dtype) if dtype else None if self.mnnvl_allreduce: From 91f2d570f1c48f0bc590b60463bc68f2eda06ae7 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 13 Aug 2025 23:26:58 -0700 Subject: [PATCH 03/12] Address review comments Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 4 +++- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 7bd9e0f3cd1..6aba1e0bb15 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -420,7 +420,9 @@ def __init__( # Initialize MNNVL 
AllReduce if needed if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): if self.mapping.tp_size != self.mapping.world_size: - logger.debug(f"MNNVLAllReduce is disabled due to tp_size != world_size") + logger.debug( + f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} != world_size:{self.mapping.world_size}" + ) self.mnnvl_allreduce = None elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 43ea94a1155..0b926e6a5d4 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -746,10 +746,12 @@ def _compute_mlp_tp_size(self, intermediate_size: int, ) if tp > self.mapping.gpus_per_node and ( - self.allreduce.strategy not in ( + self.model_config.allreduce_strategy not in ( AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL, - ) or not MNNVLAllReduce.is_mnnvl()): + ) or not MNNVLAllReduce.is_mnnvl( + self.mapping, + self.model_config.pretrained_config.torch_dtype)): mlp_tp_size = math.gcd( tp, self.mapping.gpus_per_node, From a016e59887dae2ea1ac7b81c3bcbd20ae5e89032 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 13 Aug 2025 23:35:15 -0700 Subject: [PATCH 04/12] Fix long log information. Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 6aba1e0bb15..03d3e54ab8a 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -421,8 +421,8 @@ def __init__( if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): if self.mapping.tp_size != self.mapping.world_size: logger.debug( - f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} != world_size:{self.mapping.world_size}" - ) + f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} " + f"!= world_size:{self.mapping.world_size}") self.mnnvl_allreduce = None elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: From a8d978c2077f56f9e259bf8a3a7bcb3cb3ec8c12 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 14 Aug 2025 09:32:15 -0700 Subject: [PATCH 05/12] Fix import error Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 0b926e6a5d4..53304fa299a 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -51,8 +51,8 @@ from ..attention_backend import AttentionMetadata from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, - MNNVLAllReduce, MoEAllReduce, MoEAllReduceParams, - allgather) + MoEAllReduce, MoEAllReduceParams, allgather) +from ..distributed.ops import MNNVLAllReduce from ..model_config import ModelConfig from ..modules.attention import MLA from ..modules.decoder_layer import DecoderLayer From e9465dc034c1e1b686dc36f6f26b06f63a9cfd3c Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 14 Aug 2025 09:47:14 -0700 Subject: [PATCH 06/12] Fix format error during rebase Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 147 ++++++++++++++++--------- 1 file changed, 97 insertions(+), 50 deletions(-) diff 
--git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 03d3e54ab8a..e5b65137b76 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -10,7 +10,8 @@ from tensorrt_llm._utils import mpi_barrier from tensorrt_llm.bindings.internal.runtime import McastGPUBuffer -from tensorrt_llm.functional import AllReduceFusionOp, AllReduceParams, AllReduceStrategy, MoEAllReduceParams +from tensorrt_llm.functional import (AllReduceFusionOp, AllReduceParams, + AllReduceStrategy, MoEAllReduceParams) from tensorrt_llm.logger import logger from tensorrt_llm.mapping import Mapping from tensorrt_llm.plugin.plugin import CustomAllReduceHelper @@ -23,11 +24,13 @@ def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: if not hasattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}"): setattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}", {}) - allreduce_workspaces = getattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}") + allreduce_workspaces = getattr(_thread_local, + f"allreduce_workspaces_{mapping.pp_rank}") if mapping not in allreduce_workspaces: ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace( mapping, - CustomAllReduceHelper.max_workspace_size_auto(mapping.tp_size, support_deterministic=False), + CustomAllReduceHelper.max_workspace_size_auto( + mapping.tp_size, support_deterministic=False), ) allreduce_workspaces[mapping] = (ipc_buffers, workspace) return allreduce_workspaces[mapping][1] @@ -40,28 +43,34 @@ def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None: if mapping not in lowprecision_allreduce_workspaces: ipc_buffers, workspace = CustomAllReduceHelper.allocate_lowprecision_workspace( mapping, - CustomAllReduceHelper.max_workspace_size_lowprecision(mapping.tp_size), + CustomAllReduceHelper.max_workspace_size_lowprecision( + mapping.tp_size), ) lowprecision_allreduce_workspaces[mapping] = (ipc_buffers, workspace) - CustomAllReduceHelper.initialize_lowprecision_buffers(workspace, mapping.tp_size) + CustomAllReduceHelper.initialize_lowprecision_buffers( + workspace, mapping.tp_size) return def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]: - if not hasattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}"): - setattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}", {}) + if not hasattr(_thread_local, + f"allreduce_mnnvl_workspaces_{mapping.pp_rank}"): + setattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}", + {}) force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" - allreduce_mnnvl_workspaces = getattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}") + allreduce_mnnvl_workspaces = getattr( + _thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}") if mapping not in allreduce_mnnvl_workspaces: # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize # Max hidden_size_to_support max_hidden_dim = 16384 - buffer_size_in_bytes = math.ceil(12_000_000 / (max_hidden_dim * stride)) * (max_hidden_dim * stride) + buffer_size_in_bytes = math.ceil( + 12_000_000 / (max_hidden_dim * stride)) * (max_hidden_dim * stride) max_num_elements = buffer_size_in_bytes // stride mcast_buffer = McastGPUBuffer( @@ -72,7 +81,8 @@ def get_allreduce_mnnvl_workspace( True, # mnNvlink ) - buffer = mcast_buffer.get_uc_buffer(mapping.tp_rank, (3, 2, max_num_elements), dtype, 0) + 
buffer = mcast_buffer.get_uc_buffer(mapping.tp_rank, + (3, 2, max_num_elements), dtype, 0) # Only initialize the buffer when we need to resize it buffer.fill_(-0.0) # CPU barrier since we assume this should not be called in cuda graph @@ -82,32 +92,44 @@ def get_allreduce_mnnvl_workspace( # This is a buffer to maintain the state of this allreduce Op # Should have the same lifetime with self._buffer # [Buffer_ptr, Clear_ptr, num_tokens_to_clear,atomic access counter] - buffer_flags = torch.tensor([0, 2, 0, 0], dtype=torch.uint32, device=torch.device("cuda", mapping.local_rank)) + buffer_flags = torch.tensor([0, 2, 0, 0], + dtype=torch.uint32, + device=torch.device("cuda", + mapping.local_rank)) - allreduce_mnnvl_workspaces[mapping] = (mcast_buffer, buffer, buffer_flags, max_num_elements) + allreduce_mnnvl_workspaces[mapping] = (mcast_buffer, buffer, + buffer_flags, max_num_elements) return allreduce_mnnvl_workspaces[mapping] -def userbuffers_allreduce_finalize(input: torch.Tensor, force_applying_finalize: bool = False) -> torch.Tensor: - output = torch.ops.trtllm.userbuffers_allreduce_finalize(input, force_applying_finalize) +def userbuffers_allreduce_finalize( + input: torch.Tensor, + force_applying_finalize: bool = False) -> torch.Tensor: + output = torch.ops.trtllm.userbuffers_allreduce_finalize( + input, force_applying_finalize) return output def get_output_info(input: torch.Tensor, dim: int) -> List[int]: dim = dim % input.ndim - output_shape = [val if idx != dim else -1 for idx, val in enumerate(input.shape)] + output_shape = [ + val if idx != dim else -1 for idx, val in enumerate(input.shape) + ] numel_base = -math.prod(output_shape) return {"output_shape": output_shape, "numel_base": numel_base} -def filter_valid_input(input_list: List[torch.Tensor]) -> Tuple[List[torch.Tensor], List[bool]]: +def filter_valid_input( + input_list: List[torch.Tensor] +) -> Tuple[List[torch.Tensor], List[bool]]: func_valid = lambda x: x is not None valid_list = list(map(func_valid, input_list)) input_list = list(filter(func_valid, input_list)) return input_list, valid_list -def restore_full_output(valid_outputs: List[torch.Tensor], valid_list: List[bool]) -> List[torch.Tensor]: +def restore_full_output(valid_outputs: List[torch.Tensor], + valid_list: List[bool]) -> List[torch.Tensor]: idx = 0 full_outputs = [] for v in valid_list: @@ -155,7 +177,10 @@ def allgather( if isinstance(input, torch.Tensor): assert input.shape[dim] == sizes[mapping.tp_rank] else: - assert all([val.shape[dim] == sizes[mapping.tp_rank] for val in input if val is not None]) + assert all([ + val.shape[dim] == sizes[mapping.tp_rank] for val in input + if val is not None + ]) # Inputs are reshaped in this way to pass necessary shape information to the allgather op if isinstance(input, torch.Tensor): @@ -166,7 +191,10 @@ def allgather( input, valid = filter_valid_input(input) torch_op = torch.ops.trtllm.allgather_list output_info = [get_output_info(val, dim) for val in input] - input = [val.contiguous().view(-1, val_info["numel_base"]) for val, val_info in zip(input, output_info)] + input = [ + val.contiguous().view(-1, val_info["numel_base"]) + for val, val_info in zip(input, output_info) + ] output = torch_op( input, @@ -182,13 +210,17 @@ def convert_output(x, x_info): x_list = x.chunk(mapping.tp_size) else: x_list = x.split(sizes) - x = torch.cat([x.reshape(x_info["output_shape"]) for x in x_list], dim=dim) + x = torch.cat([x.reshape(x_info["output_shape"]) for x in x_list], + dim=dim) return x if isinstance(input, torch.Tensor): 
output = convert_output(output, output_info) else: - output = [convert_output(val, val_info) for val, val_info in zip(output, output_info)] + output = [ + convert_output(val, val_info) + for val, val_info in zip(output, output_info) + ] output = restore_full_output(output, valid) return output @@ -208,7 +240,10 @@ def reducescatter( if isinstance(input, torch.Tensor): assert input.shape[dim] == sum_split_size else: - assert all([val.shape[dim] == sum_split_size for val in input if val is not None]) + assert all([ + val.shape[dim] == sum_split_size for val in input + if val is not None + ]) def convert_input(x, x_info): # Inputs are reshaped in this way to pass necessary shape information to the reducescatter op @@ -230,7 +265,10 @@ def convert_input(x, x_info): input, valid = filter_valid_input(input) torch_op = torch.ops.trtllm.reducescatter_list output_info = [get_output_info(val, dim) for val in input] - input = [convert_input(val, val_info) for val, val_info in zip(input, output_info)] + input = [ + convert_input(val, val_info) + for val, val_info in zip(input, output_info) + ] output = torch_op( input, @@ -241,7 +279,10 @@ def convert_input(x, x_info): if isinstance(input, torch.Tensor): output = output.view(output_info["output_shape"]) else: - output = [val.view(val_info["output_shape"]) for val, val_info in zip(output, output_info)] + output = [ + val.view(val_info["output_shape"]) + for val, val_info in zip(output, output_info) + ] output = restore_full_output(output, valid) return output @@ -264,8 +305,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): ), "MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp." self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = ( - get_allreduce_mnnvl_workspace(self.mapping, dtype) - ) + get_allreduce_mnnvl_workspace(self.mapping, dtype)) @staticmethod def get_supported_dtypes(): @@ -277,13 +317,9 @@ def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: arch = platform.machine().lower() is_on_aarch64 = "aarch64" in arch - return ( - dtype in MNNVLAllReduce.get_supported_dtypes() - and not mapping.has_cp() - and mapping.is_multi_node() - and MnnvlMemory.supports_mnnvl() - and is_on_aarch64 - ) + return (dtype in MNNVLAllReduce.get_supported_dtypes() + and not mapping.has_cp() and mapping.is_multi_node() + and MnnvlMemory.supports_mnnvl() and is_on_aarch64) def forward( self, @@ -309,7 +345,9 @@ def forward( max_num_tokens = self.max_num_elements_mnnvl // hidden_dim num_elements_in_use = max_num_tokens * hidden_dim if num_tokens > max_num_tokens: - logger.debug(f"MNNVL AllReduce can't be enabled due to {num_tokens=} larger than {max_num_tokens=}.") + logger.debug( + f"MNNVL AllReduce can't be enabled due to {num_tokens=} larger than {max_num_tokens=}." 
+ ) return None # This should not happen but leave this check for future code changes @@ -320,7 +358,9 @@ def forward( return None output = torch.empty_like(input) - buffer_mnnvl = self.buffer_mnnvl.view(-1)[: (3 * 2 * num_elements_in_use)].view(3, 2, -1, hidden_dim) + buffer_mnnvl = self.buffer_mnnvl.view(-1)[:(3 * 2 * + num_elements_in_use)].view( + 3, 2, -1, hidden_dim) if fusion_op == AllReduceFusionOp.NONE: output = torch.ops.trtllm.mnnvl_twoshot_allreduce( @@ -332,10 +372,8 @@ def forward( ) return output.view(shape) # Fallback to use other allreduce if hidden_size is not supported - elif ( - fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM - and hidden_dim in MNNVLAllReduce.SUPPORTED_FUSION_HIDDEN_DIMS - ): + elif (fusion_op == AllReduceFusionOp.RESIDUAL_RMS_NORM + and hidden_dim in MNNVLAllReduce.SUPPORTED_FUSION_HIDDEN_DIMS): torch.ops.trtllm.mnnvl_twoshot_allreduce( input, buffer_mnnvl, @@ -418,7 +456,8 @@ def __init__( self.workspace = get_allreduce_workspace(self.mapping) # Initialize MNNVL AllReduce if needed - if self.strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): + if self.strategy in (AllReduceStrategy.AUTO, + AllReduceStrategy.MNNVL): if self.mapping.tp_size != self.mapping.world_size: logger.debug( f"MNNVLAllReduce is disabled due to tp_size:{self.mapping.tp_size} " @@ -426,16 +465,20 @@ def __init__( self.mnnvl_allreduce = None elif MNNVLAllReduce.is_mnnvl(self.mapping, dtype): try: - self.mnnvl_allreduce = MNNVLAllReduce(self.mapping, dtype) if dtype else None + self.mnnvl_allreduce = MNNVLAllReduce( + self.mapping, dtype) if dtype else None if self.mnnvl_allreduce: logger.debug(f"MNNVLAllReduce is enabled") else: logger.debug(f"MNNVLAllReduce is disabled") except Exception as e: - logger.debug(f"MNNVL AllReduce can't be enabled due to {e}.") + logger.debug( + f"MNNVL AllReduce can't be enabled due to {e}.") self.mnnvl_allreduce = None else: - logger.debug(f"MNNVLAllReduce can't be enabled due to failing the is_mnnvl check.") + logger.debug( + f"MNNVLAllReduce can't be enabled due to failing the is_mnnvl check." + ) self.mnnvl_allreduce = None def forward( @@ -467,7 +510,9 @@ def forward( RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual] RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual] """ - if self.mapping.tp_size == 1 or (all_reduce_params is not None and all_reduce_params.enable_allreduce == False): + if self.mapping.tp_size == 1 or (all_reduce_params is not None + and all_reduce_params.enable_allreduce + == False): return input input = input.contiguous() # Underlying op requires contiguous input @@ -478,7 +523,8 @@ def forward( # Try MNNVL AllReduce first if available if self.mnnvl_allreduce: - mnnvl_output = self.mnnvl_allreduce(input, all_reduce_params=all_reduce_params) + mnnvl_output = self.mnnvl_allreduce( + input, all_reduce_params=all_reduce_params) if mnnvl_output is not None: return mnnvl_output @@ -497,7 +543,8 @@ def forward( strategy=allreduce_strategy, op=all_reduce_params.fusion_op, eps=all_reduce_params.eps, - trigger_completion_at_end=all_reduce_params.trigger_completion_at_end, + trigger_completion_at_end=all_reduce_params. 
+ trigger_completion_at_end, ) return output if len(output) > 1 else output[0] @@ -576,15 +623,15 @@ def forward( eps=all_reduce_params.eps, ) else: - assert ( - all_reduce_params.residual.shape[0] <= self.max_token - ), "Num tokens must be less than or equal to max_token" + assert (all_reduce_params.residual.shape[0] <= self.max_token + ), "Num tokens must be less than or equal to max_token" return torch.ops.trtllm.moe_finalize_allreduce( input=input, residual=all_reduce_params.residual, norm_weight=all_reduce_params.norm_weight, - expanded_idx_to_permuted_idx=all_reduce_params.expanded_idx_to_permuted_idx, + expanded_idx_to_permuted_idx=all_reduce_params. + expanded_idx_to_permuted_idx, shared_expert_output=all_reduce_params.shared_expert_output, expert_scale_factor=all_reduce_params.expert_scale_factor, workspace=self.workspace, From 5cea31c9d710426ef734a1aa354b509d97e57769 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 14 Aug 2025 09:54:21 -0700 Subject: [PATCH 07/12] Remove diffs caused by format Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 63 +++++++++++++------------- 1 file changed, 31 insertions(+), 32 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index e5b65137b76..437dc5b1dd9 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -21,11 +21,11 @@ def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: - if not hasattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}"): - setattr(_thread_local, f"allreduce_workspaces_{mapping.pp_rank}", {}) + if not hasattr(_thread_local, f'allreduce_workspaces_{mapping.pp_rank}'): + setattr(_thread_local, f'allreduce_workspaces_{mapping.pp_rank}', {}) allreduce_workspaces = getattr(_thread_local, - f"allreduce_workspaces_{mapping.pp_rank}") + f'allreduce_workspaces_{mapping.pp_rank}') if mapping not in allreduce_workspaces: ipc_buffers, workspace = CustomAllReduceHelper.allocate_allreduce_fusion_workspace( mapping, @@ -37,7 +37,7 @@ def get_allreduce_workspace(mapping: Mapping) -> torch.LongTensor: def allocate_low_presicion_allreduce_workspace(mapping: Mapping) -> None: - if not hasattr(_thread_local, "lowprecision_allreduce_workspaces"): + if not hasattr(_thread_local, 'lowprecision_allreduce_workspaces'): _thread_local.lowprecision_allreduce_workspaces = {} lowprecision_allreduce_workspaces = _thread_local.lowprecision_allreduce_workspaces if mapping not in lowprecision_allreduce_workspaces: @@ -56,14 +56,14 @@ def get_allreduce_mnnvl_workspace( mapping: Mapping, dtype: torch.dtype ) -> Tuple[McastGPUBuffer, torch.Tensor, torch.Tensor, int]: if not hasattr(_thread_local, - f"allreduce_mnnvl_workspaces_{mapping.pp_rank}"): - setattr(_thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}", + f'allreduce_mnnvl_workspaces_{mapping.pp_rank}'): + setattr(_thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}', {}) force_mn = os.environ.get("TRTLLM_FORCE_MNNVL_AR", "0") == "1" allreduce_mnnvl_workspaces = getattr( - _thread_local, f"allreduce_mnnvl_workspaces_{mapping.pp_rank}") + _thread_local, f'allreduce_mnnvl_workspaces_{mapping.pp_rank}') if mapping not in allreduce_mnnvl_workspaces: # buffer shape: [3, 2, buffer_tokens, hidden_dim] stride = 3 * 2 * dtype.itemsize @@ -116,7 +116,7 @@ def get_output_info(input: torch.Tensor, dim: int) -> List[int]: val if idx != dim else -1 for idx, val in enumerate(input.shape) ] numel_base = -math.prod(output_shape) - return {"output_shape": 
output_shape, "numel_base": numel_base} + return {'output_shape': output_shape, 'numel_base': numel_base} def filter_valid_input( @@ -144,7 +144,7 @@ def allgather( dim: int = -1, sizes: Optional[List[int]] = None, ) -> Union[torch.Tensor, List[torch.Tensor]]: - """ + ''' Add an operation that performs a collective all-gather. If 'sizes' is 'None', the input tensors in the different ranks must have the same shape. @@ -168,7 +168,7 @@ def allgather( sizes(Optional[List[int]]): An optional list indicating 'input.shape[dim]' in all ranks. By default None. Returns: The gathered tensor or tensor list. - """ + ''' if mapping.tp_size == 1: return input @@ -186,13 +186,13 @@ def allgather( if isinstance(input, torch.Tensor): torch_op = torch.ops.trtllm.allgather output_info = get_output_info(input, dim) - input = input.contiguous().view(-1, output_info["numel_base"]) + input = input.contiguous().view(-1, output_info['numel_base']) else: input, valid = filter_valid_input(input) torch_op = torch.ops.trtllm.allgather_list output_info = [get_output_info(val, dim) for val in input] input = [ - val.contiguous().view(-1, val_info["numel_base"]) + val.contiguous().view(-1, val_info['numel_base']) for val, val_info in zip(input, output_info) ] @@ -204,13 +204,13 @@ def allgather( def convert_output(x, x_info): if dim == 0: - x = x.view(x_info["output_shape"]) + x = x.view(x_info['output_shape']) else: if sizes is None: x_list = x.chunk(mapping.tp_size) else: x_list = x.split(sizes) - x = torch.cat([x.reshape(x_info["output_shape"]) for x in x_list], + x = torch.cat([x.reshape(x_info['output_shape']) for x in x_list], dim=dim) return x @@ -248,13 +248,13 @@ def reducescatter( def convert_input(x, x_info): # Inputs are reshaped in this way to pass necessary shape information to the reducescatter op if dim == 0: - x = x.contiguous().view(-1, x_info["numel_base"]) + x = x.contiguous().view(-1, x_info['numel_base']) else: if sizes is None: x_list = x.chunk(mapping.tp_size, dim=dim) else: x_list = x.split(sizes, dim=dim) - x = torch.cat([x.reshape(-1, x_info["numel_base"]) for x in x_list]) + x = torch.cat([x.reshape(-1, x_info['numel_base']) for x in x_list]) return x if isinstance(input, torch.Tensor): @@ -277,10 +277,10 @@ def convert_input(x, x_info): ) if isinstance(input, torch.Tensor): - output = output.view(output_info["output_shape"]) + output = output.view(output_info['output_shape']) else: output = [ - val.view(val_info["output_shape"]) + val.view(val_info['output_shape']) for val, val_info in zip(output, output_info) ] output = restore_full_output(output, valid) @@ -300,12 +300,13 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): super().__init__() self.mapping = mapping self.dtype = dtype - assert dtype in MNNVLAllReduce.get_supported_dtypes() and ( - not mapping.has_cp() + assert ( + dtype in MNNVLAllReduce.get_supported_dtypes() + and (not mapping.has_cp()) ), "MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp." 
- self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = ( - get_allreduce_mnnvl_workspace(self.mapping, dtype)) + self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace( + self.mapping, dtype) @staticmethod def get_supported_dtypes(): @@ -397,14 +398,12 @@ def forward( class AllReduce(nn.Module): - def __init__( - self, - mapping: Mapping, - strategy: AllReduceStrategy = AllReduceStrategy.AUTO, - dtype: Optional[torch.dtype] = None, - ): + def __init__(self, + mapping: Mapping, + strategy: AllReduceStrategy = AllReduceStrategy.AUTO, + dtype: Optional[torch.dtype] = None): super().__init__() - """ + ''' AllReduce is a module that performs an all-reduce operation on a tensor. Args: @@ -441,7 +440,7 @@ def __init__( https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py The LOWPRECISION strategy can be selected either by directly specifying it in the constructor. - """ + ''' self.mapping = mapping self.workspace = None @@ -623,8 +622,8 @@ def forward( eps=all_reduce_params.eps, ) else: - assert (all_reduce_params.residual.shape[0] <= self.max_token - ), "Num tokens must be less than or equal to max_token" + assert all_reduce_params.residual.shape[ + 0] <= self.max_token, "Num tokens must be less than or equal to max_token" return torch.ops.trtllm.moe_finalize_allreduce( input=input, From 6eb66f05c496e40f0af1b02fc91e7cd57206cf39 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Thu, 14 Aug 2025 09:55:55 -0700 Subject: [PATCH 08/12] Fix format issue Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 437dc5b1dd9..5b9bd6e20b8 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -403,7 +403,7 @@ def __init__(self, strategy: AllReduceStrategy = AllReduceStrategy.AUTO, dtype: Optional[torch.dtype] = None): super().__init__() - ''' + """ AllReduce is a module that performs an all-reduce operation on a tensor. Args: @@ -440,7 +440,7 @@ def __init__(self, https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py The LOWPRECISION strategy can be selected either by directly specifying it in the constructor. - ''' + """ self.mapping = mapping self.workspace = None @@ -486,7 +486,7 @@ def forward( *, all_reduce_params: Optional[AllReduceParams] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: - """ + ''' The input tensors in the different ranks must have the same shape. The output tensor will have that same shape with the input tensor. The output tensor will be replicated among the TP group. 
@@ -508,7 +508,7 @@ def forward( RESIDUAL_RMS_NORM_OUT_QUANT_FP8: [norm, norm_quant, residual] RESIDUAL_RMS_NORM_QUANT_NVFP4: [norm_quant_fp4, scale_factor, residual] RESIDUAL_RMS_NORM_OUT_QUANT_NVFP4: [norm, norm_quant_fp4, scale_factor, residual] - """ + ''' if self.mapping.tp_size == 1 or (all_reduce_params is not None and all_reduce_params.enable_allreduce == False): From 0b1f3f777fd8e8708f5a4f8e9da3bbf1485f0b18 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 26 Aug 2025 14:01:17 -0700 Subject: [PATCH 09/12] Address review comments Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 9 +++++++++ tensorrt_llm/_torch/models/modeling_deepseekv3.py | 15 ++++++--------- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index 5b9bd6e20b8..d9e6671b3bf 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -312,6 +312,7 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype): def get_supported_dtypes(): return (torch.float16, torch.bfloat16, torch.float32) + # Check if MNNVL is supported @staticmethod def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: from tensorrt_llm._mnnvl_utils import MnnvlMemory @@ -322,6 +323,14 @@ def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: and not mapping.has_cp() and mapping.is_multi_node() and MnnvlMemory.supports_mnnvl() and is_on_aarch64) + # Check if MNNVL strategy is used + @staticmethod + def should_use_mnnvl(strategy: AllReduceStrategy, mapping: Mapping, + dtype: torch.dtype) -> bool: + if strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): + return MNNVLAllReduce.is_mnnvl(mapping, dtype) + return False + def forward( self, input: torch.Tensor, diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 53304fa299a..5ae99fe0085 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -41,7 +41,7 @@ from tensorrt_llm._ipc_utils import can_access_peer from tensorrt_llm._utils import get_sm_version -from tensorrt_llm.functional import AllReduceStrategy, PositionEmbeddingType +from tensorrt_llm.functional import PositionEmbeddingType from tensorrt_llm.llmapi.utils import enable_llm_debug from tensorrt_llm.mapping import Mapping from tensorrt_llm.models.modeling_utils import QuantConfig @@ -612,7 +612,8 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], torch.cuda.Stream]): super().__init__() self.model_config = model_config - config = model_config.pretrained_config + self.config = model_config.pretrained_config + config = self.config self.hidden_size = config.hidden_size self.moe_intermediate_size = config.moe_intermediate_size @@ -745,13 +746,9 @@ def _compute_mlp_tp_size(self, intermediate_size: int, self.mapping.tp_size, ) - if tp > self.mapping.gpus_per_node and ( - self.model_config.allreduce_strategy not in ( - AllReduceStrategy.AUTO, - AllReduceStrategy.MNNVL, - ) or not MNNVLAllReduce.is_mnnvl( - self.mapping, - self.model_config.pretrained_config.torch_dtype)): + if tp > self.mapping.gpus_per_node and not MNNVLAllReduce.should_use_mnnvl( + self.model_config.allreduce_strategy, self.mapping, + self.config.torch_dtype): mlp_tp_size = math.gcd( tp, self.mapping.gpus_per_node, From 9548d4dafa951598d3c807776316a1c254e6d0c2 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 26 Aug 2025 14:09:39 -0700 Subject: [PATCH 10/12] Move 
the check into allreduce class. Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/distributed/ops.py | 11 +++-------- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 6 ++---- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/tensorrt_llm/_torch/distributed/ops.py b/tensorrt_llm/_torch/distributed/ops.py index d9e6671b3bf..b3811204dfa 100644 --- a/tensorrt_llm/_torch/distributed/ops.py +++ b/tensorrt_llm/_torch/distributed/ops.py @@ -323,14 +323,6 @@ def is_mnnvl(mapping: Mapping, dtype: torch.dtype) -> bool: and not mapping.has_cp() and mapping.is_multi_node() and MnnvlMemory.supports_mnnvl() and is_on_aarch64) - # Check if MNNVL strategy is used - @staticmethod - def should_use_mnnvl(strategy: AllReduceStrategy, mapping: Mapping, - dtype: torch.dtype) -> bool: - if strategy in (AllReduceStrategy.AUTO, AllReduceStrategy.MNNVL): - return MNNVLAllReduce.is_mnnvl(mapping, dtype) - return False - def forward( self, input: torch.Tensor, @@ -489,6 +481,9 @@ def __init__(self, ) self.mnnvl_allreduce = None + def is_mnnvl(self) -> bool: + return self.mnnvl_allreduce is not None + def forward( self, input: torch.Tensor, diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 5ae99fe0085..0767f6be29e 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -52,7 +52,6 @@ from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams from ..distributed import (AllReduce, AllReduceFusionOp, AllReduceParams, MoEAllReduce, MoEAllReduceParams, allgather) -from ..distributed.ops import MNNVLAllReduce from ..model_config import ModelConfig from ..modules.attention import MLA from ..modules.decoder_layer import DecoderLayer @@ -746,9 +745,8 @@ def _compute_mlp_tp_size(self, intermediate_size: int, self.mapping.tp_size, ) - if tp > self.mapping.gpus_per_node and not MNNVLAllReduce.should_use_mnnvl( - self.model_config.allreduce_strategy, self.mapping, - self.config.torch_dtype): + if tp > self.mapping.gpus_per_node and not self.allreduce.is_mnnvl( + ): mlp_tp_size = math.gcd( tp, self.mapping.gpus_per_node, From 82c1e9828de99fe2068ab5678a40e295d5ff39ef Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Tue, 26 Aug 2025 14:16:46 -0700 Subject: [PATCH 11/12] Move the allreduce instance creation to the top. 
Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 0767f6be29e..2e3c3b6568e 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -622,6 +622,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], self.mapping = model_config.mapping mapping = self.mapping + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) + self.moe_allreduce = MoEAllReduce(self.mapping) self.self_attn = DeepseekV3Attention( model_config, @@ -695,10 +699,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], eps=config.rms_norm_eps, dtype=config.torch_dtype) self.layer_idx = layer_idx - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy, - dtype=config.torch_dtype) - self.moe_allreduce = MoEAllReduce(self.mapping) self.next_layer_layernorm: RMSNorm = None def _get_decoder_layer_quant_config( From 38e2b91cc18146a81837a1ede08f9dca32ea1358 Mon Sep 17 00:00:00 2001 From: Shiyu Li Date: Wed, 27 Aug 2025 11:19:32 -0700 Subject: [PATCH 12/12] Address CI error Signed-off-by: Shiyu Li --- tensorrt_llm/_torch/models/modeling_deepseekv3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorrt_llm/_torch/models/modeling_deepseekv3.py b/tensorrt_llm/_torch/models/modeling_deepseekv3.py index 2e3c3b6568e..f34ff92ba93 100644 --- a/tensorrt_llm/_torch/models/modeling_deepseekv3.py +++ b/tensorrt_llm/_torch/models/modeling_deepseekv3.py @@ -622,10 +622,6 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], self.mapping = model_config.mapping mapping = self.mapping - self.allreduce = AllReduce(mapping=model_config.mapping, - strategy=model_config.allreduce_strategy, - dtype=config.torch_dtype) - self.moe_allreduce = MoEAllReduce(self.mapping) self.self_attn = DeepseekV3Attention( model_config, @@ -647,6 +643,10 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig], self.is_nvfp4 = quant_config.layer_quant_mode.has_nvfp4() has_tp = mapping.has_tp() + self.allreduce = AllReduce(mapping=model_config.mapping, + strategy=model_config.allreduce_strategy, + dtype=config.torch_dtype) + self.moe_allreduce = MoEAllReduce(self.mapping) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace
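
For reference, the MLP TP sizing rule that this series converges on (introduced in PATCH 01 and refined in PATCH 09-10 as `_compute_mlp_tp_size` plus `AllReduce.is_mnnvl()`) can be summarized with the standalone sketch below. The flat parameters are hypothetical stand-ins for the values the real method reads from `self.mapping` and `self.allreduce`; this is an illustration of the decision logic, not the actual method signature.

import math

def compute_mlp_tp_size(intermediate_size: int, block_size: int,
                        tp_size: int, gpus_per_node: int,
                        mnnvl_allreduce_active: bool) -> int:
    # TP cannot exceed the number of quantization blocks in the intermediate dim.
    tp = math.gcd(intermediate_size // block_size, tp_size)
    # Shrink TP to a single node only when the MNNVL all-reduce is not active;
    # with MNNVL available, inter-node TP is kept instead of being split.
    if tp > gpus_per_node and not mnnvl_allreduce_active:
        return math.gcd(tp, gpus_per_node)
    return tp

For example, with tp = 16 (after the gcd against the block count) and gpus_per_node = 8, the sketch returns 8 when the MNNVL all-reduce is unavailable and 16 when it is active, which is the behavior change this series makes to avoid splitting MLP TP on MNNVL-capable systems.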