tensorrt_llm/_torch/models: 1 file changed (+8 −2 lines)
@@ -11,7 +11,8 @@
 from transformers.models.llama4.modeling_llama4 import Llama4MultiModalProjector
 
 from tensorrt_llm._torch.distributed import (AllReduce, AllReduceFusionOp,
-                                             AllReduceParams, MoEAllReduce)
+                                             AllReduceParams, AllReduceStrategy,
+                                             MoEAllReduce)
 from tensorrt_llm._torch.models.checkpoints.base_weight_mapper import \
     BaseWeightMapper
 from tensorrt_llm._utils import get_sm_version
@@ -652,7 +653,12 @@ def __init__(
             eps=config.rms_norm_eps,
             dtype=config.torch_dtype)
 
-        self.all_reduce = AllReduce(mapping=model_config.mapping)
+        # TODO: This is a temporary fix to disable oneshot kernel for pre-Blackwell arch to avoid perf regressions
+        self.all_reduce = AllReduce(
+            strategy=model_config.allreduce_strategy
+            if get_sm_version() >= 100 else AllReduceStrategy.NCCL,
+            mapping=model_config.mapping,
+        )
 
         self.next_layer_layernorm: RMSNorm = None
         self.next_attn: LlamaAttention = None
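
The functional change gates the all-reduce strategy on GPU architecture: on SM version 100 (Blackwell) and newer, the user-configured model_config.allreduce_strategy is honored; on older architectures the strategy is pinned to AllReduceStrategy.NCCL so the oneshot kernel can never be selected. Below is a minimal sketch of that gating factored into a standalone helper. get_sm_version and AllReduceStrategy are the symbols imported in the diff above; pick_allreduce_strategy is a hypothetical name used only for illustration, not a TensorRT-LLM API.

# Sketch only: mirrors the strategy selection added by this diff.
# pick_allreduce_strategy is a hypothetical helper, not a TensorRT-LLM API.
from tensorrt_llm._torch.distributed import AllReduceStrategy
from tensorrt_llm._utils import get_sm_version


def pick_allreduce_strategy(configured_strategy):
    # SM version 100 corresponds to the Blackwell architecture; earlier
    # GPUs fall back to NCCL to avoid oneshot-kernel perf regressions.
    if get_sm_version() >= 100:
        return configured_strategy
    return AllReduceStrategy.NCCL

Per the TODO in the diff, this is intended as a temporary workaround until the oneshot kernel's performance regression on pre-Blackwell hardware is resolved.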