Skip to content

Commit 2d2d214

Browse files
committed
Addressing review comments
1 parent 0318b33 commit 2d2d214

File tree

3 files changed

+7
-19
lines changed

3 files changed

+7
-19
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,13 +691,14 @@ def aten_ops_clamp(
691691
)
692692

693693

694+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
695+
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
694696
@enforce_tensor_types(
695697
{
696698
0: (TRTTensor,),
699+
2: (TRTTensor,),
697700
}
698701
)
699-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.src)
700-
@dynamo_tensorrt_converter(torch.ops.aten.scatter.value)
701702
def aten_ops_scatter(
702703
ctx: ConversionContext,
703704
target: Target,

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 4 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
set_layer_name,
2222
)
2323
from torch_tensorrt.fx.types import Shape, TRTTensor
24-
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2524

2625
_LOGGER: logging.Logger = logging.getLogger(__name__)
2726

@@ -405,26 +404,15 @@ def scatter(
405404
) -> TRTTensor:
406405
input_shape = input.shape
407406
index_shape = index.shape
408-
index_shape_list = list(index.shape)
409-
if not (isinstance(index, TRTTensor)):
410-
if isinstance(index, torch.Tensor):
411-
if index.dtype == torch.int64:
412-
index = index.to(torch.int32)
413-
elif isinstance(index, np.ndarray):
414-
if index.dtype == np.int64:
415-
index = index.astype(np.int32)
416-
index = get_trt_tensor(ctx, index, f"_index_tensor")
407+
index_shape_list = list(index_shape)
408+
if index.dtype == trt.int64:
409+
index = cast_trt_tensor(ctx, index, trt.int32, name + "_cast_index_tensor")
417410
dim = get_positive_dim(dim, len(input_shape))
418-
dynamic_shape = has_dynamic_shape(input.shape)
419-
if dynamic_shape:
420-
# Check whether slice target dim is dynamic shape dim
421-
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
422-
423411
src_tensor = src
424412
# scatter.value
425413
if isinstance(src, int) or isinstance(src, float):
426414
src_tensor = get_trt_tensor(
427-
ctx, src * torch.ones(index_shape_list), name + "_value_tensor"
415+
ctx, src * np.ones(index_shape_list), name + "_value_tensor"
428416
)
429417
src_tensor = cast_trt_tensor(
430418
ctx, src_tensor, input.dtype, name + "_cast_value_tensor"

tests/py/dynamo/conversion/harness.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import copy
21
import logging
32
import time
43
import unittest

0 commit comments

Comments
 (0)