|
4 | 4 | from torch.testing._internal.common_utils import run_tests
|
5 | 5 | from torch_tensorrt import Input
|
6 | 6 | from parameterized import parameterized
|
7 |
| -from .harness import DispatchTestCase |
| 7 | +from harness import DispatchTestCase |
8 | 8 |
|
9 | 9 | class TestGridConverter(DispatchTestCase):
|
10 | 10 | @parameterized.expand(
|
11 | 11 | [
|
12 |
| - ("input_grid_interpolation_nearest_sample_fill", [5,5], [5,2], 0, 0), |
13 |
| - ("input_grid_interpolation_nearest_sample_clamp", [5,5], [5,2], 0, 1), |
14 |
| - ("input_grid_interpolation_nearest_sample_reflect", [5,5], [5,2], 0, 2), |
15 |
| - ("input_grid_interpolation_linear_sample_fill", [5,5], [5,2], 1, 0), |
16 |
| - ("input_grid_interpolation_linear_sample_clamp", [5,5], [5,2], 1, 1), |
17 |
| - ("input_grid_interpolation_linear_sample_reflect", [5,5], [5,2], 1, 2), |
18 |
| - ("input_grid_interpolation_cubic_sample_fill", [5,5], [5,2], 2, 0), |
19 |
| - ("input_grid_interpolation_cubic_sample_clamp", [5,5], [5,2], 2, 1), |
20 |
| - ("input_grid_interpolation_cubic_sample_reflect", [5,5], [5,2], 2, 2), |
| 12 | + ("input_grid_interpolation_nearest_sample_fill", [1,1,5,5], [1,5,2,2], 0, 0), |
| 13 | + ("input_grid_interpolation_nearest_sample_clamp", [1,1,5,5], [1,5,2,2], 0, 1), |
| 14 | + ("input_grid_interpolation_nearest_sample_reflect", [1,1,5,5], [1,5,2,2], 0, 2), |
| 15 | + ("input_grid_interpolation_linear_sample_fill", [1,1,5,5], [1,5,2,2], 1, 0), |
| 16 | + ("input_grid_interpolation_linear_sample_clamp", [1,1,5,5], [1,5,2,2], 1, 1), |
| 17 | + ("input_grid_interpolation_linear_sample_reflect", [1,1,5,5], [1,5,2,2], 1, 2), |
| 18 | + ("input_grid_interpolation_cubic_sample_fill", [1,1,5,5], [1,5,2,2], 2, 0), |
| 19 | + ("input_grid_interpolation_cubic_sample_clamp", [1,1,5,5], [1,5,2,2], 2, 1), |
| 20 | + ("input_grid_interpolation_cubic_sample_reflect", [1,1,5,5], [1,5,2,2], 2, 2), |
21 | 21 | ]
|
22 | 22 | )
|
23 |
| - def test_grid(self,_, input_shape, dim_shape, interpolation, sample): |
| 23 | + def test_grid(self, _, input_shape, dim_shape, interpolation, sample): |
24 | 24 | class TestModule(nn.Module):
|
25 |
| - def forward(self, x): |
26 |
| - input = torch.randn(10).reshape(input_shape) |
27 |
| - grid = torch.randint(-1, 1, dim_shape) |
28 |
| - return nn.functional.grid(input, grid, interpolation, sample) |
29 |
| - |
30 |
| - inputs = [torch.randn(1, 10)] |
31 |
| - self.run_test(TestModule(), inputs, expected_ops={torch.ops.aten.grid_sampler.out}) |
| 25 | + def forward(self, x): |
| 26 | + grid = torch.randint(-1, 1, dim_shape, dtype=torch.float32) |
| 27 | + return torch.ops.aten.grid_sampler(x, grid, interpolation, sample, True) |
| 28 | + inputs = [torch.randn(input_shape, dtype = torch.float32)] |
| 29 | + self.run_test(TestModule(), inputs) |
32 | 30 |
|
| 31 | +if __name__ == "__main__": |
| 32 | + run_tests() |
33 | 33 |
|
34 | 34 |
|
35 | 35 |
|
|
0 commit comments