diff --git a/.gitignore b/.gitignore
index e39990d72..21e3a56b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ build
 outputs
 dist/*
 .vscode
+slurm-*.out
 
 # data
 data
diff --git a/torchtitan/config_manager.py b/torchtitan/config_manager.py
index 4300c3bb8..c2a686605 100644
--- a/torchtitan/config_manager.py
+++ b/torchtitan/config_manager.py
@@ -353,6 +353,9 @@ class Parallelism:
     - 'alltoall' means to all-to-all shuffle the kv shards.
     The default value is 'allgather'.
     """
+
+    enable_tp2ep: bool = False
+    """Whether to use expert parallelism instead of tensor parallelism for the routed experts (the shared experts stay tensor parallel)."""
 
 
 @dataclass
diff --git a/torchtitan/experiments/kernels/moe/token_dispatcher.py b/torchtitan/experiments/kernels/moe/token_dispatcher.py
new file mode 100644
index 000000000..289f3107b
--- /dev/null
+++ b/torchtitan/experiments/kernels/moe/token_dispatcher.py
@@ -0,0 +1,128 @@
+from typing import Tuple
+import torch
+import torch.distributed as dist
+from torch.distributed._functional_collectives import all_to_all_single_autograd
+
+
+class DefaultTokenDispatcher:
+
+    def __init__(self, num_experts: int, ep_size: int = 1):
+        self.num_experts = num_experts
+        self.ep_size = ep_size
+        self.experts_per_rank = num_experts // ep_size
+        self.ep_group = None
+
+    def token_permutation(
+        self,
+        routed_input: torch.Tensor,
+        top_scores: torch.Tensor,
+        num_local_tokens_per_expert: torch.Tensor,
+        training: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None,
+               torch.Tensor | None]:
+        return routed_input, top_scores, num_local_tokens_per_expert, None, None
+
+    def token_unpermutation(
+        self,
+        routed_output: torch.Tensor,
+        input_splits: torch.Tensor | None = None,
+        output_splits: torch.Tensor | None = None,
+        training: bool = True,
+    ) -> torch.Tensor:
+        return routed_output
+
+
+class TorchAllToAllTokenDispatcher(DefaultTokenDispatcher):
+
+    def __init__(
+        self,
+        num_experts: int,
+        ep_size: int,
+        ep_group: torch.distributed.ProcessGroup,
+    ):
+        super().__init__(num_experts, ep_size)
+        self.ep_group = ep_group
+
+    def token_permutation(
+        self,
+        routed_input: torch.Tensor,
+        top_scores: torch.Tensor,
+        num_local_tokens_per_expert: torch.Tensor,
+        training: bool = True,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None,
+               torch.Tensor | None]:
+        dim = routed_input.shape[-1]
+        with torch.no_grad():
+            tokens_per_expert_group = num_local_tokens_per_expert.new_empty(
+                num_local_tokens_per_expert.shape[0])
+            dist.all_to_all_single(tokens_per_expert_group,
+                                   num_local_tokens_per_expert,
+                                   group=self.ep_group)
+            input_splits = num_local_tokens_per_expert.view(
+                self.ep_size, -1).sum(dim=1)
+            output_splits = tokens_per_expert_group.view(
+                self.ep_size, -1).sum(dim=1)
+        if training:
+            gathered_tokens = all_to_all_single_autograd(
+                routed_input,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                self.ep_group,
+            )
+            gathered_top_scores = all_to_all_single_autograd(
+                top_scores,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                self.ep_group,
+            )
+        else:
+            # TODO: unify with all_to_all_single_autograd after
+            # https://github.com/pytorch/pytorch/issues/154370 is resolved
+            gathered_num_tokens = output_splits.sum()
+            gathered_tokens = routed_input.new_empty(
+                (gathered_num_tokens, dim))
+            dist.all_to_all_single(
+                gathered_tokens,
+                routed_input,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                group=self.ep_group,
+            )
+            gathered_top_scores = top_scores.new_empty(gathered_num_tokens)
+            dist.all_to_all_single(
+                gathered_top_scores,
+                top_scores,
+                output_splits.tolist(),
+                input_splits.tolist(),
+                group=self.ep_group,
+            )
+        return gathered_tokens, gathered_top_scores, tokens_per_expert_group, input_splits, output_splits
+
+    def token_unpermutation(
+        self,
+        routed_output: torch.Tensor,
+        input_splits: torch.Tensor | None = None,
+        output_splits: torch.Tensor | None = None,
+        training: bool = True,
+    ) -> torch.Tensor:
+        dim = routed_output.shape[-1]
+        if training:
+            returned_tokens = all_to_all_single_autograd(
+                routed_output,
+                input_splits.tolist(),
+                output_splits.tolist(),
+                self.ep_group,
+            )
+        else:
+            # TODO: unify with all_to_all_single_autograd after
+            # https://github.com/pytorch/pytorch/issues/154370 is resolved
+            returned_tokens = routed_output.new_empty(
+                (input_splits.sum(), dim))
+            dist.all_to_all_single(
+                returned_tokens,
+                routed_output,
+                input_splits.tolist(),
+                output_splits.tolist(),
+                group=self.ep_group,
+            )
+        return returned_tokens
diff --git a/torchtitan/experiments/llama4/infra/expert_parallel.py b/torchtitan/experiments/llama4/infra/expert_parallel.py
index 68f2e7a75..64023b3d1 100644
--- a/torchtitan/experiments/llama4/infra/expert_parallel.py
+++ b/torchtitan/experiments/llama4/infra/expert_parallel.py
@@ -18,7 +18,10 @@
     Replicate,
     Shard,
 )
-from torch.distributed.tensor.parallel import ParallelStyle
+from torch.distributed.tensor.parallel import (
+    ParallelStyle,
+    PrepareModuleInputOutput,
+)
 from torch.distributed.tensor.placement_types import Placement
 
 
@@ -141,3 +144,39 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
             ),
             partial(self._prepare_output_fn, self.output_layout, self.use_local_output),
         )
+
+
+class ExpertParallel(ParallelStyle):
+
+    def __init__(self):
+        super().__init__()
+
+    @staticmethod
+    def _prepare_input_fn(mod, inputs, device_mesh):
+        for inp in inputs:
+            if isinstance(inp, torch.Tensor):
+                assert not isinstance(
+                    inp, DTensor), "ExpertParallel expects local tensor inputs."
+        return inputs
+
+    def _partition_fn(self, name, module, device_mesh: DeviceMesh):
+        # shard on the expert dimension
+        for name, param in module.named_parameters(recurse=False):
+            dist_param = nn.Parameter(
+                distribute_tensor(param, device_mesh, [Shard(0)]))
+            module.register_parameter(name, dist_param)
+
+    @staticmethod
+    def _prepare_output_fn(mod, outputs, device_mesh):
+        assert not isinstance(
+            outputs, DTensor), "ExpertParallel expects local tensor outputs."
+        return outputs
+
+    def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
+        return distribute_module(
+            module,
+            device_mesh,
+            self._partition_fn,
+            self._prepare_input_fn,
+            self._prepare_output_fn,
+        )
diff --git a/torchtitan/experiments/llama4/infra/parallelize_llama.py b/torchtitan/experiments/llama4/infra/parallelize_llama.py
index 785d9d8a5..49b3bf48c 100644
--- a/torchtitan/experiments/llama4/infra/parallelize_llama.py
+++ b/torchtitan/experiments/llama4/infra/parallelize_llama.py
@@ -64,7 +64,11 @@ def parallelize_llama(
             enable_async_tp=job_config.parallelism.enable_async_tensor_parallel,
         )
 
-        apply_moe_tp(model, world_mesh["tp"])
+        apply_moe_tp(
+            model,
+            world_mesh["tp"],
+            enable_tp2ep=job_config.parallelism.enable_tp2ep,
+        )
 
     if job_config.activation_checkpoint.mode != "none":
         apply_ac(model, job_config.activation_checkpoint)
@@ -145,6 +149,7 @@ def _sync_tokens_per_expert(module, *_):
 def apply_moe_tp(
     model: nn.Module,
     tp_mesh: DeviceMesh,
+    enable_tp2ep: bool = False,
 ):
     from torch.distributed.tensor import Partial, Replicate, Shard
     from torch.distributed.tensor.parallel import (
@@ -152,25 +157,62 @@
         PrepareModuleInputOutput,
     )
 
-    from .expert_parallel import NoParallel, TensorParallel
+    from .expert_parallel import (
+        NoParallel,
+        TensorParallel,
+        ExpertParallel,
+    )
 
     for transformer_block in model.layers.values():
-        moe_layer_plan = {
-            # input / output sharding on the seqlen dim
-            # all-gather for input, reduce-scatter for output
-            "moe": PrepareModuleInputOutput(
-                input_layouts=(Shard(1),),
-                desired_input_layouts=(Replicate(),),
-                use_local_input=True,
-                output_layouts=(Partial(),),
-                desired_output_layouts=(Shard(1),),
-            ),
-            # replicate computation for the router
-            "moe.router.gate": NoParallel(),
-            # input Replicate, output Partial
-            "moe.experts": TensorParallel(output_layout=Partial()),
-            "moe.shared_expert": TensorParallel(output_layout=Partial()),
-        }
+        if enable_tp2ep:
+            moe_layer_plan = {
+                # input / output sharding on the seqlen dim
+                "moe":
+                PrepareModuleInputOutput(
+                    input_layouts=(Shard(1), ),
+                    desired_input_layouts=(Shard(1), ),
+                    use_local_input=True,
+                    output_layouts=(Shard(1), ),
+                    desired_output_layouts=(Shard(1), ),
+                ),
+                # FIXME: The input is reshaped after being sharded along
+                # the seqlen dimension. Should we use local tensors
+                # instead of Replicate?
+                "moe.router.gate":
+                NoParallel(),
+                # Since the tokens are not split evenly across experts,
+                # we need to use local tensors for both input and output.
+                # After the manual all-to-all, the result is
+                # sharded along the seqlen dim.
+ "moe.experts": + ExpertParallel(), + "moe.shared_expert": + TensorParallel( + input_layouts=(Shard(1), None), + output_layout=Shard(1), + ), + } + else: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": + PrepareModuleInputOutput( + input_layouts=(Shard(1), ), + desired_input_layouts=(Replicate(), ), + use_local_input=True, + output_layouts=(Partial(), ), + desired_output_layouts=(Shard(1), ), + ), + # replicate computation for the router + "moe.router.gate": + NoParallel(), + # input Replicate, output Partial + "moe.experts": + TensorParallel(output_layout=Partial()), + "moe.shared_expert": + TensorParallel(output_layout=Partial()), + } parallelize_module( module=transformer_block, device_mesh=tp_mesh, diff --git a/torchtitan/experiments/llama4/model/moe.py b/torchtitan/experiments/llama4/model/moe.py index 0dad02d25..875293168 100644 --- a/torchtitan/experiments/llama4/model/moe.py +++ b/torchtitan/experiments/llama4/model/moe.py @@ -7,6 +7,11 @@ import torch import torch.nn.functional as F from torch import nn +from torch.distributed.tensor import DTensor, Shard +from torchtitan.experiments.kernels.moe.token_dispatcher import ( + DefaultTokenDispatcher, + TorchAllToAllTokenDispatcher, +) from .args import TransformerModelArgs @@ -31,6 +36,20 @@ def forward( x: torch.Tensor, num_local_tokens_per_expert: torch.Tensor | list[int] | None = None, ) -> torch.Tensor: + if isinstance(self.w1, DTensor) and self.w1.placements == ( + Shard(0), ) and self.w1.device_mesh.size() > 1: + # expert parallel enabled + w1 = self.w1.to_local() + w2 = self.w2.to_local() + w3 = self.w3.to_local() + experts_per_rank = self.num_experts // self.w1.device_mesh.size() + else: + # expert parallel disabled + w1 = self.w1 + w2 = self.w2 + w3 = self.w3 + experts_per_rank = self.num_experts + # TODO: keeping this for loop implementation for comparison # and readability, will remove later if not self.use_grouped_mm: @@ -44,24 +63,27 @@ def forward( ) out_experts_splits = [] for expert_idx, x_expert in enumerate(x): - w1, w2, w3 = ( - self.w1[expert_idx], - self.w2[expert_idx], - self.w3[expert_idx], + expert_idx = expert_idx % experts_per_rank + current_w1, current_w2, current_w3 = ( + w1[expert_idx], + w2[expert_idx], + w3[expert_idx], ) - h = F.silu(torch.matmul(x_expert, w1)) - h = h * torch.matmul(x_expert, w3) - h = torch.matmul(h, w2) + h = F.silu(torch.matmul(x_expert, current_w1)) + h = h * torch.matmul(x_expert, current_w3) + h = torch.matmul(h, current_w2) # h shape (tokens_per_expert(varying), dim) out_experts_splits.append(h) out = torch.cat(out_experts_splits, dim=0) else: + bs, slen, dim = x.shape + x = x.reshape(1, bs * slen, dim) # x shape (num_experts, tokens_per_expert, dim) - h = F.silu(torch.bmm(x, self.w1)) - h = h * torch.bmm(x, self.w3) + h = F.silu(torch.bmm(x, w1)) + h = h * torch.bmm(x, w3) # out shape (num_experts, tokens_per_expert, dim) - out = torch.bmm(h, self.w2) - + out = torch.bmm(h, w2) + out = out.reshape(bs, slen, dim) return out # grouped mm implementation @@ -76,15 +98,15 @@ def forward( assert x.dim() == 2 else: offsets = None + bs, slen, dim = x.shape # fall back to regular bmm between 3D tensors - assert x.dim() == 3 + x = x.reshape(1, bs * slen, dim) - assert ( - x.dtype == self.w1.dtype == self.w2.dtype == self.w3.dtype == torch.bfloat16 - ), "torch._grouped_mm only supports bf16 dtypes" - h = F.silu(torch._grouped_mm(x, self.w1, offs=offsets)) - h = h * torch._grouped_mm(x, self.w3, offs=offsets) 
-        out = torch._grouped_mm(h, self.w2, offs=offsets)
+        assert (x.dtype == w1.dtype == w2.dtype == w3.dtype ==
+                torch.bfloat16), "torch._grouped_mm only supports bf16 dtypes"
+        h = F.silu(torch._grouped_mm(x, w1, offs=offsets))
+        h = h * torch._grouped_mm(x, w3, offs=offsets)
+        out = torch._grouped_mm(h, w2, offs=offsets)
 
         return out
 
@@ -172,8 +194,14 @@ def init_weights(self, init_std: float):
 
 
 class MoE(nn.Module):
-    def __init__(self, model_args: TransformerModelArgs):
+    def __init__(
+        self,
+        model_args: TransformerModelArgs,
+        scoring_before_experts: bool = True,
+    ):
         super().__init__()
+        # compatibility with DeepSeek MoE
+        self.scoring_before_experts = scoring_before_experts
         dim = model_args.dim
         hidden_dim = 4 * model_args.dim
         ffn_dim_multiplier = model_args.ffn_dim_multiplier
@@ -181,7 +209,7 @@ def __init__(self, model_args: TransformerModelArgs):
         if ffn_dim_multiplier is not None:
             hidden_dim = int(ffn_dim_multiplier * hidden_dim)
 
-        num_experts = model_args.num_experts
+        self.num_experts = model_args.num_experts
 
         hidden_dim_denom = 1
         if model_args.auto_scale_hidden_dim:
@@ -195,11 +223,11 @@ def __init__(self, model_args: TransformerModelArgs):
         self.experts = GroupedExperts(
             dim=dim,
             hidden_dim=hidden_dim,
-            num_experts=num_experts,
+            num_experts=self.num_experts,
             use_grouped_mm=self.use_grouped_mm,
         )
         self.router = TokenChoiceTopKRouter(
-            dim=dim, num_experts=num_experts, top_k=model_args.top_k
+            dim=dim, num_experts=self.num_experts, top_k=model_args.top_k
         )
         self.shared_expert = (
             GroupedExperts(
@@ -212,34 +240,38 @@ def __init__(self, model_args: TransformerModelArgs):
             else None
         )
 
+        self.token_dispatcher = DefaultTokenDispatcher(self.num_experts)
+
         # auxiliary-loss-free load balancing
         self.load_balance_coeff = model_args.load_balance_coeff
+        self.expert_bias_enabled = self.load_balance_coeff is not None and self.load_balance_coeff > 0
         # the fields below are defined even when load_balance_coeff is None
         # to make initialization and checkpointing code simpler
         self.register_buffer(
             "expert_bias",
-            torch.zeros(num_experts, dtype=torch.float32),
+            torch.zeros(self.num_experts, dtype=torch.float32),
             persistent=True,
         )
         self.register_buffer(
             "tokens_per_expert",
-            torch.zeros(num_experts, dtype=torch.float32),
+            torch.zeros(self.num_experts, dtype=torch.float32),
             persistent=True,
         )
 
         # NOTE: forward hook, forward pre hook, or backward pre hook
         # would conflict with activation checkpointing
-        if self.load_balance_coeff is not None and self.load_balance_coeff > 0:
+        if self.expert_bias_enabled:
             self.register_full_backward_hook(self._update_expert_bias)
 
     def _update_expert_bias(self, *_):
-        expert_bias_delta = self.load_balance_coeff * torch.sign(
-            self.tokens_per_expert.mean() - self.tokens_per_expert
-        )
-        expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
-        self.expert_bias.add_(expert_bias_delta)
+        with torch.no_grad():
+            expert_bias_delta = self.load_balance_coeff * torch.sign(
+                self.tokens_per_expert.mean() - self.tokens_per_expert
+            )
+            expert_bias_delta = expert_bias_delta - expert_bias_delta.mean()
+            self.expert_bias.add_(expert_bias_delta)
 
-        self.tokens_per_expert.zero_()
+            self.tokens_per_expert.zero_()
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
@@ -259,8 +291,29 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             num_local_tokens_per_expert,
         ) = self.router(x.reshape(bs * slen, dim), self.expert_bias)
 
-        # will be used to update the expert bias for load balancing
-        self.tokens_per_expert += num_local_tokens_per_expert
+        # TODO: Find a better place to initialize the token dispatcher.
+        # I tried putting it in PrepareModuleInputOutputWithParams._apply,
+        # but it caused torch.compile issues.
+        if (isinstance(self.experts.w1, DTensor)
+                and self.experts.w1.placements == (Shard(0), )):
+            self.token_dispatcher = TorchAllToAllTokenDispatcher(
+                num_experts=self.num_experts,
+                ep_size=self.experts.w1.device_mesh.size(),
+                ep_group=self.experts.w1.device_mesh.get_group(),
+            )
+
+        # Prevent extra accumulation of local token counts during evaluation or activation recomputation
+        if self.expert_bias_enabled and torch.is_grad_enabled():
+            with torch.no_grad():
+                num_local_tokens_per_expert_detached = num_local_tokens_per_expert.detach().clone()
+                if self.token_dispatcher.ep_group is not None:
+                    # sum all num_local_tokens_per_expert from ep_mesh
+                    torch.distributed.all_reduce(
+                        num_local_tokens_per_expert_detached,
+                        group=self.token_dispatcher.ep_group,
+                    )
+                # will be used to update the expert bias for load balancing
+                self.tokens_per_expert.add_(num_local_tokens_per_expert_detached)
 
         # shape (bs*slen*top_k, dim)
         token_indices = token_indices.reshape(-1, 1).expand(-1, dim)
@@ -271,10 +324,24 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             dim=0,
             index=token_indices,
         )
-        routed_input = (routed_input.to(torch.float32) * top_scores.reshape(-1, 1)).to(
-            x.dtype
+
+        (
+            gathered_tokens,
+            gathered_top_scores,
+            tokens_per_expert_group,
+            input_splits,
+            output_splits,
+        ) = self.token_dispatcher.token_permutation(
+            routed_input,
+            top_scores,
+            num_local_tokens_per_expert,
+            self.training,
         )
 
+        if self.scoring_before_experts:
+            gathered_tokens = (gathered_tokens.to(torch.float32) *
+                               gathered_top_scores.reshape(-1, 1)).to(x.dtype)
+
         if self.use_grouped_mm:
             # NOTE: In order to use torch._grouped_mm, we need to make sure
             # the number of tokens each expert gets is a multiple of 16.
@@ -289,36 +356,48 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             with torch.no_grad():
                 (
                     permuted_indices,
-                    num_local_tokens_per_expert,
+                    tokens_per_expert_group,
                     _,
                 ) = generate_permute_indices(
-                    num_local_tokens_per_expert,
-                    self.experts.num_experts,
-                    1,
+                    tokens_per_expert_group,
+                    self.token_dispatcher.experts_per_rank,
+                    self.token_dispatcher.ep_size,
                     ALIGN_SIZE_M,
                 )
-            token_indices = torch.vstack(
-                (token_indices, token_indices.new_zeros((dim)))
-            )
-            token_indices = token_indices[permuted_indices, :]
-            routed_input = torch.vstack((routed_input, routed_input.new_zeros((dim))))
-            routed_input = routed_input[permuted_indices, :]
+            gathered_tokens_buffer = torch.vstack(
+                (gathered_tokens, gathered_tokens.new_zeros((dim))))
+            buffer_shape = gathered_tokens_buffer.shape
+            gathered_tokens = gathered_tokens_buffer[permuted_indices, :]
+
+            gathered_top_scores = torch.cat(
+                (gathered_top_scores, gathered_top_scores.new_zeros(1)))
+            gathered_top_scores = gathered_top_scores[permuted_indices]
         else:
             # NOTE: this would incur a synchronization between device and host
-            num_local_tokens_per_expert = num_local_tokens_per_expert.tolist()
+            if tokens_per_expert_group is not None:
+                tokens_per_expert_group = tokens_per_expert_group.tolist()
 
         # shape (bs*slen*top_k, dim)
-        routed_output = self.experts(routed_input, num_local_tokens_per_expert)
+        routed_output = self.experts(gathered_tokens, tokens_per_expert_group)
+        if not self.scoring_before_experts:
+            routed_output = (routed_output.to(torch.float32) *
+                             gathered_top_scores.reshape(-1, 1)).to(x.dtype)
+
+        if self.use_grouped_mm:
+            gathered_tokens_buffer = routed_output.new_empty(buffer_shape)
+            gathered_tokens_buffer[permuted_indices, :] = routed_output
+            routed_output = gathered_tokens_buffer[:(buffer_shape[0] - 1), :]
+
+        returned_tokens = self.token_dispatcher.token_unpermutation(
+            routed_output, input_splits, output_splits, self.training)
 
         # shared expert
         if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
-                bs * slen, dim
-            )
+            out = self.shared_expert(x).reshape(bs * slen, dim)
         else:
-            out = torch.zeros_like(x.reshape(bs * slen, dim))
+            out = x.new_zeros((bs * slen, dim))
 
-        out = out.scatter_add(dim=0, index=token_indices, src=routed_output)
+        out = out.scatter_add(dim=0, index=token_indices, src=returned_tokens)
         out = out.reshape(bs, slen, dim)
         return out
diff --git a/torchtitan/experiments/llama4/train_configs/debug_model.toml b/torchtitan/experiments/llama4/train_configs/debug_model.toml
index bc48d3809..a1083f507 100644
--- a/torchtitan/experiments/llama4/train_configs/debug_model.toml
+++ b/torchtitan/experiments/llama4/train_configs/debug_model.toml
@@ -52,6 +52,7 @@ tensor_parallel_degree = 1
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
+enable_tp2ep = false
 
 [checkpoint]
 enable_checkpoint = false
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
index f508968c8..1c6fd25db 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx128e.toml
@@ -46,6 +46,7 @@ pipeline_parallel_degree = 4
 # pipeline_parallel_schedule = "interleaved1f1b"
 # pipeline_parallel_microbatches = 2
 context_parallel_degree = 1
+enable_tp2ep = false
 
 [checkpoint]
 enable_checkpoint = false
diff --git a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
index c899dd508..ee76deafb 100644
--- a/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
+++ b/torchtitan/experiments/llama4/train_configs/llama4_17bx16e.toml
@@ -44,6 +44,7 @@ tensor_parallel_degree = 8
 enable_async_tensor_parallel = false
 pipeline_parallel_degree = 1
 context_parallel_degree = 1
+enable_tp2ep = false
 
 [checkpoint]
 enable_checkpoint = false
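For reference, the core bookkeeping in the new `TorchAllToAllTokenDispatcher` is deriving the all-to-all split sizes from the router's per-expert token counts. The snippet below is a minimal single-process sketch of that reduction only; the `ep_size`, expert count, and token counts are made up, and it deliberately skips `dist.all_to_all_single` / `all_to_all_single_autograd`, which the patch uses to exchange the counts and tokens across the EP group.

```python
# Single-rank sketch of the split computation in token_permutation.
# Assumptions (not from the patch): ep_size=2, num_experts=4, arbitrary counts.
import torch

ep_size, num_experts = 2, 4
experts_per_rank = num_experts // ep_size  # 2 experts live on each EP rank

# Tokens this rank routes to each *global* expert, as produced by the router.
num_local_tokens_per_expert = torch.tensor([3, 1, 2, 4])

# input_splits[r] = number of tokens this rank sends to EP rank r,
# mirroring num_local_tokens_per_expert.view(ep_size, -1).sum(dim=1).
input_splits = num_local_tokens_per_expert.view(ep_size, -1).sum(dim=1)
print(input_splits.tolist())  # [4, 6]: experts 0-1 sit on rank 0, experts 2-3 on rank 1

# In the real dispatcher, output_splits is computed the same way from
# tokens_per_expert_group (received over the EP group), and the two lists
# parameterize the token all-to-all in both directions.
```

With `enable_tp2ep = true` under `[parallelism]` (and a tensor-parallel degree greater than 1), the routed experts take this expert-parallel path while the shared expert remains tensor parallel, as set up in `apply_moe_tp`.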