diff --git a/torchrec/distributed/embedding_lookup.py b/torchrec/distributed/embedding_lookup.py
index 22d7c52b8..f7e5846fd 100644
--- a/torchrec/distributed/embedding_lookup.py
+++ b/torchrec/distributed/embedding_lookup.py
@@ -355,6 +355,9 @@ def purge(self) -> None:
 
 
 class CommOpGradientScaling(torch.autograd.Function):
+    # user override: inline autograd.Function is safe to trace since only tensor mutations / no global state
+    _compiled_autograd_should_lift = False
+
     @staticmethod
     # pyre-ignore
     def forward(
diff --git a/torchrec/distributed/train_pipeline/__init__.py b/torchrec/distributed/train_pipeline/__init__.py
index 9e9d3bd73..e49f7ff98 100644
--- a/torchrec/distributed/train_pipeline/__init__.py
+++ b/torchrec/distributed/train_pipeline/__init__.py
@@ -16,6 +16,7 @@
     TrainPipelineBase,  # noqa
     TrainPipelinePT2,  # noqa
     TrainPipelineSparseDist,  # noqa
+    TrainPipelineSparseDistCompAutograd,  # noqa
 )
 from torchrec.distributed.train_pipeline.utils import (  # noqa
     _override_input_dist_forwards,  # noqa
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
index 77b5c4ff3..6c7878189 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -19,6 +19,7 @@
 from hypothesis import given, settings, strategies as st, Verbosity
 from torch import nn, optim
 from torch._dynamo.testing import reduce_to_scalar_loss
+from torch._dynamo.utils import counters
 from torchrec.distributed import DistributedModelParallel
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
 from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -53,6 +54,7 @@
     TrainPipelinePT2,
     TrainPipelineSemiSync,
     TrainPipelineSparseDist,
+    TrainPipelineSparseDistCompAutograd,
 )
 from torchrec.distributed.train_pipeline.utils import (
     DataLoadingThread,
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
             sharded_sparse_arch_pipeline.parameters(), lr=0.1
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             sharded_sparse_arch_pipeline,
             optimizer_pipeline,
             self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline(
             dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
             lambda params: optim.SGD(params, lr=0.1),
         )
-        return TrainPipelineSparseDist(
+        return self.pipeline_class(
             model=distributed_model,
             optimizer=optimizer_distributed,
             device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined(
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None:
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None:
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal(
             sharded_model.state_dict(), sharded_model_pipelined.state_dict()
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive(
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
             model, self.sharding_type, self.kernel_type, self.fused_params
         )
 
-        pipeline = TrainPipelineSparseDist(
+        pipeline = self.pipeline_class(
             model=sharded_model_pipelined,
             optimizer=optim_pipelined,
             device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut:
         self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
         for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
             torch.testing.assert_close(out, ref_out)
+
+
+class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
+    def setUp(self) -> None:
+        super().setUp()
+        self.pipeline_class = TrainPipelineSparseDistCompAutograd
+        torch._dynamo.reset()
+        counters["compiled_autograd"].clear()
+        # Compiled Autograd doesn't work with Anomaly Mode
+        torch.autograd.set_detect_anomaly(False)
+
+    def tearDown(self) -> None:
+        # Every single test has two captures, one for forward and one for backward
+        self.assertEqual(counters["compiled_autograd"]["captures"], 2)
+        return super().tearDown()
+
+    @unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
+    # pyre-ignore[56]
+    @given(execute_all_batches=st.booleans())
+    def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
+        super().test_pipelining_fsdp_pre_trace()
diff --git a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
index 6ca45371a..47ac23d8c 100644
--- a/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
+++ b/torchrec/distributed/train_pipeline/tests/test_train_pipelines_base.py
@@ -23,6 +23,7 @@
     TestEBCSharder,
     TestSparseNN,
 )
+from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
 from torchrec.distributed.types import ModuleSharder, ShardingEnv
 from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
 from torchrec.test_utils import get_free_port, init_distributed_single_host
@@ -59,6 +60,7 @@ def setUp(self) -> None:
         ]
 
         self.device = torch.device("cuda:0")
+        self.pipeline_class = TrainPipelineSparseDist
 
     def tearDown(self) -> None:
         super().tearDown()
diff --git a/torchrec/distributed/train_pipeline/train_pipelines.py b/torchrec/distributed/train_pipeline/train_pipelines.py
index a59ad72f8..ae0aecd69 100644
--- a/torchrec/distributed/train_pipeline/train_pipelines.py
+++ b/torchrec/distributed/train_pipeline/train_pipelines.py
@@ -8,13 +8,16 @@
 # pyre-strict
 
 import abc
+import contextlib
 import logging
 from collections import deque
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import (
     Any,
     Callable,
     cast,
+    ContextManager,
     Deque,
     Dict,
     Generic,
@@ -27,6 +30,7 @@
 )
 
 import torch
+import torchrec.distributed.comm_ops
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
 from torchrec.distributed.model_parallel import ShardedModule
@@ -59,7 +63,6 @@
 from torchrec.pt2.checks import is_torchdynamo_compiling
 from torchrec.pt2.utils import default_pipeline_input_transformer
 from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
-from torchrec.streamable import Multistreamable
 
 
 logger: logging.Logger = logging.getLogger(__name__)
@@ -1506,3 +1509,88 @@ def progress(
             return self.progress(dataloader_iter)
 
         return out
+
+
+class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
+    """
+    This pipeline clones TrainPipelineSparseDist, but executes the progress
+    method within a compiled autograd context.
+    """
+
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        optimizer: torch.optim.Optimizer,
+        device: torch.device,
+        execute_all_batches: bool = True,
+        apply_jit: bool = False,
+        context_type: Type[TrainPipelineContext] = TrainPipelineContext,
+        pipeline_preproc: bool = False,
+        custom_model_fwd: Optional[
+            Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
+        ] = None,
+    ) -> None:
+        super().__init__(
+            model,
+            optimizer,
+            device,
+            execute_all_batches,
+            apply_jit,
+            context_type,
+            pipeline_preproc,
+            custom_model_fwd,
+        )
+
+        # check this attribute on the model to allow injecting a configuration
+        # other than the default one
+        self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr(
+            model,
+            "_compiled_autograd_options",
+            {
+                "backend": "inductor",
+                "dynamic": True,
+                "fullgraph": True,
+            },
+        )
+
+        torch._dynamo.config.optimize_ddp = "python_reducer"
+        torch._dynamo.config.inline_inbuilt_nn_modules = True
+        torch._dynamo.config.skip_fsdp_hooks = False
+        torch._functorch.config.recompute_views = True
+        torch._functorch.config.cse = False
+        torch._inductor.config.reorder_for_compute_comm_overlap = True
+        torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
+            "sink_waits",
+            "raise_comms",
+            "reorder_compute_for_overlap",
+        ]
+        self.initialized = False
+
+    def get_compiled_autograd_ctx(
+        self,
+    ) -> ContextManager:
+        # skip compiled autograd on the first call, while the pipeline is still
+        # filling up; this allows pipelining and avoids doing a sum on None
+        # when the pipeline is empty
+        if not self.initialized:
+            self.initialized = True
+            return contextlib.nullcontext()
+
+        return torch._dynamo.compiled_autograd.enable(
+            # pyre-ignore
+            torch.compile(**self.compiled_autograd_options)
+        )
+
+    @contextmanager
+    def sync_collectives_ctx(self) -> Iterator[None]:
+        try:
+            if is_torchdynamo_compiling():
+                torchrec.distributed.comm_ops.set_use_sync_collectives(True)
+            yield
+        finally:
+            torchrec.distributed.comm_ops.set_use_sync_collectives(False)
+
+    def progress(self, dataloader_iter: Iterator[In]) -> Out:
+
+        with self.get_compiled_autograd_ctx(), self.sync_collectives_ctx():
+            return super().progress(dataloader_iter)
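
For reference, a minimal sketch of how the new pipeline might be driven. The call pattern mirrors the existing TrainPipelineSparseDist usage from the tests above; `build_sharded_model`, `build_optimizer`, and `dataloader` are hypothetical placeholders for an existing TorchRec training setup, not part of this patch.

# Illustrative sketch only; build_sharded_model, build_optimizer, and
# dataloader are hypothetical placeholders.
import torch

from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

model = build_sharded_model()       # e.g. a DistributedModelParallel-wrapped model
optimizer = build_optimizer(model)  # optimizer over the dense parameters
device = torch.device("cuda:0")

pipeline = TrainPipelineSparseDistCompAutograd(
    model=model,
    optimizer=optimizer,
    device=device,
)

batches = iter(dataloader)
while True:
    try:
        # progress() runs under compiled autograd (after the first warm-up
        # batch) and enables sync collectives while Dynamo is compiling.
        out = pipeline.progress(batches)
    except StopIteration:
        break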