Skip to content

Commit 9f24f92

Browse files
committed
test case for more output types
1 parent 853de1f commit 9f24f92

File tree

3 files changed

+17
-8
lines changed

3 files changed

+17
-8
lines changed

py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -259,13 +259,18 @@ def aten_ops_neg(
259259
kwargs: Dict[str, Argument],
260260
name: str,
261261
) -> Union[TRTTensor, Sequence[TRTTensor]]:
262+
input_val = args[0]
263+
if (isinstance(input_val, TRTTensor)) and (
264+
input_val.dtype == trt.int8 or input_val.dtype == trt.int32
265+
):
266+
input_val = cast_trt_tensor(network, input_val, trt.float32, name)
262267

263268
return impl.unary.neg(
264269
network,
265270
target,
266271
SourceIR.ATEN,
267272
name,
268-
args[0],
273+
input_val,
269274
)
270275

271276

py/torch_tensorrt/dynamo/test_utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def run_test(
261261
atol=1e-03,
262262
precision=torch.float,
263263
check_dtype=True,
264+
output_dtypes=None,
264265
):
265266
mod.eval()
266267
mod = self.generate_graph(mod, inputs, expected_ops, unexpected_ops, None)
@@ -272,6 +273,7 @@ def run_test(
272273
interp = TRTInterpreter(
273274
mod,
274275
Input.from_tensors(inputs),
276+
output_dtypes=output_dtypes,
275277
)
276278
super().run_test(
277279
mod,

tests/py/dynamo/converters/test_neg_aten.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -11,40 +11,42 @@ class TestNegConverter(DispatchTestCase):
1111
[
1212
("2d_dim_dtype_float", (2, 2), torch.float),
1313
("3d_dim_dtype_float", (2, 2, 2), torch.float),
14-
14+
("2d_dim_dtype_half", (2, 2), torch.half),
15+
("3d_dim_dtype_half", (2, 2, 2), torch.half),
1516
]
1617
)
1718
def test_neg_float(self, _, x, type):
1819
class neg(nn.Module):
1920
def forward(self, input):
2021
return torch.neg(input)
21-
22+
2223
inputs = [torch.randn(x, dtype=type)]
2324
self.run_test(
2425
neg(),
2526
inputs,
27+
precision=type,
2628
expected_ops={torch.ops.aten.neg.default},
2729
)
2830

2931
@parameterized.expand(
3032
[
31-
("2d_dim_dtype_int", (2, 2), torch.int32, 0, 5),
32-
("3d_dim_dtype_int", (2, 2, 2), torch.int32, 0, 5),
33+
("2d_dim_dtype_int32", (2, 2), torch.int32, 0, 5),
34+
("3d_dim_dtype_int32", (2, 2, 2), torch.int32, 0, 5),
3335
]
3436
)
35-
3637
def test_neg_int(self, _, x, type, min, max):
3738
class neg(nn.Module):
3839
def forward(self, input):
3940
return torch.neg(input)
40-
41-
inputs = [torch.randint(min, max, (x), dtype=type)]
4241

42+
inputs = [torch.randint(min, max, x, dtype=type)]
4343
self.run_test(
4444
neg(),
4545
inputs,
46+
output_dtypes=[torch.int32],
4647
expected_ops={torch.ops.aten.neg.default},
4748
)
4849

50+
4951
if __name__ == "__main__":
5052
run_tests()

0 commit comments

Comments
 (0)