From 31cf883c8b08733a65a88ed1f7b7bb81837e3af2 Mon Sep 17 00:00:00 2001 From: gs-olive <113141689+gs-olive@users.noreply.github.com> Date: Mon, 11 Sep 2023 16:45:18 -0700 Subject: [PATCH] fix: Move aten.neg test case --- py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py | 5 +++++ tests/py/dynamo/{converters => conversion}/test_neg_aten.py | 5 +++-- 2 files changed, 8 insertions(+), 2 deletions(-) rename tests/py/dynamo/{converters => conversion}/test_neg_aten.py (93%) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py index f5f6309657..a91efac621 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py @@ -393,6 +393,11 @@ def neg( name: str, input_val: TRTTensor, ) -> TRTTensor: + if (isinstance(input_val, TRTTensor)) and ( + input_val.dtype == trt.int8 or input_val.dtype == trt.int32 + ): + input_val = cast_trt_tensor(network, input_val, trt.float32, name) + return convert_unary( network, target, source_ir, name, trt.UnaryOperation.NEG, input_val ) diff --git a/tests/py/dynamo/converters/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py similarity index 93% rename from tests/py/dynamo/converters/test_neg_aten.py rename to tests/py/dynamo/conversion/test_neg_aten.py index d5d805f9c2..bcb95b4172 100644 --- a/tests/py/dynamo/converters/test_neg_aten.py +++ b/tests/py/dynamo/conversion/test_neg_aten.py @@ -3,7 +3,8 @@ from parameterized import parameterized from torch.testing._internal.common_utils import run_tests from torch_tensorrt import Input -from torch_tensorrt.dynamo.test_utils import DispatchTestCase + +from .harness import DispatchTestCase class TestNegConverter(DispatchTestCase): @@ -43,8 +44,8 @@ def forward(self, input): self.run_test( neg(), inputs, - output_dtypes=[torch.int32], expected_ops={torch.ops.aten.neg.default}, + check_dtype=False, )