1 change: 1 addition & 0 deletions estimation.py
@@ -65,6 +65,7 @@ def estimate_memory(job_config: JobConfig):

parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
6 changes: 6 additions & 0 deletions torchtitan/config_manager.py
@@ -323,6 +323,12 @@ def __init__(self):
action="store_true",
help="Enable CompiledAutograd to compile the backward.",
)
self.parser.add_argument(
"--experimental.context_parallel_degree",
type=int,
default=1,
help="Context parallelism degree. 1 means disabled.",
)
self.parser.add_argument(
"--training.mixed_precision_param",
type=str,
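As a quick usage note (flag names taken from the arguments above; the concrete degrees are made up): on 8 GPUs, passing --experimental.context_parallel_degree 2 together with --training.tensor_parallel_degree 2 and leaving --training.data_parallel_degree at -1 resolves the data-parallel degree to 8 // (2 * 2) = 2, since the product dp * cp * tp * pp must equal WORLD_SIZE (see parallel_dims.py below).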
2 changes: 1 addition & 1 deletion torchtitan/models/llama/model.py
@@ -417,7 +417,7 @@ def _precompute_freqs_cis(self) -> torch.Tensor:
self.model_args.dim // self.model_args.n_heads,
# Need to compute until at least the max token limit for generation
# (use 2x max sequence length to be safe)
self.model_args.max_seq_len * 2,
self.model_args.max_seq_len,
self.model_args.rope_theta,
)

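Presumably the 2x safety margin can be dropped because freqs_cis now becomes one of the context-parallel buffers and is sharded along its sequence dimension (see cp_buffers / cp_seq_dims in train.py below), so its length has to match the sharded inputs exactly. A rough illustration with made-up numbers, not taken from the PR:

# Illustrative only: max_seq_len = 2048, context-parallel degree 2.
# input_ids : seq dim 2048 -> each cp rank keeps a 1024-token shard
# freqs_cis : length 2048  -> shards to 1024, one entry per local position
# the old freqs_cis of length 2 * 2048 = 4096 would shard to 2048 and no longer line up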
23 changes: 16 additions & 7 deletions torchtitan/parallelisms/parallel_dims.py
@@ -14,6 +14,7 @@
@dataclass
class ParallelDims:
dp: int
cp: int
tp: int
pp: int
world_size: int
@@ -25,34 +26,41 @@ def __post_init__(self):
self._validate()

def _validate(self):
dp, tp, pp = self.dp, self.tp, self.pp
dp, cp, tp, pp = self.dp, self.cp, self.tp, self.pp
if dp == -1:
self.dp = dp = self.world_size // (tp * pp)
self.dp = dp = self.world_size // (cp * tp * pp)
assert dp >= 1, dp
assert cp >= 1, cp
assert tp >= 1, tp
assert pp >= 1, pp
assert (
dp * tp * pp == self.world_size
), f"Invalid parallel dims: dp({dp}) * tp({tp}) * pp({pp}) != WORLD_SIZE({self.world_size})"
assert dp * cp * tp * pp == self.world_size, (
f"Invalid parallel dims: dp({dp}) * cp ({cp}) * tp({tp}) * pp({pp}) "
f"!= WORLD_SIZE({self.world_size})"
)
assert self.dp_type in ("fsdp", "ddp")

def build_mesh(self, device_type):
dims = []
names = []
for d, name in zip(
[self.pp, self.dp, self.tp], ["pp", "dp", "tp"], strict=True
[self.pp, self.dp, self.cp, self.tp], ["pp", "dp", "cp", "tp"], strict=True
):
if d > 1:
dims.append(d)
names.append(name)
logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}")
names = tuple(names)
return init_device_mesh(device_type, dims, mesh_dim_names=names)
world_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names)
return world_mesh

@property
def dp_enabled(self):
return self.dp > 1

@property
def cp_enabled(self):
return self.cp > 1

@property
def tp_enabled(self):
return self.tp > 1
@@ -68,3 +76,4 @@ def loss_parallel_enabled(self):
@cached_property
def model_parallel_size(self):
return self.tp * self.pp

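Below is a small, hedged sketch of how the extended ParallelDims behaves end to end. The degree values are made up, and the enable_loss_parallel / dp_type fields are assumed from the rest of the existing dataclass; run under torchrun so that 16 ranks exist.

from torchtitan.parallelisms.parallel_dims import ParallelDims

# Example values: 16 ranks, cp=2, tp=2, pp=1, dp left at -1 (auto).
parallel_dims = ParallelDims(
    dp=-1, cp=2, tp=2, pp=1, world_size=16,
    enable_loss_parallel=True, dp_type="fsdp",
)
# _validate() resolves dp = 16 // (2 * 2 * 1) = 4 and checks the product.
assert parallel_dims.dp == 4 and parallel_dims.cp_enabled

# pp == 1 is skipped, so the mesh is 3-D with names ("dp", "cp", "tp")
# and shape (4, 2, 2).
world_mesh = parallel_dims.build_mesh(device_type="cuda")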
17 changes: 14 additions & 3 deletions torchtitan/parallelisms/parallelize_llama.py
@@ -15,6 +15,11 @@
from torch.distributed._composable.fsdp import fully_shard, MixedPrecisionPolicy
from torch.distributed._composable.replicate import replicate
from torch.distributed._tensor import Replicate, Shard

try:
from torch.distributed._tensor.experimental.attention import enable_context_parallel
except ImportError:
print("The PyTorch version does not include the experimental CP APIs.")
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
checkpoint_wrapper as ptd_checkpoint_wrapper,
)
@@ -72,10 +77,16 @@ def parallelize_llama(
)
apply_compile(model)

if parallel_dims.dp_enabled:
if parallel_dims.dp_enabled or parallel_dims.cp_enabled:
if parallel_dims.dp_type == "fsdp":
dp_mesh = world_mesh["dp"] if world_mesh.ndim > 1 else world_mesh
assert dp_mesh.mesh_dim_names == ("dp",), dp_mesh.mesh_dim_names
if parallel_dims.cp_enabled:
# Temporary solution to enable FSDP + CP
if parallel_dims.dp_enabled:
dp_mesh = world_mesh["dp", "cp"]._flatten()
else:
dp_mesh = world_mesh["cp"]
else:
dp_mesh = world_mesh["dp"]

apply_fsdp(
model,
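A minimal sketch of the FSDP + CP mesh-flattening workaround above (sizes are illustrative; _flatten() and multi-name mesh slicing are private/experimental DeviceMesh APIs in recent PyTorch builds):

from torch.distributed.device_mesh import init_device_mesh

# Assume an 8-rank job with dp=4 and cp=2 (example values).
world_mesh = init_device_mesh("cuda", (4, 2), mesh_dim_names=("dp", "cp"))

# FSDP shards over a single mesh dimension, so dp and cp are flattened into
# one 8-rank mesh: parameters and gradients are sharded across all dp * cp
# ranks, while attention still splits the sequence across the cp ranks.
dp_mesh = world_mesh["dp", "cp"]._flatten()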
68 changes: 50 additions & 18 deletions train.py
@@ -10,6 +10,10 @@
from datetime import timedelta

import torch
from torch.distributed.device_mesh import DeviceMesh
from typing import List, Optional, Set
from functools import partial
from torch.distributed._tensor.experimental.attention import context_parallel
from torch.distributed.elastic.multiprocessing.errors import record
from torch.fx import GraphModule

@@ -30,16 +34,39 @@
from torchtitan.profiling import maybe_enable_memory_snapshot, maybe_enable_profiling


def get_train_context(enable_loss_parallel: bool, enable_compiled_autograd: bool):
def get_train_context(
enable_loss_parallel: bool,
enable_compiled_autograd: bool,
cp_mesh: Optional[DeviceMesh],
):
if cp_mesh is not None:
context_parallel_ctx = partial(context_parallel, mesh=cp_mesh)
else:
context_parallel_ctx = partial(context_parallel, mesh=None)

@contextlib.contextmanager
def context():
def context(
cp_buffers: List[torch.Tensor],
cp_seq_dims: List[int],
cp_no_restore_buffers: Set[torch.Tensor],
):
with contextlib.ExitStack() as stack:
if enable_loss_parallel:
stack.enter_context(torch.distributed.tensor.parallel.loss_parallel())

if enable_compiled_autograd:
stack.enter_context(
torch._dynamo.utils.maybe_enable_compiled_autograd(True)
)

buffers = stack.enter_context(
context_parallel_ctx(
buffers=cp_buffers,
buffer_seq_dims=cp_seq_dims,
no_restore_buffers=cp_no_restore_buffers,
)
)

yield

return context
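A minimal sketch of how the reworked context is meant to be called (tensor shapes and the cp mesh are illustrative; the real call site is in the training loop further down):

train_context = get_train_context(
    enable_loss_parallel=False,
    enable_compiled_autograd=False,
    cp_mesh=world_mesh["cp"],  # the PR passes None here when CP is disabled
)

input_ids = torch.randint(0, 32000, (8, 2048))
labels = torch.randint(0, 32000, (8, 2048))

# context_parallel shards each listed buffer in place along its sequence dim
# for the duration of the block; buffers in no_restore_buffers are not
# restored to their full length on exit.
with train_context(
    cp_buffers=[input_ids, labels],
    cp_seq_dims=[1, 1],
    cp_no_restore_buffers={input_ids, labels},
):
    ...  # each cp rank now works on a seq_len // cp slice of the sequence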
@@ -61,6 +88,7 @@ def main(job_config: JobConfig):
world_size = int(os.environ["WORLD_SIZE"])
parallel_dims = ParallelDims(
dp=job_config.training.data_parallel_degree,
cp=job_config.experimental.context_parallel_degree,
tp=job_config.training.tensor_parallel_degree,
pp=job_config.experimental.pipeline_parallel_degree,
world_size=world_size,
@@ -233,6 +261,7 @@ def loss_fn(pred, labels):
train_context = get_train_context(
parallel_dims.loss_parallel_enabled,
job_config.experimental.enable_compiled_autograd,
world_mesh["cp"] if parallel_dims.cp_enabled else None,
)

# variables used to keep info for metrics logging
@@ -266,18 +295,23 @@ def loss_fn(pred, labels):
data_load_start = time.perf_counter()
batch = next(data_iterator)
input_ids, labels = batch
ntokens_since_last_log += labels.numel()
ntokens_since_last_log += labels.numel() // parallel_dims.cp
data_loading_times.append(time.perf_counter() - data_load_start)

input_ids = input_ids.cuda()
labels = labels.cuda()
optimizers.zero_grad()

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1
with train_context(
cp_buffers=[input_ids, labels, model.freqs_cis],
cp_seq_dims=[1, 1, 0],
cp_no_restore_buffers={input_ids, labels},
) as cp_buffers:
input_ids = input_ids.cuda()
labels = labels.cuda()

if parallel_dims.pp_enabled:
# Pipeline Parallel forward / backward inside step() call
is_last_stage = pp_mesh.get_local_rank() == pp_mesh.size() - 1

with train_context():
if pp_mesh.get_local_rank() == 0:
pp_schedule.step(input_ids)
elif is_last_stage:
Expand All @@ -286,15 +320,13 @@ def loss_fn(pred, labels):
else:
pp_schedule.step()

# accumulate losses across pipeline microbatches
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
with train_context():
loss = (
torch.mean(torch.stack(losses))
if is_last_stage
else torch.Tensor([-1.0])
)
else:
# Non-PP forward / backward
pred = model(input_ids)
loss = loss_fn(pred, labels)
# pred.shape=(bs, seq_len, vocab_size)