
Commit c181683

Author: Andrew Gu
Renamed parallel styles for transformer block weights
ghstack-source-id: 4e41e66
Pull Request resolved: #448
1 parent: 181b8ca · commit: c181683

File tree: 1 file changed (+15, −13 lines)


torchtitan/parallelisms/parallelize_llama.py

Lines changed: 15 additions & 13 deletions
@@ -298,10 +298,12 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
 
     tp_mesh = world_mesh["tp"]
+    # Parallel styles for transformer block linear weights may be different for
+    # float8 linears
     (
-        row_parallel_strategy,
-        col_parallel_strategy,
-        prepare_module_input,
+        rowwise_parallel_weight,
+        colwise_parallel_weight,
+        prepare_weight_input,
     ) = get_tp_parallel_strategy(job_config)
     loss_parallel = parallel_dims.loss_parallel_enabled
 
@@ -318,7 +320,7 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
                 output_layouts=Shard(1),
             ),
             "norm": SequenceParallel(),
-            "output": col_parallel_strategy(
+            "output": colwise_parallel_weight(
                 input_layouts=Shard(1),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
@@ -333,22 +335,22 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in model.layers.items():
         layer_plan = {
             "attention_norm": SequenceParallel(),
-            "attention": prepare_module_input(
+            "attention": prepare_weight_input(
                 input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
-            "attention.wq": col_parallel_strategy(),
-            "attention.wk": col_parallel_strategy(),
-            "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention.wq": colwise_parallel_weight(),
+            "attention.wk": colwise_parallel_weight(),
+            "attention.wv": colwise_parallel_weight(),
+            "attention.wo": rowwise_parallel_weight(output_layouts=Shard(1)),
             "ffn_norm": SequenceParallel(),
-            "feed_forward": prepare_module_input(
+            "feed_forward": prepare_weight_input(
                 input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
-            "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
-            "feed_forward.w3": col_parallel_strategy(),
+            "feed_forward.w1": colwise_parallel_weight(),
+            "feed_forward.w2": rowwise_parallel_weight(output_layouts=Shard(1)),
+            "feed_forward.w3": colwise_parallel_weight(),
         }
 
         # Adjust attention module to use the local number of heads
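
For context, the tuple unpacked at the top of the first hunk supplies the parallel styles that apply_tp attaches to every transformer block linear weight, and the new names make explicit that these styles may differ when float8 linears are enabled. Below is a minimal sketch of how a selector like get_tp_parallel_strategy could choose between the two sets; the enable_float8_linear flag and the float8_experimental class names are assumptions for illustration, not part of this commit.

# Hypothetical sketch, not from this commit: one way get_tp_parallel_strategy
# could return (rowwise_parallel_weight, colwise_parallel_weight,
# prepare_weight_input). The standard styles come from
# torch.distributed.tensor.parallel; the float8 branch below is assumed.
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    PrepareModuleInput,
    RowwiseParallel,
)


def get_tp_parallel_strategy(job_config):
    """Pick parallel styles for transformer block linear weights."""
    # Assumed config field; check the real JobConfig for the actual flag name.
    if getattr(job_config.training, "enable_float8_linear", False):
        # Assumed float8-aware counterparts; verify these names in
        # float8_experimental (or torchao.float8) before relying on them.
        from float8_experimental.float8_tensor_parallel import (
            Float8ColwiseParallel,
            Float8RowwiseParallel,
            PrepareFloat8ModuleInput,
        )

        return Float8RowwiseParallel, Float8ColwiseParallel, PrepareFloat8ModuleInput
    return RowwiseParallel, ColwiseParallel, PrepareModuleInput

Each layer_plan built in the loop above is then applied to its transformer_block with parallelize_module(transformer_block, tp_mesh, layer_plan) from torch.distributed.tensor.parallel.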
