3 changes: 3 additions & 0 deletions torchrec/distributed/embedding_lookup.py
@@ -355,6 +355,9 @@ def purge(self) -> None:


class CommOpGradientScaling(torch.autograd.Function):
# user override: this inline autograd.Function is safe to trace since its backward only performs tensor math and touches no global state
_compiled_autograd_should_lift = False

@staticmethod
# pyre-ignore
def forward(
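For context, a minimal standalone sketch of the pattern used above (illustrative only, not part of this PR; the ScaleGrad name and scale factor are invented): setting _compiled_autograd_should_lift = False on a custom autograd.Function asks compiled autograd to trace the backward body inline rather than lifting it as an opaque call, which is only safe when backward performs pure tensor math and touches no global state.

import torch


class ScaleGrad(torch.autograd.Function):
    # opt out of lifting so compiled autograd traces backward inline
    _compiled_autograd_should_lift = False

    @staticmethod
    def forward(ctx, x: torch.Tensor, scale: float) -> torch.Tensor:
        ctx.scale = scale
        return x

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        # pure tensor math, no global state -> safe to trace
        return grad_out * ctx.scale, None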
1 change: 1 addition & 0 deletions torchrec/distributed/train_pipeline/__init__.py
@@ -16,6 +16,7 @@
TrainPipelineBase, # noqa
TrainPipelinePT2, # noqa
TrainPipelineSparseDist, # noqa
TrainPipelineSparseDistCompAutograd, # noqa
)
from torchrec.distributed.train_pipeline.utils import ( # noqa
_override_input_dist_forwards, # noqa
43 changes: 33 additions & 10 deletions torchrec/distributed/train_pipeline/tests/test_train_pipelines.py
@@ -19,6 +19,7 @@
from hypothesis import given, settings, strategies as st, Verbosity
from torch import nn, optim
from torch._dynamo.testing import reduce_to_scalar_loss
from torch._dynamo.utils import counters
from torchrec.distributed import DistributedModelParallel
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
@@ -53,6 +54,7 @@
TrainPipelinePT2,
TrainPipelineSemiSync,
TrainPipelineSparseDist,
TrainPipelineSparseDistCompAutograd,
)
from torchrec.distributed.train_pipeline.utils import (
DataLoadingThread,
@@ -393,7 +395,7 @@ def forward(self, model_input) -> Tuple[torch.Tensor, torch.Tensor]:
sharded_sparse_arch_pipeline.parameters(), lr=0.1
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
sharded_sparse_arch_pipeline,
optimizer_pipeline,
self.device,
@@ -441,7 +443,7 @@ def _setup_pipeline(
dict(in_backward_optimizer_filter(distributed_model.named_parameters())),
lambda params: optim.SGD(params, lr=0.1),
)
return TrainPipelineSparseDist(
return self.pipeline_class(
model=distributed_model,
optimizer=optimizer_distributed,
device=self.device,
@@ -508,7 +510,7 @@ def test_equal_to_non_pipelined(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -621,7 +623,7 @@ def test_model_detach_during_train(self) -> None:
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -719,7 +721,7 @@ def test_model_detach_after_train(self) -> None:
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -862,7 +864,7 @@ def _check_output_equal(
sharded_model.state_dict(), sharded_model_pipelined.state_dict()
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1116,7 +1118,7 @@ def test_pipeline_invalid_preproc_inputs_has_trainable_params(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1171,7 +1173,7 @@ def test_pipeline_invalid_preproc_trainable_params_recursive(
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1217,7 +1219,7 @@ def test_pipeline_invalid_preproc_inputs_modify_kjt_recursive(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -1280,7 +1282,7 @@ def test_pipeline_preproc_fwd_values_cached(self) -> None:
model, self.sharding_type, self.kernel_type, self.fused_params
)

pipeline = TrainPipelineSparseDist(
pipeline = self.pipeline_class(
model=sharded_model_pipelined,
optimizer=optim_pipelined,
device=self.device,
@@ -2100,3 +2102,24 @@ def gpu_preproc(x: StageOut) -> StageOut:
self.assertEqual(len(pipelined_out), len(non_pipelined_outputs))
for out, ref_out in zip(pipelined_out, non_pipelined_outputs):
torch.testing.assert_close(out, ref_out)


class TrainPipelineSparseDistCompAutogradTest(TrainPipelineSparseDistTest):
def setUp(self) -> None:
super().setUp()
self.pipeline_class = TrainPipelineSparseDistCompAutograd
torch._dynamo.reset()
counters["compiled_autograd"].clear()
# Compiled autograd does not work with anomaly mode
torch.autograd.set_detect_anomaly(False)

def tearDown(self) -> None:
# Every single test has two captures, one for forward and one for backward
self.assertEqual(counters["compiled_autograd"]["captures"], 2)
return super().tearDown()

@unittest.skip("Dynamo only supports FSDP with use_orig_params=True")
# pyre-ignore[56]
@given(execute_all_batches=st.booleans())
def test_pipelining_fsdp_pre_trace(self, execute_all_batches: bool) -> None:
super().test_pipelining_fsdp_pre_trace()
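As a rough illustration of the capture count asserted in tearDown above, a minimal standalone sketch (not part of the test suite; the toy Linear model is an assumption) of inspecting compiled autograd captures via torch._dynamo.utils.counters:

import torch
import torch._dynamo.compiled_autograd
from torch._dynamo.utils import counters

torch._dynamo.reset()
counters["compiled_autograd"].clear()

model = torch.nn.Linear(8, 8)
x = torch.randn(4, 8)

# run backward under compiled autograd; "captures" increments once per
# newly captured backward graph
with torch._dynamo.compiled_autograd.enable(torch.compile(backend="eager")):
    model(x).sum().backward()

print(counters["compiled_autograd"]["captures"])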
Changes to an additional test file (path not shown in this diff)
@@ -23,6 +23,7 @@
TestEBCSharder,
TestSparseNN,
)
from torchrec.distributed.train_pipeline.train_pipelines import TrainPipelineSparseDist
from torchrec.distributed.types import ModuleSharder, ShardingEnv
from torchrec.modules.embedding_configs import DataType, EmbeddingBagConfig
from torchrec.test_utils import get_free_port, init_distributed_single_host
@@ -59,6 +60,7 @@ def setUp(self) -> None:
]

self.device = torch.device("cuda:0")
self.pipeline_class = TrainPipelineSparseDist

def tearDown(self) -> None:
super().tearDown()
90 changes: 89 additions & 1 deletion torchrec/distributed/train_pipeline/train_pipelines.py
@@ -8,13 +8,16 @@
# pyre-strict

import abc
import contextlib
import logging
from collections import deque
from contextlib import contextmanager
from dataclasses import dataclass
from typing import (
Any,
Callable,
cast,
ContextManager,
Deque,
Dict,
Generic,
@@ -27,6 +30,7 @@
)

import torch
import torchrec.distributed.comm_ops
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.model_parallel import ShardedModule
@@ -59,7 +63,6 @@
from torchrec.pt2.checks import is_torchdynamo_compiling
from torchrec.pt2.utils import default_pipeline_input_transformer
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.streamable import Multistreamable

logger: logging.Logger = logging.getLogger(__name__)

@@ -1506,3 +1509,88 @@ def progress(
return self.progress(dataloader_iter)

return out


class TrainPipelineSparseDistCompAutograd(TrainPipelineSparseDist[In, Out]):
"""
This pipeline clones TrainPipelineSparseDist, but executes the progress
method within a compiled autograd context.
"""

def __init__(
self,
model: torch.nn.Module,
optimizer: torch.optim.Optimizer,
device: torch.device,
execute_all_batches: bool = True,
apply_jit: bool = False,
context_type: Type[TrainPipelineContext] = TrainPipelineContext,
pipeline_preproc: bool = False,
custom_model_fwd: Optional[
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
] = None,
) -> None:
super().__init__(
model,
optimizer,
device,
execute_all_batches,
apply_jit,
context_type,
pipeline_preproc,
custom_model_fwd,
)

# check for this attribute on the model to allow injecting a configuration
# other than the default one.
self.compiled_autograd_options: Dict[str, Union[str, bool]] = getattr(
model,
"_compiled_autograd_options",
{
"backend": "inductor",
"dynamic": True,
"fullgraph": True,
},
)

torch._dynamo.config.optimize_ddp = "python_reducer"
torch._dynamo.config.inline_inbuilt_nn_modules = True
torch._dynamo.config.skip_fsdp_hooks = False
torch._functorch.config.recompute_views = True
torch._functorch.config.cse = False
torch._inductor.config.reorder_for_compute_comm_overlap = True
torch._inductor.config.reorder_for_compute_comm_overlap_passes = [
"sink_waits",
"raise_comms",
"reorder_compute_for_overlap",
]
self.initialized = False

def get_compiled_autograd_ctx(
self,
) -> ContextManager:
# Skip compiled autograd on the first call, while the pipeline is still
# empty, to avoid doing a sum on None; this allows the pipeline to warm up.
if not self.initialized:
self.initialized = True
return contextlib.nullcontext()

return torch._dynamo.compiled_autograd.enable(
# pyre-ignore
torch.compile(**self.compiled_autograd_options)
)

@contextmanager
def sync_collectives_ctx(self) -> Iterator[None]:
try:
if is_torchdynamo_compiling():
torchrec.distributed.comm_ops.set_use_sync_collectives(True)
yield
finally:
torchrec.distributed.comm_ops.set_use_sync_collectives(False)

def progress(self, dataloader_iter: Iterator[In]) -> Out:

with self.get_compiled_autograd_ctx(), self.sync_collectives_ctx():
return super().progress(dataloader_iter)
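A hypothetical usage sketch (sharded_model, optimizer, and dataloader are placeholders, not defined in this PR): the new pipeline is used as a drop-in replacement for TrainPipelineSparseDist in the tests above, and a model may supply its own settings through the _compiled_autograd_options attribute read in __init__.

import torch

from torchrec.distributed.train_pipeline import TrainPipelineSparseDistCompAutograd

# optional per-model override of the compiled autograd settings
# (sharded_model is a placeholder for an already-sharded model)
sharded_model._compiled_autograd_options = {
    "backend": "inductor",
    "dynamic": True,
    "fullgraph": False,
}

pipeline = TrainPipelineSparseDistCompAutograd(
    model=sharded_model,
    optimizer=optimizer,
    device=torch.device("cuda:0"),
)

dataloader_iter = iter(dataloader)  # dataloader is a placeholder
while True:
    try:
        out = pipeline.progress(dataloader_iter)
    except StopIteration:
        break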