
Commit ee8330a

select_scatter decomp
Changing lowering of select_scatter
select_scatter changes
select_scatter changes
Test case for select_scatter
removing assertion
adding select_scatter decomp
lowering ops in test
implement select_scatter using slice_scatter
adding test case
linting commit fix
1 parent 6152607 commit ee8330a
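
For context, aten.select_scatter(input, src, dim, index) writes `src` into `input` at a single position `index` along dimension `dim`. A minimal eager-mode illustration (shapes and values are arbitrary, not taken from this commit):

import torch

x = torch.zeros(2, 2)
src = torch.ones(2)

# Replace row 0 of x with src (dim=0, index=0).
out = torch.select_scatter(x, src, 0, 0)
# out is tensor([[1., 1.],
#                [0., 0.]])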

File tree

2 files changed: +205, -3 lines changed

py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 13 additions & 0 deletions
@@ -213,6 +213,19 @@ def slice_scatter_decomposition(
     return output_tensor


+@register_torch_trt_decomposition(
+    torch.ops.aten.select_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def select_scatter_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    index: int,
+) -> torch.Tensor:
+    src_tensor = torch.unsqueeze(src_tensor, dim)
+    return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
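
As a quick sanity check outside of TensorRT, the decomposition above can be compared against eager aten.select_scatter directly. A minimal sketch (the helper name and shapes are illustrative only, not part of the commit):

import torch

def select_scatter_via_slice_scatter(input_tensor, src_tensor, dim, index):
    # Mirrors the decomposition: unsqueeze src along `dim`, then write it back
    # as the length-1 slice [index, index + 1) with step 1.
    src_tensor = torch.unsqueeze(src_tensor, dim)
    return torch.slice_scatter(input_tensor, src_tensor, dim, index, index + 1, 1)

x = torch.zeros(2, 3, 4)
src = torch.ones(2, 4)

assert torch.equal(
    torch.select_scatter(x, src, 1, 0),
    select_scatter_via_slice_scatter(x, src, 1, 0),
)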

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 192 additions & 3 deletions
@@ -530,7 +530,7 @@ def forward(self, x, src, dim, start=None, end=None, step=1):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -593,7 +593,7 @@ def forward(self, x, src, dim, start, end, step):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -663,7 +663,7 @@ def forward(self, x, src, dim, start, end, step):
             "torch_compile",
             inputs,
             min_block_size=1,
-            truncate_long_and_double=True,
+            truncate_double=True,
             pass_through_build_failures=True,
         )
         optimized_model_results = optimized_model(*inputs).detach().cpu()
@@ -679,6 +679,195 @@ def forward(self, x, src, dim, start, end, step):
             f"Slice_scatter TRT outputs don't match with the original model.",
         )

+    def test_lowering_select_scatter_dimZero_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 0, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_select_scatter_dimOne_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 2).cuda(), torch.ones(2).cuda(), 1, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_select_scatter_multidimension_module(self):
+        class selectScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, index):
+                y = torch.ops.aten.select_scatter.default(x, src, dim, index)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {torch.ops.aten.scatter.src, torch.ops.aten.unsqueeze.default}
+        unexpected_ops = {
+            torch.ops.aten.select_scatter.default,
+            torch.ops.aten.slice_scatter.default,
+        }
+
+        inputs = [torch.zeros(2, 3, 4).cuda(), torch.ones(2, 4).cuda(), 1, 0]
+
+        fx_graph = torch.fx.symbolic_trace(selectScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Select_scatter TRT outputs don't match with the original model.",
+        )
+

 if __name__ == "__main__":
     run_tests()
