
Commit ac7c95c
Grid test changes
Committed Oct 13, 2023
1 parent 09ffab2

4 files changed, +139 -49 lines changed
py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py (+20 -3)
@@ -245,20 +245,37 @@ def aten_ops_fmod(
     return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])


-@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
-@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
+@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
 @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
 @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
 @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
 @dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
+@enforce_tensor_types(
+    {
+        0: (TRTTensor,),
+        1: (TRTTensor,),
+    }
+)  # type: ignore[misc]
 def aten_ops_grid(
     ctx: ConversionContext,
     target: Target,
     args: Tuple[Argument, ...],
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
+    return impl.grid.grid(
+        ctx,
+        target,
+        SourceIR.ATEN,
+        name,
+        input=args[0],
+        grid=args[1],
+        interpolation_mode=args[2],
+        padding_mode=args[3],
+        align_corners=args_bounds_check(args, 4, True),
+        output_mask=args_bounds_check(args, 5, None),
+
+    )


 @dynamo_tensorrt_converter(torch.ops.aten.relu.default)
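For orientation, here is a minimal eager-mode sketch (shapes and mode integers are illustrative assumptions, not taken from the commit) of the op this converter now targets. torch.ops.aten.grid_sampler takes (input, grid, interpolation_mode, padding_mode, align_corners), which is the same positional order aten_ops_grid unpacks from args above; args_bounds_check presumably falls back to the supplied default when an optional trailing argument is absent.

# Illustrative sketch only; values are assumptions.
import torch

x = torch.randn(1, 1, 5, 5)            # NCHW input
grid = torch.rand(1, 2, 2, 2) * 2 - 1  # sampling coordinates, normalized to [-1, 1]

# (input, grid, interpolation_mode, padding_mode, align_corners)
out = torch.ops.aten.grid_sampler(x, grid, 0, 0, True)
print(out.shape)  # torch.Size([1, 1, 2, 2])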

py/torch_tensorrt/dynamo/conversion/converter_utils.py (+21 -17)
@@ -23,32 +23,36 @@
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
-#nearesr, linear, cubc
+
+# nearest, linear, cubic
 class GridSamplerInterpolation:
     def __init__(self):
         self.interpolator_mode = None
-    def __call__(self, interpolator_int):
-        if(interpolator_int == 0) :
+
+    def __call__(self, interpolator_int):
+        if interpolator_int == 0:
             self.interpolator_mode = trt.InterpolationMode.NEAREST
-        elif(interpolator_int == 1) :
+        elif interpolator_int == 1:
             self.interpolator_mode = trt.InterpolationMode.LINEAR
-        elif(interpolator_int == 2) :
+        elif interpolator_int == 2:
             self.interpolator_mode = trt.InterpolationMode.CUBIC
         return self.interpolator_mode
-
 
-#zeros, border, reflection
-class GridSamplerPadding:
+
+# zeros, border, reflection
+class GridSamplerSampling:
     def __init__(self):
-        self.padding_mode = None
-    def __call__(self, padding_int):
-        if(padding_int == 0) :
-            self.padding_mode = trt.SampleMode.kFILL
-        elif(padding_int == 1) :
-            self.padding_mode = trt.SampleMode.kCLAMP
-        elif(padding_int == 2) :
-            self.padding_mode = trt.SampleMode.kREFLECT
-        return self.padding_mode
+        self.sample_mode = None
+
+    def __call__(self, sample_int):
+        if sample_int == 0:
+            self.sample_mode = trt.SampleMode.FILL
+        elif sample_int == 1:
+            self.sample_mode = trt.SampleMode.CLAMP
+        elif sample_int == 2:
+            self.sample_mode = trt.SampleMode.REFLECT
+        return self.sample_mode
+
 
 def get_node_name(node: torch.fx.Node) -> str:
     # nn_module_stack preserves the call stack of pytorch nn.modules
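A quick sketch of how these helper classes are meant to be used (assumes a local TensorRT install; the import path reflects this commit's tree): calling an instance with the integer mode from the ATen schema returns the corresponding TensorRT enum, which the grid converter then assigns to the layer.

# Sketch, not part of the commit: exercising the integer-to-enum mappers added above.
import tensorrt as trt
from torch_tensorrt.dynamo.conversion.converter_utils import (
    GridSamplerInterpolation,
    GridSamplerSampling,
)

assert GridSamplerInterpolation()(0) == trt.InterpolationMode.NEAREST
assert GridSamplerInterpolation()(1) == trt.InterpolationMode.LINEAR
assert GridSamplerSampling()(2) == trt.SampleMode.REFLECT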
py/torch_tensorrt/dynamo/conversion/impl/grid.py (+25 -7)
@@ -1,14 +1,21 @@
-from typing import Optional
+from typing import Optional, Sequence
 
+import tensorrt as trt
 import torch
 from torch.fx.node import Target
 from torch_tensorrt.dynamo._SourceIR import SourceIR
-from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
+from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
+from torch_tensorrt.dynamo.conversion.converter_utils import (
+    GridSamplerInterpolation,
+    GridSamplerSampling,
+    cast_trt_tensor,
+)
 from torch_tensorrt.fx.converters.converter_utils import set_layer_name
 from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
 
+
 def grid(
-    network: TRTNetwork,
+    ctx: ConversionContext,
     target: Target,
     source_ir: Optional[SourceIR],
     name: str,
@@ -17,10 +24,21 @@ def grid(
     interpolation_mode: int,
     padding_mode: int,
     align_corners: bool,
+    output_mask: Optional[Sequence[bool]] = None,
 ) -> TRTTensor:
-    grid_layer = network.add_grid_sample(input, grid)
-    grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
-    grid_layer.padding_mode = GridSamplerPadding(padding_mode)
+    grid_layer = ctx.net.add_grid_sample(input, grid)
+    interpolation_mode_trt = GridSamplerInterpolation()
+    grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
+    sample_mode_trt = GridSamplerSampling()
+    grid_layer.sample_mode = sample_mode_trt(padding_mode)
     grid_layer.align_corners = align_corners
     set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
-    return grid_layer.get_output(0)
+    if output_mask is None:
+        return grid_layer.get_output(0)
+    else:
+        if output_mask[0] and output_mask[1]:
+            return (grid_layer.get_output(0), None)
+        elif output_mask[0]:
+            return grid_layer.get_output(0)
+        else:
+            return None
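For reference, the same layer wiring can be sketched directly against the TensorRT network API (requires TensorRT 8.5 or newer; names and shapes here are assumptions). In the converter, ctx.net appears to be used as the underlying INetworkDefinition, so add_grid_sample plus the three attributes below is all that grid() configures before naming the layer and returning its output.

# Standalone sketch, not part of the commit: assumed shapes, TensorRT >= 8.5.
import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
builder = trt.Builder(logger)
network = builder.create_network(
    1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
)

inp = network.add_input("input", trt.float32, (1, 1, 5, 5))
grd = network.add_input("grid", trt.float32, (1, 2, 2, 2))

layer = network.add_grid_sample(inp, grd)             # IGridSampleLayer
layer.interpolation_mode = trt.InterpolationMode.LINEAR
layer.sample_mode = trt.SampleMode.CLAMP
layer.align_corners = True
network.mark_output(layer.get_output(0))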
tests/py/dynamo/conversion/test_grid_aten.py (+73 -22)
@@ -1,38 +1,89 @@
 import pytest
 import torch
 import torch.nn as nn
+from .harness import DispatchTestCase
+from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
-from parameterized import parameterized
-from .harness import DispatchTestCase
+
 
 class TestGridConverter(DispatchTestCase):
     @parameterized.expand(
         [
-            ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
-            ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
-            ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
-            ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
-            ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
-            ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
-            ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
-            ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
-            ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
+            (
+                "input_grid_interpolation_nearest_sample_fill",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                0,
+                0,
+            ),
+            (
+                "input_grid_interpolation_nearest_sample_clamp",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                0,
+                1,
+            ),
+            (
+                "input_grid_interpolation_nearest_sample_reflect",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                0,
+                2,
+            ),
+            (
+                "input_grid_interpolation_linear_sample_fill",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                1,
+                0,
+            ),
+            (
+                "input_grid_interpolation_linear_sample_clamp",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                1,
+                1,
+            ),
+            (
+                "input_grid_interpolation_linear_sample_reflect",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                1,
+                2,
+            ),
+            (
+                "input_grid_interpolation_cubic_sample_fill",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                2,
+                0,
+            ),
+            (
+                "input_grid_interpolation_cubic_sample_clamp",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                2,
+                1,
+            ),
+            (
+                "input_grid_interpolation_cubic_sample_reflect",
+                [1, 1, 5, 5],
+                [1, 5, 2, 2],
+                2,
+                2,
+            ),
         ]
     )
-    def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
+    def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
         class TestModule(nn.Module):
             def forward(self, x):
-                input = torch.randn(10).reshape(input_shape)
-                grid = torch.randint(-1, 1, dim_shape)
-                return nn.functional.grid(input, grid, interpolation, sample)
-
-        inputs = [torch.randn(1, 10)]
-        self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
-
-
+                grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
+                return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
 
+        inputs = [torch.randn(input_shape, dtype=torch.float32)]
+        self.run_test(TestModule(), inputs)
 
-
 
-
+if __name__ == "__main__":
+    run_tests()
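To make the parameterization concrete, the following sketch (values assumed, not taken from the commit) runs one case in eager mode. grid_sampler expects the grid's last dimension to hold coordinates normalized to [-1, 1], which is why the test draws grid values from randint(-1, 1, ...) as floats; with input [1, 1, 5, 5] and grid [1, 5, 2, 2] the sampled output has shape [1, 1, 5, 2].

# Eager-mode sketch of one parameterized case; not part of the commit.
import torch

input_shape, dim_shape, interpolation, sample = [1, 1, 5, 5], [1, 5, 2, 2], 0, 0
x = torch.randn(input_shape, dtype=torch.float32)
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)  # only -1.0 and 0.0 values
out = torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
print(out.shape)  # torch.Size([1, 1, 5, 2])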
