
Commit ab26d21

hyuknchzblych authored and committed

[https://nvbugs/5517023][fix] Pass allreduce strategy and force NCCL on pre-Blackwell arch (#7768)

Signed-off-by: Yukun He <[email protected]>
Signed-off-by: Wangshanshan <[email protected]>
1 parent edbe270 commit ab26d21

File tree

1 file changed (+8, -2 lines)


tensorrt_llm/_torch/models/modeling_llama.py

Lines changed: 8 additions & 2 deletions
@@ -11,7 +11,8 @@
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, MoEAllReduce)
+                                             AllReduceParams, AllReduceStrategy,
+                                             MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -652,7 +653,12 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
+        self.all_reduce = AllReduce(
+            strategy=model_config.allreduce_strategy
+            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
+            mapping=model_config.mapping,
+        )
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
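For reference, below is a minimal standalone sketch of the selection logic this diff adds: on pre-Blackwell GPUs (SM version < 100) the configured allreduce strategy is overridden and NCCL is forced to avoid the oneshot-kernel perf regression; on Blackwell and newer the strategy from model_config is honored. The helper function and the string stand-ins for AllReduceStrategy members are illustrative only, not part of the actual change.

# Sketch of the strategy selection introduced by this commit.
# "NCCL" stands in for AllReduceStrategy.NCCL; the real code reads the
# configured strategy from model_config.allreduce_strategy and the SM
# version from tensorrt_llm._utils.get_sm_version().
def select_allreduce_strategy(configured_strategy: str, sm_version: int) -> str:
    """Return the strategy to pass to AllReduce for a given SM version."""
    return configured_strategy if sm_version >= 100 else "NCCL"

# Hopper (SM 90): the configured strategy is overridden, NCCL is forced.
assert select_allreduce_strategy("AUTO", 90) == "NCCL"
# Blackwell (SM 100): the configured strategy is kept.
assert select_allreduce_strategy("AUTO", 100) == "AUTO"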
