9 changes: 2 additions & 7 deletions cpp/tensorrt_llm/thop/allreduceOp.cpp
@@ -621,14 +621,12 @@ class AllreduceOp

AllReduceStrategyType getRuntimeStrategy(size_t seq_len, size_t size)
{
static char* force_nccl_all_reduce_strategy_char = std::getenv("FORCE_NCCL_ALL_REDUCE_STRATEGY");
bool force_nccl_all_reduce_strategy = (force_nccl_all_reduce_strategy_char != nullptr);
AllReduceStrategyType runtime_strategy;
if (mStrategy == AllReduceStrategyType::UB)
{
runtime_strategy = AllReduceStrategyType::UB;
}
else if (force_nccl_all_reduce_strategy || mStrategy == AllReduceStrategyType::NCCL)
else if (mStrategy == AllReduceStrategyType::NCCL)
{
runtime_strategy = AllReduceStrategyType::NCCL;
}
@@ -936,10 +934,7 @@ class AllreduceOp

bool isUsingLowPrecision(size_t message_size) const noexcept
{
static char* force_low_precision_allreduce_strategy_char
= std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY");
bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr)
|| (mStrategy == AllReduceStrategyType::LOWPRECISION);
bool force_low_precision = mStrategy == AllReduceStrategyType::LOWPRECISION;

#ifdef ENABLE_FP8
// Use LowPrecision if PCIe and p2p support and message size is larger than 2MB
14 changes: 5 additions & 9 deletions docs/source/advanced/lowprecision-pcie-allreduce.md
@@ -41,12 +41,12 @@ The Low-Precision-AllReduce algorithm can be enabled in two ways:
```
AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION);
```
2. **Environment variable control** with AUTO strategy:

2. **LlmArgs control**:
```
// In your code
AllReduce allreduce(mapping=mapping, strategy=AllReduceStrategy.AUTO);
// Set environment variable before running
export FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY=1
Set the allreduce_strategy field in LlmArgs.
Supported values are "AUTO", "NCCL", "UB", "MINLATENCY", "ONESHOT", "TWOSHOT", "LOWPRECISION" and "MNNVL".
If no strategy is set, AUTO is used.
```
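
For example, a minimal sketch of enabling the strategy through the LLM API (this assumes `allreduce_strategy` is exposed as an `LlmArgs` field and forwarded to the runtime; the model path and parallelism settings are illustrative):

```
from tensorrt_llm import LLM

# "LOWPRECISION" is one of the strategy strings listed above;
# allreduce_strategy is assumed to be the LlmArgs field described in this section.
llm = LLM(model="/path/to/model",        # illustrative path
          tensor_parallel_size=8,
          allreduce_strategy="LOWPRECISION")
```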

## Performance and Accuracy Considerations
@@ -58,8 +58,4 @@ Low-Precision-AllReduce reduces communication volume by using FP8 data format fo

Users should evaluate the precision impact on their specific models and workloads.

## Environment Variables

- `FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY`: When set to `1`, forces the use of low-precision algorithm with AUTO strategy. If the algorithm determines it cannot provide performance benefits, it will automatically fall back to other strategies.

**Note**: If TensorRT-LLM is compiled without the `ENABLE_FP8` option, the low-precision allreduce setting has no effect.
34 changes: 16 additions & 18 deletions examples/pytorch/out_of_tree_example/modeling_opt.py
@@ -64,24 +64,22 @@ def __init__(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
dtype=config.torch_dtype)
self.fc1 = Linear(
config.hidden_size,
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
)
self.fc2 = Linear(
config.ffn_dim,
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
)
self.fc1 = Linear(config.hidden_size,
config.ffn_dim,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.fc2 = Linear(config.ffn_dim,
config.hidden_size,
bias=config.enable_bias,
dtype=config.torch_dtype,
mapping=model_config.mapping,
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=model_config.get_quant_config(),
allreduce_strategy=model_config.allreduce_strategy)
self.final_layer_norm = LayerNorm(
config.hidden_size,
elementwise_affine=config.layer_norm_elementwise_affine,
4 changes: 2 additions & 2 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -6,7 +6,7 @@
try:
from ....mapping import Mapping
from ...distributed import AllReduce, allgather
from ...modules.linear import AllReduceFusionOp, AllReduceParams
from ...modules.linear import AllReduceFusionOp, AllReduceParams, AllReduceStrategy

def trtllm_allgather(tensor, dim, sizes=None):
rank, world_size = get_rank_world_size()
@@ -17,7 +17,7 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
rank, world_size = get_rank_world_size()
assert op == ReduceOp.SUM, "TRT-LLM all reduce only supports SUM op."
p_config = Mapping(world_size=world_size, tp_size=world_size, rank=rank)
torch_op = AllReduce(p_config)
torch_op = AllReduce(mapping=p_config, strategy=AllReduceStrategy.AUTO)
return torch_op(tensor, all_reduce_params=all_reduce_params)

@torch.library.custom_op(
35 changes: 19 additions & 16 deletions tensorrt_llm/_torch/distributed/ops.py
@@ -307,14 +307,17 @@ def __init__(self, mapping: Mapping, dtype: torch.dtype):
super().__init__()
self.mapping = mapping
self.dtype = dtype
self.enable_mnnvl = (os.environ.get("TRTLLM_MNNVL_AR_ENABLED",
"0") == "1"
and dtype in [torch.bfloat16, torch.float32]
and (not mapping.has_cp()))
assert (
dtype in MNNVLAllReduce.get_supported_dtypes()
and (not mapping.has_cp())
), "MNNVL all reduce only supports dtype {MNNVLAllReduce.get_supported_dtypes()} and without cp."

if self.enable_mnnvl:
self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace(
self.mapping, dtype)
self.mcast_buffer_mnnvl, self.buffer_mnnvl, self.buffer_flags_mnnvl, self.max_num_elements_mnnvl = get_allreduce_mnnvl_workspace(
self.mapping, dtype)

@staticmethod
def get_supported_dtypes():
return (torch.bfloat16, torch.float32)

def forward(
self,
Expand All @@ -330,7 +333,7 @@ def forward(
Returns:
Union[torch.Tensor, Tuple[torch.Tensor, ...]]: Reduced tensor(s)
"""
if not self.enable_mnnvl or input.numel() > self.max_num_elements_mnnvl:
if input.numel() > self.max_num_elements_mnnvl:
return None

fusion_op = all_reduce_params.fusion_op
@@ -411,27 +414,27 @@ def __init__(self,
For the reference implementation for each pattern, please refer to the following unit test:
https://github.com/NVIDIA/TensorRT-LLM/blob/main/tests/unittest/_torch/multi_gpu/test_allreduce.py

The LOWPRECISION strategy can be selected either by directly specifying it in the constructor
or by setting the environment variable FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY when using
the AUTO strategy.
The LOWPRECISION strategy can be selected by specifying it directly in the constructor.
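
A minimal usage sketch (assuming bfloat16 activations and an existing `mapping`):
allreduce = AllReduce(mapping=mapping, strategy=AllReduceStrategy.LOWPRECISION, dtype=torch.bfloat16)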
"""

self.mapping = mapping
self.workspace = None
self.strategy = strategy
self.mnnvl_allreduce = None

self.force_low_precision_env = os.environ.get(
"FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY")
if self.mapping.tp_size > 1:
# When Strategy is UB, it is guaranteed that the workspace is not used.
if self.strategy != AllReduceStrategy.UB:
if self.strategy == AllReduceStrategy.LOWPRECISION or self.force_low_precision_env is not None:
if self.strategy == AllReduceStrategy.LOWPRECISION:
allocate_low_presicion_allreduce_workspace(self.mapping)
self.workspace = get_allreduce_workspace(self.mapping)

# Initialize MNNVL AllReduce if needed
self.mnnvl_allreduce = MNNVLAllReduce(mapping,
dtype) if dtype else None
if self.strategy == AllReduceStrategy.MNNVL and (
dtype and dtype in MNNVLAllReduce.get_supported_dtypes()
) and (not self.mapping.has_cp()):
self.mnnvl_allreduce = MNNVLAllReduce(self.mapping,
dtype) if dtype else None

def forward(
self,
21 changes: 21 additions & 0 deletions tensorrt_llm/_torch/model_config.py
@@ -8,6 +8,8 @@

from tensorrt_llm import logger
from tensorrt_llm._utils import torch_dtype_to_binding
from tensorrt_llm.functional import AllReduceStrategy
from tensorrt_llm.logger import logger
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.modeling_utils import QuantConfig
from tensorrt_llm.quantization.mode import QuantAlgo
@@ -77,6 +79,7 @@ class ModelConfig(Generic[TConfig]):

attn_backend: str = 'TRTLLM'
moe_backend: str = 'CUTLASS' # options can be CUTLASS, TRTLLM
allreduce_strategy: AllReduceStrategy = AllReduceStrategy.AUTO

# If true, enable min-latency mode. Currently only used for Llama4.
enable_min_latency: bool = False
@@ -106,6 +109,24 @@ def __post_init__(self):
self.is_generation = self.is_generation_model(
self.pretrained_config.architectures)

def get_all_reduce_strategy(strategy: str = "AUTO"):
maps = {
"AUTO": AllReduceStrategy.AUTO,
"NCCL": AllReduceStrategy.NCCL,
"UB": AllReduceStrategy.UB,
"MINLATENCY": AllReduceStrategy.MIN_LATENCY,
"ONESHOT": AllReduceStrategy.ONESHOT,
"TWOSHOT": AllReduceStrategy.TWOSHOT,
"LOWPRECISION": AllReduceStrategy.LOWPRECISION,
"MNNVL": AllReduceStrategy.MNNVL
}
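# Case-insensitive lookup; unrecognized names fall back to AUTO.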
key = strategy.upper()
return maps.get(key, AllReduceStrategy.AUTO)

if isinstance(self.allreduce_strategy, str):
self.allreduce_strategy = get_all_reduce_strategy(
self.allreduce_strategy)

@property
def fuse_pos_embd(self):
if self.attn_backend == 'TRTLLM':
7 changes: 5 additions & 2 deletions tensorrt_llm/_torch/models/modeling_deepseekv3.py
@@ -399,7 +399,8 @@ def __init__(self,
overridden_tp_size=shared_tp_size,
reduce_output=False)

self.allreduce = AllReduce(self.mapping)
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.aux_stream = aux_stream_dict[AuxStreamType.MoeShared]
self.event_dict = {
key: torch.cuda.Event()
@@ -628,7 +629,9 @@ def __init__(self, model_config: ModelConfig[PretrainedConfig],
eps=config.rms_norm_eps,
dtype=config.torch_dtype)
self.layer_idx = layer_idx
self.allreduce = AllReduce(self.mapping, dtype=config.torch_dtype)
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
dtype=config.torch_dtype)
self.moe_allreduce = MoEAllReduce(self.mapping)
self.next_layer_layernorm: RMSNorm = None

10 changes: 7 additions & 3 deletions tensorrt_llm/_torch/models/modeling_llama.py
@@ -282,7 +282,10 @@ def __init__(
quant_config=None)

self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.all_reduce = AllReduce(
mapping=model_config.mapping,
strategy=model_config.allreduce_strategy,
)
self.moe_event = [torch.cuda.Event(), torch.cuda.Event()]
self.aux_stream = aux_stream

@@ -414,7 +417,8 @@ def __init__(
dtype=config.torch_dtype)

self.mapping = model_config.mapping
self.all_reduce = AllReduce(self.mapping)
self.all_reduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.next_layer_layernorm: RMSNorm = None
self.next_attn: LlamaAttention = None

@@ -625,7 +629,7 @@ def __init__(
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.
skip_create_weights_in_init,
)
allreduce_strategy=model_config.allreduce_strategy)


class Eagle3LlamaDecoderLayer(DecoderLayer):
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/models/modeling_nemotron_nas.py
@@ -44,7 +44,7 @@ def _create_linear_from_configs(model_config: ModelConfig[PretrainedConfig],
gather_output=True,
quant_config=model_config.get_quant_config(),
skip_create_weights_in_init=model_config.skip_create_weights_in_init,
)
allreduce_strategy=model_config.allreduce_strategy)


class NemotronNASAttention(Attention):
6 changes: 4 additions & 2 deletions tensorrt_llm/_torch/models/modeling_qwen3_moe.py
@@ -89,7 +89,8 @@ def __init__(
self.top_k = config.num_experts_per_tok
self.enable_attention_dp = model_config.mapping.enable_attention_dp
self.mapping = model_config.mapping
self.allreduce = AllReduce(self.mapping)
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.enable_alltoall = Qwen3MoE.should_enable_alltoall(
model_config, self.top_k)
if self.enable_alltoall:
@@ -202,7 +203,8 @@ def __init__(self, model_config: ModelConfig[Qwen3MoeConfig],
dtype=config.torch_dtype)
self.layer_idx = layer_idx

self.allreduce = AllReduce(self.mapping)
self.allreduce = AllReduce(mapping=model_config.mapping,
strategy=model_config.allreduce_strategy)
self.next_layer_layernorm: RMSNorm = None

self.fusion_config = EagerFusionConfig()
14 changes: 8 additions & 6 deletions tensorrt_llm/_torch/modules/attention.py
@@ -126,7 +126,7 @@ def __init__(
weight_mode=WeightMode.FUSED_QKV_LINEAR),
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)
self.o_lora = LoraLayer([LoraModuleType.ATTENTION_DENSE],
[self.hidden_size])

@@ -140,7 +140,7 @@
quant_config=config.get_quant_config(),
skip_create_weights_in_init=config.skip_create_weights_in_init,
lora=self.o_lora,
)
allreduce_strategy=config.allreduce_strategy)

self.quant_config = config.get_quant_config()
self.attn_backend = config.attn_backend
@@ -481,7 +481,8 @@ def __init__(
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init)
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy)
else:
self.fused_a = Linear(
hidden_size,
@@ -501,7 +502,7 @@
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)
self.q_b_proj = self.q_proj

self.kv_a_layernorm = RMSNorm(hidden_size=kv_lora_rank,
@@ -517,7 +518,8 @@
mapping=mapping,
tensor_parallel_mode=TensorParallelMode.COLUMN,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init)
skip_create_weights_in_init=config.skip_create_weights_in_init,
allreduce_strategy=config.allreduce_strategy)
# This parameter will view into self.kv_b_proj.weight after loading weights.
# For dummy weight initialization, this parameter is initialized with empty tensor.
# Used in forward_generation only
@@ -538,7 +540,7 @@
tensor_parallel_mode=TensorParallelMode.ROW,
quant_config=quant_config,
skip_create_weights_in_init=config.skip_create_weights_in_init,
)
allreduce_strategy=config.allreduce_strategy)

def yarn_get_mscale(scale=1, mscale=1):
if scale <= 1: