@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
530
530
"torch_compile" ,
531
531
inputs ,
532
532
min_block_size = 1 ,
533
- truncate_long_and_double = True ,
533
+ truncate_double = True ,
534
534
pass_through_build_failures = True ,
535
535
)
536
536
optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
593
593
"torch_compile" ,
594
594
inputs ,
595
595
min_block_size = 1 ,
596
- truncate_long_and_double = True ,
596
+ truncate_double = True ,
597
597
pass_through_build_failures = True ,
598
598
)
599
599
optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
663
663
"torch_compile" ,
664
664
inputs ,
665
665
min_block_size = 1 ,
666
- truncate_long_and_double = True ,
666
+ truncate_double = True ,
667
667
pass_through_build_failures = True ,
668
668
)
669
669
optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
679
679
f"Slice_scatter TRT outputs don't match with the original model." ,
680
680
)
681
681
682
+ def test_lowering_select_scatter_dimZero_module (self ):
683
+ class selectScatter (torch .nn .Module ):
684
+ def __init__ (self , * args , ** kwargs ) -> None :
685
+ super ().__init__ (* args , ** kwargs )
686
+
687
+ def forward (self , x , src , dim , index ):
688
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
689
+ return y
690
+
691
+ # Operations expected to be removed in the traced graph after decompositions
692
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
693
+ unexpected_ops = {
694
+ torch .ops .aten .select_scatter .default ,
695
+ torch .ops .aten .slice_scatter .default ,
696
+ }
697
+
698
+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 0 , 0 ]
699
+
700
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
701
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
702
+ fx_graph ,
703
+ inputs ,
704
+ expected_ops = expected_ops ,
705
+ unexpected_ops = unexpected_ops ,
706
+ min_block_size = 1 ,
707
+ )
708
+
709
+ self .assertEqual (
710
+ len (unexpected_ops_seen ),
711
+ 0 ,
712
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
713
+ )
714
+
715
+ self .assertEqual (
716
+ len (expected_ops_unseen ),
717
+ 0 ,
718
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
719
+ )
720
+
721
+ torch ._dynamo .reset ()
722
+
723
+ # Validate that the results between Torch and Torch-TRT are similar
724
+ optimized_model = torch_tensorrt .compile (
725
+ fx_graph ,
726
+ "torch_compile" ,
727
+ inputs ,
728
+ min_block_size = 1 ,
729
+ truncate_and_double = True ,
730
+ pass_through_build_failures = True ,
731
+ )
732
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
733
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
734
+
735
+ max_diff = float (
736
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
737
+ )
738
+ self .assertAlmostEqual (
739
+ max_diff ,
740
+ 0 ,
741
+ DECIMALS_OF_AGREEMENT ,
742
+ f"Select_scatter TRT outputs don't match with the original model." ,
743
+ )
744
+
745
+ def test_lowering_select_scatter_dimOne_module (self ):
746
+ class selectScatter (torch .nn .Module ):
747
+ def __init__ (self , * args , ** kwargs ) -> None :
748
+ super ().__init__ (* args , ** kwargs )
749
+
750
+ def forward (self , x , src , dim , index ):
751
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
752
+ return y
753
+
754
+ # Operations expected to be removed in the traced graph after decompositions
755
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
756
+ unexpected_ops = {
757
+ torch .ops .aten .select_scatter .default ,
758
+ torch .ops .aten .slice_scatter .default ,
759
+ }
760
+
761
+ inputs = [torch .zeros (2 , 2 ).cuda (), torch .ones (2 ).cuda (), 1 , 0 ]
762
+
763
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
764
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
765
+ fx_graph ,
766
+ inputs ,
767
+ expected_ops = expected_ops ,
768
+ unexpected_ops = unexpected_ops ,
769
+ min_block_size = 1 ,
770
+ )
771
+
772
+ self .assertEqual (
773
+ len (unexpected_ops_seen ),
774
+ 0 ,
775
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
776
+ )
777
+
778
+ self .assertEqual (
779
+ len (expected_ops_unseen ),
780
+ 0 ,
781
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
782
+ )
783
+
784
+ torch ._dynamo .reset ()
785
+
786
+ # Validate that the results between Torch and Torch-TRT are similar
787
+ optimized_model = torch_tensorrt .compile (
788
+ fx_graph ,
789
+ "torch_compile" ,
790
+ inputs ,
791
+ min_block_size = 1 ,
792
+ truncate_double = True ,
793
+ pass_through_build_failures = True ,
794
+ )
795
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
796
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
797
+
798
+ max_diff = float (
799
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
800
+ )
801
+ self .assertAlmostEqual (
802
+ max_diff ,
803
+ 0 ,
804
+ DECIMALS_OF_AGREEMENT ,
805
+ f"Select_scatter TRT outputs don't match with the original model." ,
806
+ )
807
+
808
+ def test_lowering_select_scatter_multidimension_module (self ):
809
+ class selectScatter (torch .nn .Module ):
810
+ def __init__ (self , * args , ** kwargs ) -> None :
811
+ super ().__init__ (* args , ** kwargs )
812
+
813
+ def forward (self , x , src , dim , index ):
814
+ y = torch .ops .aten .select_scatter .default (x , src , dim , index )
815
+ return y
816
+
817
+ # Operations expected to be removed in the traced graph after decompositions
818
+ expected_ops = {torch .ops .aten .scatter .src , torch .ops .aten .unsqueeze .default }
819
+ unexpected_ops = {
820
+ torch .ops .aten .select_scatter .default ,
821
+ torch .ops .aten .slice_scatter .default ,
822
+ }
823
+
824
+ inputs = [torch .zeros (2 , 3 , 4 ).cuda (), torch .ones (2 , 4 ).cuda (), 1 , 0 ]
825
+
826
+ fx_graph = torch .fx .symbolic_trace (selectScatter ())
827
+ unexpected_ops_seen , expected_ops_unseen = lower_graph_testing (
828
+ fx_graph ,
829
+ inputs ,
830
+ expected_ops = expected_ops ,
831
+ unexpected_ops = unexpected_ops ,
832
+ min_block_size = 1 ,
833
+ )
834
+
835
+ self .assertEqual (
836
+ len (unexpected_ops_seen ),
837
+ 0 ,
838
+ f"The following unexpected ops were encountered: { unexpected_ops_seen } " ,
839
+ )
840
+
841
+ self .assertEqual (
842
+ len (expected_ops_unseen ),
843
+ 0 ,
844
+ f"The following expected ops were not encountered: { expected_ops_unseen } " ,
845
+ )
846
+
847
+ torch ._dynamo .reset ()
848
+
849
+ # Validate that the results between Torch and Torch-TRT are similar
850
+ optimized_model = torch_tensorrt .compile (
851
+ fx_graph ,
852
+ "torch_compile" ,
853
+ inputs ,
854
+ min_block_size = 1 ,
855
+ truncate_double = True ,
856
+ pass_through_build_failures = True ,
857
+ )
858
+ optimized_model_results = optimized_model (* inputs ).detach ().cpu ()
859
+ torch_model_results = fx_graph (* inputs ).detach ().cpu ()
860
+
861
+ max_diff = float (
862
+ torch .max (torch .abs (optimized_model_results - torch_model_results ))
863
+ )
864
+ self .assertAlmostEqual (
865
+ max_diff ,
866
+ 0 ,
867
+ DECIMALS_OF_AGREEMENT ,
868
+ f"Select_scatter TRT outputs don't match with the original model." ,
869
+ )
870
+
682
871
683
872
if __name__ == "__main__" :
684
873
run_tests ()
0 commit comments