@@ -153,18 +153,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                     input_layouts=Replicate(),
                 ),
                 "output": col_parallel_strategy(
-                    input_layouts=Shard(0),
+                    input_layouts=Shard(1),
                     output_layouts=(
                         Shard(-1)
                         if parallel_dims.loss_parallel_enabled
                         else Replicate()
                     ),
                     use_local_output=not parallel_dims.loss_parallel_enabled,
                 ),
-                "norm": SequenceParallel(sequence_dim=0),
+                "norm": SequenceParallel(),
                 "layers.0": PrepareModuleInput(
                     input_layouts=(Replicate(), None),
-                    desired_input_layouts=(Shard(0), None),
+                    desired_input_layouts=(Shard(1), None),
                     use_local_output=True,
                 ),
             },
@@ -174,22 +174,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in enumerate(model.layers):
         layer_plan = {
             "attention": PrepareModuleInput(
-                input_layouts=(Shard(0), None),
+                input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
-            "attention_norm": SequenceParallel(sequence_dim=0),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
             "feed_forward": PrepareModuleInput(
-                input_layouts=(Shard(0),),
+                input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(sequence_dim=0),
+            "ffn_norm": SequenceParallel(),
         }

         # Adjust attention module to use the local number of heads
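
A quick sanity check on the layout change (not part of the patch itself): moving from `Shard(0)` to `Shard(1)`, and from `SequenceParallel(sequence_dim=0)` to the default `SequenceParallel()`, is consistent with activations being laid out batch-first as `(batch, seq, dim)`, where the sequence axis is dim 1 (PyTorch's `SequenceParallel` style defaults its `sequence_dim` accordingly). The sketch below is a minimal, hypothetical illustration of which axis `Shard(1)` splits; the shapes and `tp_degree` are made-up values, not taken from the repo.

```python
# Minimal sketch (hypothetical sizes): with batch-first activations of shape
# (batch, seq, dim), Shard(1) splits the sequence axis across the TP group,
# which is what the updated plan requests for sequence-parallel regions.
import torch

batch, seq, dim = 2, 8, 16        # made-up sizes for illustration
tp_degree = 4                     # hypothetical tensor-parallel group size

x = torch.randn(batch, seq, dim)  # batch-first activation layout

# Shard(1) semantics: chunk along dim 1 (the sequence dimension).
local_shards = torch.chunk(x, tp_degree, dim=1)
assert local_shards[0].shape == (batch, seq // tp_degree, dim)

# Under a sequence-first layout (seq, batch, dim), the same split would
# have been along dim 0 -- hence the old Shard(0) values in the diff.
```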