Commit 0ebaf8c

Enable CP

This PR adds experimental flags and functions to enable context parallelism (CP). We currently support only FSDP + CP and CP alone; CP + TP is still being tested.

ghstack-source-id: 20b8844
Pull Request resolved: #433

1 parent: f2fae18

File tree: 5 files changed, +108 / -37 lines

estimation.py

Lines changed: 1 addition & 0 deletions

@@ -67,6 +67,7 @@ def estimate_memory(job_config: JobConfig):
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions

@@ -323,6 +323,12 @@ def __init__(self):
             action="store_true",
             help="Enable CompiledAutograd to compile the backward.",
         )
+        self.parser.add_argument(
+            "--experimental.context_parallel_degree",
+            type=int,
+            default=1,
+            help="Context parallelism degree. 1 means disabled.",
+        )
         self.parser.add_argument(
             "--training.mixed_precision_param",
             type=str,
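
For reference, a minimal standalone sketch of how a dotted flag like this parses with plain argparse (this is not the project's JobConfig machinery, which also layers TOML handling on top); the attribute name keeps the dot, so the value is read back with getattr:

import argparse

# Minimal sketch: parse the dotted CLI flag added in this commit.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--experimental.context_parallel_degree",
    type=int,
    default=1,
    help="Context parallelism degree. 1 means disabled.",
)
args = parser.parse_args(["--experimental.context_parallel_degree", "2"])
# The dest keeps the dot, so use getattr/vars instead of attribute access.
print(getattr(args, "experimental.context_parallel_degree"))  # prints 2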

torchtitan/parallelisms/__init__.py

Lines changed: 13 additions & 6 deletions

@@ -24,6 +24,7 @@
 @dataclass
 class ParallelDims:
     dp: int
+    cp: int
     tp: int
     pp: int
     world_size: int
@@ -35,22 +36,24 @@ def __post_init__(self):
         self._validate()

     def _validate(self):
-        dp, tp, pp = self.dp, self.tp, self.pp
+        dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
         if dp == -1:
-            self.dp = dp = self.world_size // (tp * pp)
+            self.dp = dp = self.world_size // (cp * tp * pp)
         assert dp >= 1, dp
+        assert cp >= 1, cp
         assert tp >= 1, tp
         assert pp >= 1, pp
-        assert (
-            dp * tp * pp == self.world_size
-        ), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
+        assert dp * cp * tp * pp == self.world_size, (
+            f"Invalid parallel dims: dp({dp}) * cp({cp}) * tp({tp}) * pp({pp}) "
+            f"!= WORLD_SIZE({self.world_size})"
+        )
         assert self.dp_type in ("fsdp", "ddp")

     def build_mesh(self, device_type):
         dims = []
         names = []
         for d, name in zip(
-            [self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
+            [self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
         ):
             if d > 1:
                 dims.append(d)
@@ -63,6 +66,10 @@ def build_mesh(self, device_type):
     def dp_enabled(self):
         return self.dp > 1

+    @property
+    def cp_enabled(self):
+        return self.cp > 1
+
     @property
     def tp_enabled(self):
         return self.tp > 1
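
To make the updated validation arithmetic concrete, here is a simplified standalone sketch (not the class above, hypothetical degrees): with dp=-1 the data-parallel degree is inferred as world_size // (cp * tp * pp), and the product of all degrees must equal the world size.

# Simplified sketch of the dp-inference rule added above (hypothetical numbers).
world_size = 8
cp, tp, pp = 2, 1, 1
dp = -1
if dp == -1:
    dp = world_size // (cp * tp * pp)  # 8 // (2 * 1 * 1) = 4
assert dp * cp * tp * pp == world_size, (dp, cp, tp, pp, world_size)
print(f"dp={dp}, cp={cp}, tp={tp}, pp={pp}")  # dp=4, cp=2, tp=1, pp=1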

torchtitan/parallelisms/parallelize_llama.py

Lines changed: 36 additions & 1 deletion

@@ -19,9 +19,15 @@

 from torch.distributed._composable.replicate import replicate
 from torch.distributed._tensor import Replicate, Shard
+
+try:
+    from torch.distributed._tensor.experimental.attention import enable_context_parallel
+except ImportError:
+    print("The PyTorch version does not include the experimental CP APIs.")
 from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
     checkpoint_wrapper as ptd_checkpoint_wrapper,
 )
+from torch.distributed.device_mesh import init_device_mesh
 from torch.distributed.pipelining import pipeline, PipelineStage, SplitPoint
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
@@ -455,17 +461,43 @@ def apply_compile(model: nn.Module, job_config: JobConfig):
     return model


+def apply_cp(model, world_mesh, parallel_dims, job_config: JobConfig):
+    """
+    Apply context parallelism to the model. This is an experimental feature.
+    """
+    if parallel_dims.tp_enabled or parallel_dims.pp_enabled:
+        raise NotImplementedError("CP + TP or CP + PP are not supported yet.")
+    cp_mesh = world_mesh["cp"]
+    callers = []
+    for layer_id, transformer_block in model.layers.items():
+        callers.append(transformer_block.attention)
+    enable_context_parallel(seq_dim=2, callers=callers, device_mesh=cp_mesh)
+    logger.info("Applied CP to the model")
+
+    return model
+
+
 def apply_fsdp(
     model: nn.Module,
     world_mesh: DeviceMesh,
     parallel_dims: "ParallelDims",
     job_config: JobConfig,
 ):
+
     """
     Apply data parallelism to the model. FSDP2 is used here.
     """

-    dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
+    if parallel_dims.cp_enabled:
+        # Temporary solution to enable FSDP + CP
+        dp_mesh = init_device_mesh(
+            world_mesh.device_type,
+            (parallel_dims.dp * parallel_dims.cp,),
+            mesh_dim_names=["dp"],
+        )
+    else:
+        dp_mesh = world_mesh["dp"]
+
     assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names

     mp_policy = MixedPrecisionPolicy(
@@ -542,6 +574,9 @@ def parallelize_llama(
     if job_config.training.compile:
         model = apply_compile(model, job_config)

+    if parallel_dims.cp_enabled:
+        model = apply_cp(model, world_mesh, parallel_dims, job_config)
+
     if parallel_dims.dp_enabled:
         if parallel_dims.dp_type == "fsdp":
             model = apply_fsdp(model, world_mesh, parallel_dims, job_config)
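
The temporary FSDP + CP workaround above shards FSDP over a fresh 1-D mesh of size dp * cp instead of the "dp" sub-mesh of the world mesh. A hedged sketch of just that mesh construction follows (hypothetical degrees dp=2, cp=2; it assumes a `torchrun --nproc_per_node=4` launch with CUDA devices, and is an illustration rather than the commit's code path):

# Sketch: build the flattened 1-D "dp" mesh used when CP is enabled.
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh

dp, cp = 2, 2  # hypothetical parallel degrees; dp * cp must equal the world size
dp_mesh = init_device_mesh("cuda", (dp * cp,), mesh_dim_names=("dp",))
assert dp_mesh.mesh_dim_names == ("dp",)
if dist.get_rank() == 0:
    # FSDP shards parameters across all dp * cp ranks; the CP hooks installed by
    # enable_context_parallel() handle sharding of the sequence dimension.
    print(f"flattened dp mesh covers {dp_mesh.size()} ranks")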

train.py

Lines changed: 52 additions & 30 deletions

@@ -11,6 +11,7 @@

 from dataclasses import dataclass, field
 from datetime import timedelta
+from functools import partial
 from io import BytesIO
 from timeit import default_timer as timer
 from typing import Any, Dict, List
@@ -20,6 +21,7 @@
 import torch
 import torch.nn.functional as F
 from torch.distributed import destroy_process_group
+from torch.distributed._tensor.experimental.attention import context_parallel_buffers
 from torch.distributed.checkpoint.stateful import Stateful
 from torch.distributed.elastic.multiprocessing.errors import record
 from torch.distributed.tensor.parallel import loss_parallel
@@ -172,6 +174,7 @@ def main(job_config: JobConfig):
     world_size = int(os.environ["WORLD_SIZE"])
     parallel_dims = ParallelDims(
         dp=job_config.training.data_parallel_degree,
+        cp=job_config.experimental.context_parallel_degree,
         tp=job_config.training.tensor_parallel_degree,
         pp=job_config.experimental.pipeline_parallel_degree,
         world_size=world_size,
@@ -216,6 +219,20 @@ def main(job_config: JobConfig):
         job_config.experimental.enable_compiled_autograd,
     )

+    if parallel_dims.cp_enabled:
+        cp_mesh = world_mesh["cp"]
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=cp_mesh.get_local_rank(),
+            cp_world_size=cp_mesh.size(),
+        )
+    else:
+        context_parallel_ctx = partial(
+            context_parallel_buffers,
+            cp_rank=0,
+            cp_world_size=1,
+        )
+
     # loss fn can be shared by pipeline-parallel or non-pp execution
     def loss_fn(pred, labels):
         return F.cross_entropy(pred.flatten(0, 1), labels.flatten(0, 1))
@@ -373,38 +390,43 @@ def loss_fn(pred, labels):
             ntokens_since_last_log += labels.numel()
             data_loading_times.append(timer() - data_load_start)

-            input_ids = input_ids.cuda()
-            labels = labels.cuda()
             optimizers.zero_grad()

-            if parallel_dims.pp_enabled:
-                # pipeline parallel forward / backward inside step() call
-                is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
-
-                with train_context():
-                    if pp_mesh.get_local_rank() == 0:
-                        pp_schedule.step(input_ids)
-                    elif is_last_stage:
-                        losses = []
-                        pp_schedule.step(target=labels, losses=losses)
-                    else:
-                        pp_schedule.step()
-
-                # accumulate losses across pipeline microbatches
-                loss = (
-                    torch.mean(torch.stack(losses))
-                    if is_last_stage
-                    else torch.Tensor([-1.0])
-                )
-            else:
-                # Non-PP forward / backward
-                with train_context():
-                    pred = model(input_ids)
-                    loss = loss_fn(pred, labels)
-                    # pred.shape=(bs, seq_len, vocab_size)
-                    # need to free to before bwd to avoid peaking memory
-                    del pred
-                    loss.backward()
+            with context_parallel_ctx(
+                buffers=[input_ids, labels, model.freqs_cis],
+                seq_dims=[1, 1, 0],
+                keep_orig_buffers=[False, False, True],
+            ):
+                input_ids = input_ids.cuda()
+                labels = labels.cuda()
+                if parallel_dims.pp_enabled:
+                    # pipeline parallel forward / backward inside step() call
+                    is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
+
+                    with train_context():
+                        if pp_mesh.get_local_rank() == 0:
+                            pp_schedule.step(input_ids)
+                        elif is_last_stage:
+                            losses = []
+                            pp_schedule.step(target=labels, losses=losses)
+                        else:
+                            pp_schedule.step()
+
+                    # accumulate losses across pipeline microbatches
+                    loss = (
+                        torch.mean(torch.stack(losses))
+                        if is_last_stage
+                        else torch.Tensor([-1.0])
+                    )
+                else:
+                    # Non-PP forward / backward
+                    with train_context():
+                        pred = model(input_ids)
+                        loss = loss_fn(pred, labels)
+                        # pred.shape=(bs, seq_len, vocab_size)
+                        # need to free to before bwd to avoid peaking memory
+                        del pred
+                        loss.backward()

             # clip gradients
             for model in model_parts:
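
Conceptually, the buffer sharding that the context_parallel_buffers context manager performs hands each CP rank its own slice of every listed buffer along that buffer's sequence dimension. A rough standalone sketch of that slicing follows (hypothetical shapes and a hypothetical helper, not the PyTorch API; the real context manager also manages restoring the original buffers per keep_orig_buffers):

import torch

def shard_along_seq(buf: torch.Tensor, seq_dim: int, cp_rank: int, cp_world_size: int) -> torch.Tensor:
    # Illustrative helper (not the PyTorch API): keep only this rank's slice.
    return buf.chunk(cp_world_size, dim=seq_dim)[cp_rank]

# Hypothetical shapes: batch=2, seq_len=8, and 2 CP ranks.
input_ids = torch.arange(16).reshape(2, 8)
local_ids = shard_along_seq(input_ids, seq_dim=1, cp_rank=0, cp_world_size=2)
print(local_ids.shape)  # torch.Size([2, 4]) -- each CP rank sees half the sequence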
