From 31cf883c8b08733a65a88ed1f7b7bb81837e3af2 Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Mon, 11 Sep 2023 16:45:18 -0700
Subject: [PATCH] fix: Move aten.neg test case

---
 py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py       | 5 +++++
 tests/py/dynamo/{converters => conversion}/test_neg_aten.py | 5 +++--
 2 files changed, 8 insertions(+), 2 deletions(-)
 rename tests/py/dynamo/{converters => conversion}/test_neg_aten.py (93%)

diff --git a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
index f5f6309657..a91efac621 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/unary/ops.py
@@ -393,6 +393,11 @@ def neg(
     name: str,
     input_val: TRTTensor,
 ) -> TRTTensor:
+    if (isinstance(input_val, TRTTensor)) and (
+        input_val.dtype == trt.int8 or input_val.dtype == trt.int32
+    ):
+        input_val = cast_trt_tensor(network, input_val, trt.float32, name)
+
     return convert_unary(
         network, target, source_ir, name, trt.UnaryOperation.NEG, input_val
     )
diff --git a/tests/py/dynamo/converters/test_neg_aten.py b/tests/py/dynamo/conversion/test_neg_aten.py
similarity index 93%
rename from tests/py/dynamo/converters/test_neg_aten.py
rename to tests/py/dynamo/conversion/test_neg_aten.py
index d5d805f9c2..bcb95b4172 100644
--- a/tests/py/dynamo/converters/test_neg_aten.py
+++ b/tests/py/dynamo/conversion/test_neg_aten.py
@@ -3,7 +3,8 @@
 from parameterized import parameterized
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt import Input
-from torch_tensorrt.dynamo.test_utils import DispatchTestCase
+
+from .harness import DispatchTestCase
 
 
 class TestNegConverter(DispatchTestCase):
@@ -43,8 +44,8 @@ def forward(self, input):
         self.run_test(
             neg(),
             inputs,
-            output_dtypes=[torch.int32],
             expected_ops={torch.ops.aten.neg.default},
+            check_dtype=False,
         )