@@ -298,10 +298,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     tp_mesh = world_mesh["tp"]
+    # Parallel styles for transformer block linear weights may be different for
+    # float8 linears
     (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_weight_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
@@ -318,7 +320,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": col_parallel_strategy(
+            "output": colwise_parallel_weight(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
@@ -333,22 +335,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }
 
         # Adjust attention module to use the local number of heads