from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import (
-    build_pipeline_schedule,
-    models_parallelize_fns,
-    models_pipelining_fns,
-    ParallelDims,
-)
+from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling

@@ -51,6 +46,9 @@ def main(job_config: JobConfig):
    init_logger()
    logger.info(f"Starting job: {job_config.job.description}")

+    if job_config.experimental.pipeline_parallel_degree > 1:
+        raise RuntimeError("To use Pipeline Parallelism, please run train.py")
+
    # used for colorful printing
    color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor

@@ -82,9 +80,6 @@ def main(job_config: JobConfig):
    else:
        dp_degree, dp_rank = 1, 0

-    if parallel_dims.pp_enabled:
-        pp_mesh = world_mesh["pp"]
-
    model_name = job_config.model.name

    # build tokenizer
@@ -115,17 +110,17 @@ def main(job_config: JobConfig):

    logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
    with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)

    # a no-op handler if float8 is not enabled
    float8_handler = Float8Handler(job_config, parallel_dims)
    # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)

    # log model size
-    model_param_count = utils.get_num_params(whole_model)
+    model_param_count = utils.get_num_params(model)
    num_flop_per_token = utils.get_num_flop_per_token(
-        utils.get_num_params(whole_model, exclude_embedding=True),
+        utils.get_num_params(model, exclude_embedding=True),
        model_config,
        job_config.training.seq_len,
    )
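As an aside, the num_flop_per_token value computed above is a closed-form estimate rather than a measurement. The sketch below shows one common approximation for decoder-only transformers; the helper name approx_flop_per_token and its exact constants are illustrative assumptions, not necessarily what torchtitan's utils implements.

# Rough FLOPs-per-token estimate (illustrative sketch; not torchtitan's exact helper).
def approx_flop_per_token(
    num_params: int, n_layers: int, n_heads: int, head_dim: int, seq_len: int
) -> int:
    # ~6 FLOPs per parameter per token for the dense matmuls (forward + backward),
    # plus an attention term that, per token, grows linearly with sequence length.
    return 6 * num_params + 12 * n_layers * n_heads * head_dim * seq_len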
@@ -134,41 +129,10 @@ def main(job_config: JobConfig):
        f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
    )

-    if parallel_dims.pp_enabled:
-        stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    else:
-        # In 1D/2D cases or PP with simple schedules, model_parts is just one item
-        # for PP with looped schedules, each item is one stage-model-chunk
-        # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
-        model_parts = [whole_model]
-
    # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
-
-    # loss fn can be shared by pipeline-parallel or non-pp execution
-    def loss_fn(pred, labels):
-        return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
-        )
-
-    if parallel_dims.pp_enabled:
-        pp_schedule = build_pipeline_schedule(
-            job_config, parallel_dims, stages, loss_fn
-        )
-    else:
-        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
-        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-        # allocate sharded model on GPU and initialize weights via DTensor
-        whole_model.init_weights()
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+    model.to_empty(device="cuda")
+    model.init_weights()

    gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
    logger.info(
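The replacement above follows PyTorch's meta-device initialization flow: build the model on the meta device (no parameter storage), apply the parallelism wrappers, then materialize storage with to_empty() and run the model's own weight init. A minimal, self-contained sketch of that flow follows; ToyModel and its init_weights method are made-up stand-ins, not torchtitan code.

# Minimal sketch of the meta-init -> to_empty -> init_weights pattern (toy example).
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self, dim: int = 16):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def init_weights(self):
        # runs only after real storage exists
        nn.init.trunc_normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

with torch.device("meta"):
    toy = ToyModel()           # parameters are meta tensors; no memory is allocated

# (the real code applies FSDP/TP wrappers here, while parameters are still on meta)

device = "cuda" if torch.cuda.is_available() else "cpu"
toy.to_empty(device=device)    # allocate uninitialized storage on the target device
toy.init_weights()             # fill in real values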
@@ -178,43 +142,26 @@ def loss_fn(pred, labels):
    )

    # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)

    train_state = TrainState()

    # train loop
-    for model in model_parts:
-        model.train()
+    model.train()

    # load initial checkpoint
    checkpoint = CheckpointManager(
        dataloader=data_loader,
-        model_parts=model_parts,
+        model_parts=[model],
        optimizers=optimizers.optimizers,
        lr_schedulers=lr_schedulers.schedulers,
        states={"train_state": train_state},
        job_config=job_config,
    )
-
-    if job_config.checkpoint.create_seed_checkpoint:
-        assert (
-            world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
-        checkpoint.save(curr_step=0, force=True)
-        logger.info("Created seed checkpoint")
-        return
-
    checkpoint_loaded = checkpoint.load()

-    if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
-        )
-
    metric_logger = build_metric_logger(job_config, parallel_dims)
-
    # plot losses loaded from checkpoint (if any) to TensorBoard
    # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
    # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
@@ -271,43 +218,23 @@ def loss_fn(pred, labels):
            labels = labels.cuda()
            optimizers.zero_grad()

-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
+            with train_context():
+                pred = model(input_ids)
+                loss = torch.nn.functional.cross_entropy(
+                    pred.flatten(0, 1), labels.flatten(0, 1)
                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+                # pred.shape = (bs, seq_len, vocab_size)
+                # free pred before backward to avoid peak memory
+                del pred
+                loss.backward()

            # clip gradients
-            for model in model_parts:
-                torch.nn.utils.clip_grad_norm_(
-                    model.parameters(), job_config.training.max_norm, foreach=True
-                )
+            torch.nn.utils.clip_grad_norm_(
+                model.parameters(), job_config.training.max_norm, foreach=True
+            )

            # sync float8 amaxes and scales
-            float8_handler.sync_float8_amax_and_scale_history(model_parts)
+            float8_handler.sync_float8_amax_and_scale_history(model)

            # optimizer step
            checkpoint.maybe_wait_for_staging()
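One detail worth calling out in the new forward/backward path above: cross_entropy expects (N, C) logits and (N,) class indices, which is why both pred and labels are flattened over the batch and sequence dimensions. A standalone shape check with made-up sizes:

# Standalone shape check for the flattened cross-entropy (made-up sizes).
import torch
import torch.nn.functional as F

bs, seq_len, vocab_size = 2, 8, 32
pred = torch.randn(bs, seq_len, vocab_size)           # model output: (bs, seq_len, vocab_size)
labels = torch.randint(0, vocab_size, (bs, seq_len))  # targets:      (bs, seq_len)

# flatten(0, 1) merges batch and sequence dims: (bs*seq_len, vocab_size) vs (bs*seq_len,)
loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
print(loss)  # scalar mean loss over all tokens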
@@ -316,7 +243,7 @@ def loss_fn(pred, labels):

            # calculate float8 dynamic amax/scale for all parameters for FSDP2
            # it issues a single all-reduce for all parameters at once for better performance
-            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
+            float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)

            losses_since_last_log.append(loss)

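Finally, a note on the simplified gradient clipping in the training loop above: torch.nn.utils.clip_grad_norm_ rescales gradients in place so their global L2 norm does not exceed max_norm and returns the pre-clipping norm; foreach=True selects the multi-tensor implementation. A tiny self-contained example, with a toy Linear layer that is just for illustration:

# Tiny illustration of clip_grad_norm_ (toy layer, not torchtitan code).
import torch

layer = torch.nn.Linear(4, 4)
layer(torch.randn(3, 4)).sum().backward()

total_norm = torch.nn.utils.clip_grad_norm_(
    layer.parameters(), max_norm=1.0, foreach=True
)
print(total_norm)  # total gradient norm measured before clipping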