
Commit 296c875

WIP integrate pippy's tracer frontend
Loss now runs and propagates to the logger, but the optimizer isn't working yet.

ghstack-source-id: 56b0ef0
Pull Request resolved: #161
1 parent: 8f2e31a

File tree

run_llama_train.sh
torchtrain/parallelisms/parallelize_llama.py
train.py
train_configs/debug_model.toml

4 files changed: +110 −24 lines

run_llama_train.sh

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
 
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 30 additions & 2 deletions

@@ -8,7 +8,7 @@
 from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, Pipe, PipeSplitWrapper
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {
+                    f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING
+                },
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048) # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = Pipe.from_tracing(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +257,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         logger.info(f"Applied {ac_mode} activation checkpointing to the model")
     logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
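For reference, the split-and-trace flow above can be exercised standalone. Below is a minimal sketch, assuming pippy is installed and exposes the same annotate_split_points / Pipe.from_tracing / get_stage_module API used in this diff; TinyModel, its sizes, and the example input shape are hypothetical and not from this repository.

# Hypothetical sketch: split a toy "layers.N"-style module into two pipeline
# stages with pippy's tracer frontend, mirroring the hunk above.
import torch
import torch.nn as nn
from pippy import annotate_split_points, Pipe, PipeSplitWrapper


class TinyModel(nn.Module):
    # Stand-in for the llama model: embedding, a stack of "layers", and an
    # output projection, so submodule FQNs look like "layers.0", "layers.1", ...
    def __init__(self, vocab_size=128, dim=32, n_layers=4):
        super().__init__()
        self.vocab_size = vocab_size
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleList([nn.Linear(dim, dim) for _ in range(n_layers)])
        self.output = nn.Linear(dim, vocab_size)

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        for layer in self.layers:
            h = layer(h)
        return self.output(h)


pp_degree = 2
model = TinyModel()

# Mark stage boundaries at evenly spaced layer indices, as the diff does.
layers_per_rank = len(model.layers) // pp_degree
for i in range(1, pp_degree):
    annotate_split_points(
        model,
        {f"layers.{i * layers_per_rank}": PipeSplitWrapper.SplitPoint.BEGINNING},
    )

# Trace once with example inputs, then pull out one stage's submodule
# (the diff uses pipe.get_stage_module(pp_mesh.get_local_rank())).
example_ids = torch.randint(model.vocab_size, (8, 16), dtype=torch.int64)
pipe = Pipe.from_tracing(model, pp_degree, example_args=(example_ids,))
stage0 = pipe.get_stage_module(0)
print(stage0)

Each stage module is a submodule of pipe.split_gm; the diff parallelizes it with TP/FSDP and then writes it back with setattr(pipe.split_gm, f"submod_{stage_idx}", model), so the PipelineStage built in train.py sees the parallelized weights.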

train.py

Lines changed: 78 additions & 20 deletions

@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from torch.distributed.tensor.parallel import loss_parallel
@@ -129,7 +131,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -148,6 +152,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,18 +209,54 @@ def main(job_config: JobConfig):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
-    model.to_empty(device="cuda")
-    model.init_weights()
-
-    # build optimizer after applying parallelisms to the model
-    optimizer = build_optimizer(model, job_config)
-    scheduler = get_lr_scheduler(optimizer, job_config)
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
 
     # build grad scaler which is effective only when mixed precision training
     # is enabled with fp16 param dtype under FSDP
     scaler = build_grad_scaler(model)
 
+    def loss_fn(pred, labels):
+        with (
+            loss_parallel()
+            if parallel_dims.loss_parallel_enabled
+            else contextlib.nullcontext()
+        ):
+            loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+
+        # backward on scaled loss to create scaled gradients
+        scaler.scale(loss)
+        return loss
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+        model.to_empty(device="cuda")
+    else:
+        # if PP is enabled, we can't use init_weights. Instead, we have to rely on offline creating an initial checkpoint
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exist on one rank. In PP, the init_weights on stage1 might crash
+        # because it can't find "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.to_empty(device="cuda")
+        model.init_weights()
+
+    # build optimizer after applying parallelisms to the model
+    optimizer = build_optimizer(model, job_config)
+    scheduler = get_lr_scheduler(optimizer, job_config)
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -278,21 +326,32 @@ def main(job_config: JobConfig):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
             optimizer.zero_grad()
 
-            # forward
-            pred = model(input_ids)
+            if parallel_dims.pp_enabled:
+                # pipeline F/Loss/B
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
 
-            with (
-                loss_parallel()
-                if parallel_dims.loss_parallel_enabled
-                else contextlib.nullcontext()
-            ):
-                loss = F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # non-pipeline F/Loss/B
+                pred = model(input_ids)
+
+                loss = loss_fn(pred, labels)
+                loss.backward()
 
-            # backward on scaled loss to create scaled gradients
-            scaler.scale(loss).backward()
+                current_loss = loss.item()
 
             # clip gradients (after unscaling gradients of the optimizer's params)
             scaler.unscale_(optimizer)
@@ -309,7 +368,6 @@ def main(job_config: JobConfig):
             # updates the scale for next iteration
             scaler.update()
 
-            current_loss = loss.item()
            losses_since_last_log.append(current_loss)
 
             # log metrics
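The rank-dependent schedule calls above follow the usual GPipe dispatch: the first stage feeds the real inputs, the last stage supplies targets and collects one loss per microbatch, and interior stages only relay activations and gradients. Below is a minimal sketch of that dispatch, assuming the pp_schedule built in the diff; run_pipeline_step and its argument names are hypothetical, not code from this repository.

# Hypothetical sketch of the per-rank dispatch used in the training loop above.
# pp_schedule is the PipelineScheduleGPipe built in train.py; the function and
# argument names here are invented for illustration.
import torch


def run_pipeline_step(pp_schedule, pp_rank, num_stages, input_ids, labels):
    is_first_stage = pp_rank == 0
    is_last_stage = pp_rank == num_stages - 1

    if is_first_stage:
        # The first stage owns the real inputs; activations flow downstream.
        pp_schedule.step(input_ids)
        return -1.0
    elif is_last_stage:
        # The last stage owns the labels; the schedule fills `losses` with one
        # entry per microbatch, which we reduce to a scalar for logging.
        losses = []
        pp_schedule.step(target=labels, losses=losses)
        return torch.mean(torch.stack(losses)).item()
    else:
        # Interior stages only relay activations and gradients.
        pp_schedule.step()
        return -1.0

As in the diff, non-last stages report -1.0 because the real loss only exists where the labels are, so the metric logger shows a meaningful loss only on the last pipeline rank.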

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion

@@ -38,7 +38,7 @@ max_norm = 1.0 # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca" # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
