Commit 65a5be7

Add Pipeline Parallel (and 2D PP+FSDP) support
- uses pipeline tracer frontend to extract a graph and partition it into chunks per stage
- hardcodes one schedule (1F1B) for now (need to expose an option to switch schedules and test other schedules)
- supports 2D parallelism currently; 3D (TP) is work in progress

ghstack-source-id: feb45e1
Pull Request resolved: #161
1 parent f72a2a0 commit 65a5be7
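For orientation, here is a minimal standalone sketch of the flow this commit wires up: trace the model with pippy's pipeline frontend, cut it into per-stage chunks at module boundaries, and keep only the local stage's submodule. It mirrors the pippy calls added in parallelize_llama.py and train.py; ToyModel, the shapes, and the split point are made up for illustration, and it assumes the pippy API behaves the same on a toy module as on the traced llama model.

# Minimal sketch (not part of the diff): mirrors the pippy calls added below.
# ToyModel, shapes, and the split point are hypothetical.
import torch
import torch.nn as nn

from pippy import pipeline, SplitPoint

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # 4 stand-ins for transformer blocks, named layers.0 .. layers.3
        self.layers = nn.Sequential(*[nn.Linear(16, 16) for _ in range(4)])

    def forward(self, x):
        return self.layers(x)

pp_degree = 2
example_input = torch.randn(8, 16, device="meta")

# Trace the whole model on meta device and cut it into pp_degree chunks at
# "layers.2", i.e. len(layers) // pp_degree layers per stage, the same rule
# apply_pipeline_parallelism uses below.
pipe = pipeline(
    ToyModel().to("meta"),
    pp_degree,
    example_args=(example_input,),
    split_spec={"layers.2": SplitPoint.BEGINNING},
)

# Each rank would then keep only its own chunk, wrap it in a runtime stage and
# a schedule (ScheduleGPipe here), and drive it with schedule.step() in the
# training loop -- see the train.py changes below.
stage_module = pipe.get_stage_module(0)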

4 files changed: +202 −27 lines

.github/workflows/unit_test_4gpu.yaml

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ jobs:
           pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
           python -m pip install -r requirements.txt
           python -m pip install -r dev-requirements.txt
+          python -m pip install git+https://github.com/pytorch/pippy
       - name: Run test_runner.py
         run: python ./test_runner.py
       - name: Upload Coverage to Codecov

test_runner.py

Lines changed: 57 additions & 8 deletions
@@ -26,6 +26,8 @@ class OverrideDefinitions:

     override_args: Sequence[Sequence[str]] = tuple(tuple(" "))
     test_descr: str = "default"
+    requires_seed_checkpoint: bool = False
+    ngpu: int = 4


 CONFIG_DIR = "./train_configs"
@@ -85,25 +87,72 @@ class OverrideDefinitions:
         ],
         "Checkpoint Integration Test - Save Model Weights Only bf16",
     ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp",
+                "--training.pipeline_parallel_degree 2",
+                "--training.data_parallel_degree 1",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP 1D test",
+        requires_seed_checkpoint=True,
+        ngpu=2,
+    ),
+    OverrideDefinitions(
+        [
+            [
+                "--checkpoint.enable_checkpoint",
+                f"--checkpoint.folder {test_checkpoint_dir}_pp_dp",
+                "--training.pipeline_parallel_degree 2",
+                "--training.data_parallel_degree 2",
+                "--model.norm_type rmsnorm",  # TODO fix fused_rmsnorm issue
+            ],
+        ],
+        "PP+DP 2D test",
+        requires_seed_checkpoint=True,
+    ),
 ]


+def _run_cmd(cmd):
+    return subprocess.run(
+        [cmd],
+        stdout=subprocess.PIPE,
+        stderr=subprocess.STDOUT,
+        text=True,
+        shell=True,
+    )
+
+
 def run_test(test_flavor: OverrideDefinitions, full_path: str):
     # run_test supports sequence of tests.
     for override_arg in test_flavor.override_args:
-        cmd = f"CONFIG_FILE={full_path} NGPU=4 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
+
+        cmd = f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK=0,1,2,3 ./run_llama_train.sh"
         if override_arg:
             cmd += " " + " ".join(override_arg)
         print(
             f"=====Integration test, flavor : {test_flavor.test_descr}, command : {cmd}====="
         )
-        result = subprocess.run(
-            [cmd],
-            stdout=subprocess.PIPE,
-            stderr=subprocess.STDOUT,
-            text=True,
-            shell=True,
-        )
+
+        if test_flavor.requires_seed_checkpoint:
+            checkpoint_folder_arg = None
+            for arg in override_arg:
+                if "--checkpoint.folder" in arg:
+                    checkpoint_folder_arg = arg
+            assert (
+                checkpoint_folder_arg is not None
+            ), "Can't use seed checkpoint if folder is not specified"
+            print("Creating seed checkpoint")
+            result = _run_cmd(
+                f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"
+            )
+            print(result.stdout)
+
+        result = _run_cmd(cmd)
         print(result.stdout)
         if result.returncode != 0:
            raise Exception(
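For reference, the command lines that the new "PP 1D test" flavor expands to look roughly like the standalone snippet below (not part of the diff). The config path and checkpoint folder shown are hypothetical; the real values come from CONFIG_DIR and test_checkpoint_dir in test_runner.py.

# Rough illustration of what run_test builds for the "PP 1D test" flavor
# (paths are hypothetical; the real ones come from CONFIG_DIR / test_checkpoint_dir).
full_path = "./train_configs/debug_model.toml"
checkpoint_folder_arg = "--checkpoint.folder /tmp/test_checkpoint_pp"

# Seed checkpoint is created first because requires_seed_checkpoint=True:
seed_cmd = f"CONFIG_FILE={full_path} ./create_seed_checkpoint.sh {checkpoint_folder_arg}"

# Then the 2-GPU training run itself (ngpu=2 for this flavor):
train_cmd = (
    f"CONFIG_FILE={full_path} NGPU=2 LOG_RANK=0,1,2,3 ./run_llama_train.sh"
    " --checkpoint.enable_checkpoint"
    f" {checkpoint_folder_arg}"
    " --training.pipeline_parallel_degree 2"
    " --training.data_parallel_degree 1"
    " --model.norm_type rmsnorm"
)
print(seed_cmd)
print(train_cmd)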

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 56 additions & 7 deletions
@@ -12,6 +12,15 @@

 import torch

+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import pipeline, SplitPoint
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
 from torch.distributed._tensor import Replicate, Shard
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
@@ -129,15 +138,48 @@ def get_tp_parallel_strategy(
     return RowwiseParallel, ColwiseParallel


+def apply_pipeline_parallelism(model, world_mesh, parallel_dims, job_config: JobConfig):
+    assert (
+        parallel_dims.pp_enabled
+    ), "can't apply pipeline parallelism if it is not enabled"
+
+    if job_config.model.norm_type == "fused_rmsnorm":
+        # TODO(whc) - torch._dynamo.exc.Unsupported: Illegal getattr invocation stride in strict mode
+        # coming from ` if dy.stride(-1) != 1:` in fused_rmsnorm
+        raise NotImplementedError(
+            "fused_rmsnorm not yet compatible with Pipeline Tracer (strides error). Please use layernorm or rmsnorm."
+        )
+    pp_mesh = world_mesh["pp"]
+    stage_idx = pp_mesh.get_local_rank()
+    layers_per_rank = len(model.layers) // parallel_dims.pp
+    split_spec = {
+        f"layers.{i * layers_per_rank}": SplitPoint.BEGINNING
+        for i in range(1, parallel_dims.pp)
+    }
+    # Get example input
+    label_shape = input_shape = (8, 2048)  # TODO
+    input_ids = torch.randint(
+        model.vocab_size, input_shape, dtype=torch.int64, device="meta"
+    )
+    labels = torch.randint(
+        model.vocab_size, label_shape, dtype=torch.int64, device="meta"
+    )
+
+    # Create a pipeline representation from the model
+    pipe = pipeline(
+        model, parallel_dims.pp, example_args=(input_ids,), split_spec=split_spec
+    )
+    model = pipe.get_stage_module(stage_idx)
+    return model, pipe.pipe_info
+
+
 def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
     """
-    Apply parallelisms and activation checkpointing to the model.
+    Apply SPMD parallelisms and activation checkpointing to the model.

     NOTE: The passed-in model preferably should be on meta device. Otherwise,
     the model must fit on GPU or CPU memory.
     """
-    if parallel_dims.pp_enabled:
-        raise NotImplementedError("PP not implemented yet.")

     if parallel_dims.tp_enabled:
         if job_config.model.norm_type == "fused_rmsnorm":
@@ -215,24 +257,31 @@ def parallelize_llama(model, world_mesh, parallel_dims, job_config: JobConfig):
         assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
         # TODO: Expose `reduce_dtype` as a config option.
         mp_policy = MixedPrecisionPolicy(
-            param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            # TODO(whc) need to fix PP + FSDP-mixed-precision
+            # tracer for PP assumes f32 and is caught off guard when runtime FSDP interacts using bf16 inputs
+            # param_dtype=torch.bfloat16, reduce_dtype=torch.float32
+            param_dtype=torch.float32,
+            reduce_dtype=torch.float32,
         )
         ac_mode = job_config.activation_checkpoint.mode
         fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy}
-        for layer_id, transformer_block in enumerate(model.layers):
+        for layer_name, transformer_block in model.layers.named_children():
             if job_config.activation_checkpoint.mode in ("full", "selective"):
                 transformer_block = checkpoint_wrapper(
                     transformer_block, job_config.activation_checkpoint
                 )
             # As an optimization, do not reshard after forward for the last
             # transformer block since FSDP would prefetch it immediately
-            reshard_after_forward = layer_id < len(model.layers) - 1
+            # reshard_after_forward = layer_id < len(model.layers) - 1
+            # TODO(whc) need to fix correctly handle layer-ids on pp-split module
+            reshard_after_forward = True
             fully_shard(
                 transformer_block,
                 **fsdp_config,
                 reshard_after_forward=reshard_after_forward,
             )
-            model.layers[layer_id] = transformer_block
+            model.layers.add_module(layer_name, transformer_block)
+
         model = fully_shard(model, **fsdp_config)
         if ac_mode in ("full", "selective"):
             logger.info(f"Applied {ac_mode} activation checkpointing to the model")

train.py

Lines changed: 88 additions & 12 deletions
@@ -19,6 +19,17 @@

 import torch
 import torch.nn.functional as F
+
+# TODO(whc) this can be removed after pippy migration into pytorch core is complete.
+try:
+    from pippy import ScheduleGPipe
+    from pippy.PipelineStage import _PipelineStage
+except ImportError as exc:
+    raise ImportError(
+        "pippy is not installed. Please install it to use pipeline parallelism. "
+        "`pip install git+https://github.com/pytorch/pippy`"
+    ) from exc
+
 from torch.distributed import destroy_process_group
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
@@ -126,7 +137,8 @@ def main(job_config: JobConfig):
         world_size=world_size,
         enable_loss_parallel=job_config.training.enable_loss_parallel,
     )
-    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
+    device = torch.device(f"cuda:{int(os.environ['LOCAL_RANK'])}")
+    torch.cuda.set_device(device)
     init_distributed(job_config)

     world_mesh = parallel_dims.build_mesh(device_type="cuda")
@@ -144,6 +156,15 @@ def main(job_config: JobConfig):
         dp_rank = dp_mesh.get_local_rank()
     else:
         dp_degree, dp_rank = 1, 0
+
+    if parallel_dims.pp_enabled:
+        pp_mesh = world_mesh["pp"]
+        pp_degree = pp_mesh.size()
+        pp_rank = pp_mesh.get_local_rank()
+
+    else:
+        pp_degree, pp_rank = 1, 0
+
     data_loader = build_hf_data_loader(
         job_config.training.dataset,
         job_config.training.dataset_path,
@@ -201,13 +222,44 @@ def loss_fn(pred, labels):
     # obtain the peak flops of bf16 type for MFU calculation
     gpu_peak_flops = get_peak_flops(gpu_memory_monitor.device_name)

-    # apply PT-D parallelisms and activation checkpointing
+    if parallel_dims.pp_enabled:
+        # TODO(whc) now I need to figure out how to align this with the `model_parallelize_fns[model_name]` pattern
+        from torchtitan.parallelisms.parallelize_llama import apply_pipeline_parallelism
+
+        model, pipe_info = apply_pipeline_parallelism(
+            model, world_mesh, parallel_dims, job_config
+        )
+
+    # apply PT-D DP/TP parallelisms and activation checkpointing
     model = models_parallelize_fns[model_name](
         model, world_mesh, parallel_dims, job_config
     )
-    # allocate sharded model on GPU and initialize weights via DTensor
+
     model.to_empty(device="cuda")
-    model.init_weights()
+
+    # TODO(whc) everything below needs to become a function that can be applied to each 'virtual stage' of PP, if
+    # there are virtual stages
+    if parallel_dims.pp_enabled:
+        stage = _PipelineStage(
+            stage_module=model,
+            stage_index=pp_rank,
+            pipe_info=pipe_info,
+            device=device,
+            group=pp_mesh.get_group(),
+        )
+        pp_schedule = ScheduleGPipe(
+            stage,
+            n_microbatches=parallel_dims.pp,
+            loss_fn=loss_fn,
+        )
+    else:
+        # if PP is enabled, we can't use init_weights. instead, we have to rely on offline creating an initial checkpoint
+        # and loading it to get initialization values. This is because the init_weights functions are written assuming
+        # the whole model (all its weights, or FQNs) exist on one rank. In PP, the init_weights on stage1 might crash
+        # because it can't find "embedding" layer, for example.

+        # allocate sharded model on GPU and initialize weights via DTensor
+        model.init_weights()

     gpu_mem_stats = gpu_memory_monitor.get_peak_stats()
     logger.info(
@@ -219,7 +271,6 @@ def loss_fn(pred, labels):
     # build optimizer after applying parallelisms to the model
     optimizer = build_optimizer(model, job_config)
     scheduler = get_lr_scheduler(optimizer, job_config)
-
     metric_logger = build_metric_logger(job_config)

     # torch.compile model for improved performance
@@ -257,7 +308,13 @@ def loss_fn(pred, labels):
         logger.info("Created seed checkpoint")
         return

-    checkpoint.load()
+    checkpoint_loaded = checkpoint.load()
+
+    if parallel_dims.pp_enabled and not checkpoint_loaded:
+        raise RuntimeError(
+            "Pipeline Parallelism requires meta-initialization and loading seed checkpoint. "
+            "Please run `./create_seed_checkpoint.sh` and rerun training with `--checkpoint.enable_checkpoint`"
+        )

     # plot losses loaded from checkpoint (if any) to TensorBoard
     # NOTE: Loss info after the last log step before checkpoint saving will not be plotted.
@@ -299,14 +356,33 @@ def loss_fn(pred, labels):

         input_ids = input_ids.cuda()
         labels = labels.cuda()
-
         optimizer.zero_grad()

-        # forward / backward
-        with loss_parallel_ctx():
-            pred = model(input_ids)
-            loss = loss_fn(pred, labels)
-            loss.backward()
+        if parallel_dims.pp_enabled:
+            # pipeline parallel forward / backward inside step() call
+            is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+            with loss_parallel_ctx():
+                if pp_mesh.get_local_rank() == 0:
+                    pp_schedule.step(input_ids)
+                elif is_last_stage:
+                    losses = []
+                    pp_schedule.step(target=labels, losses=losses)
+                else:
+                    pp_schedule.step()
+
+            # accumulate losses across pipeline microbatches
+            loss = (
+                torch.mean(torch.stack(losses))
+                if is_last_stage
+                else torch.Tensor([-1.0])
+            )
+        else:
+            # Non-PP forward / backward
+            with loss_parallel_ctx():
+                pred = model(input_ids)
+                loss = loss_fn(pred, labels)
+                loss.backward()

         # clip gradients
         torch.nn.utils.clip_grad_norm_(
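To make the loss handling in the PP branch above concrete, the standalone snippet below (not part of the diff) shows how the last stage reduces per-microbatch losses while the other stages report a placeholder value. The microbatch loss values are invented purely for illustration.

# Standalone illustration of the loss reduction used in the PP branch above.
import torch

# On the last pipeline stage, the schedule's step(...) fills `losses` with one
# scalar per microbatch (values here are made up).
losses = [torch.tensor(2.31), torch.tensor(2.27), torch.tensor(2.19), torch.tensor(2.25)]
is_last_stage = True

loss = (
    torch.mean(torch.stack(losses))   # average across microbatches
    if is_last_stage
    else torch.Tensor([-1.0])         # non-last stages have no loss to report
)
print(loss)  # tensor(2.2550)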
