From 732f0cab5fddc9b54d1cf6421260a2b0b6a5f934 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 2 Oct 2025 12:32:34 -0700 Subject: [PATCH 01/11] Fork SimpleFSDP --- torchtitan/experiments/__init__.py | 2 +- .../experiments/joint_graph_runner/README.md | 11 ++ .../joint_graph_runner/llama3/__init__.py | 33 +++++ .../joint_graph_runner/llama3/parallelize.py | 127 ++++++++++++++++++ 4 files changed, 172 insertions(+), 1 deletion(-) create mode 100644 torchtitan/experiments/joint_graph_runner/README.md create mode 100644 torchtitan/experiments/joint_graph_runner/llama3/__init__.py create mode 100644 torchtitan/experiments/joint_graph_runner/llama3/parallelize.py diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 000823803..2b99a9051 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -5,5 +5,5 @@ # LICENSE file in the root directory of this source tree. _supported_experiments = frozenset( - ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"] + ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm", "joint_graph_runner.llama3"] ) diff --git a/torchtitan/experiments/joint_graph_runner/README.md b/torchtitan/experiments/joint_graph_runner/README.md new file mode 100644 index 000000000..833342a0b --- /dev/null +++ b/torchtitan/experiments/joint_graph_runner/README.md @@ -0,0 +1,11 @@ +## Joint Graph Runner + +Exploring toolkit-style use of the compiler stack for authoring parallel models. + +Joint Graph based Training Prototype: + +Llama3 +- User code: SimpleFSDP + TP +- Trace joint +- Apply passes to the joint +- Run using the Joint Graph Runner diff --git a/torchtitan/experiments/joint_graph_runner/llama3/__init__.py b/torchtitan/experiments/joint_graph_runner/llama3/__init__.py new file mode 100644 index 000000000..99c350cde --- /dev/null +++ b/torchtitan/experiments/joint_graph_runner/llama3/__init__.py @@ -0,0 +1,33 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.llama3 import llama3_configs, pipeline_llama +from torchtitan.protocols.train_spec import TrainSpec + +from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer +from torchtitan.experiments.joint_graph_runner.llama3.parallelize import parallelize_llama + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + name="joint_graph_runner.llama3", + model_cls=SimpleFSDPTransformer, + model_args=llama3_configs, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llama, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py new file mode 100644 index 000000000..b94e9238b --- /dev/null +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -0,0 +1,127 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama3.infra.parallelize import apply_tp +from torchtitan.tools.logging import logger + +from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel, MixedPrecisionPolicy + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + + +def parallelize_llama( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + tp_mesh = parallel_dims.world_mesh["tp"] + apply_tp( + model, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, tp_mesh) + + if job_config.activation_checkpoint.mode != "none": + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + op_sac_save_list=_op_sac_save_list, + ) + + # apply data parallel + if ( + parallel_dims.dp_replicate_enabled + or parallel_dims.dp_shard_enabled + or parallel_dims.cp_enabled + ): + if parallel_dims.dp_replicate_enabled: + if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mode = "hybrid_shard" + else: + dp_mesh_dim_names = ("dp_replicate",) + dp_mode = "replicate" + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mode = "fully_shard" + + mp_policy = MixedPrecisionPolicy( + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + ) + + model = data_parallel( + model, + parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + mode=dp_mode, + ac_mode=job_config.activation_checkpoint.mode, + mp_policy=mp_policy, + ) + logger.info( + "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode + ) + + if job_config.compile.enable and "model" in job_config.compile.components: + torch._inductor.config.reorder_for_peak_memory = False + model = torch.compile(model, fullgraph=True) + + return model From 2d047a6ca85ab28ee3157a711d1db79b351d4bd0 Mon Sep 17 00:00:00 2001 From: Simon Fan Date: Thu, 2 Oct 2025 12:58:37 -0700 Subject: [PATCH 02/11] Hijack the execution flow for a single training loop --- .../experiments/joint_graph_runner/README.md | 2 + .../joint_graph_runner/llama3/parallelize.py | 82 ++++++++++++++++++- 2 files changed, 82 insertions(+), 2 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/README.md b/torchtitan/experiments/joint_graph_runner/README.md index 833342a0b..30ae59594 100644 --- a/torchtitan/experiments/joint_graph_runner/README.md +++ b/torchtitan/experiments/joint_graph_runner/README.md @@ -9,3 +9,5 @@ Llama3 - Trace joint - Apply passes to the joint - Run using the Joint Graph Runner + +Run with: NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" wp ./run_train.sh --model.name joint_graph_runner.llama3 --compile.enable diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index b94e9238b..a49e16282 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -121,7 +121,85 @@ def parallelize_llama( ) if job_config.compile.enable and "model" in job_config.compile.components: - torch._inductor.config.reorder_for_peak_memory = False - model = torch.compile(model, fullgraph=True) + # torch._inductor.config.reorder_for_peak_memory = False + # model = torch.compile(model, fullgraph=True) + model = HijackWrapper(model) return model + +# Just to bootstrap our experiment. NOT the final API. +class HijackWrapper(torch.nn.Module): + def __init__(self, inner: torch.nn.Module, **overrides): + super().__init__() + self.inner = inner # register as submodule + self._overrides = overrides # for custom hooks + + def __getattr__(self, name): + # check overrides + if "_overrides" in self.__dict__ and name in self._overrides: + return self._overrides[name] + try: + # let nn.Module handle registered stuff + return super().__getattr__(name) + except AttributeError: + # fallback to inner model + return getattr(self.inner, name) + + def __setattr__(self, name, value): + if "_overrides" in self.__dict__ and name in self._overrides: + self._overrides[name] = value + else: + super().__setattr__(name, value) + + def __delattr__(self, name): + if "_overrides" in self.__dict__ and name in self._overrides: + del self._overrides[name] + else: + super().__delattr__(name) + + def forward(self, *args, **kwargs): + assert "forward" not in self._overrides, "forward cannot be overridden" + # EDIT ME + joint_graph_runner(self.inner, *args, **kwargs) + # calling the line below returns control to torchtitan's runner + # letting it call the backward, and optimizer. + return self.inner(*args, **kwargs) + +# Think of this as a "main" function. +def joint_graph_runner(model, *inputs, **kwargs): + from contextlib import ExitStack + from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer + from torch._functorch.aot_autograd import ( + aot_compile_joint_with_descriptors, + aot_export_joint_with_descriptors, + boxed_nop_preserve_node_meta, + ) + from torch._logging import trace_structured + + assert isinstance(model, SimpleFSDPTransformer) + assert isinstance(inputs, tuple) + assert not kwargs + + stack = ExitStack() + joint_with_descriptors = aot_export_joint_with_descriptors( + stack, + model, + inputs, + decompositions=None, + fw_compiler=boxed_nop_preserve_node_meta, + bw_compiler=boxed_nop_preserve_node_meta, + ) + gm = joint_with_descriptors.graph_module + + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "aot_export_joint_with_descriptors", + "encoding": "string", + }, + payload_fn=lambda: gm.print_readable( + print_output=False, include_stride=True, include_device=True + ), + ) + + exit(123) # just manually force the exit for now From 4f8a0ba78629e3c986e30ef23315cd3af3fed79e Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 2 Oct 2025 20:15:30 -0700 Subject: [PATCH 03/11] Introduce Joint Graph Runner --- .../experiments/joint_graph_runner/README.md | 2 +- .../joint_graph_runner/llama3/parallelize.py | 157 ++++++++++++++---- 2 files changed, 127 insertions(+), 32 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/README.md b/torchtitan/experiments/joint_graph_runner/README.md index 30ae59594..562aa98bb 100644 --- a/torchtitan/experiments/joint_graph_runner/README.md +++ b/torchtitan/experiments/joint_graph_runner/README.md @@ -10,4 +10,4 @@ Llama3 - Apply passes to the joint - Run using the Joint Graph Runner -Run with: NGPU=1 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" wp ./run_train.sh --model.name joint_graph_runner.llama3 --compile.enable +Run with: NGPU=8 CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" with-proxy ./run_train.sh --model.name joint_graph_runner.llama3 --compile.enable --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index a49e16282..afbc8eb62 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -6,6 +6,7 @@ import torch import torch.nn as nn +from torch.distributed.tensor import DTensor from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims @@ -16,7 +17,25 @@ from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel, MixedPrecisionPolicy +from torch._functorch.aot_autograd import aot_export_joint_with_descriptors +from torch._functorch.partitioners import min_cut_rematerialization_partition +from torch._dynamo.functional_export import _dynamo_graph_capture_for_export + +from torch._functorch._aot_autograd.aot_eager_runner import ( + get_num_mutate_inputs, + get_num_user_outputs, + JointGraphModule, + RunMode, +) +import contextlib +from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer +from torch._functorch.aot_autograd import ( + aot_export_joint_with_descriptors, + boxed_nop_preserve_node_meta, +) +from torch._logging import trace_structured + # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -30,6 +49,29 @@ torch._higher_order_ops.flex_attention, } +def print_if_rank0(msg): + if torch.distributed.get_rank() == 0: + print(msg) + +def graph_capture_and_aot_export_joint_with_descriptors(model, inputs): + assert isinstance(inputs, tuple) + with torch._dynamo.config.patch(install_free_tensors=True): + # TODO: switch to use the official graph_capture API once it is ready + gm = _dynamo_graph_capture_for_export(model)(*inputs) + return aot_export_joint_with_descriptors_alone(gm, inputs) + + +def aot_export_joint_with_descriptors_alone(model, inputs): + assert isinstance(inputs, tuple) + with contextlib.ExitStack() as stack: + joint_with_descriptors = aot_export_joint_with_descriptors( + stack, + model, + inputs, + ) + return joint_with_descriptors + + def parallelize_llama( model: nn.Module, @@ -132,6 +174,8 @@ class HijackWrapper(torch.nn.Module): def __init__(self, inner: torch.nn.Module, **overrides): super().__init__() self.inner = inner # register as submodule + + self.joint_graph_module = None self._overrides = overrides # for custom hooks def __getattr__(self, name): @@ -159,47 +203,98 @@ def __delattr__(self, name): def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" - # EDIT ME - joint_graph_runner(self.inner, *args, **kwargs) + + # HACK: doing graph capture on the fly, we should do it AOT + if self.joint_graph_module is None: + # first time, we need to initialize the runner + self.joint_graph_module = joint_graph_runner(self.inner, *args, **kwargs) + # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. + # return self.joint_graph_module(*args, **kwargs) return self.inner(*args, **kwargs) # Think of this as a "main" function. def joint_graph_runner(model, *inputs, **kwargs): - from contextlib import ExitStack - from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer - from torch._functorch.aot_autograd import ( - aot_compile_joint_with_descriptors, - aot_export_joint_with_descriptors, - boxed_nop_preserve_node_meta, - ) - from torch._logging import trace_structured - assert isinstance(model, SimpleFSDPTransformer) assert isinstance(inputs, tuple) assert not kwargs - stack = ExitStack() - joint_with_descriptors = aot_export_joint_with_descriptors( - stack, - model, - inputs, - decompositions=None, - fw_compiler=boxed_nop_preserve_node_meta, - bw_compiler=boxed_nop_preserve_node_meta, + # convert args and kwargs to DTensor + # local_inputs = tuple(DTensor.from_local(input, ) for input in inputs) + + # get fw/bw graphs + joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) + + # Now partition the joint grapg + joint_gm = joint_with_descriptors.graph_module + aot_state = joint_with_descriptors._aot_state + aot_graph_capture = joint_with_descriptors._aot_graph_capture + + # Get the joint graph module + joint_inputs = aot_graph_capture.updated_flat_args + fw_metadata = aot_state.fw_metadata + + num_user_outputs = get_num_user_outputs(fw_metadata) + num_mutate_inputs = get_num_mutate_inputs(fw_metadata) + num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs + + fw_gm, bw_gm = min_cut_rematerialization_partition( + joint_gm, + joint_inputs, + num_fwd_outputs=num_inner_fwd_outputs, + static_lifetime_input_indices=fw_metadata.static_input_indices or [], ) - gm = joint_with_descriptors.graph_module - - trace_structured( - "artifact", - metadata_fn=lambda: { - "name": "aot_export_joint_with_descriptors", - "encoding": "string", - }, - payload_fn=lambda: gm.print_readable( - print_output=False, include_stride=True, include_device=True - ), + + # print_if_rank0(f"fw_gm:") + # print_if_rank0(fw_gm.print_readable(print_output=False)) + + # print_if_rank0(f"bw_gm:") + # print_if_rank0(bw_gm.print_readable(print_output=False)) + + # Run graph passes here + ## Apply bucketing here + ## Apply Flex Attention compilation here + + # Codgen Autograd.Function Wrappers + + # Get the model parameters and buffers - the partitioned graphs expect these as arguments + local_params = [] + for p in model.parameters(): + if isinstance(p, DTensor): + local_params.append(p.to_local()) + else: + local_params.append(p) + + local_buffers = [] + for b in model.buffers(): + if isinstance(b, DTensor): + local_buffers.append(b.to_local()) + else: + local_buffers.append(b) + + + # local_inputs = tuple(input.clone().detach() for input in inputs) + + joint_graph_module = JointGraphModule( + local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN ) - exit(123) # just manually force the exit for now + return joint_graph_module + + # Run forward pass through the custom function + # outputs = joint_graph_module(local_inputs) + + # trace_structured( + # "artifact", + # metadata_fn=lambda: { + # "name": "aot_export_joint_with_descriptors", + # "encoding": "string", + # }, + # payload_fn=lambda: gm.print_readable( + # print_output=False, include_stride=True, include_device=True + # ), + # ) + + + # exit(123) # just manually force the exit for now From 3919e7878b38c2d5b48aa0f8f330fb113c51eeaa Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 2 Oct 2025 21:12:35 -0700 Subject: [PATCH 04/11] convert inputs into DTensor --- .../joint_graph_runner/llama3/parallelize.py | 31 +++++++++++++------ 1 file changed, 22 insertions(+), 9 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index afbc8eb62..b9747b45f 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -6,7 +6,8 @@ import torch import torch.nn as nn -from torch.distributed.tensor import DTensor +from torch.distributed.tensor import DTensor, Shard, Replicate + from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims @@ -165,15 +166,16 @@ def parallelize_llama( if job_config.compile.enable and "model" in job_config.compile.components: # torch._inductor.config.reorder_for_peak_memory = False # model = torch.compile(model, fullgraph=True) - model = HijackWrapper(model) + model = HijackWrapper(model, parallel_dims) return model # Just to bootstrap our experiment. NOT the final API. class HijackWrapper(torch.nn.Module): - def __init__(self, inner: torch.nn.Module, **overrides): + def __init__(self, inner: torch.nn.Module, parallel_dims, **overrides): super().__init__() self.inner = inner # register as submodule + self.parallel_dims = parallel_dims self.joint_graph_module = None self._overrides = overrides # for custom hooks @@ -204,25 +206,36 @@ def __delattr__(self, name): def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" + # print_if_rank0(self.parallel_dims.world_mesh) + # 2-D device mesh with ['dp_shard', 'tp'], [2, 4] + + # Hack: convert args and kwargs to DTensor. This should be fixed at data loader. + dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()]) for arg in args) + + # RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec((Shard(dim=0), Replicate()) on (16, 2048)) @ mesh: (2, 4))') + # dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh, [Shard(0), Replicate()]) for arg in args) + + # RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec(S(0) on (16, 2048)) @ mesh: (2,))') + # dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["dp_shard"], [Shard(0)]) for arg in args) + # HACK: doing graph capture on the fly, we should do it AOT if self.joint_graph_module is None: # first time, we need to initialize the runner - self.joint_graph_module = joint_graph_runner(self.inner, *args, **kwargs) + self.joint_graph_module = joint_graph_builder(self.inner, *dt_args, **kwargs) # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. # return self.joint_graph_module(*args, **kwargs) return self.inner(*args, **kwargs) -# Think of this as a "main" function. -def joint_graph_runner(model, *inputs, **kwargs): + +def joint_graph_builder(model, *inputs, **kwargs): assert isinstance(model, SimpleFSDPTransformer) assert isinstance(inputs, tuple) + for input in inputs: + assert isinstance(input, DTensor) assert not kwargs - # convert args and kwargs to DTensor - # local_inputs = tuple(DTensor.from_local(input, ) for input in inputs) - # get fw/bw graphs joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) From 35ce6a0407fead525015fa8c69ff93f9faa49010 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Fri, 3 Oct 2025 00:14:43 -0700 Subject: [PATCH 05/11] fixes --- .../joint_graph_runner/llama3/parallelize.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index b9747b45f..ae0895781 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -7,6 +7,7 @@ import torch import torch.nn as nn from torch.distributed.tensor import DTensor, Shard, Replicate +from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing from torchtitan.config import JobConfig, TORCH_DTYPE_MAP @@ -164,8 +165,6 @@ def parallelize_llama( ) if job_config.compile.enable and "model" in job_config.compile.components: - # torch._inductor.config.reorder_for_peak_memory = False - # model = torch.compile(model, fullgraph=True) model = HijackWrapper(model, parallel_dims) return model @@ -210,6 +209,7 @@ def forward(self, *args, **kwargs): # 2-D device mesh with ['dp_shard', 'tp'], [2, 4] # Hack: convert args and kwargs to DTensor. This should be fixed at data loader. + # This works, but kinda cheating? dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()]) for arg in args) # RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec((Shard(dim=0), Replicate()) on (16, 2048)) @ mesh: (2, 4))') @@ -220,13 +220,17 @@ def forward(self, *args, **kwargs): # HACK: doing graph capture on the fly, we should do it AOT if self.joint_graph_module is None: + # needed to avoid having fwd_rng_state in the fw_gm inp + # this doesn't work! + # torch._functorch.config.graphsafe_rng_functionalization = False + # first time, we need to initialize the runner self.joint_graph_module = joint_graph_builder(self.inner, *dt_args, **kwargs) # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. # return self.joint_graph_module(*args, **kwargs) - return self.inner(*args, **kwargs) + return self.joint_graph_module(args) def joint_graph_builder(model, *inputs, **kwargs): @@ -266,8 +270,12 @@ def joint_graph_builder(model, *inputs, **kwargs): # print_if_rank0(bw_gm.print_readable(print_output=False)) # Run graph passes here - ## Apply bucketing here - ## Apply Flex Attention compilation here + + ## Apply bucketing + # schedule_overlap_bucketing(fw_gm) + # schedule_overlap_bucketing(bw_gm) + + ## TODO: Apply Flex Attention compilation here # Codgen Autograd.Function Wrappers @@ -287,16 +295,12 @@ def joint_graph_builder(model, *inputs, **kwargs): local_buffers.append(b) - # local_inputs = tuple(input.clone().detach() for input in inputs) - joint_graph_module = JointGraphModule( - local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN + local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN_AUTOGRAD, f"rank{torch.distributed.get_rank()}" ) return joint_graph_module - # Run forward pass through the custom function - # outputs = joint_graph_module(local_inputs) # trace_structured( # "artifact", @@ -308,6 +312,3 @@ def joint_graph_builder(model, *inputs, **kwargs): # print_output=False, include_stride=True, include_device=True # ), # ) - - - # exit(123) # just manually force the exit for now From 51c1538a1ef422737ee4d1de06b22bc703d68f3a Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 8 Oct 2025 15:35:50 -0700 Subject: [PATCH 06/11] use aot_compile_joint_with_descriptors --- .../joint_graph_runner/llama3/parallelize.py | 278 ++++++++++++++---- 1 file changed, 220 insertions(+), 58 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index ae0895781..360fde87a 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors import torch import torch.nn as nn from torch.distributed.tensor import DTensor, Shard, Replicate @@ -24,12 +25,12 @@ from torch._dynamo.functional_export import _dynamo_graph_capture_for_export -from torch._functorch._aot_autograd.aot_eager_runner import ( - get_num_mutate_inputs, - get_num_user_outputs, - JointGraphModule, - RunMode, -) +# from torch._functorch._aot_autograd.aot_eager_runner import ( +# get_num_mutate_inputs, +# get_num_user_outputs, +# JointGraphModule, +# RunMode, +# ) import contextlib from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer from torch._functorch.aot_autograd import ( @@ -60,6 +61,13 @@ def graph_capture_and_aot_export_joint_with_descriptors(model, inputs): with torch._dynamo.config.patch(install_free_tensors=True): # TODO: switch to use the official graph_capture API once it is ready gm = _dynamo_graph_capture_for_export(model)(*inputs) + + # Restore the state dict to match the original module + _restore_state_dict(model, gm) + + # Validate that state dict and ordering match + _validate_state_dict_match(model, gm) + return aot_export_joint_with_descriptors_alone(gm, inputs) @@ -75,6 +83,158 @@ def aot_export_joint_with_descriptors_alone(model, inputs): +def _assign_attr( + from_obj: torch.Tensor | torch.nn.Parameter, + to_module: torch.nn.Module, + target: str, + is_buffer: bool = False, +) -> None: + """ + Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'. + This installs empty Modules where none exist yet if they are subpaths of target. + Based on _assign_attr from torch.export.unflatten. + """ + *prefix, field = target.split(".") + + # Generate all submodules along the path + for item in prefix: + if not hasattr(to_module, item): + setattr(to_module, item, torch.nn.Module()) + to_module = getattr(to_module, item) + + # Assign the actual parameter or buffer + if is_buffer: + to_module.register_buffer(field, from_obj) + else: + to_module.register_parameter(field, from_obj) + + +def _clear_traced_params_buffers(traced_module: torch.fx.GraphModule) -> None: + """Remove all parameters and buffers from traced module before restoring.""" + # Remove all parameters + for name in list(traced_module._parameters.keys()): + delattr(traced_module, name) + + # Remove all buffers + for name in list(traced_module._buffers.keys()): + delattr(traced_module, name) + + +def _restore_state_dict( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Restores the state dict of the traced module to match the original module exactly. + Preserves the original FQNs with dots, creating intermediate empty modules as needed. + Ensures that the ordering of parameters/buffers matches the original module. + """ + # Build ID-based lookups for traced module params/buffers + traced_params: dict[int, tuple[str, torch.nn.Parameter]] = {} + for name, param in traced_module.named_parameters(remove_duplicate=False): + traced_params[id(param)] = (name, param) + + traced_buffers: dict[int, tuple[str, torch.Tensor]] = {} + for name, buffer in traced_module.named_buffers(remove_duplicate=False): + traced_buffers[id(buffer)] = (name, buffer) + + # Build mapping from old names to new names for graph node updates + name_mapping: dict[str, str] = {} + + # Clear existing parameters and buffers from traced module + _clear_traced_params_buffers(traced_module) + + # Restore parameters in the order they appear in original module + for orig_name, orig_param in original_module.named_parameters( + remove_duplicate=False + ): + if id(orig_param) in traced_params: + # This param exists in traced module - restore it with original FQN + traced_name, traced_param = traced_params[id(orig_param)] + _assign_attr(traced_param, traced_module, orig_name, is_buffer=False) + name_mapping[traced_name] = orig_name + else: + # This param doesn't exist in traced module - add it + _assign_attr(orig_param, traced_module, orig_name, is_buffer=False) + + # Restore buffers in the order they appear in original module + for orig_name, orig_buffer in original_module.named_buffers(remove_duplicate=False): + if id(orig_buffer) in traced_buffers: + # This buffer exists in traced module - restore it with original FQN + traced_name, traced_buffer = traced_buffers[id(orig_buffer)] + _assign_attr(traced_buffer, traced_module, orig_name, is_buffer=True) + name_mapping[traced_name] = orig_name + else: + # This buffer doesn't exist in traced module - add it + _assign_attr(orig_buffer, traced_module, orig_name, is_buffer=True) + + # Update get_attr nodes in the graph to use the correct FQNs + for node in traced_module.graph.nodes: + if node.op == "get_attr" and node.target in name_mapping: + node.target = name_mapping[node.target] + + traced_module.recompile() + + +def _validate_state_dict_match( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Validates that the traced module's state dict matches the original module. + Checks that parameter/buffer names, ordering, and tensor values match. + + Raises: + AssertionError: If any validation check fails. + """ + # Check 1: Verify parameter names and ordering match + orig_param_names = list( + dict(original_module.named_parameters(remove_duplicate=False)).keys() + ) + gm_param_names = list( + dict(traced_module.named_parameters(remove_duplicate=False)).keys() + ) + assert orig_param_names == gm_param_names, ( + f"Parameter names or ordering mismatch!\n" + f"Original: {orig_param_names}\n" + f"Traced: {gm_param_names}" + ) + + # Check 2: Verify buffer names and ordering match + orig_buffer_names = list( + dict(original_module.named_buffers(remove_duplicate=False)).keys() + ) + gm_buffer_names = list( + dict(traced_module.named_buffers(remove_duplicate=False)).keys() + ) + assert orig_buffer_names == gm_buffer_names, ( + f"Buffer names or ordering mismatch!\n" + f"Original: {orig_buffer_names}\n" + f"Traced: {gm_buffer_names}" + ) + + # Check 3: Verify parameter tensors match by identity or value + for (orig_name, orig_param), (gm_name, gm_param) in zip( + original_module.named_parameters(remove_duplicate=False), + traced_module.named_parameters(remove_duplicate=False), + ): + assert ( + orig_name == gm_name + ), f"Parameter name mismatch: {orig_name} != {gm_name}" + assert id(orig_param) == id(gm_param) or torch.equal( + orig_param, gm_param + ), f"Parameter tensor mismatch for {orig_name}" + + # Check 4: Verify buffer tensors match by identity or value + for (orig_name, orig_buffer), (gm_name, gm_buffer) in zip( + original_module.named_buffers(remove_duplicate=False), + traced_module.named_buffers(remove_duplicate=False), + ): + assert orig_name == gm_name, f"Buffer name mismatch: {orig_name} != {gm_name}" + assert id(orig_buffer) == id(gm_buffer) or torch.equal( + orig_buffer, gm_buffer + ), f"Buffer tensor mismatch for {orig_name}" + + + def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -243,72 +403,74 @@ def joint_graph_builder(model, *inputs, **kwargs): # get fw/bw graphs joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) - # Now partition the joint grapg - joint_gm = joint_with_descriptors.graph_module - aot_state = joint_with_descriptors._aot_state - aot_graph_capture = joint_with_descriptors._aot_graph_capture + fn = aot_compile_joint_with_descriptors(joint_with_descriptors) - # Get the joint graph module - joint_inputs = aot_graph_capture.updated_flat_args - fw_metadata = aot_state.fw_metadata + def wrapper_fn(args): + input = [ + *model.parameters(), + *model.buffers(), + *args, + ] + return fn(*input) + + return wrapper_fn - num_user_outputs = get_num_user_outputs(fw_metadata) - num_mutate_inputs = get_num_mutate_inputs(fw_metadata) - num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs - fw_gm, bw_gm = min_cut_rematerialization_partition( - joint_gm, - joint_inputs, - num_fwd_outputs=num_inner_fwd_outputs, - static_lifetime_input_indices=fw_metadata.static_input_indices or [], - ) - # print_if_rank0(f"fw_gm:") - # print_if_rank0(fw_gm.print_readable(print_output=False)) + # # Now partition the joint grapg + # joint_gm = joint_with_descriptors.graph_module + # aot_state = joint_with_descriptors._aot_state + # aot_graph_capture = joint_with_descriptors._aot_graph_capture - # print_if_rank0(f"bw_gm:") - # print_if_rank0(bw_gm.print_readable(print_output=False)) + # # Get the joint graph module + # joint_inputs = aot_graph_capture.updated_flat_args + # fw_metadata = aot_state.fw_metadata - # Run graph passes here + # num_user_outputs = get_num_user_outputs(fw_metadata) + # num_mutate_inputs = get_num_mutate_inputs(fw_metadata) + # num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs - ## Apply bucketing - # schedule_overlap_bucketing(fw_gm) - # schedule_overlap_bucketing(bw_gm) + # fw_gm, bw_gm = min_cut_rematerialization_partition( + # joint_gm, + # joint_inputs, + # num_fwd_outputs=num_inner_fwd_outputs, + # static_lifetime_input_indices=fw_metadata.static_input_indices or [], + # ) - ## TODO: Apply Flex Attention compilation here + # # print_if_rank0(f"fw_gm:") + # # print_if_rank0(fw_gm.print_readable(print_output=False)) - # Codgen Autograd.Function Wrappers + # # print_if_rank0(f"bw_gm:") + # # print_if_rank0(bw_gm.print_readable(print_output=False)) - # Get the model parameters and buffers - the partitioned graphs expect these as arguments - local_params = [] - for p in model.parameters(): - if isinstance(p, DTensor): - local_params.append(p.to_local()) - else: - local_params.append(p) + # # Run graph passes here - local_buffers = [] - for b in model.buffers(): - if isinstance(b, DTensor): - local_buffers.append(b.to_local()) - else: - local_buffers.append(b) + # ## Apply bucketing + # # schedule_overlap_bucketing(fw_gm) + # # schedule_overlap_bucketing(bw_gm) + # ## TODO: Apply Flex Attention compilation here - joint_graph_module = JointGraphModule( - local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN_AUTOGRAD, f"rank{torch.distributed.get_rank()}" - ) + # # Codgen Autograd.Function Wrappers + + # # Get the model parameters and buffers - the partitioned graphs expect these as arguments + # local_params = [] + # for p in model.parameters(): + # if isinstance(p, DTensor): + # local_params.append(p.to_local()) + # else: + # local_params.append(p) - return joint_graph_module + # local_buffers = [] + # for b in model.buffers(): + # if isinstance(b, DTensor): + # local_buffers.append(b.to_local()) + # else: + # local_buffers.append(b) - # trace_structured( - # "artifact", - # metadata_fn=lambda: { - # "name": "aot_export_joint_with_descriptors", - # "encoding": "string", - # }, - # payload_fn=lambda: gm.print_readable( - # print_output=False, include_stride=True, include_device=True - # ), + # joint_graph_module = JointGraphModule( + # local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN_AUTOGRAD, f"rank{torch.distributed.get_rank()}" # ) + + # return joint_graph_module From b2684c00574cde3f3f94dc5d326b87c7d7499c41 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 8 Oct 2025 16:46:57 -0700 Subject: [PATCH 07/11] apply fw/bw compiler --- .../experiments/joint_graph_runner/llama3/parallelize.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index 360fde87a..ef4107a6a 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -4,7 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from torch._functorch.aot_autograd import aot_compile_joint_with_descriptors import torch import torch.nn as nn from torch.distributed.tensor import DTensor, Shard, Replicate @@ -20,7 +19,7 @@ from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel, MixedPrecisionPolicy -from torch._functorch.aot_autograd import aot_export_joint_with_descriptors +from torch._functorch.aot_autograd import aot_export_joint_with_descriptors, aot_compile_joint_with_descriptors from torch._functorch.partitioners import min_cut_rematerialization_partition from torch._dynamo.functional_export import _dynamo_graph_capture_for_export @@ -402,8 +401,12 @@ def joint_graph_builder(model, *inputs, **kwargs): # get fw/bw graphs joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) + + def compiler(gm: torch.fx.GraphModule, example_inputs): + print_if_rank0(gm.print_readable(print_output=False)) + return gm - fn = aot_compile_joint_with_descriptors(joint_with_descriptors) + fn = aot_compile_joint_with_descriptors(joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler) def wrapper_fn(args): input = [ From 1cdeb61e3d99483630e4e4c8c5ec77b78fb5104f Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 8 Oct 2025 17:28:27 -0700 Subject: [PATCH 08/11] apply schedule_overlap_bucketing --- .../experiments/joint_graph_runner/llama3/parallelize.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index ef4107a6a..6f6bf2247 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -403,6 +403,12 @@ def joint_graph_builder(model, *inputs, **kwargs): joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) def compiler(gm: torch.fx.GraphModule, example_inputs): + print_if_rank0("Before compiler:") + print_if_rank0(gm.print_readable(print_output=False)) + + gm = schedule_overlap_bucketing(gm) + + print_if_rank0("After compiler:") print_if_rank0(gm.print_readable(print_output=False)) return gm From cf845c05a18af2c33360ba3861da9e5c01733363 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 8 Oct 2025 17:30:52 -0700 Subject: [PATCH 09/11] Clean up --- .../joint_graph_runner/llama3/parallelize.py | 68 ------------------- 1 file changed, 68 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index 6f6bf2247..7265d8224 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -9,7 +9,6 @@ from torch.distributed.tensor import DTensor, Shard, Replicate from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing - from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac @@ -24,19 +23,12 @@ from torch._dynamo.functional_export import _dynamo_graph_capture_for_export -# from torch._functorch._aot_autograd.aot_eager_runner import ( -# get_num_mutate_inputs, -# get_num_user_outputs, -# JointGraphModule, -# RunMode, -# ) import contextlib from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer from torch._functorch.aot_autograd import ( aot_export_joint_with_descriptors, boxed_nop_preserve_node_meta, ) -from torch._logging import trace_structured # for selective op activation checkpointing _op_sac_save_list = { @@ -423,63 +415,3 @@ def wrapper_fn(args): return fn(*input) return wrapper_fn - - - - # # Now partition the joint grapg - # joint_gm = joint_with_descriptors.graph_module - # aot_state = joint_with_descriptors._aot_state - # aot_graph_capture = joint_with_descriptors._aot_graph_capture - - # # Get the joint graph module - # joint_inputs = aot_graph_capture.updated_flat_args - # fw_metadata = aot_state.fw_metadata - - # num_user_outputs = get_num_user_outputs(fw_metadata) - # num_mutate_inputs = get_num_mutate_inputs(fw_metadata) - # num_inner_fwd_outputs = num_mutate_inputs + num_user_outputs - - # fw_gm, bw_gm = min_cut_rematerialization_partition( - # joint_gm, - # joint_inputs, - # num_fwd_outputs=num_inner_fwd_outputs, - # static_lifetime_input_indices=fw_metadata.static_input_indices or [], - # ) - - # # print_if_rank0(f"fw_gm:") - # # print_if_rank0(fw_gm.print_readable(print_output=False)) - - # # print_if_rank0(f"bw_gm:") - # # print_if_rank0(bw_gm.print_readable(print_output=False)) - - # # Run graph passes here - - # ## Apply bucketing - # # schedule_overlap_bucketing(fw_gm) - # # schedule_overlap_bucketing(bw_gm) - - # ## TODO: Apply Flex Attention compilation here - - # # Codgen Autograd.Function Wrappers - - # # Get the model parameters and buffers - the partitioned graphs expect these as arguments - # local_params = [] - # for p in model.parameters(): - # if isinstance(p, DTensor): - # local_params.append(p.to_local()) - # else: - # local_params.append(p) - - # local_buffers = [] - # for b in model.buffers(): - # if isinstance(b, DTensor): - # local_buffers.append(b.to_local()) - # else: - # local_buffers.append(b) - - - # joint_graph_module = JointGraphModule( - # local_params, local_buffers, fw_metadata, fw_gm, bw_gm, RunMode.CODEGEN_AUTOGRAD, f"rank{torch.distributed.get_rank()}" - # ) - - # return joint_graph_module From b262d9aa23b08e4ca906b011132d83fb6eec18f5 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Wed, 8 Oct 2025 17:36:35 -0700 Subject: [PATCH 10/11] lint --- torchtitan/experiments/__init__.py | 10 ++- .../joint_graph_runner/llama3/__init__.py | 8 ++- .../joint_graph_runner/llama3/parallelize.py | 63 +++++++++++-------- 3 files changed, 52 insertions(+), 29 deletions(-) diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 2b99a9051..d8e020e0d 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -5,5 +5,13 @@ # LICENSE file in the root directory of this source tree. _supported_experiments = frozenset( - ["flux", "llama4", "qwen3", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm", "joint_graph_runner.llama3"] + [ + "flux", + "llama4", + "qwen3", + "simple_fsdp.llama3", + "simple_fsdp.deepseek_v3", + "vlm", + "joint_graph_runner.llama3", + ] ) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/__init__.py b/torchtitan/experiments/joint_graph_runner/llama3/__init__.py index 99c350cde..9f21ea56b 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/__init__.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/__init__.py @@ -11,11 +11,13 @@ from torchtitan.components.optimizer import build_optimizers from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.datasets.hf_datasets import build_hf_dataloader -from torchtitan.models.llama3 import llama3_configs, pipeline_llama -from torchtitan.protocols.train_spec import TrainSpec +from torchtitan.experiments.joint_graph_runner.llama3.parallelize import ( + parallelize_llama, +) from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer -from torchtitan.experiments.joint_graph_runner.llama3.parallelize import parallelize_llama +from torchtitan.models.llama3 import llama3_configs, pipeline_llama +from torchtitan.protocols.train_spec import TrainSpec def get_train_spec() -> TrainSpec: diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index 7265d8224..1f34cf518 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -4,32 +4,35 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib + import torch import torch.nn as nn -from torch.distributed.tensor import DTensor, Shard, Replicate + +from torch._dynamo.functional_export import _dynamo_graph_capture_for_export + +from torch._functorch.aot_autograd import ( + aot_compile_joint_with_descriptors, + aot_export_joint_with_descriptors, + boxed_nop_preserve_node_meta, +) +from torch._functorch.partitioners import min_cut_rematerialization_partition from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.distributed.tensor import DTensor, Replicate, Shard from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer + +from torchtitan.experiments.simple_fsdp.simple_fsdp import ( + data_parallel, + MixedPrecisionPolicy, +) from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger -from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel, MixedPrecisionPolicy - -from torch._functorch.aot_autograd import aot_export_joint_with_descriptors, aot_compile_joint_with_descriptors -from torch._functorch.partitioners import min_cut_rematerialization_partition - -from torch._dynamo.functional_export import _dynamo_graph_capture_for_export - -import contextlib -from torchtitan.experiments.simple_fsdp.llama3.model import SimpleFSDPTransformer -from torch._functorch.aot_autograd import ( - aot_export_joint_with_descriptors, - boxed_nop_preserve_node_meta, -) - # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -43,10 +46,12 @@ torch._higher_order_ops.flex_attention, } + def print_if_rank0(msg): if torch.distributed.get_rank() == 0: print(msg) + def graph_capture_and_aot_export_joint_with_descriptors(model, inputs): assert isinstance(inputs, tuple) with torch._dynamo.config.patch(install_free_tensors=True): @@ -73,7 +78,6 @@ def aot_export_joint_with_descriptors_alone(model, inputs): return joint_with_descriptors - def _assign_attr( from_obj: torch.Tensor | torch.nn.Parameter, to_module: torch.nn.Module, @@ -225,7 +229,6 @@ def _validate_state_dict_match( ), f"Buffer tensor mismatch for {orig_name}" - def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -320,11 +323,12 @@ def parallelize_llama( return model + # Just to bootstrap our experiment. NOT the final API. class HijackWrapper(torch.nn.Module): def __init__(self, inner: torch.nn.Module, parallel_dims, **overrides): super().__init__() - self.inner = inner # register as submodule + self.inner = inner # register as submodule self.parallel_dims = parallel_dims self.joint_graph_module = None @@ -359,15 +363,18 @@ def forward(self, *args, **kwargs): # print_if_rank0(self.parallel_dims.world_mesh) # 2-D device mesh with ['dp_shard', 'tp'], [2, 4] - # Hack: convert args and kwargs to DTensor. This should be fixed at data loader. + # Hack: convert args and kwargs to DTensor. This should be fixed at data loader. # This works, but kinda cheating? - dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()]) for arg in args) + dt_args = tuple( + DTensor.from_local(arg, self.parallel_dims.world_mesh["tp"], [Replicate()]) + for arg in args + ) # RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec((Shard(dim=0), Replicate()) on (16, 2048)) @ mesh: (2, 4))') # dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh, [Shard(0), Replicate()]) for arg in args) # RuntimeError('Sharding propagation failed for Op(op=aten.embedding.default, args_schema=Spec(S(0) on (2048, 256)), Spec(S(0) on (16, 2048)) @ mesh: (2,))') - # dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["dp_shard"], [Shard(0)]) for arg in args) + # dt_args = tuple(DTensor.from_local(arg, self.parallel_dims.world_mesh["dp_shard"], [Shard(0)]) for arg in args) # HACK: doing graph capture on the fly, we should do it AOT if self.joint_graph_module is None: @@ -376,7 +383,9 @@ def forward(self, *args, **kwargs): # torch._functorch.config.graphsafe_rng_functionalization = False # first time, we need to initialize the runner - self.joint_graph_module = joint_graph_builder(self.inner, *dt_args, **kwargs) + self.joint_graph_module = joint_graph_builder( + self.inner, *dt_args, **kwargs + ) # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. @@ -392,8 +401,10 @@ def joint_graph_builder(model, *inputs, **kwargs): assert not kwargs # get fw/bw graphs - joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors(model, inputs) - + joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors( + model, inputs + ) + def compiler(gm: torch.fx.GraphModule, example_inputs): print_if_rank0("Before compiler:") print_if_rank0(gm.print_readable(print_output=False)) @@ -404,7 +415,9 @@ def compiler(gm: torch.fx.GraphModule, example_inputs): print_if_rank0(gm.print_readable(print_output=False)) return gm - fn = aot_compile_joint_with_descriptors(joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler) + fn = aot_compile_joint_with_descriptors( + joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler + ) def wrapper_fn(args): input = [ From 7535a4a47066cf89fffff523c6355703ec1bf4f1 Mon Sep 17 00:00:00 2001 From: Sherlock Huang Date: Thu, 9 Oct 2025 20:30:37 -0700 Subject: [PATCH 11/11] patch _restore_state_dict --- .../joint_graph_runner/llama3/parallelize.py | 163 ++++++------------ torchtitan/models/attention.py | 5 +- 2 files changed, 53 insertions(+), 115 deletions(-) diff --git a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py index 1f34cf518..ba5702e96 100644 --- a/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py +++ b/torchtitan/experiments/joint_graph_runner/llama3/parallelize.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import contextlib +from torch._guards import tracing, TracingContext import torch import torch.nn as nn @@ -33,6 +34,9 @@ from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger + +from torch.fx.passes.regional_inductor import compile_fx_annotated_nodes_with_inductor + # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, @@ -54,17 +58,20 @@ def print_if_rank0(msg): def graph_capture_and_aot_export_joint_with_descriptors(model, inputs): assert isinstance(inputs, tuple) - with torch._dynamo.config.patch(install_free_tensors=True): + with torch._dynamo.config.patch(install_free_tensors=True), torch.fx.traceback.preserve_node_meta(): # TODO: switch to use the official graph_capture API once it is ready gm = _dynamo_graph_capture_for_export(model)(*inputs) # Restore the state dict to match the original module _restore_state_dict(model, gm) - # Validate that state dict and ordering match - _validate_state_dict_match(model, gm) + print_if_rank0("Dynamo gm:") + print_if_rank0(gm.print_readable(print_output=False)) - return aot_export_joint_with_descriptors_alone(gm, inputs) + fake_mode = gm.meta.get("fake_mode", None) + + with tracing(TracingContext(fake_mode)): + return aot_export_joint_with_descriptors_alone(gm, inputs), fake_mode def aot_export_joint_with_descriptors_alone(model, inputs): @@ -78,47 +85,24 @@ def aot_export_joint_with_descriptors_alone(model, inputs): return joint_with_descriptors -def _assign_attr( - from_obj: torch.Tensor | torch.nn.Parameter, - to_module: torch.nn.Module, - target: str, - is_buffer: bool = False, +def _clear_traced_params_buffers( + traced_module: torch.fx.GraphModule, const_keys: list[str] ) -> None: - """ - Assign attribute 'from_obj' to the qualified name 'target' on 'to_module'. - This installs empty Modules where none exist yet if they are subpaths of target. - Based on _assign_attr from torch.export.unflatten. - """ - *prefix, field = target.split(".") - - # Generate all submodules along the path - for item in prefix: - if not hasattr(to_module, item): - setattr(to_module, item, torch.nn.Module()) - to_module = getattr(to_module, item) - - # Assign the actual parameter or buffer - if is_buffer: - to_module.register_buffer(field, from_obj) - else: - to_module.register_parameter(field, from_obj) - - -def _clear_traced_params_buffers(traced_module: torch.fx.GraphModule) -> None: """Remove all parameters and buffers from traced module before restoring.""" - # Remove all parameters - for name in list(traced_module._parameters.keys()): - delattr(traced_module, name) - - # Remove all buffers - for name in list(traced_module._buffers.keys()): - delattr(traced_module, name) + for key in const_keys: + assert key in traced_module._buffers.keys() + # We don't want constants to show up as a buffer in the state dict. + # Instead they should just be a direct attribute. + buffer = getattr(traced_module, key) + torch.fx.graph_module._del_attr(traced_module, key) + setattr(traced_module, key, buffer) def _restore_state_dict( original_module: torch.nn.Module, traced_module: torch.fx.GraphModule ) -> None: """ + TODO: move this into torch.export Restores the state dict of the traced module to match the original module exactly. Preserves the original FQNs with dots, creating intermediate empty modules as needed. Ensures that the ordering of parameters/buffers matches the original module. @@ -135,9 +119,6 @@ def _restore_state_dict( # Build mapping from old names to new names for graph node updates name_mapping: dict[str, str] = {} - # Clear existing parameters and buffers from traced module - _clear_traced_params_buffers(traced_module) - # Restore parameters in the order they appear in original module for orig_name, orig_param in original_module.named_parameters( remove_duplicate=False @@ -145,22 +126,30 @@ def _restore_state_dict( if id(orig_param) in traced_params: # This param exists in traced module - restore it with original FQN traced_name, traced_param = traced_params[id(orig_param)] - _assign_attr(traced_param, traced_module, orig_name, is_buffer=False) + torch.fx.graph_module._assign_attr(traced_param, traced_module, orig_name) + torch.fx.graph_module._del_attr(traced_module, traced_name) name_mapping[traced_name] = orig_name else: # This param doesn't exist in traced module - add it - _assign_attr(orig_param, traced_module, orig_name, is_buffer=False) + torch.fx.graph_module._assign_attr(orig_param, traced_module, orig_name) # Restore buffers in the order they appear in original module for orig_name, orig_buffer in original_module.named_buffers(remove_duplicate=False): if id(orig_buffer) in traced_buffers: # This buffer exists in traced module - restore it with original FQN traced_name, traced_buffer = traced_buffers[id(orig_buffer)] - _assign_attr(traced_buffer, traced_module, orig_name, is_buffer=True) + torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name) name_mapping[traced_name] = orig_name + torch.fx.graph_module._del_attr(traced_module, traced_name) else: # This buffer doesn't exist in traced module - add it - _assign_attr(orig_buffer, traced_module, orig_name, is_buffer=True) + torch.fx.graph_module._assign_attr(orig_buffer, traced_module, orig_name) + + param_names = [v[0] for v in traced_params.values()] + buffer_names = [v[0] for v in traced_buffers.values()] + const_keys = set(param_names + buffer_names).difference(set(name_mapping.keys())) + + _clear_traced_params_buffers(traced_module, const_keys) # Update get_attr nodes in the graph to use the correct FQNs for node in traced_module.graph.nodes: @@ -170,65 +159,6 @@ def _restore_state_dict( traced_module.recompile() -def _validate_state_dict_match( - original_module: torch.nn.Module, traced_module: torch.fx.GraphModule -) -> None: - """ - Validates that the traced module's state dict matches the original module. - Checks that parameter/buffer names, ordering, and tensor values match. - - Raises: - AssertionError: If any validation check fails. - """ - # Check 1: Verify parameter names and ordering match - orig_param_names = list( - dict(original_module.named_parameters(remove_duplicate=False)).keys() - ) - gm_param_names = list( - dict(traced_module.named_parameters(remove_duplicate=False)).keys() - ) - assert orig_param_names == gm_param_names, ( - f"Parameter names or ordering mismatch!\n" - f"Original: {orig_param_names}\n" - f"Traced: {gm_param_names}" - ) - - # Check 2: Verify buffer names and ordering match - orig_buffer_names = list( - dict(original_module.named_buffers(remove_duplicate=False)).keys() - ) - gm_buffer_names = list( - dict(traced_module.named_buffers(remove_duplicate=False)).keys() - ) - assert orig_buffer_names == gm_buffer_names, ( - f"Buffer names or ordering mismatch!\n" - f"Original: {orig_buffer_names}\n" - f"Traced: {gm_buffer_names}" - ) - - # Check 3: Verify parameter tensors match by identity or value - for (orig_name, orig_param), (gm_name, gm_param) in zip( - original_module.named_parameters(remove_duplicate=False), - traced_module.named_parameters(remove_duplicate=False), - ): - assert ( - orig_name == gm_name - ), f"Parameter name mismatch: {orig_name} != {gm_name}" - assert id(orig_param) == id(gm_param) or torch.equal( - orig_param, gm_param - ), f"Parameter tensor mismatch for {orig_name}" - - # Check 4: Verify buffer tensors match by identity or value - for (orig_name, orig_buffer), (gm_name, gm_buffer) in zip( - original_module.named_buffers(remove_duplicate=False), - traced_module.named_buffers(remove_duplicate=False), - ): - assert orig_name == gm_name, f"Buffer name mismatch: {orig_name} != {gm_name}" - assert id(orig_buffer) == id(gm_buffer) or torch.equal( - orig_buffer, gm_buffer - ), f"Buffer tensor mismatch for {orig_name}" - - def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -378,10 +308,6 @@ def forward(self, *args, **kwargs): # HACK: doing graph capture on the fly, we should do it AOT if self.joint_graph_module is None: - # needed to avoid having fwd_rng_state in the fw_gm inp - # this doesn't work! - # torch._functorch.config.graphsafe_rng_functionalization = False - # first time, we need to initialize the runner self.joint_graph_module = joint_graph_builder( self.inner, *dt_args, **kwargs @@ -400,24 +326,35 @@ def joint_graph_builder(model, *inputs, **kwargs): assert isinstance(input, DTensor) assert not kwargs - # get fw/bw graphs - joint_with_descriptors = graph_capture_and_aot_export_joint_with_descriptors( + joint_with_descriptors, fake_mode = graph_capture_and_aot_export_joint_with_descriptors( model, inputs ) + # verify user annotation show up in the graph + for node in joint_with_descriptors.graph_module.graph.nodes: + if node.target in {torch.ops.higher_order.flex_attention, torch.ops.higher_order.flex_attention_backward}: + if "custom" not in node.meta: + # this is currently failing, as backward nodes are missing the annotation + # raise RuntimeError(f"node {node} is not annotated with custom metadata, seeing node.meta: {node.meta}") + pass + def compiler(gm: torch.fx.GraphModule, example_inputs): print_if_rank0("Before compiler:") print_if_rank0(gm.print_readable(print_output=False)) - gm = schedule_overlap_bucketing(gm) + # gm = schedule_overlap_bucketing(gm) + + # RuntimeError: Cannot access data pointer of Tensor that doesn't have storage P1985731155 + # gm = compile_fx_annotated_nodes_with_inductor(gm, example_inputs) print_if_rank0("After compiler:") print_if_rank0(gm.print_readable(print_output=False)) return gm - fn = aot_compile_joint_with_descriptors( - joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler - ) + with tracing(TracingContext(fake_mode)): + fn = aot_compile_joint_with_descriptors( + joint_with_descriptors, fw_compiler=compiler, bw_compiler=compiler + ) def wrapper_fn(args): input = [ diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index f66361a6d..44018386c 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -85,8 +85,9 @@ def forward( v: torch.Tensor, scale: float | None = None, ) -> torch.Tensor: - block_mask = FlexAttention.block_masks[self.mask_key] - return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) + with torch.fx.traceback.annotate({"compile_with_inductor": "flex_attention"}): + block_mask = FlexAttention.block_masks[self.mask_key] + return FlexAttention.flex_attn(q, k, v, block_mask=block_mask, scale=scale) @staticmethod def _get_causal_mask_mod() -> _mask_mod_signature: