diff --git a/core/runtime/execute_engine.cpp b/core/runtime/execute_engine.cpp
index f0ec653219..c4b34cb218 100644
--- a/core/runtime/execute_engine.cpp
+++ b/core/runtime/execute_engine.cpp
@@ -136,7 +136,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
     TORCHTRT_CHECK(
         inputs[i].dtype() == expected_type,
         "Expected input tensors to have type " << expected_type << ", found type " << inputs[i].dtype());
-    auto dims = core::util::toDimsPad(inputs[i].sizes(), 1);
+    auto dims = core::util::toDims(inputs[i].sizes());
     auto shape = core::util::toVec(dims);
     LOG_DEBUG("Input Name: " << name << " Shape: " << dims);
     compiled_engine->exec_ctx->setInputShape(name.c_str(), dims);
diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py
index 796b8c6253..92afefbb92 100644
--- a/py/torch_tensorrt/_Input.py
+++ b/py/torch_tensorrt/_Input.py
@@ -32,11 +32,11 @@ class _ShapeMode(Enum):
     shape: Optional[
         Tuple[int, ...] | Dict[str, Tuple[int, ...]]
     ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = (  # type: ignore[name-defined]
+    dtype: _enums.dtype = (
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = (  # type: ignore[name-defined]
+    format: _enums.TensorFormat = (
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
 
@@ -208,7 +208,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False
 
     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -236,7 +236,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
         )
 
     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -255,7 +255,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)
 
     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -337,9 +337,9 @@ def from_tensor(
             A Input object.
""" if not ( - t.is_contiguous(memory_format=torch.contiguous_format) + disable_memory_format_check + or t.is_contiguous(memory_format=torch.contiguous_format) or t.is_contiguous(memory_format=torch.channels_last) - or disable_memory_format_check ): raise ValueError( "Tensor does not have a supported memory format, supported formats are contiguous or channel_last" @@ -347,8 +347,8 @@ def from_tensor( frmt = ( torch.contiguous_format if ( - t.is_contiguous(memory_format=torch.contiguous_format) - or disable_memory_format_check + disable_memory_format_check + or t.is_contiguous(memory_format=torch.contiguous_format) ) else torch.channels_last ) diff --git a/tests/py/dynamo/backend/test_specialized_models.py b/tests/py/dynamo/backend/test_specialized_models.py index ed9fc35a59..edaa538363 100644 --- a/tests/py/dynamo/backend/test_specialized_models.py +++ b/tests/py/dynamo/backend/test_specialized_models.py @@ -1,7 +1,7 @@ -from utils import lower_graph_testing -from torch.testing._internal.common_utils import run_tests, TestCase import torch import torch_tensorrt +from torch.testing._internal.common_utils import TestCase, run_tests +from utils import lower_graph_testing class TestFakeTensors(TestCase): @@ -118,5 +118,43 @@ def forward(self, x): torch._dynamo.reset() +class Test0DTensors(TestCase): + def test_0D_input(self): + class Tensor0DInput(torch.nn.Module): + def forward(self, x): + return x * 7 + + inputs = [ + torch.tensor( + 3, + ) + .cuda() + .int(), + ] + + fx_graph = torch.fx.symbolic_trace(Tensor0DInput()) + + # Validate that the results between Torch and Torch-TRT are similar + optimized_model = torch_tensorrt.compile( + fx_graph, + "torch_compile", + inputs, + min_block_size=1, + pass_through_build_failures=True, + ) + optimized_model_results = optimized_model(*inputs).detach().cpu() + torch_model_results = fx_graph(*inputs).detach().cpu() + + max_diff = float( + torch.max(torch.abs(optimized_model_results - torch_model_results)) + ) + self.assertAlmostEqual( + max_diff, + 0, + msg=f"0D-Tensor TRT outputs don't match with the original model.", + ) + torch._dynamo.reset() + + if __name__ == "__main__": run_tests()