Commit 975ce31

scatter adding test cases for scatter.value and scatter.src
1 parent a4e506d

File tree

4 files changed: +210 -51 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 14 deletions
@@ -706,7 +706,7 @@ def aten_ops_scatter_value(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.select.scatter_value(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
     )
 
 
@@ -719,19 +719,7 @@ def aten_ops_scatter_src(
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return impl.select.scatter_src(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
-    )
-
-
-def aten_ops_select(
-    ctx: ConversionContext,
-    target: Target,
-    args: Tuple[Argument, ...],
-    kwargs: Dict[str, Argument],
-    name: str,
-) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.select.select(
-        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
+        ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
     )
 
 
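Threading `args[3]` through matters because both ATen overloads take four positional arguments: `aten.scatter.value(self, dim, index, value)` and `aten.scatter.src(self, dim, index, src)`. A minimal eager-mode sketch of the two overloads these converters cover (the tensors and shapes here are illustrative, not taken from the test suite):

import torch

x = torch.zeros(3, 5)
idx = torch.tensor([[0, 1, 2, 0, 0]])  # int64 index, same rank as x

# aten.scatter.value(self, dim, index, value): scatter a scalar
out_value = torch.ops.aten.scatter.value(x, 0, idx, 1.0)

# aten.scatter.src(self, dim, index, src): scatter a tensor
src = torch.arange(5, dtype=torch.float32).unsqueeze(0)
out_src = torch.ops.aten.scatter.src(x, 0, idx, src)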

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 35 additions & 14 deletions
@@ -9,6 +9,7 @@
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
     broadcastable,
+    cast_trt_tensor,
     get_positive_dim,
     get_trt_tensor,
     to_numpy,
@@ -20,6 +21,7 @@
     set_layer_name,
 )
 from torch_tensorrt.fx.types import Shape, TRTTensor
+from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -378,8 +380,8 @@ def scatter_value(
     source_ir: Optional[SourceIR],
     name: str,
     input: TRTTensor,
-    dim: Shape,
-    index: Shape,
+    dim: int,
+    index: Union[TRTTensor, np.ndarray, torch.Tensor],
     value: float,
 ) -> TRTTensor:
     if not isinstance(input, TRTTensor):
@@ -389,26 +391,34 @@ def scatter_value(
         )
     input_shape = input.shape
     index_shape = index.shape
+    index_shape_list = list(index.shape)
+    if not (isinstance(index, TRTTensor)):
+        index = get_trt_tensor(ctx, index, f"_index_tensor")
     if len(input_shape) != len(index_shape):
         raise RuntimeError(f"The no of dimensions of input and index should be equal")
-    ranks = len(input_shape)
-    dim = get_positive_dim(cast(int, dim), ranks)
+    dim = get_positive_dim(dim, len(input_shape))
     dynamic_shape = has_dynamic_shape(input.shape)
     if dynamic_shape:
         # Check whether slice target dim is dynamic shape dim
         assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
-    input_dims = len(input.shape)
+    input_dims = len(input_shape)
     for i in range(0, input_dims):
-        if index[i] >= input.shape[i]:
+        if i != dim and (index_shape[i] >= input.shape[i]):
             raise RuntimeError(
-                f"cannot have index greater than the dimension length! {input.shape[dim]}"
+                f"cannot have index size greater than the input size along dimension {dim}"
             )
-    value_tensor = value * torch.ones(index.shape)
+
+    value_tensor = get_trt_tensor(
+        ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
+    )
+    value_tensor = cast_trt_tensor(
+        ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
+    )
     scatter_layer = ctx.net.add_scatter(
-        input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
+        input, index, value_tensor, trt.ScatterMode.ELEMENT
     )
-    scatter_layer.set_axis(dim)
+    scatter_layer.axis = dim
     set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
     out = scatter_layer.get_output(0)
     return out
@@ -432,6 +442,8 @@ def scatter_src(
     input_shape = input.shape
     index_shape = index.shape
     src_shape = src.shape
+    if not (isinstance(index, TRTTensor)):
+        index = get_trt_tensor(ctx, index, f"_index_tensor")
     if len(input_shape) != len(index_shape):
         raise RuntimeError(f"The no of dimensions of input and index should be equal")
     if len(index_shape) != len(src_shape):
@@ -445,14 +457,23 @@ def scatter_src(
         assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
 
     for i in range(0, input_dims):
-        if index[i] >= input.shape[i]:
+        if i != dim and (index_shape[i] >= input.shape[i]):
             raise RuntimeError(
-                f"cannot have index greater than the dimension length! {input.shape[dim]}"
+                f"cannot have index size greater than the input size along dimension {dim}"
             )
+    input_dtype = input.dtype
+    # required for cases where src is a constant
+    src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
+    if input_dtype != src_dtype:
+        raise RuntimeError("The type of input and src should be the same")
+    src_tensor = src
+    if not (isinstance(src, TRTTensor)):
+        src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
+
     scatter_layer = ctx.net.add_scatter(
-        input, index, src, trt.tensorrt.ScatterModekELEMENT
+        input, index, src_tensor, trt.ScatterMode.ELEMENT
    )
-    scatter_layer.set_axis(dim)
+    scatter_layer.axis = dim
     set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
     out = scatter_layer.get_output(0)
     return out
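
For context, `trt.ScatterMode.ELEMENT` follows the same element-wise rule as `torch.scatter`: each entry of `index` selects a position along `dim`, and the matching entry of the updates tensor is written there. A rough eager-mode mirror of what `scatter_value` now builds (an updates tensor of `value` with the index's shape, cast to the input dtype), offered as a mental model rather than the converter itself:

import torch

def scatter_value_reference(input, dim, index, value):
    # mirror of the converter: updates tensor = value * ones(index.shape),
    # cast to the input dtype before the element-wise scatter
    updates = (value * torch.ones(index.shape)).to(input.dtype)
    return torch.scatter(input, dim, index, updates)

x = torch.zeros(3, 5)
idx = torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
assert torch.equal(
    scatter_value_reference(x, 0, idx, 7.0),
    torch.ops.aten.scatter.value(x, 0, idx, 7.0),
)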

tests/py/dynamo/conversion/harness.py

Lines changed: 30 additions & 4 deletions
@@ -1,3 +1,4 @@
+import copy
 import logging
 import time
 import unittest
@@ -10,6 +11,9 @@
 
 # Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
+    DYNAMO_CONVERTERS as CONVERTERS,
+)
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
 
@@ -46,15 +50,17 @@ def setUp(self):
     def run_test(
         self,
         mod,
-        inputs,
+        fx_inputs,
+        trt_interpreter_inputs,
         interpreter,
         rtol,
         atol,
         check_dtype=True,
     ):
         with torch.no_grad():
             cuda_inputs = []
-            for i in inputs:
+            cuda_fx_inputs = []
+            for i in trt_interpreter_inputs:
                 cuda_inputs.append(i.cuda())
 
             mod.eval()
@@ -68,7 +74,7 @@ def run_test(
                 interpreter_result.output_names,
             )
 
-            ref_outputs = mod(*inputs)
+            ref_outputs = mod(*fx_inputs)
 
             torch.cuda.synchronize()
             start_event = torch.cuda.Event(enable_timing=True)
@@ -237,15 +243,35 @@ def run_test(
             precision=precision, truncate_long_and_double=True
         )
 
+        num_inputs = len(inputs)
+        trt_inputs = inputs
+        for num_input in range(num_inputs):
+            input = inputs[num_input]
+            if input.dtype in (torch.int64, torch.float64):
+                dtype_32bit = (
+                    torch.int32 if (input.dtype == torch.int64) else torch.float32
+                )
+                # should we modify graph here to insert clone nodes?
+                # ideally not required
+                trt_inputs = (
+                    list(trt_inputs[:num_input])
+                    + [
+                        input.to(dtype_32bit),
+                    ]
+                    + list(trt_inputs[num_input + 1 :])
+                )
+
         interp = TRTInterpreter(
             mod,
-            Input.from_tensors(inputs),
+            Input.from_tensors(trt_inputs),
            output_dtypes=output_dtypes,
             compilation_settings=compilation_settings,
         )
+
         super().run_test(
             mod,
             inputs,
+            trt_inputs,
             interp,
             rtol,
             atol,
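
The new loop exists because TensorRT has no 64-bit tensor types: the eager reference still sees the original `inputs` (scatter indices are int64 in PyTorch), while the interpreter receives 32-bit copies, consistent with `truncate_long_and_double=True`. A standalone sketch of that truncation step, with a hypothetical `truncate_64bit_inputs` helper standing in for the inline loop:

import torch

def truncate_64bit_inputs(inputs):
    # 64-bit tensors are truncated to their 32-bit counterparts before being
    # handed to the TRT interpreter; everything else passes through unchanged
    trt_inputs = []
    for t in inputs:
        if t.dtype == torch.int64:
            trt_inputs.append(t.to(torch.int32))
        elif t.dtype == torch.float64:
            trt_inputs.append(t.to(torch.float32))
        else:
            trt_inputs.append(t)
    return trt_inputs

# scatter indices arrive as int64 from PyTorch but must be int32 for TRT
inputs = [torch.zeros(3, 5), torch.tensor([[0, 1, 2, 0, 0]])]
print([t.dtype for t in truncate_64bit_inputs(inputs)])
# [torch.float32, torch.int32]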
