Skip to content

Commit aea9d94

Browse files
committed
addressing review comments and changing test names
1 parent 855265d commit aea9d94

File tree

3 files changed

+26
-94
lines changed

3 files changed

+26
-94
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 8 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -699,28 +699,21 @@ def aten_ops_clamp(
699699
)
700700

701701

702-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
703-
def aten_ops_scatter_value(
704-
ctx: ConversionContext,
705-
target: Target,
706-
args: Tuple[Argument, ...],
707-
kwargs: Dict[str, Argument],
708-
name: str,
709-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
710-
return impl.select.scatter_value(
711-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
712-
)
713-
714-
702+
@enforce_tensor_types(
703+
{
704+
0: (TRTTensor,),
705+
}
706+
)
715707
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
716-
def aten_ops_scatter_src(
708+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
709+
def aten_ops_scatter(
717710
ctx: ConversionContext,
718711
target: Target,
719712
args: Tuple[Argument, ...],
720713
kwargs: Dict[str, Argument],
721714
name: str,
722715
) -> Union[TRTTensor, Sequence[TRTTensor]]:
723-
return impl.select.scatter_src(
716+
return impl.select.scatter(
724717
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
725718
)
726719

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 12 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -395,100 +395,38 @@ def index_select(
395395
return gather_layer.get_output(0)
396396

397397

398-
def scatter_value(
398+
def scatter(
399399
ctx: ConversionContext,
400400
target: Target,
401401
source_ir: Optional[SourceIR],
402402
name: str,
403403
input: TRTTensor,
404404
dim: int,
405405
index: Union[TRTTensor, np.ndarray, torch.Tensor],
406-
value: float,
406+
src: Union[TRTTensor, int, float],
407407
) -> TRTTensor:
408-
if not isinstance(input, TRTTensor):
409-
raise RuntimeError(
410-
f"scatter_tensor received input {input} that is not part "
411-
"of the TensorRT region!"
412-
)
413408
input_shape = input.shape
414409
index_shape = index.shape
415410
index_shape_list = list(index.shape)
416411
if not (isinstance(index, TRTTensor)):
417412
index = get_trt_tensor(ctx, index, f"_index_tensor")
418-
if len(input_shape) != len(index_shape):
419-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
420413
dim = get_positive_dim(dim, len(input_shape))
421414
dynamic_shape = has_dynamic_shape(input.shape)
422415
if dynamic_shape:
423416
# Check whether slice target dim is dynamic shape dim
424417
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
425418

426-
input_dims = len(input_shape)
427-
for i in range(0, input_dims):
428-
if i != dim and (index_shape[i] >= input.shape[i]):
429-
raise RuntimeError(
430-
f"cannot have index size greater than the input size along dimension {dim}"
431-
)
432-
433-
value_tensor = get_trt_tensor(
434-
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
435-
)
436-
value_tensor = cast_trt_tensor(
437-
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
438-
)
439-
scatter_layer = ctx.net.add_scatter(
440-
input, index, value_tensor, trt.ScatterMode.ELEMENT
441-
)
442-
scatter_layer.axis = dim
443-
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
444-
out = scatter_layer.get_output(0)
445-
return out
446-
447-
448-
def scatter_src(
449-
ctx: ConversionContext,
450-
target: Target,
451-
source_ir: Optional[SourceIR],
452-
name: str,
453-
input: TRTTensor,
454-
dim: Shape,
455-
index: Shape,
456-
src: TRTTensor,
457-
) -> TRTTensor:
458-
if not isinstance(input, TRTTensor):
459-
raise RuntimeError(
460-
f"scatter_tensor received input {input} that is not part "
461-
"of the TensorRT region!"
462-
)
463-
input_shape = input.shape
464-
index_shape = index.shape
465-
src_shape = src.shape
466-
if not (isinstance(index, TRTTensor)):
467-
index = get_trt_tensor(ctx, index, f"_index_tensor")
468-
if len(input_shape) != len(index_shape):
469-
raise RuntimeError(f"The no of dimensions of input and index should be equal")
470-
if len(index_shape) != len(src_shape):
471-
raise RuntimeError(f"The no of dimensions of src and index should be equal")
472-
473-
input_dims = len(input_shape)
474-
dim = get_positive_dim(cast(int, dim), input_dims)
475-
dynamic_shape = has_dynamic_shape(input.shape)
476-
if dynamic_shape:
477-
# Check whether slice target dim is dynamic shape dim
478-
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
479-
480-
for i in range(0, input_dims):
481-
if i != dim and (index_shape[i] >= input.shape[i]):
482-
raise RuntimeError(
483-
f"cannot have index size greater than the input size along dimension {dim}"
484-
)
485-
input_dtype = input.dtype
486-
# required for cases where src is a constant
487-
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
488-
if input_dtype != src_dtype:
489-
raise RuntimeError(f"The type of input and src should be made")
490419
src_tensor = src
491-
if not (isinstance(src, TRTTensor)):
420+
# scatter.value
421+
if isinstance(src, int) or isinstance(src, float):
422+
src_tensor = get_trt_tensor(
423+
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
424+
)
425+
src_tensor = cast_trt_tensor(
426+
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"
427+
)
428+
# scatter.src
429+
elif not (isinstance(src, TRTTensor)):
492430
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
493431

494432
scatter_layer = ctx.net.add_scatter(

tests/py/dynamo/conversion/test_scatter_aten.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import torch
2-
from harness import DispatchTestCase
32
from parameterized import parameterized
43
from torch.testing._internal.common_utils import run_tests
54
from torch_tensorrt import Input
65

6+
from .harness import DispatchTestCase
7+
78

89
class TestScatterValueConverter(DispatchTestCase):
910
@parameterized.expand(
@@ -87,25 +88,25 @@ class TestScatterSrcConverter(DispatchTestCase):
8788
@parameterized.expand(
8889
[
8990
(
90-
"scatter_zero_dim_indexOne_constant_src",
91+
"scatter_zero_dim_indexOne_src",
9192
0,
9293
torch.tensor([[0, 1, 2, 0]]),
9394
torch.tensor([[1, 2, 3, 4]], dtype=torch.int32),
9495
),
9596
(
96-
"scatter_zero_dim_indexTwo_constant_src",
97+
"scatter_zero_dim_indexTwo_src",
9798
0,
9899
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
99100
torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=torch.int32),
100101
),
101102
(
102-
"scatter_one_dim_indexOne_constant_src",
103+
"scatter_one_dim_indexOne_src",
103104
1,
104105
torch.tensor([[0, 1, 2, 0]]),
105106
torch.tensor([[1, 2, 3, 1]], dtype=torch.int32),
106107
),
107108
(
108-
"scatter_one_dim_indexTwo_constant_src",
109+
"scatter_one_dim_indexTwo_src",
109110
1,
110111
torch.tensor([[0, 1, 2, 0], [1, 2, 1, 1]]),
111112
torch.tensor([[1, 2, 3, 1], [5, 6, 5, 5]], dtype=torch.int32),

0 commit comments

Comments (0)