Skip to content

Commit 73f1158

Browse files
committed
Grid test changes
1 parent 09ffab2 commit 73f1158

File tree

3 files changed

+58
-28
lines changed

3 files changed

+58
-28
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,20 +245,37 @@ def aten_ops_fmod(
245245
return impl.elementwise.fmod(ctx, target, SourceIR.ATEN, name, args[0], args[1])
246246

247247

248-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler.out)
249-
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_backward.out)
248+
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler)
250249
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d.out)
251250
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_2d_backward.out)
252251
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d.out)
253252
@dynamo_tensorrt_converter(torch.ops.aten.grid_sampler_3d_backward.out)
253+
@enforce_tensor_types(
254+
{
255+
0: (TRTTensor,),
256+
1: (TRTTensor,),
257+
}
258+
) # type: ignore[misc]
254259
def aten_ops_grid(
255260
ctx: ConversionContext,
256261
target: Target,
257262
args: Tuple[Argument, ...],
258263
kwargs: Dict[str, Argument],
259264
name: str,
260265
) -> Union[TRTTensor, Sequence[TRTTensor]]:
261-
return impl.grid.grid(ctx, target, SourceIR.ATEN, name, args[0], args[1], args[2], args[3], args[4])
266+
return impl.grid.grid(
267+
ctx,
268+
target,
269+
SourceIR.ATEN,
270+
name,
271+
input=args[0],
272+
grid=args[1],
273+
interpolation_mode=args[2],
274+
padding_mode=args[3],
275+
align_corners=args_bounds_check(args, 4, True),
276+
output_mask=args_bounds_check(args, 5, None),
277+
278+
)
262279

263280

264281
@dynamo_tensorrt_converter(torch.ops.aten.relu.default)
Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
1-
from typing import Optional
1+
from typing import Optional, Sequence
22

33
import torch
4+
import tensorrt as trt
45
from torch.fx.node import Target
56
from torch_tensorrt.dynamo._SourceIR import SourceIR
6-
from torch_tensorrt.dynamo.conversion.converter_utils import GridSamplerInterpolation, GridSamplerPadding
7+
from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
8+
from torch_tensorrt.dynamo.conversion.converter_utils import cast_trt_tensor, GridSamplerInterpolation, GridSamplerSampling
79
from torch_tensorrt.fx.converters.converter_utils import set_layer_name
810
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
911

1012
def grid(
11-
network: TRTNetwork,
13+
ctx: ConversionContext,
1214
target: Target,
1315
source_ir: Optional[SourceIR],
1416
name: str,
@@ -17,10 +19,21 @@ def grid(
1719
interpolation_mode: int,
1820
padding_mode: int,
1921
align_corners: bool,
22+
output_mask: Optional[Sequence[bool]] = None,
2023
) -> TRTTensor:
21-
grid_layer = network.add_grid_sample(input, grid)
22-
grid_layer.interpolation_mode = GridSamplerInterpolation(interpolation_mode)
23-
grid_layer.padding_mode = GridSamplerPadding(padding_mode)
24+
grid_layer = ctx.net.add_grid_sample(input, grid)
25+
interpolation_mode_trt = GridSamplerInterpolation()
26+
grid_layer.interpolation_mode = interpolation_mode_trt(interpolation_mode)
27+
sample_mode_trt = GridSamplerSampling()
28+
grid_layer.sample_mode = sample_mode_trt(padding_mode)
2429
grid_layer.align_corners = align_corners
2530
set_layer_name(grid_layer, target, name + "_grid_layer", source_ir)
26-
return grid_layer.get_output(0)
31+
if(output_mask is None):
32+
return grid_layer.get_output(0)
33+
else:
34+
if(output_mask[0] and output_mask[1]):
35+
return (grid_layer.get_output(0), None)
36+
elif(output_mask[0]):
37+
return grid_layer.get_output(0)
38+
else:
39+
return None

tests/py/dynamo/conversion/test_grid_aten.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,32 +4,32 @@
44
from torch.testing._internal.common_utils import run_tests
55
from torch_tensorrt import Input
66
from parameterized import parameterized
7-
from .harness import DispatchTestCase
7+
from harness import DispatchTestCase
88

99
class TestGridConverter(DispatchTestCase):
1010
@parameterized.expand(
1111
[
12-
("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0),
13-
("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1),
14-
("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2),
15-
("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0),
16-
("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1),
17-
("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2),
18-
("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0),
19-
("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1),
20-
("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2),
12+
("input_grid_interpolation_nearest_sample_fill", [1,1,5,5], [1,5,2,2], 0, 0),
13+
("input_grid_interpolation_nearest_sample_clamp", [1,1,5,5], [1,5,2,2], 0, 1),
14+
("input_grid_interpolation_nearest_sample_reflect", [1,1,5,5], [1,5,2,2], 0, 2),
15+
("input_grid_interpolation_linear_sample_fill", [1,1,5,5], [1,5,2,2], 1, 0),
16+
("input_grid_interpolation_linear_sample_clamp", [1,1,5,5], [1,5,2,2], 1, 1),
17+
("input_grid_interpolation_linear_sample_reflect", [1,1,5,5], [1,5,2,2], 1, 2),
18+
("input_grid_interpolation_cubic_sample_fill", [1,1,5,5], [1,5,2,2], 2, 0),
19+
("input_grid_interpolation_cubic_sample_clamp", [1,1,5,5], [1,5,2,2], 2, 1),
20+
("input_grid_interpolation_cubic_sample_reflect", [1,1,5,5], [1,5,2,2], 2, 2),
2121
]
2222
)
23-
def test_grid(self,_, input_shape, dim_shape, interpolation, sample):
23+
def test_grid(self, _, input_shape, dim_shape, interpolation, sample):
2424
class TestModule(nn.Module):
25-
def forward(self, x):
26-
input = torch.randn(10).reshape(input_shape)
27-
grid = torch.randint(-1, 1, dim_shape)
28-
return nn.functional.grid(input, grid, interpolation, sample)
29-
30-
inputs = [torch.randn(1, 10)]
31-
self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out})
25+
def forward(self, x):
26+
grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32)
27+
return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True)
28+
inputs = [torch.randn(input_shape, dtype = torch.float32)]
29+
self.run_test(TestModule(), inputs)
3230

31+
if __name__ == "__main__":
32+
run_tests()
3333

3434

3535

0 commit comments

Comments
 (0)