
Commit 014cc7e

remove folding and unfolding of sequence dim in model.py
ghstack-source-id: 5d299ad
Pull Request resolved: #190
1 parent b4ab627 · commit 014cc7e

2 files changed (+11, -23 lines)

torchtrain/models/llama/model.py

Lines changed: 2 additions & 14 deletions
@@ -226,10 +226,7 @@ def forward(
             torch.Tensor: Output tensor after attention.

         """
-        seqlen, _ = freqs_cis.shape
-        bs_seqlen, _ = x.shape
-        bsz = bs_seqlen // seqlen
-
+        bsz, seqlen, _ = x.shape
         xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)

         xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
@@ -255,8 +252,7 @@ def forward(
         output = output.transpose(
             1, 2
         ).contiguous()  # (bs, seqlen, n_local_heads, head_dim)
-        # output stay folded with batch and sequence dimension
-        output = output.view(bsz * seqlen, -1)
+        output = output.view(bsz, seqlen, -1)
         return self.wo(output)


@@ -487,17 +483,9 @@ def forward(self, tokens: torch.Tensor):

         """
         h, freqs_cis = self.embeddings(tokens)
-        # fold batch and sequence dimension for more efficient allgather/reduce_scatter
-        h = h.view(-1, self.model_args.dim)
-
         for layer in self.layers:
             h = layer(h, freqs_cis)
-
         h = self.norm(h)
-        # unfold batch and sequence dimension
-        bsz = tokens.shape[0]
-        bs_seqlen = h.shape[0]
-        h = h.view(bsz, bs_seqlen // bsz, self.model_args.dim)
         output = self.output(h).float()
         return output
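For context, a minimal sketch (shapes and variable names are illustrative, not taken from the repository) of why the fold/unfold round-trip can be dropped: folding batch and sequence into one dimension and unfolding them afterwards is a pure reshape, so keeping the activation as (bsz, seqlen, dim) end to end yields the same values while letting attention read both sizes directly from its input instead of deriving bsz from freqs_cis.

import torch

# Hypothetical sizes for illustration only.
bsz, seqlen, dim = 2, 8, 16
h = torch.randn(bsz, seqlen, dim)

# Old path: fold before the transformer layers, unfold after the norm.
folded = h.view(bsz * seqlen, dim)
unfolded = folded.view(bsz, seqlen, dim)
assert torch.equal(h, unfolded)  # the round-trip is a no-op on values

# New path: attention recovers both sizes straight from the 3-D input.
b, s, _ = h.shape
assert (b, s) == (bsz, seqlen)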

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 9 additions & 9 deletions
@@ -153,18 +153,18 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
                 input_layouts=Replicate(),
             ),
             "output": col_parallel_strategy(
-                input_layouts=Shard(0),
+                input_layouts=Shard(1),
                 output_layouts=(
                     Shard(-1)
                     if parallel_dims.loss_parallel_enabled
                     else Replicate()
                 ),
                 use_local_output=not parallel_dims.loss_parallel_enabled,
             ),
-            "norm": SequenceParallel(sequence_dim=0),
+            "norm": SequenceParallel(),
             "layers.0": PrepareModuleInput(
                 input_layouts=(Replicate(), None),
-                desired_input_layouts=(Shard(0), None),
+                desired_input_layouts=(Shard(1), None),
                 use_local_output=True,
             ),
         },
@@ -174,22 +174,22 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     for layer_id, transformer_block in enumerate(model.layers):
         layer_plan = {
             "attention": PrepareModuleInput(
-                input_layouts=(Shard(0), None),
+                input_layouts=(Shard(1), None),
                 desired_input_layouts=(Replicate(), None),
             ),
             "attention.wq": col_parallel_strategy(),
             "attention.wk": col_parallel_strategy(),
             "attention.wv": col_parallel_strategy(),
-            "attention.wo": row_parallel_strategy(output_layouts=Shard(0)),
-            "attention_norm": SequenceParallel(sequence_dim=0),
+            "attention.wo": row_parallel_strategy(output_layouts=Shard(1)),
+            "attention_norm": SequenceParallel(),
             "feed_forward": PrepareModuleInput(
-                input_layouts=(Shard(0),),
+                input_layouts=(Shard(1),),
                 desired_input_layouts=(Replicate(),),
             ),
             "feed_forward.w1": col_parallel_strategy(),
-            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(0)),
+            "feed_forward.w2": row_parallel_strategy(output_layouts=Shard(1)),
             "feed_forward.w3": col_parallel_strategy(),
-            "ffn_norm": SequenceParallel(sequence_dim=0),
+            "ffn_norm": SequenceParallel(),
         }

         # Adjust attention module to use the local number of heads
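A hedged illustration (plain tensor chunking, no DTensor; all names and sizes are made up) of why the plan's shard dimension moves from 0 to 1: with the old folded (bsz * seqlen, dim) activations, Shard(0) split the combined batch-sequence dimension, whereas with unfolded (bsz, seqlen, dim) activations the sequence dimension is dim 1, which is what Shard(1) and SequenceParallel()'s default sequence dimension now refer to.

import torch

# Illustrative only: mimic the shard-dimension shift with local chunking.
# Real sharding is expressed through DTensor placements in the plan above.
bsz, seqlen, dim, tp_degree = 2, 8, 16, 4

folded = torch.randn(bsz * seqlen, dim)
old_shards = folded.tensor_split(tp_degree, dim=0)      # old plan: Shard(0)
assert old_shards[0].shape == (bsz * seqlen // tp_degree, dim)

unfolded = torch.randn(bsz, seqlen, dim)
new_shards = unfolded.tensor_split(tp_degree, dim=1)    # new plan: Shard(1)
assert new_shards[0].shape == (bsz, seqlen // tp_degree, dim)

Because the norm modules now see (bsz, seqlen, dim) inputs, the explicit sequence_dim=0 override on SequenceParallel is no longer needed.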
