@@ -484,6 +484,201 @@ def forward(self, x):
484
484
f"The optimized model results shape and torch model results shape should be equal in empty_like" ,
485
485
)
486
486
487
def test_lowering_slice_scatter_dimOne_module(self):
    # Scenario: scatter an (8, 2) block of ones into an (8, 8) tensor of zeros
    # along dim 1, starting at index 6 with the end left open and a step of 1.
    # The test first inspects the lowered graph, then compares Torch-TensorRT
    # output with the output of the traced module.
    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start=None, end=None, step=1):
            y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
            return y

    # Operations expected to be present in the traced graph after decompositions
    expected_ops = {
        torch.ops.aten.scatter.src,
    }
    # Operations expected to be removed in the traced graph after decompositions.
    # This test exercises slice_scatter, so slice_scatter (not select_scatter)
    # is the op that must no longer appear.
    unexpected_ops = {torch.ops.aten.slice_scatter}

    inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]

    fx_graph = torch.fx.symbolic_trace(sliceScatter())
    unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
        fx_graph,
        inputs,
        expected_ops=expected_ops,
        unexpected_ops=unexpected_ops,
        min_block_size=1,
    )

    self.assertEqual(
        len(unexpected_ops_seen),
        0,
        f"The following unexpected ops were encountered: { unexpected_ops_seen } ",
    )

    self.assertEqual(
        len(expected_ops_unseen),
        0,
        f"The following expected ops were not encountered: { expected_ops_unseen } ",
    )

    # Clear dynamo state left over from the graph inspection above before
    # compiling the module again.
    torch._dynamo.reset()

    # Validate that the results between Torch and Torch-TRT are similar
    optimized_model = torch_tensorrt.compile(
        fx_graph,
        "torch_compile",
        inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    optimized_model_results = optimized_model(*inputs).detach().cpu()
    torch_model_results = fx_graph(*inputs).detach().cpu()

    # Largest element-wise absolute deviation between the two outputs.
    max_diff = float(
        torch.max(torch.abs(optimized_model_results - torch_model_results))
    )
    self.assertAlmostEqual(
        max_diff,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
548
+
549
def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
    # Scenario: scatter a (2, 8) block of ones into an (8, 8) tensor of zeros
    # along dim 0, using start=2, end=6 and step=2.
    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            return torch.ops.aten.slice_scatter.default(
                x, src, dim, start, end, step
            )

    # Ops that must show up in the lowered graph.
    must_appear = {torch.ops.aten.scatter.src}
    # Ops that must be gone from the lowered graph.
    must_vanish = {torch.ops.aten.slice_scatter}

    sample_inputs = [
        torch.zeros(8, 8).cuda(),
        torch.ones(2, 8).cuda(),
        0,
        2,
        6,
        2,
    ]

    traced_module = torch.fx.symbolic_trace(sliceScatter())

    leftover_ops, missing_ops = lower_graph_testing(
        traced_module,
        sample_inputs,
        expected_ops=must_appear,
        unexpected_ops=must_vanish,
        min_block_size=1,
    )

    self.assertEqual(
        len(leftover_ops),
        0,
        msg=f"The following unexpected ops were encountered: { leftover_ops } ",
    )
    self.assertEqual(
        len(missing_ops),
        0,
        msg=f"The following expected ops were not encountered: { missing_ops } ",
    )

    # Start the numerical comparison from a clean dynamo state.
    torch._dynamo.reset()

    # Compile with Torch-TRT and compare its output against the traced module.
    trt_module = torch_tensorrt.compile(
        traced_module,
        "torch_compile",
        sample_inputs,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    trt_output = trt_module(*sample_inputs).detach().cpu()
    reference_output = traced_module(*sample_inputs).detach().cpu()

    # Largest element-wise absolute deviation between the two outputs.
    worst_gap = float((trt_output - reference_output).abs().max())
    self.assertAlmostEqual(
        worst_gap,
        0,
        places=DECIMALS_OF_AGREEMENT,
        msg=f"Slice_scatter TRT outputs don't match with the original model.",
    )
611
+
612
def test_lowering_slice_scatter_dimOne_3d_module(self):
    # Scenario: scatter an (8, 2, 8) block of ones into an (8, 8, 8) tensor of
    # zeros along dim 1, starting at index 6 with the end left open and step 1.
    class sliceScatter(torch.nn.Module):
        def __init__(self, *args, **kwargs) -> None:
            super().__init__(*args, **kwargs)

        def forward(self, x, src, dim, start, end, step):
            return torch.ops.aten.slice_scatter.default(
                x, src, dim, start, end, step
            )

    # Ops required in the lowered graph / ops that must have been decomposed away.
    required_ops = {torch.ops.aten.scatter.src}
    forbidden_ops = {torch.ops.aten.slice_scatter}

    call_args = [torch.zeros(8, 8, 8).cuda(), torch.ones(8, 2, 8).cuda(), 1, 6, None, 1]

    graph_module = torch.fx.symbolic_trace(sliceScatter())

    lowering_report = lower_graph_testing(
        graph_module,
        call_args,
        expected_ops=required_ops,
        unexpected_ops=forbidden_ops,
        min_block_size=1,
    )
    forbidden_found, required_absent = lowering_report

    self.assertEqual(
        len(forbidden_found),
        0,
        f"The following unexpected ops were encountered: { forbidden_found } ",
    )
    self.assertEqual(
        len(required_absent),
        0,
        f"The following expected ops were not encountered: { required_absent } ",
    )

    # Reset dynamo before the second compilation of the same module.
    torch._dynamo.reset()

    # Numerical check: the Torch-TRT result should agree with the traced module.
    compiled = torch_tensorrt.compile(
        graph_module,
        "torch_compile",
        call_args,
        min_block_size=1,
        truncate_long_and_double=True,
        pass_through_build_failures=True,
    )
    compiled_out = compiled(*call_args).detach().cpu()
    eager_out = graph_module(*call_args).detach().cpu()

    # Largest element-wise absolute deviation between the two outputs.
    abs_error = torch.abs(compiled_out - eager_out)
    max_abs_error = float(torch.max(abs_error))
    self.assertAlmostEqual(
        max_abs_error,
        0,
        DECIMALS_OF_AGREEMENT,
        f"Slice_scatter TRT outputs don't match with the original model.",
    )
681
+
487
682
488
683
# Run the test cases in this file when it is executed directly as a script.
if __name__ == "__main__":
    run_tests()
0 commit comments