diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py index eb9924f2b9..a29cee509d 100644 --- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py +++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py @@ -410,7 +410,9 @@ def output(self, target, args, kwargs): if output_bool: output.dtype = trt.bool elif self.output_dtypes is not None: - output.dtype = torch_dtype_to_trt(self.output_dtypes[i]) + output.dtype = unified_dtype_converter( + self.output_dtypes[i], Frameworks.TRT + ) elif self.output_fp16 and output.dtype == trt.float32: output.dtype = trt.float16 self._output_names.append(name)