
Commit 96f11b1

[None][feat] support attention dp for qwen3 dense model (#7618)
Signed-off-by: Nekofish-L <[email protected]>
Parent: 44d5ccf


tensorrt_llm/_torch/models/modeling_qwen3.py

Lines changed: 16 additions & 1 deletion
@@ -8,6 +8,7 @@
 
 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -82,6 +83,8 @@ def __init__(
             model_config,
             layer_idx=layer_idx,
         )
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = self.mapping.enable_attention_dp
 
         # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
         # and https://nvbugspro.nvidia.com/bug/5505402)
@@ -92,6 +95,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
+            overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
             disable_deep_gemm=disable_deep_gemm,
         )
@@ -102,6 +106,8 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
+        self.disable_allreduce = (self.mapping.tp_size == 1
+                                  or self.enable_attention_dp)
 
     def forward(
         self,
@@ -126,13 +132,22 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             mrope_config=mrope_config,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
             **kwargs,
         )
 
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states,
+            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
+            cutlass_min_latency_mode=False,
+        )
 
         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
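
For context, the change keys off the model's parallel mapping: when mapping.enable_attention_dp is set, the MLP's tensor-parallel size is overridden to 1 and the per-layer all-reduce is skipped. Below is a minimal sketch of how attention DP would typically be turned on from the PyTorch-backend LLM API; the enable_attention_dp constructor argument and the checkpoint name are assumptions for illustration and are not part of this commit.

# Minimal sketch (assumption, not part of this commit): enabling attention
# data parallelism for a Qwen3 dense model via the LLM API. With
# enable_attention_dp=True, mapping.enable_attention_dp is set, so the decoder
# layer above overrides the MLP TP size to 1 and disables the per-layer
# all-reduce (self.disable_allreduce becomes True).
from tensorrt_llm import LLM, SamplingParams


def main():
    llm = LLM(
        model="Qwen/Qwen3-8B",     # illustrative Qwen3 dense checkpoint
        tensor_parallel_size=4,
        enable_attention_dp=True,  # assumed flag that populates mapping.enable_attention_dp
    )
    prompts = ["Explain attention data parallelism in one sentence."]
    for output in llm.generate(prompts, SamplingParams(max_tokens=64)):
        print(output.outputs[0].text)


if __name__ == "__main__":
    main()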
