
Commit 67dffdb

Update on "[DO NOT MERGE][example] fold batch and sequence dimensions to accelerate Sequence Parallel"
Note: This PR is for showcasing purposes only and is almost a reverse of #190. At the cost of a model code change, we can obtain better Sequence Parallel performance. Without folding and unfolding, all-gather and reduce-scatter are performed on dim 1 (the sequence dim) instead of dim 0 (the folded dim), which incurs an extra `aten.cat` after each collective.

Stats from awgu:

> for 8k seq len, batch size 1 on H100, these two cats take about 0.18 ms out of 3 ms of FFN compute (6%)

Experiment on the 8-layer `debug_model`:

before:

<img width="1023" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/04e5ea4b-fa9e-48e5-92be-582841cb2796">

after:

<img width="1023" alt="image" src="https://github.com/pytorch/torchtitan/assets/150487191/38c39506-462d-485a-a16c-48770a28edb0">

[ghstack-poisoned]
1 parent 59773bb commit 67dffdb
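
For context, a minimal single-process sketch of the fold/unfold pattern this PR applies (the shapes `bs`/`seqlen`/`dim`/`vocab` and the plain `nn.Embedding`/`nn.Linear` stand-ins are illustrative assumptions, not the torchtitan model or its TP plan):

```python
import torch
import torch.nn as nn

# illustrative sizes and stand-in modules, not the torchtitan Transformer
bs, seqlen, dim, vocab = 2, 16, 32, 100
tokens = torch.randint(0, vocab, (bs, seqlen))

tok_embeddings = nn.Embedding(vocab, dim)
ffn = nn.Linear(dim, dim)            # stand-in for the transformer layers
output_proj = nn.Linear(dim, vocab)  # stand-in for the output projection

# fold (bs, seqlen) into a single leading dimension before the embedding,
# so Sequence Parallel all-gather / reduce-scatter run on the contiguous dim 0
h = tok_embeddings(tokens.view(-1))            # (bs * seqlen, dim)
h = ffn(h)                                     # layers operate on the folded dim 0
# unfold only once at the end, before the output projection
logits = output_proj(h.view(-1, seqlen, dim))  # (bs, seqlen, vocab)
print(logits.shape)                            # torch.Size([2, 16, 100])
```

With batch and sequence folded into dim 0, the Sequence Parallel collectives shard and gather along the leading, contiguous dimension, so no extra `aten.cat` is needed to stitch shards back together, unlike sharding on dim 1.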

File tree

2 files changed: +18 −16


torchtitan/models/llama/model.py

Lines changed: 17 additions & 11 deletions
```diff
@@ -185,7 +185,7 @@ def forward(
             torch.Tensor: Output tensor after attention.
 
         """
-        # dim 0 of x is a folded dimension of [bs, seqlen]
+        # dim 0 of x is a folded dimension of (bs, seqlen)
         seqlen, _ = freqs_cis.shape
         bs_seqlen, _ = x.shape
         bs = bs_seqlen // seqlen
```
```diff
@@ -427,21 +427,27 @@ def forward(self, tokens: torch.Tensor):
             torch.Tensor: Output logits after applying the Transformer model.
 
         """
-        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages
-        h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
-        # fold batch dimension and sequence dimension
-        # for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.model_args.dim)
+        # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stage
+        if self.tok_embeddings:
+            # fold batch dimension and sequence dimension
+            # for more efficient allgather/reduce_scatter
+            tokens = tokens.view(-1)
+            h = self.tok_embeddings(tokens)
+        else:
+            h = tokens
 
-        freqs_cis = self.freqs_cis[0 : self.model_args.max_seq_len]
+        seqlen = self.model_args.max_seq_len
+        freqs_cis = self.freqs_cis[0:seqlen]
         for layer in self.layers.values():
             h = layer(h, freqs_cis)
 
         h = self.norm(h) if self.norm else h
-        # unfold batch and sequence dimension
-        bs, seqlen = tokens.shape
-        h = h.view(bs, seqlen, self.model_args.dim)
-        output = self.output(h).float() if self.output else h
+        if self.output:
+            # unfold batch and sequence dimension
+            h = h.view(-1, seqlen, self.model_args.dim)
+            output = self.output(h).float()
+        else:
+            output = h
         return output
 
     @classmethod
```

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -350,18 +350,14 @@ def apply_tp(model, world_mesh, parallel_dims, job_config: JobConfig):
         {
             "tok_embeddings": RowwiseParallel(
                 input_layouts=Replicate(),
+                output_layouts=Shard(0),
             ),
             "output": col_parallel_strategy(
                 input_layouts=Shard(0),
                 output_layouts=Shard(-1) if loss_parallel else Replicate(),
                 use_local_output=not loss_parallel,
             ),
             "norm": SequenceParallel(sequence_dim=0),
-            "layers.0": PrepareModuleInput(
-                input_layouts=(Replicate(), None),
-                desired_input_layouts=(Shard(0), None),
-                use_local_output=True,
-            ),
         },
     )
```
