 from ..attention_backend import AttentionMetadata
 from ..attention_backend.interface import PositionalEmbeddingParams, RopeParams
+from ..distributed import AllReduceParams
 from ..model_config import ModelConfig
 from ..modules.decoder_layer import DecoderLayer
 from ..modules.embedding import Embedding
@@ -82,6 +83,8 @@ def __init__(
             model_config,
             layer_idx=layer_idx,
         )
+        self.mapping = model_config.mapping
+        self.enable_attention_dp = self.mapping.enable_attention_dp

         # Qwen3 has accuracy issues with deep_gemm (see: https://nvbugspro.nvidia.com/bug/5461712
         # and https://nvbugspro.nvidia.com/bug/5505402)
@@ -92,6 +95,7 @@ def __init__(
             intermediate_size=config.intermediate_size,
             bias=config.mlp_bias if hasattr(config, "mlp_bias") else False,
             dtype=config.torch_dtype,
+            overridden_tp_size=1 if self.enable_attention_dp else None,
             config=model_config,
             disable_deep_gemm=disable_deep_gemm,
         )
@@ -102,6 +106,8 @@ def __init__(
         self.post_attention_layernorm = RMSNorm(hidden_size=config.hidden_size,
                                                 eps=config.rms_norm_eps,
                                                 dtype=config.torch_dtype)
+        self.disable_allreduce = (self.mapping.tp_size == 1
+                                  or self.enable_attention_dp)

     def forward(
         self,
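What the two __init__ changes buy: under attention data parallelism each rank already holds only its own tokens, so the MLP is built with its tensor parallelism overridden to 1, and the layer records that no TP allreduce is needed. A minimal sketch of that gating, with Mapping reduced to a stand-in dataclass (the real model_config.mapping carries many more fields):

# Sketch only: "Mapping" is a stand-in for model_config.mapping; real
# TensorRT-LLM mappings carry more fields than the two used here.
from dataclasses import dataclass


@dataclass
class Mapping:
    tp_size: int
    enable_attention_dp: bool


def mlp_tp_override(mapping):
    # Attention DP keeps each rank's tokens local, so the MLP is built
    # with TP=1 (no weight sharding, no collective); None leaves the
    # model-wide TP size untouched.
    return 1 if mapping.enable_attention_dp else None


def allreduce_disabled(mapping):
    # Nothing to reduce with a single TP rank, and nothing to reduce
    # when attention DP keeps activations rank-local.
    return mapping.tp_size == 1 or mapping.enable_attention_dp


for m in (Mapping(4, False), Mapping(4, True), Mapping(1, False)):
    print(m, "->", mlp_tp_override(m), allreduce_disabled(m))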
@@ -126,13 +132,22 @@ def forward(
             hidden_states=hidden_states,
             attn_metadata=attn_metadata,
             mrope_config=mrope_config,
+            all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
             **kwargs,
         )

         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
-        hidden_states = self.mlp(hidden_states)
+        hidden_states = self.mlp(
+            hidden_states,
+            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
+            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
+            final_all_reduce_params=AllReduceParams(
+                enable_allreduce=not self.disable_allreduce),
+            cutlass_min_latency_mode=False,
+        )

         if spec_metadata is not None:
             spec_metadata.maybe_capture_hidden_states(self.layer_idx,
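Putting the forward() changes together: the same flag is wrapped in AllReduceParams and handed to both sub-blocks (enable_allreduce=False tells them to skip their internal allreduce), and the MLP additionally receives the per-rank token counts from attn_metadata so it can size its work under attention DP. A condensed, runnable sketch of that control flow, with AllReduceParams pared down to the single flag this diff touches and the attention/MLP modules stubbed out with prints:

# Sketch only: AllReduceParams is reduced to one flag and the
# attention/MLP modules are stubs; this mirrors the call pattern of
# the patched forward(), not the real TensorRT-LLM modules.
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class AllReduceParams:
    enable_allreduce: bool = True


class DecoderLayerSketch:

    def __init__(self, tp_size, enable_attention_dp):
        self.disable_allreduce = tp_size == 1 or enable_attention_dp

    def self_attn(self, x, all_reduce_params):
        # Stub: the real attention skips its internal TP allreduce
        # when enable_allreduce is False.
        print("attn allreduce:", all_reduce_params.enable_allreduce)
        return x

    def mlp(self, x, all_rank_num_tokens, all_rank_max_num_tokens,
            final_all_reduce_params):
        # Stub: per-rank token counts let the real MLP balance work
        # under attention DP; the final allreduce is gated the same way.
        print("mlp allreduce:", final_all_reduce_params.enable_allreduce)
        return x

    def forward(self, hidden_states, attn_metadata):
        params = AllReduceParams(enable_allreduce=not self.disable_allreduce)
        hidden_states = self.self_attn(hidden_states,
                                       all_reduce_params=params)
        return self.mlp(
            hidden_states,
            all_rank_num_tokens=attn_metadata.all_rank_num_tokens,
            all_rank_max_num_tokens=attn_metadata.all_rank_max_num_tokens,
            final_all_reduce_params=params,
        )


meta = SimpleNamespace(all_rank_num_tokens=[3, 5], all_rank_max_num_tokens=5)
DecoderLayerSketch(tp_size=4, enable_attention_dp=True).forward("h", meta)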