Commit 05fb65f

Add Pipeline Parallel support
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now
- supports 1D parallelism currently

WIP: support 2D/3D parallel and clean up seed-checkpoint UX

ghstack-source-id: 7055ffe
Pull Request resolved: #161
Parent: 7286124

4 files changed: +92 −14 lines

run_llama_train.sh

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@ TRAINER_DIR=${1:-/home/$USER/local/torchtrain}
 # e.g.
 # LOG_RANK=0,1 NGPU=4 ./run_llama_train.sh
 
-NGPU=${NGPU:-"8"}
+NGPU=${NGPU:-"2"}
 
 # by default log just rank 0 output,
 LOG_RANK=${LOG_RANK:-0}

torchtrain/parallelisms/parallelize_llama.py

Lines changed: 28 additions & 2 deletions
@@ -8,7 +8,7 @@
 from typing import Tuple
 
 import torch
-
+from pippy import annotate_split_points, pipeline, SplitPoint
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -134,7 +134,29 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     the model must fit on GPU or CPU memory.
     """
     if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")
+        pp_mesh = world_mesh["pp"]
+        stage_idx = pp_mesh.get_local_rank()
+        layers_per_rank = len(model.layers) // parallel_dims.pp
+        for i in range(1, parallel_dims.pp):
+            annotate_split_points(
+                model,
+                {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING},
+            )
+
+        # Get example input
+        label_shape = input_shape = (8, 2048)  # TODO
+        input_ids = torch.randint(
+            model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+        )
+        labels = torch.randint(
+            model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+        )
+        print("input_ids: ", input_ids.shape, input_ids.dtype)
+        print("labels: ", labels.shape, labels.dtype)
+
+        # Create a pipeline representation from the model
+        pipe = pipeline(model, parallel_dims.pp, example_args=(input_ids,))
+        model = pipe.get_stage_module(stage_idx)
 
     if parallel_dims.tp_enabled:
         tp_mesh = world_mesh["tp"]
@@ -233,4 +255,8 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         logger.info(f"Applied {ac_mode} activation checkpointing to the model")
         logger.info("Applied FSDP to the model")
 
+    if parallel_dims.pp_enabled:
+        setattr(pipe.split_gm, f"submod_{stage_idx}", model)
+        return pipe
+
     return model
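
For readers unfamiliar with pippy's tracer frontend used above, the following is a minimal, self-contained sketch (not part of this commit) of the same flow on a toy module: annotate split points at layer boundaries, trace the model into a pipeline, and pull out one stage's submodule. ToyModel, its dimensions, and the single-process usage are hypothetical; only the pippy calls mirror what the diff does.

# Minimal sketch of the tracer-frontend flow from parallelize_llama above.
# ToyModel and its sizes are made up for illustration.
import torch
import torch.nn as nn
from pippy import annotate_split_points, pipeline, SplitPoint


class ToyModel(nn.Module):
    def __init__(self, n_layers: int = 4, dim: int = 16):
        super().__init__()
        self.layers = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layers))

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


pp_degree = 2
model = ToyModel()
layers_per_rank = len(model.layers) // pp_degree

# Mark a split at the first layer of every non-first stage, as the diff does.
for i in range(1, pp_degree):
    annotate_split_points(
        model, {f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING}
    )

# Trace the model into a pipeline representation from an example input,
# then extract the submodule a given stage would own.
example_input = torch.randn(8, 16)
pipe = pipeline(model, pp_degree, example_args=(example_input,))
stage0 = pipe.get_stage_module(0)
print(stage0)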

train.py

Lines changed: 62 additions & 10 deletions
@@ -15,6 +15,8 @@
 
 import torch
 import torch.nn.functional as F
+from pippy.PipelineSchedule import PipelineScheduleGPipe
+from pippy.PipelineStage import PipelineStage
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
 
@@ -120,7 +122,9 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    # torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    torch.cuda.set_device(device)
     init_distributed(job_config)
 
     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -139,6 +143,14 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_dataloader_fn(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -197,14 +209,38 @@ def loss_fn(pred, labels):
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+    if parallel_dims.pp_enabled:
+        pipe_meta = model
+        model = pipe_meta.get_stage_module(pp_rank)
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = PipelineStage(
+            pipe=pipe_meta,
+            stage_index=pp_rank,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = PipelineScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # if PP is enabled, we can't use init_weights. instead, we have to rely on offline creating an initial checkpoint
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exist on one rank. In PP, the init_weights on stage 1 might crash
+        # because it can't find "embedding" layer, for example.
+
+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()
 
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)
 
     # torch.compile model for improved performance
@@ -274,13 +310,30 @@ def loss_fn(pred, labels):
 
             input_ids = input_ids.cuda()
             labels = labels.cuda()
-
             optimizer.zero_grad()
 
-            # forward / backward
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+            if parallel_dims.pp_enabled:
+                # pipeline parallel forward / backward inside step() call
+                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+                # accumulate losses across pipeline microbatches
+                current_loss = (
+                    torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0
+                )
+            else:
+                # forward / backward
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()
+                current_loss = loss.item()
 
             # clip gradients
             torch.nn.utils.clip_grad_norm_(
@@ -291,7 +344,6 @@ def loss_fn(pred, labels):
             optimizer.step()
             scheduler.step()
 
-            current_loss = loss.item()
             losses_since_last_log.append(current_loss)
 
             # log metrics
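
To make the per-rank control flow in the training loop above easier to follow, here is a small helper (hypothetical, not part of the commit) that factors out the same logic: the first stage feeds input_ids, the last stage supplies labels and collects per-microbatch losses, and interior stages simply run the schedule.

import torch


def run_pp_step(pp_schedule, pp_mesh, input_ids, labels):
    # Hypothetical helper mirroring the loop body above; not part of the commit.
    pp_rank = pp_mesh.get_local_rank()
    is_last_stage = pp_rank == pp_mesh.size() - 1

    if pp_rank == 0:
        # First stage: feed the microbatched inputs into the pipeline.
        pp_schedule.step(input_ids)
    elif is_last_stage:
        # Last stage: provide targets and collect one loss per microbatch.
        losses = []
        pp_schedule.step(target=labels, losses=losses)
    else:
        # Interior stages: just run their share of forwards and backwards.
        pp_schedule.step()

    # Only the last stage holds real losses; other ranks report a sentinel,
    # matching the -1.0 used in the training loop above.
    return torch.mean(torch.stack(losses)).item() if is_last_stage else -1.0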

train_configs/debug_model.toml

Lines changed: 1 addition & 1 deletion
@@ -38,7 +38,7 @@ max_norm = 1.0  # grad norm clipping
 steps = 10
 data_parallel_degree = -1
 tensor_parallel_degree = 1
-pipeline_parallel_degree = 1
+pipeline_parallel_degree = 2
 fp8_linear = ""
 compile = false
 dataset = "alpaca"  # supported datasets = alpaca (52K), minipile (1M), c4 (177M)
