Skip to content

Commit 855265d

Browse files
committed
scatter: add test cases for scatter.value and scatter.src
1 parent 517e8bc commit 855265d

File tree

4 files changed

+209
-64
lines changed

4 files changed

+209
-64
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -708,7 +708,7 @@ def aten_ops_scatter_value(
708708
name: str,
709709
) -> Union[TRTTensor, Sequence[TRTTensor]]:
710710
return impl.select.scatter_value(
711-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
711+
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
712712
)
713713

714714

@@ -721,19 +721,7 @@ def aten_ops_scatter_src(
721721
name: str,
722722
) -> Union[TRTTensor, Sequence[TRTTensor]]:
723723
return impl.select.scatter_src(
724-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
725-
)
726-
727-
728-
def aten_ops_select(
729-
ctx: ConversionContext,
730-
target: Target,
731-
args: Tuple[Argument, ...],
732-
kwargs: Dict[str, Argument],
733-
name: str,
734-
) -> Union[TRTTensor, Sequence[TRTTensor]]:
735-
return impl.select.select(
736-
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2]
724+
ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3]
737725
)
738726

739727

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
99
from torch_tensorrt.dynamo.conversion.converter_utils import (
1010
broadcastable,
11+
cast_trt_tensor,
1112
get_positive_dim,
1213
get_trt_tensor,
1314
to_numpy,
@@ -19,6 +20,7 @@
1920
set_layer_name,
2021
)
2122
from torch_tensorrt.fx.types import Shape, TRTTensor
23+
from torch_tensorrt.fx.utils import Frameworks, unified_dtype_converter
2224

2325
import tensorrt as trt
2426

@@ -399,8 +401,8 @@ def scatter_value(
399401
source_ir: Optional[SourceIR],
400402
name: str,
401403
input: TRTTensor,
402-
dim: Shape,
403-
index: Shape,
404+
dim: int,
405+
index: Union[TRTTensor, np.ndarray, torch.Tensor],
404406
value: float,
405407
) -> TRTTensor:
406408
if not isinstance(input, TRTTensor):
@@ -410,26 +412,34 @@ def scatter_value(
410412
)
411413
input_shape = input.shape
412414
index_shape = index.shape
415+
index_shape_list = list(index.shape)
416+
if not (isinstance(index, TRTTensor)):
417+
index = get_trt_tensor(ctx, index, f"_index_tensor")
413418
if len(input_shape) != len(index_shape):
414419
raise RuntimeError(f"The no of dimensions of input and index should be equal")
415-
ranks = len(input_shape)
416-
dim = get_positive_dim(cast(int, dim), ranks)
420+
dim = get_positive_dim(dim, len(input_shape))
417421
dynamic_shape = has_dynamic_shape(input.shape)
418422
if dynamic_shape:
419423
# Check whether slice target dim is dynamic shape dim
420424
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
421425

422-
input_dims = len(input.shape)
426+
input_dims = len(input_shape)
423427
for i in range(0, input_dims):
424-
if index[i] >= input.shape[i]:
428+
if i != dim and (index_shape[i] >= input.shape[i]):
425429
raise RuntimeError(
426-
f"cannot have index greater than the dimension length! {input.shape[dim]}"
430+
f"cannot have index size greater than the input size along dimension {dim}"
427431
)
428-
value_tensor = value * torch.ones(index.shape)
432+
433+
value_tensor = get_trt_tensor(
434+
ctx, value * torch.ones(index_shape_list), name + "_value_tensor"
435+
)
436+
value_tensor = cast_trt_tensor(
437+
ctx, value_tensor, input.dtype, name + "_cast_value_tensor"
438+
)
429439
scatter_layer = ctx.net.add_scatter(
430-
input, index, value_tensor, trt.tensorrt.ScatterModekELEMENT
440+
input, index, value_tensor, trt.ScatterMode.ELEMENT
431441
)
432-
scatter_layer.set_axis(dim)
442+
scatter_layer.axis = dim
433443
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
434444
out = scatter_layer.get_output(0)
435445
return out
@@ -453,6 +463,8 @@ def scatter_src(
453463
input_shape = input.shape
454464
index_shape = index.shape
455465
src_shape = src.shape
466+
if not (isinstance(index, TRTTensor)):
467+
index = get_trt_tensor(ctx, index, f"_index_tensor")
456468
if len(input_shape) != len(index_shape):
457469
raise RuntimeError(f"The no of dimensions of input and index should be equal")
458470
if len(index_shape) != len(src_shape):
@@ -466,14 +478,23 @@ def scatter_src(
466478
assert input.shape[dim] != -1, "Can't scatter on negative shape dimension!"
467479

468480
for i in range(0, input_dims):
469-
if index[i] >= input.shape[i]:
481+
if i != dim and (index_shape[i] >= input.shape[i]):
470482
raise RuntimeError(
471-
f"cannot have index greater than the dimension length! {input.shape[dim]}"
483+
f"cannot have index size greater than the input size along dimension {dim}"
472484
)
485+
input_dtype = input.dtype
486+
# required for cases where src is a constant
487+
src_dtype = unified_dtype_converter(src.dtype, Frameworks.TRT)
488+
if input_dtype != src_dtype:
489+
raise RuntimeError(f"The type of input and src should be made")
490+
src_tensor = src
491+
if not (isinstance(src, TRTTensor)):
492+
src_tensor = get_trt_tensor(ctx, src, name + "_src_tensor")
493+
473494
scatter_layer = ctx.net.add_scatter(
474-
input, index, src, trt.tensorrt.ScatterModekELEMENT
495+
input, index, src_tensor, trt.ScatterMode.ELEMENT
475496
)
476-
scatter_layer.set_axis(dim)
497+
scatter_layer.axis = dim
477498
set_layer_name(scatter_layer, target, name + "_scatter_layer", source_ir)
478499
out = scatter_layer.get_output(0)
479500
return out

tests/py/dynamo/conversion/harness.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
# type: ignore
2-
1+
import copy
32
import logging
43
import time
54
import unittest
@@ -14,6 +13,9 @@
1413
# Use interpreter, input spec, and test case from fx_ts_compat to test Dynamo Converter Registry
1514
from torch_tensorrt.dynamo.conversion import TRTInterpreter
1615
from torch_tensorrt.dynamo.conversion._conversion import infer_module_output_dtypes
16+
from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
17+
DYNAMO_CONVERTERS as CONVERTERS,
18+
)
1719
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
1820
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
1921

@@ -50,15 +52,17 @@ def setUp(self):
5052
def run_test(
5153
self,
5254
mod,
53-
inputs,
55+
fx_inputs,
56+
trt_interpreter_inputs,
5457
interpreter,
5558
rtol,
5659
atol,
5760
check_dtype=True,
5861
):
5962
with torch.no_grad():
6063
cuda_inputs = []
61-
for i in inputs:
64+
cuda_fx_inputs = []
65+
for i in trt_interpreter_inputs:
6266
cuda_inputs.append(i.cuda())
6367

6468
mod.eval()
@@ -72,8 +76,7 @@ def run_test(
7276
interpreter_result.output_names,
7377
)
7478

75-
mod = mod.cuda()
76-
ref_outputs = mod(*cuda_inputs)
79+
ref_outputs = mod(*fx_inputs)
7780

7881
torch.cuda.synchronize()
7982
start_event = torch.cuda.Event(enable_timing=True)
@@ -242,26 +245,35 @@ def run_test(
242245
debug=True,
243246
)
244247

245-
input_specs = [Input.from_tensor(i) for i in inputs]
246-
247-
output_dtypes = None
248-
if check_dtype:
249-
output_dtypes = infer_module_output_dtypes(
250-
mod,
251-
input_specs,
252-
compilation_settings.device,
253-
truncate_long_and_double=compilation_settings.truncate_long_and_double,
254-
)
248+
num_inputs = len(inputs)
249+
trt_inputs = inputs
250+
for num_input in range(num_inputs):
251+
input = inputs[num_input]
252+
if input.dtype in (torch.int64, torch.float64):
253+
dtype_32bit = (
254+
torch.int32 if (input.dtype == torch.int64) else torch.int64
255+
)
256+
# should we modify graph here to insert clone nodes?
257+
# ideally not required
258+
trt_inputs = (
259+
list(trt_inputs[:num_input])
260+
+ [
261+
input.to(dtype_32bit),
262+
]
263+
+ list(trt_inputs[num_input + 1 :])
264+
)
255265

256266
interp = TRTInterpreter(
257267
mod,
258-
input_specs,
268+
Input.from_tensors(trt_inputs),
259269
output_dtypes=output_dtypes,
260270
compilation_settings=compilation_settings,
261271
)
272+
262273
super().run_test(
263274
mod,
264275
inputs,
276+
trt_inputs,
265277
interp,
266278
rtol,
267279
atol,

0 commit comments

Comments
 (0)