Commit 6152607

slice_scatter decomposition (#2519)

1 parent 6cc61b4 commit 6152607

File tree

2 files changed: +234 −0 lines changed
  py/torch_tensorrt/dynamo/lowering/_decompositions.py
  tests/py/dynamo/lowering/test_decompositions.py


py/torch_tensorrt/dynamo/lowering/_decompositions.py

Lines changed: 39 additions & 0 deletions
@@ -4,6 +4,7 @@
 import torch
 from torch._decomp import register_decomposition
 from torch._ops import OpOverload
+from torch_tensorrt.dynamo.conversion.converter_utils import get_positive_dim

 from ._decomposition_groups import (
     ENABLED_TORCH_DECOMPOSITIONS,
@@ -174,6 +175,44 @@ def empty_permuted_decomposition(*args, **kwargs) -> torch.Tensor:
     return torch.empty([empty_size[l] for l in empty_permute], **kwargs).permute(perm)


+@register_torch_trt_decomposition(
+    torch.ops.aten.slice_scatter.default, registry=TORCH_TRT_DECOMPOSITIONS
+)
+def slice_scatter_decomposition(
+    input_tensor: torch.Tensor,
+    src_tensor: torch.Tensor,
+    dim: int,
+    start: Optional[int] = None,
+    end: Optional[int] = None,
+    step: Optional[int] = None,
+):
+    dim_size = input_tensor.shape[dim]
+    start = get_positive_dim(start, input_tensor.shape[dim])
+    if end is None:
+        end = dim_size
+    end = get_positive_dim(end, input_tensor.shape[dim])
+    if step is None:
+        step = 1
+
+    src_dim = src_tensor.shape
+    # step == 0 is not a valid torch case
+    # also src_dim should be equal to slice dimension
+
+    if start == 0 and end == dim_size and step == 1:
+        return src_tensor
+
+    cat_tensors = []
+    index_tensor_shape = []
+    for i, src_each_dim in enumerate(list(src_dim)):
+        if i != dim:
+            index_tensor_shape.append(src_each_dim)
+    for index in range(start, end, step):
+        cat_tensors.append(index * torch.ones(index_tensor_shape, dtype=torch.long))
+    index_tensor = torch.stack(cat_tensors, dim).cuda()
+    output_tensor = torch.scatter(input_tensor, dim, index_tensor, src_tensor)
+    return output_tensor
+
+
 def get_decompositions(
     enable_experimental_decompositions: bool = False,
 ) -> Dict[OpOverload, Callable[[Any], Any]]:
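
Aside (not part of the commit): a minimal CPU sketch of what this decomposition computes. It builds the same kind of index tensor as the function above and hands it to torch.scatter, which is why the lowered graph ends up with aten.scatter.src in place of aten.slice_scatter. The shapes, values, and variable names below are illustrative assumptions.

# Illustrative sketch only; mirrors the cat_tensors / torch.stack logic above on CPU.
import torch

x = torch.zeros(8, 8)               # destination tensor
src = torch.ones(2, 8)              # slice being written back
dim, start, end, step = 0, 2, 6, 2  # write rows 2 and 4

# One constant plane per written index, stacked along `dim`.
planes = [i * torch.ones(8, dtype=torch.long) for i in range(start, end, step)]
index = torch.stack(planes, dim)    # shape (2, 8): a row of 2s and a row of 4s

decomposed = torch.scatter(x, dim, index, src)
reference = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
assert torch.equal(decomposed, reference)  # both write ones into rows 2 and 4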

tests/py/dynamo/lowering/test_decompositions.py

Lines changed: 195 additions & 0 deletions
@@ -484,6 +484,201 @@ def forward(self, x):
             f"The optimized model results shape and torch model results shape should be equal in empty_like",
         )

+    def test_lowering_slice_scatter_dimOne_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start=None, end=None, step=1):
+                y = torch.ops.aten.slice_scatter(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.select_scatter}
+
+        inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda(), 1, 6, None, 1]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Slice_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_slice_scatter_dimZero_StepTwo_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start, end, step):
+                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        inputs = [torch.zeros(8, 8).cuda(), torch.ones(2, 8).cuda(), 0, 2, 6, 2]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Slice_scatter TRT outputs don't match with the original model.",
+        )
+
+    def test_lowering_slice_scatter_dimOne_3d_module(self):
+        class sliceScatter(torch.nn.Module):
+            def __init__(self, *args, **kwargs) -> None:
+                super().__init__(*args, **kwargs)
+
+            def forward(self, x, src, dim, start, end, step):
+                y = torch.ops.aten.slice_scatter.default(x, src, dim, start, end, step)
+                return y
+
+        # Operations expected to be removed in the traced graph after decompositions
+        expected_ops = {
+            torch.ops.aten.scatter.src,
+        }
+        unexpected_ops = {torch.ops.aten.slice_scatter}
+
+        inputs = [
+            torch.zeros(8, 8, 8).cuda(),
+            torch.ones(8, 2, 8).cuda(),
+            1,
+            6,
+            None,
+            1,
+        ]
+
+        fx_graph = torch.fx.symbolic_trace(sliceScatter())
+
+        unexpected_ops_seen, expected_ops_unseen = lower_graph_testing(
+            fx_graph,
+            inputs,
+            expected_ops=expected_ops,
+            unexpected_ops=unexpected_ops,
+            min_block_size=1,
+        )
+
+        self.assertEqual(
+            len(unexpected_ops_seen),
+            0,
+            f"The following unexpected ops were encountered: {unexpected_ops_seen}",
+        )
+
+        self.assertEqual(
+            len(expected_ops_unseen),
+            0,
+            f"The following expected ops were not encountered: {expected_ops_unseen}",
+        )
+
+        torch._dynamo.reset()
+
+        # Validate that the results between Torch and Torch-TRT are similar
+        optimized_model = torch_tensorrt.compile(
+            fx_graph,
+            "torch_compile",
+            inputs,
+            min_block_size=1,
+            truncate_long_and_double=True,
+            pass_through_build_failures=True,
+        )
+        optimized_model_results = optimized_model(*inputs).detach().cpu()
+        torch_model_results = fx_graph(*inputs).detach().cpu()
+
+        max_diff = float(
+            torch.max(torch.abs(optimized_model_results - torch_model_results))
+        )
+        self.assertAlmostEqual(
+            max_diff,
+            0,
+            DECIMALS_OF_AGREEMENT,
+            f"Slice_scatter TRT outputs don't match with the original model.",
+        )
+

 if __name__ == "__main__":
     run_tests()
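
Aside (not part of the commit): a compact usage sketch in the spirit of the tests above, assuming a CUDA device and a torch_tensorrt build that includes this decomposition; the module and variable names are hypothetical. Compiling through the "torch_compile" path lets the lowering replace aten.slice_scatter with aten.scatter.src before conversion to TensorRT.

# Illustrative usage sketch; mirrors the dimOne test case (write src into columns 6 and 7 of x).
import torch
import torch_tensorrt


class SliceScatter(torch.nn.Module):
    def forward(self, x, src):
        return torch.ops.aten.slice_scatter(x, src, 1, 6, None, 1)


inputs = [torch.zeros(8, 8).cuda(), torch.ones(8, 2).cuda()]
trt_model = torch_tensorrt.compile(
    SliceScatter().eval().cuda(),
    "torch_compile",
    inputs,
    min_block_size=1,
)
print(trt_model(*inputs))  # ones in columns 6 and 7, zeros elsewhere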
