Skip to content

Commit 4870151

Browse files
sarckk and facebook-github-bot
authored and committed
Add custom model fwd in train pipelines
Summary: Add missing pipeline_preproc and custom_model_fwd args. Differential Revision: D61564467
1 parent b6380be commit 4870151

File tree

2 files changed

+64
-10
lines changed

2 files changed

+64
-10
lines changed

torchrec/distributed/train_pipeline/tests/test_train_pipelines.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,54 @@ def test_multi_dataloader_pipelining(self) -> None:
825825
)
826826
)
827827

828+
# pyre-ignore
829+
@unittest.skipIf(
830+
not torch.cuda.is_available(),
831+
"Not enough GPUs, this test requires at least one GPU",
832+
)
833+
def test_custom_fwd(
834+
self,
835+
) -> None:
836+
data = self._generate_data(
837+
num_batches=4,
838+
batch_size=32,
839+
)
840+
dataloader = iter(data)
841+
842+
fused_params_pipelined = {}
843+
sharding_type = ShardingType.ROW_WISE.value
844+
kernel_type = EmbeddingComputeKernel.FUSED.value
845+
sharded_model_pipelined: torch.nn.Module
846+
847+
model = self._setup_model()
848+
849+
(
850+
sharded_model_pipelined,
851+
optim_pipelined,
852+
) = self._generate_sharded_model_and_optimizer(
853+
model, sharding_type, kernel_type, fused_params_pipelined
854+
)
855+
856+
def custom_model_fwd(
857+
input: ModelInput,
858+
) -> Tuple[torch.Tensor, torch.Tensor]:
859+
loss, pred = sharded_model_pipelined(input)
860+
batch_size = pred.size(0)
861+
return loss, pred.expand(batch_size * 2, -1)
862+
863+
pipeline = TrainPipelineSparseDist(
864+
model=sharded_model_pipelined,
865+
optimizer=optim_pipelined,
866+
device=self.device,
867+
execute_all_batches=True,
868+
custom_model_fwd=custom_model_fwd,
869+
)
870+
871+
for _ in data:
872+
# Forward + backward w/ pipelining
873+
pred_pipeline = pipeline.progress(dataloader)
874+
self.assertEqual(pred_pipeline.size(0), 64)
875+
828876

829877
class TrainPipelinePreprocTest(TrainPipelineSparseDistTestBase):
830878
def setUp(self) -> None:

torchrec/distributed/train_pipeline/train_pipelines.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def __init__(
309309
context_type: Type[TrainPipelineContext] = TrainPipelineContext,
310310
pipeline_preproc: bool = False,
311311
custom_model_fwd: Optional[
312-
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
312+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
313313
] = None,
314314
) -> None:
315315
self._model = model
@@ -363,6 +363,10 @@ def __init__(
363363
self._dataloader_exhausted: bool = False
364364
self._context_type: Type[TrainPipelineContext] = context_type
365365

366+
self._model_fwd: Callable[[Optional[In]], Tuple[torch.Tensor, Out]] = (
367+
custom_model_fwd if custom_model_fwd else model
368+
)
369+
366370
# DEPRECATED FIELDS
367371
self._batch_i: Optional[In] = None
368372
self._batch_ip1: Optional[In] = None
@@ -480,9 +484,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
480484

481485
# forward
482486
with record_function("## forward ##"):
483-
losses, output = cast(
484-
Tuple[torch.Tensor, Out], self._model(self.batches[0])
485-
)
487+
losses, output = self._model_fwd(self.batches[0])
486488

487489
if len(self.batches) >= 2:
488490
self.wait_sparse_data_dist(self.contexts[1])
@@ -715,7 +717,7 @@ def __init__(
715717
stash_gradients: bool = False,
716718
pipeline_preproc: bool = False,
717719
custom_model_fwd: Optional[
718-
Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
720+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
719721
] = None,
720722
) -> None:
721723
super().__init__(
@@ -726,6 +728,7 @@ def __init__(
726728
apply_jit=apply_jit,
727729
context_type=EmbeddingTrainPipelineContext,
728730
pipeline_preproc=pipeline_preproc,
731+
custom_model_fwd=custom_model_fwd,
729732
)
730733
self._start_batch = start_batch
731734
self._stash_gradients = stash_gradients
@@ -749,9 +752,6 @@ def __init__(
749752
self._embedding_odd_streams: List[Optional[torch.Stream]] = []
750753
self._embedding_even_streams: List[Optional[torch.Stream]] = []
751754
self._gradients: Dict[str, torch.Tensor] = {}
752-
self._model_fwd: Union[
753-
torch.nn.Module, Callable[[In], Tuple[torch.Tensor, List[torch.Tensor]]]
754-
] = (custom_model_fwd if custom_model_fwd is not None else model)
755755

756756
def _grad_swap(self) -> None:
757757
for name, param in self._model.named_parameters():
@@ -890,7 +890,7 @@ def _mlp_forward(
890890
_wait_for_events(
891891
batch, context, torch.get_device_module(self._device).current_stream()
892892
)
893-
return cast(Tuple[torch.Tensor, Out], self._model_fwd(batch))
893+
return self._model_fwd(batch)
894894

895895
def embedding_backward(self, context: EmbeddingTrainPipelineContext) -> None:
896896
default_stream = torch.get_device_module(self._device).current_stream()
@@ -1017,6 +1017,10 @@ def __init__(
10171017
device: torch.device,
10181018
execute_all_batches: bool = True,
10191019
apply_jit: bool = False,
1020+
pipeline_preproc: bool = False,
1021+
custom_model_fwd: Optional[
1022+
Callable[[Optional[In]], Tuple[torch.Tensor, Out]]
1023+
] = None,
10201024
) -> None:
10211025
super().__init__(
10221026
model=model,
@@ -1025,6 +1029,8 @@ def __init__(
10251029
execute_all_batches=execute_all_batches,
10261030
apply_jit=apply_jit,
10271031
context_type=PrefetchTrainPipelineContext,
1032+
pipeline_preproc=pipeline_preproc,
1033+
custom_model_fwd=custom_model_fwd,
10281034
)
10291035
self._context = PrefetchTrainPipelineContext(version=0)
10301036
self._prefetch_stream: Optional[torch.Stream] = (
@@ -1081,7 +1087,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
10811087
self._wait_sparse_data_dist()
10821088
# forward
10831089
with record_function("## forward ##"):
1084-
losses, output = cast(Tuple[torch.Tensor, Out], self._model(self._batch_i))
1090+
losses, output = self._model_fwd(self._batch_i)
10851091

10861092
self._prefetch(self._batch_ip1)
10871093

0 commit comments

Comments
 (0)