@@ -309,7 +309,7 @@ def __init__(
309
309
context_type : Type [TrainPipelineContext ] = TrainPipelineContext ,
310
310
pipeline_preproc : bool = False ,
311
311
custom_model_fwd : Optional [
312
- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
312
+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
313
313
] = None ,
314
314
) -> None :
315
315
self ._model = model
@@ -363,6 +363,10 @@ def __init__(
363
363
self ._dataloader_exhausted : bool = False
364
364
self ._context_type : Type [TrainPipelineContext ] = context_type
365
365
366
+ self ._model_fwd : Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]] = (
367
+ custom_model_fwd if custom_model_fwd else model
368
+ )
369
+
366
370
# DEPRECATED FIELDS
367
371
self ._batch_i : Optional [In ] = None
368
372
self ._batch_ip1 : Optional [In ] = None
@@ -480,9 +484,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
480
484
481
485
# forward
482
486
with record_function ("## forward ##" ):
483
- losses , output = cast (
484
- Tuple [torch .Tensor , Out ], self ._model (self .batches [0 ])
485
- )
487
+ losses , output = self ._model_fwd (self .batches [0 ])
486
488
487
489
if len (self .batches ) >= 2 :
488
490
self .wait_sparse_data_dist (self .contexts [1 ])
@@ -715,7 +717,7 @@ def __init__(
715
717
stash_gradients : bool = False ,
716
718
pipeline_preproc : bool = False ,
717
719
custom_model_fwd : Optional [
718
- Callable [[In ], Tuple [torch .Tensor , List [ torch . Tensor ] ]]
720
+ Callable [[Optional [ In ]] , Tuple [torch .Tensor , Out ]]
719
721
] = None ,
720
722
) -> None :
721
723
super ().__init__ (
@@ -726,6 +728,7 @@ def __init__(
726
728
apply_jit = apply_jit ,
727
729
context_type = EmbeddingTrainPipelineContext ,
728
730
pipeline_preproc = pipeline_preproc ,
731
+ custom_model_fwd = custom_model_fwd ,
729
732
)
730
733
self ._start_batch = start_batch
731
734
self ._stash_gradients = stash_gradients
@@ -749,9 +752,6 @@ def __init__(
749
752
self ._embedding_odd_streams : List [Optional [torch .Stream ]] = []
750
753
self ._embedding_even_streams : List [Optional [torch .Stream ]] = []
751
754
self ._gradients : Dict [str , torch .Tensor ] = {}
752
- self ._model_fwd : Union [
753
- torch .nn .Module , Callable [[In ], Tuple [torch .Tensor , List [torch .Tensor ]]]
754
- ] = (custom_model_fwd if custom_model_fwd is not None else model )
755
755
756
756
def _grad_swap (self ) -> None :
757
757
for name , param in self ._model .named_parameters ():
@@ -890,7 +890,7 @@ def _mlp_forward(
890
890
_wait_for_events (
891
891
batch , context , torch .get_device_module (self ._device ).current_stream ()
892
892
)
893
- return cast ( Tuple [ torch . Tensor , Out ], self ._model_fwd (batch ) )
893
+ return self ._model_fwd (batch )
894
894
895
895
def embedding_backward (self , context : EmbeddingTrainPipelineContext ) -> None :
896
896
default_stream = torch .get_device_module (self ._device ).current_stream ()
@@ -1017,6 +1017,10 @@ def __init__(
1017
1017
device : torch .device ,
1018
1018
execute_all_batches : bool = True ,
1019
1019
apply_jit : bool = False ,
1020
+ pipeline_preproc : bool = False ,
1021
+ custom_model_fwd : Optional [
1022
+ Callable [[Optional [In ]], Tuple [torch .Tensor , Out ]]
1023
+ ] = None ,
1020
1024
) -> None :
1021
1025
super ().__init__ (
1022
1026
model = model ,
@@ -1025,6 +1029,8 @@ def __init__(
1025
1029
execute_all_batches = execute_all_batches ,
1026
1030
apply_jit = apply_jit ,
1027
1031
context_type = PrefetchTrainPipelineContext ,
1032
+ pipeline_preproc = pipeline_preproc ,
1033
+ custom_model_fwd = custom_model_fwd ,
1028
1034
)
1029
1035
self ._context = PrefetchTrainPipelineContext (version = 0 )
1030
1036
self ._prefetch_stream : Optional [torch .Stream ] = (
@@ -1081,7 +1087,7 @@ def progress(self, dataloader_iter: Iterator[In]) -> Out:
1081
1087
self ._wait_sparse_data_dist ()
1082
1088
# forward
1083
1089
with record_function ("## forward ##" ):
1084
- losses , output = cast ( Tuple [ torch . Tensor , Out ], self ._model (self ._batch_i ) )
1090
+ losses , output = self ._model_fwd (self ._batch_i )
1085
1091
1086
1092
self ._prefetch (self ._batch_ip1 )
1087
1093
0 commit comments