
Commit 81e21d0

[do not review][example] train.py without pp
ghstack-source-id: 2bc35b0
Pull Request resolved: #501
Parent: 9ac0d15


train.py

Lines changed: 27 additions & 100 deletions
--- a/train.py
+++ b/train.py
@@ -21,12 +21,7 @@
 from torchtitan.metrics import build_gpu_memory_monitor, build_metric_logger
 from torchtitan.models import model_name_to_cls, model_name_to_tokenizer, models_config
 from torchtitan.optimizer import build_lr_schedulers, build_optimizers
-from torchtitan.parallelisms import (
-    build_pipeline_schedule,
-    models_parallelize_fns,
-    models_pipelining_fns,
-    ParallelDims,
-)
+from torchtitan.parallelisms import models_parallelize_fns, ParallelDims
 from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling
 
 
@@ -51,6 +46,9 @@ def main(job_config: JobConfig):
     init_logger()
     logger.info(f"Starting job: {job_config.job.description}")
 
+    if job_config.experimental.pipeline_parallel_degree > 1:
+        raise RuntimeError("To use Pipeline Parallelism, please run train.py")
+
     # used for colorful printing
     color = utils.Color if job_config.metrics.enable_color_printing else utils.NoColor
 
@@ -82,9 +80,6 @@ def main(job_config: JobConfig):
     else:
         dp_degree, dp_rank = 1, 0
 
-    if parallel_dims.pp_enabled:
-        pp_mesh = world_mesh["pp"]
-
     model_name = job_config.model.name
 
     # build tokenizer
@@ -115,17 +110,17 @@ def main(job_config: JobConfig):
 
     logger.info(f"Building {model_name} {job_config.model.flavor} with {model_config}")
     with torch.device("meta"):
-        whole_model = model_cls.from_model_args(model_config)
+        model = model_cls.from_model_args(model_config)
 
     # a no-op hander if float8 is not enabled
     float8_handler = Float8Handler(job_config, parallel_dims)
     # swap to Float8Linear based on float8 configs
-    float8_handler.convert_to_float8_training(whole_model)
+    float8_handler.convert_to_float8_training(model)
 
     # log model size
-    model_param_count = utils.get_num_params(whole_model)
+    model_param_count = utils.get_num_params(model)
     num_flop_per_token = utils.get_num_flop_per_token(
-        utils.get_num_params(whole_model, exclude_embedding=True),
+        utils.get_num_params(model, exclude_embedding=True),
         model_config,
         job_config.training.seq_len,
     )
@@ -134,41 +129,10 @@ def main(job_config: JobConfig):
         f"{color.red}size: {model_param_count:,} total parameters{color.reset}"
     )
 
-    if parallel_dims.pp_enabled:
-        stages, model_parts = models_pipelining_fns[model_name](
-            whole_model, pp_mesh, parallel_dims, job_config, device, model_config
-        )
-    else:
-        # In 1D/2D cases or PP with simple schedules, model_parts is just one item
-        # for PP with looped schedules, each item is one stage-model-chunk
-        # we iterate all model_parts for applying SPMD parallelism, compilation, optimizer, and checkpointing
-        model_parts = [whole_model]
-
     # apply PT-D DP/TP parallelisms and activation checkpointing
-    model_parts = [
-        models_parallelize_fns[model_name](m, world_mesh, parallel_dims, job_config)
-        for m in model_parts
-    ]
-
-    init_device = "cpu" if job_config.checkpoint.create_seed_checkpoint else "cuda"
-    for model in model_parts:
-        model.to_empty(device=init_device)
-
-    # loss fn can be shared by pipeline-parallel or non-pp execution
-    def loss_fn(pred, labels):
-        return torch.nn.functional.cross_entropy(
-            pred.flatten(0, 1), labels.flatten(0, 1)
-        )
-
-    if parallel_dims.pp_enabled:
-        pp_schedule = build_pipeline_schedule(
-            job_config, parallel_dims, stages, loss_fn
-        )
-    else:
-        # If PP is enabled, we can't rely on init_weights, because some layers are missing.
-        # In the future, we may make init_weights handle missing layers, but also have to consider RNG seed propagation.
-        # allocate sharded model on GPU and initialize weights via DTensor
-        whole_model.init_weights()
+    models_parallelize_fns[model_name](model, world_mesh, parallel_dims, job_config)
+    model.to_empty(device="cuda")
+    model.init_weights()
 
     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -178,43 +142,26 @@ def loss_fn(pred, labels):
     )
 
     # build optimizer after applying parallelisms to the model
-    optimizers = build_optimizers(model_parts, job_config)
+    optimizers = build_optimizers([model], job_config)
    lr_schedulers = build_lr_schedulers(optimizers.optimizers, job_config)
 
     train_state = TrainState()
 
     # train loop
-    for model in model_parts:
-        model.train()
+    model.train()
 
     # load initial checkpoint
     checkpoint = CheckpointManager(
         dataloader=data_loader,
-        model_parts=model_parts,
+        model_parts=[model],
         optimizers=optimizers.optimizers,
         lr_schedulers=lr_schedulers.schedulers,
         states={"train_state": train_state},
         job_config=job_config,
     )
-
-    if job_config.checkpoint.create_seed_checkpoint:
-        assert (
-            world_size == 1
-        ), "Must create seed-checkpoint using one gpu, to disable sharding"
-        checkpoint.save(curr_step=0, force=True)
-        logger.info("Created seed checkpoint")
-        return
-
     checkpoint_loaded = checkpoint.load()
 
-    if parallel_dims.pp_enabled and not checkpoint_loaded:
-        raise RuntimeError(
-            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
-            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
-        )
-
     metric_logger = build_metric_logger(job_config, parallel_dims)
-
     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be ploted.
     # This can be avoided by setting checkpoint.interval to be a multiple of metrics.log_freq
@@ -271,43 +218,23 @@ def loss_fn(pred, labels):
         labels = labels.cuda()
         optimizers.zero_grad()
 
-        if parallel_dims.pp_enabled:
-            # pipeline parallel forward / backward inside step() call
-            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-            with train_context():
-                if pp_mesh.get_local_rank() == 0:
-                    pp_schedule.step(input_ids)
-                elif is_last_stage:
-                    losses = []
-                    pp_schedule.step(target=labels, losses=losses)
-                else:
-                    pp_schedule.step()
-
-            # accumulate losses across pipeline microbatches
-            loss = (
-                torch.mean(torch.stack(losses))
-                if is_last_stage
-                else torch.Tensor([-1.0])
+        with train_context():
+            pred = model(input_ids)
+            loss = torch.nn.functional.cross_entropy(
+                pred.flatten(0, 1), labels.flatten(0, 1)
             )
-        else:
-            # Non-PP forward / backward
-            with train_context():
-                pred = model(input_ids)
-                loss = loss_fn(pred, labels)
-                # pred.shape=(bs, seq_len, vocab_size)
-                # need to free to before bwd to avoid peaking memory
-                del pred
-                loss.backward()
+            # pred.shape=(bs, seq_len, vocab_size)
+            # need to free to before bwd to avoid peaking memory
+            del pred
+            loss.backward()
 
         # clip gradients
-        for model in model_parts:
-            torch.nn.utils.clip_grad_norm_(
-                model.parameters(), job_config.training.max_norm, foreach=True
-            )
+        torch.nn.utils.clip_grad_norm_(
+            model.parameters(), job_config.training.max_norm, foreach=True
+        )
 
         # sync float8 amaxes and scales
-        float8_handler.sync_float8_amax_and_scale_history(model_parts)
+        float8_handler.sync_float8_amax_and_scale_history(model)
 
         # optimizer step
         checkpoint.maybe_wait_for_staging()
@@ -316,7 +243,7 @@ def loss_fn(pred, labels):
 
         # calculate float8 dynamic amax/scale for all-parameter for FSDP2
         # it issues a single all-reduce for all parameters at once for better performance
-        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model_parts)
+        float8_handler.precompute_float8_dynamic_scale_for_fsdp(model)
 
         losses_since_last_log.append(loss)
 
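For reference, below is a minimal, self-contained sketch of the single-model training step that the simplified script reduces to: forward pass, cross-entropy over flattened logits, freeing the logits before backward, gradient clipping, then the optimizer step. The toy model, tensor shapes, and random data are assumptions made only for illustration; the real train.py builds its model, optimizer, and data loader from torchtitan's JobConfig, as the diff above shows.

# Minimal sketch of the simplified (non-PP) training step.
# The toy model, sizes, and random data below are placeholders, not torchtitan code.
import torch
import torch.nn as nn

vocab_size, seq_len, bs, max_norm = 128, 16, 4, 1.0

model = nn.Sequential(
    nn.Embedding(vocab_size, 64),  # token ids -> hidden states
    nn.Linear(64, vocab_size),     # hidden states -> logits
)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
model.train()

for step in range(3):
    input_ids = torch.randint(0, vocab_size, (bs, seq_len))
    labels = torch.randint(0, vocab_size, (bs, seq_len))
    optimizer.zero_grad()

    pred = model(input_ids)  # pred.shape = (bs, seq_len, vocab_size)
    loss = torch.nn.functional.cross_entropy(
        pred.flatten(0, 1), labels.flatten(0, 1)
    )
    # free the logits before backward to avoid peaking memory, as in the diff
    del pred
    loss.backward()

    # clip gradients, then step the optimizer
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm, foreach=True)
    optimizer.step()
    print(f"step {step}: loss {loss.item():.4f}")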