diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py
index 375eed00b..df384f7ff 100644
--- a/torchtitan/config/job_config.py
+++ b/torchtitan/config/job_config.py
@@ -399,6 +399,24 @@ class Parallelism:
     Note that this is still an experimental feature.
     """
 
+    expert_parallel_a2a_dispatch_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token dispatch step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
+    expert_parallel_a2a_combine_impl: Literal["default", "mxfp8"] = "default"
+    """
+    All-to-all implementation to use for the token combine step in expert parallelism.
+    - "default": Directly uses all_to_all_single with inputs/outputs in original precision.
+    - "mxfp8": Reduces network bandwidth utilization by quantizing inputs to MXFP8,
+      using all_to_all_single on the quantized data and scales, then dequantizing
+      the outputs back to original precision.
+    """
+
 
 @dataclass
 class Checkpoint:
diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py
index 12512bfac..8158dbddd 100644
--- a/torchtitan/distributed/expert_parallel.py
+++ b/torchtitan/distributed/expert_parallel.py
@@ -81,16 +81,41 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module:
 
 
 class ExpertParallel(ParallelStyle):
-    def __init__(self):
+    """
+    ExpertParallel is a parallel style for MoE, where the experts
+    are distributed across ranks along a given axis of the device mesh.
+
+    Args:
+        a2a_dispatch_impl (str): All-to-all implementation for token dispatch. One of ["default", "mxfp8"]; defaults to "default".
+        a2a_combine_impl (str): All-to-all implementation for token combine. One of ["default", "mxfp8"]; defaults to "default".
+    """
+
+    def __init__(
+        self, a2a_dispatch_impl: str = "default", a2a_combine_impl: str = "default"
+    ):
         super().__init__()
         self.input_splits = None
         self.output_splits = None
+        self.a2a_dispatch_func = self._get_a2a_func(a2a_dispatch_impl)
+        self.a2a_combine_func = self._get_a2a_func(a2a_combine_impl)
+
+    def _get_a2a_func(self, a2a_impl: str):
+        if a2a_impl == "default":
+            return all_to_all_single_autograd
+        elif a2a_impl == "mxfp8":
+            from torchao.prototype.moe_training.kernels.mxfp8.comms import (
+                to_mxfp8_a2a_dequant,
+            )
+
+            return to_mxfp8_a2a_dequant
+        else:
+            raise ValueError(f"Unknown a2a_impl: {a2a_impl}")
 
     # performing all-to-all dispatch on the input
     def _token_dispatch(self, mod, inputs, device_mesh):
         # annotate module input placements/sharding with input_layouts
         routed_input, num_tokens_per_expert = inputs
-        ep_size = device_mesh.shape[0]
+        ep_size = device_mesh.size(0)
 
         # generate the input splits and output splits for all-to-all
         with torch.no_grad():
@@ -119,8 +143,7 @@ def _token_dispatch(self, mod, inputs, device_mesh):
             self.input_splits = input_splits.tolist()
             self.output_splits = output_splits.tolist()
 
-        # perform all-to-all
-        routed_input = all_to_all_single_autograd(
+        routed_input = self.a2a_dispatch_func(
             routed_input,
             self.output_splits,
             self.input_splits,
@@ -148,7 +171,8 @@ def _partition_fn(name, mod, device_mesh):
 
     # performing all-to-all combine on the output
     def _token_combine(self, mod, routed_output, device_mesh):
-        routed_output = all_to_all_single_autograd(
+        # For a2a combine, the input splits and output splits are swapped relative to a2a dispatch.
+        routed_output = self.a2a_combine_func(
             routed_output,
             self.input_splits,
             self.output_splits,
diff --git a/torchtitan/experiments/llama4/infra/parallelize.py b/torchtitan/experiments/llama4/infra/parallelize.py
index 0c5bd9e78..22190ee86 100644
--- a/torchtitan/experiments/llama4/infra/parallelize.py
+++ b/torchtitan/experiments/llama4/infra/parallelize.py
@@ -107,6 +107,8 @@ def parallelize_llama(
                 else None
             ),
             etp_enabled=parallel_dims.etp_enabled,
+            a2a_dispatch_impl=job_config.parallelism.expert_parallel_a2a_dispatch_impl,
+            a2a_combine_impl=job_config.parallelism.expert_parallel_a2a_combine_impl,
         )
 
     model_compile_enabled = (
@@ -438,7 +440,11 @@ def apply_moe_ep_tp(
     ep_mesh: DeviceMesh | None,
     ep_tp_mesh: DeviceMesh | None,
     etp_enabled: bool,
+    a2a_dispatch_impl: str = "default",
+    a2a_combine_impl: str = "default",
 ):
+    logger.info(f"Using all-to-all dispatch: {a2a_dispatch_impl}")
+    logger.info(f"Using all-to-all combine: {a2a_combine_impl}")
     for transformer_block in model.layers.values():
         if not transformer_block.moe_enabled:
             continue
@@ -487,13 +493,19 @@ def apply_moe_ep_tp(
         elif tp_mesh is None:
             experts_mesh = ep_mesh
             # input / output sharding on the batch / tokens dim
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
         elif etp_enabled:
             experts_mesh = ep_tp_mesh
             experts_plan = ExpertTensorParallel(tp_mesh=tp_mesh, ep_mesh=ep_mesh)
         else:
             experts_mesh = ep_mesh
-            experts_plan = ExpertParallel()
+            experts_plan = ExpertParallel(
+                a2a_dispatch_impl=a2a_dispatch_impl,
+                a2a_combine_impl=a2a_combine_impl,
+            )
 
         parallelize_module(
             module=transformer_block.moe.experts,
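
For reference, a minimal sketch of how the two new knobs resolve inside ExpertParallel. This is illustrative only and not part of the patch: constructing the style with an "mxfp8" argument assumes torchao with its prototype MoE-training kernels is installed, since _get_a2a_func imports to_mxfp8_a2a_dequant lazily at construction time.

    # Illustrative sketch (not part of the diff above).
    from torchtitan.distributed.expert_parallel import ExpertParallel

    # Token dispatch quantizes to MXFP8 around all_to_all_single and dequantizes the
    # output; token combine stays on all_to_all_single_autograd in original precision.
    ep_plan = ExpertParallel(a2a_dispatch_impl="mxfp8", a2a_combine_impl="default")

    # ep_plan.a2a_dispatch_func -> torchao's to_mxfp8_a2a_dequant
    # ep_plan.a2a_combine_func  -> all_to_all_single_autograd

    # Unknown implementation names fail fast at construction:
    # ExpertParallel(a2a_dispatch_impl="int4")  # ValueError: Unknown a2a_impl: int4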