diff --git a/core/ir/Input.cpp b/core/ir/Input.cpp
index 8c0ccbe90a..7fb5105a22 100644
--- a/core/ir/Input.cpp
+++ b/core/ir/Input.cpp
@@ -69,11 +69,16 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
   }
 }
 
+bool valid_input_domain(std::vector<double> domain) {
+  return (domain.size() == 2) && (domain[0] < domain[1]);
+}
+
 Input::Input(
     std::vector<int64_t> shape,
     at::ScalarType dtype,
     nvinfer1::TensorFormat format,
-    bool dtype_is_user_defined) {
+    bool dtype_is_user_defined,
+    std::vector<double> tensor_domain) {
   if (shape.size() > 5) {
     LOG_WARNING("Verify that this dim size is accepted");
   }
@@ -93,6 +98,11 @@ Input::Input(
       << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
   this->format = format;
   this->dtype_is_user_defined = dtype_is_user_defined;
+
+  TORCHTRT_CHECK(
+      valid_input_domain(tensor_domain),
+      "Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
+  this->tensor_domain = tensor_domain;
 }
 
 Input::Input(
@@ -101,7 +111,8 @@ Input::Input(
     std::vector<int64_t> max_shape,
     at::ScalarType dtype,
     nvinfer1::TensorFormat format,
-    bool dtype_is_user_defined) {
+    bool dtype_is_user_defined,
+    std::vector<double> tensor_domain) {
   if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
     LOG_WARNING("Verify that this dim size is accepted");
   }
@@ -146,6 +157,10 @@ Input::Input(
       << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
   this->format = format;
   this->dtype_is_user_defined = dtype_is_user_defined;
+  TORCHTRT_CHECK(
+      valid_input_domain(tensor_domain),
+      "Unsupported tensor domain: [" << tensor_domain[0] << ", " << tensor_domain[1] << ")");
+  this->tensor_domain = tensor_domain;
 }
 
 std::ostream& operator<<(std::ostream& os, const Input& input) {
diff --git a/core/ir/ir.h b/core/ir/ir.h
index cb5a157a87..75317327f2 100644
--- a/core/ir/ir.h
+++ b/core/ir/ir.h
@@ -31,19 +31,22 @@ struct Input : torch::CustomClassHolder {
       std::vector<int64_t> shape,
       at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
-      bool dtype_is_user_defined = false);
+      bool dtype_is_user_defined = false,
+      std::vector<double> tensor_domain = std::vector<double>{0, 2});
   Input(
       std::vector<int64_t> min_shape,
       std::vector<int64_t> opt_shape,
       std::vector<int64_t> max_shape,
       at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
-      bool dtype_is_used_defined = false);
+      bool dtype_is_user_defined = false,
+      std::vector<double> tensor_domain = std::vector<double>{0, 2});
 
   friend std::ostream& operator<<(std::ostream& os, const Input& input);
 
   bool input_is_dynamic = false;
   bool dtype_is_user_defined = false;
+  std::vector<double> tensor_domain;
   nvinfer1::Dims input_shape;
   nvinfer1::Dims min;
   nvinfer1::Dims max;
diff --git a/core/partitioning/shape_analysis.cpp b/core/partitioning/shape_analysis.cpp
index 5fc2b3fcdf..387c9d27bd 100644
--- a/core/partitioning/shape_analysis.cpp
+++ b/core/partitioning/shape_analysis.cpp
@@ -26,8 +26,8 @@ at::Tensor generateSingleInput(
   }
 
   // Initialize min and max ranges for random number selection
-  int LoValIncl = 0;
-  int HiValExcl = 2;
+  double LoValIncl = input.tensor_domain[0];
+  double HiValExcl = input.tensor_domain[1];
 
   auto type = at::kFloat;
   if (type_opt) {
@@ -36,6 +36,10 @@ at::Tensor generateSingleInput(
     LOG_WARNING("Input type for doing shape analysis could not be determined, defaulting to F32");
   }
 
+  LOG_DEBUG(
+      "Using the Range: [" << LoValIncl << ", " << HiValExcl
+                           << ") as a random range for shape analysis on input with data type " << type);
+
   // Make the value range for input tensor a uniform (float) distribution
   // over [LoValIncl, HiValExcl), then cast to the desired dtype
   auto in = ((HiValExcl - LoValIncl) * at::rand(util::toVec(input_shape), {at::kCUDA}) + LoValIncl).to(type);
diff --git a/cpp/include/torch_tensorrt/torch_tensorrt.h b/cpp/include/torch_tensorrt/torch_tensorrt.h
index fb0b945012..dead16f6d9
100644 --- a/cpp/include/torch_tensorrt/torch_tensorrt.h +++ b/cpp/include/torch_tensorrt/torch_tensorrt.h @@ -381,6 +381,8 @@ struct Input : torch::CustomClassHolder { DataType dtype; /// Expected tensor format for the input TensorFormat format; + /// Expected allowed domain for tensor input + std::vector tensor_domain; Input() {} /** @@ -394,6 +396,22 @@ struct Input : torch::CustomClassHolder { */ TORCHTRT_API Input(std::vector shape, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object for static input size from + * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments + * allow the user to configure expected input shape tensor format + * dtype (Expected data type for the input) defaults to PyTorch + * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8) + * + * @param shape Input tensor shape + * @param tensor_domain Allowed range for tensor inputs [low, high) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + std::vector shape, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); + /** * @brief Construct a new Input spec object for static input size from * vector, optional arguments allow the user to configure expected input shape @@ -406,6 +424,23 @@ struct Input : torch::CustomClassHolder { */ TORCHTRT_API Input(std::vector shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object for static input size from + * vector, optional arguments allow the user to configure expected input shape + * tensor format + * + * @param shape Input tensor shape + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor + * calculation if detectable else Float32) + * @param tensor_domain Allowed range for tensor inputs [low, high) + * @param format Expected tensor format 
for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + std::vector shape, + DataType dtype, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); + /** * @brief Construct a new Input spec object for static input size from * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments @@ -418,6 +453,22 @@ struct Input : torch::CustomClassHolder { */ TORCHTRT_API Input(c10::ArrayRef shape, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object for static input size from + * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments + * allow the user to configure expected input shape tensor format + * dtype (Expected data type for the input) defaults to PyTorch + * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8) + * + * @param shape Input tensor shape + * @param tensor_domain Allowed range for tensor inputs [low, high) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + c10::ArrayRef shape, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); + /** * @brief Construct a new Input spec object for static input size from * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments @@ -430,6 +481,23 @@ struct Input : torch::CustomClassHolder { */ TORCHTRT_API Input(c10::ArrayRef shape, DataType dtype, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object for static input size from + * c10::ArrayRef (the type produced by tensor.sizes()), vector, optional arguments + * allow the user to configure expected input shape tensor format + * + * @param shape Input tensor shape + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor + * calculation if detectable else Float32) + * @param tensor_domain Allowed range for 
tensor inputs [low, high) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + c10::ArrayRef shape, + DataType dtype, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); + /** * @brief Construct a new Input spec object dynamic input size from * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max @@ -446,6 +514,24 @@ struct Input : torch::CustomClassHolder { std::vector opt_shape, std::vector max_shape, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object dynamic input size from + * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max + * supported sizes. dtype (Expected data type for the input) defaults to PyTorch + * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8) + * + * @param min_shape Minimum shape for input tensor + * @param opt_shape Target optimization shape for input tensor + * @param max_shape Maximum acceptible shape for input tensor + * @param tensor_domain Allowed range for tensor inputs [low, high) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + std::vector min_shape, + std::vector opt_shape, + std::vector max_shape, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); /** * @brief Construct a new Input spec object for a dynamic input size from vectors @@ -466,6 +552,44 @@ struct Input : torch::CustomClassHolder { DataType dtype, TensorFormat format = TensorFormat::kContiguous); + /** + * @brief Construct a new Input spec object for a dynamic input size from vectors + * for minimum shape, optimal shape, and max shape supported sizes optional arguments + * allow the user to configure expected input shape tensor format + * + * @param min_shape Minimum shape for input tensor + * @param opt_shape Target optimization shape for input tensor + * @param 
max_shape Maximum acceptible shape for input tensor + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor + * calculation if detectable else Float32) + * @param tensor_domain Allowed range for tensor inputs [low, high) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + std::vector min_shape, + std::vector opt_shape, + std::vector max_shape, + DataType dtype, + std::vector tensor_domain, + TensorFormat format = TensorFormat::kContiguous); + + /** + * @brief Construct a new Input spec object dynamic input size from + * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max + * supported sizes. dtype (Expected data type for the input) defaults to PyTorch + * / traditional TRT convection (FP32 for FP32 only, FP16 for FP32 and FP16, FP32 for Int8) + * + * @param min_shape Minimum shape for input tensor + * @param opt_shape Target optimization shape for input tensor + * @param max_shape Maximum acceptible shape for input tensor + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + c10::ArrayRef min_shape, + c10::ArrayRef opt_shape, + c10::ArrayRef max_shape, + TensorFormat format = TensorFormat::kContiguous); + /** * @brief Construct a new Input spec object dynamic input size from * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max @@ -475,12 +599,33 @@ struct Input : torch::CustomClassHolder { * @param min_shape Minimum shape for input tensor * @param opt_shape Target optimization shape for input tensor * @param max_shape Maximum acceptible shape for input tensor + * @param tensor_domain Allowed range for tensor inputs [low, high) * @param format Expected tensor format for the input (Defaults to contiguous) */ TORCHTRT_API Input( c10::ArrayRef min_shape, c10::ArrayRef opt_shape, c10::ArrayRef max_shape, + std::vector tensor_domain, + TensorFormat format = 
TensorFormat::kContiguous); + + /** + * @brief Construct a new Input spec object dynamic input size from + * c10::ArrayRef (the type produced by tensor.sizes()) for min, opt, and max + * supported sizes + * + * @param min_shape Minimum shape for input tensor + * @param opt_shape Target optimization shape for input tensor + * @param max_shape Maximum acceptible shape for input tensor + * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor + * calculation if detectable else Float32) + * @param format Expected tensor format for the input (Defaults to contiguous) + */ + TORCHTRT_API Input( + c10::ArrayRef min_shape, + c10::ArrayRef opt_shape, + c10::ArrayRef max_shape, + DataType dtype, TensorFormat format = TensorFormat::kContiguous); /** @@ -493,6 +638,7 @@ struct Input : torch::CustomClassHolder { * @param max_shape Maximum acceptible shape for input tensor * @param dtype Expected data type for the input (Defaults to the type of the weights in the first tensor * calculation if detectable else Float32) + * @param tensor_domain Allowed range for tensor inputs [low, high) * @param format Expected tensor format for the input (Defaults to contiguous) */ TORCHTRT_API Input( @@ -500,6 +646,7 @@ struct Input : torch::CustomClassHolder { c10::ArrayRef opt_shape, c10::ArrayRef max_shape, DataType dtype, + std::vector tensor_domain, TensorFormat format = TensorFormat::kContiguous); /** diff --git a/cpp/src/types.cpp b/cpp/src/types.cpp index 2d3c271694..2be7fea338 100644 --- a/cpp/src/types.cpp +++ b/cpp/src/types.cpp @@ -173,6 +173,18 @@ Input::Input(std::vector shape, TensorFormat format) { this->dtype = DataType::kUnknown; this->format = format; this->input_is_dynamic = false; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input(std::vector shape, std::vector tensor_domain, TensorFormat format) { + this->opt_shape = shape; + this->min_shape = shape; + this->max_shape = shape; + this->shape = shape; + this->dtype = 
DataType::kUnknown; + this->format = format; + this->input_is_dynamic = false; + this->tensor_domain = tensor_domain; } Input::Input(std::vector shape, DataType dtype, TensorFormat format) { @@ -183,6 +195,18 @@ Input::Input(std::vector shape, DataType dtype, TensorFormat format) { this->dtype = dtype; this->format = format; this->input_is_dynamic = false; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input(std::vector shape, DataType dtype, std::vector tensor_domain, TensorFormat format) { + this->opt_shape = shape; + this->min_shape = shape; + this->max_shape = shape; + this->shape = shape; + this->dtype = dtype; + this->format = format; + this->input_is_dynamic = false; + this->tensor_domain = tensor_domain; } Input::Input(c10::IntArrayRef shape, TensorFormat format) { @@ -193,6 +217,18 @@ Input::Input(c10::IntArrayRef shape, TensorFormat format) { this->dtype = DataType::kUnknown; this->format = format; this->input_is_dynamic = false; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input(c10::IntArrayRef shape, std::vector tensor_domain, TensorFormat format) { + this->opt_shape = torch_tensorrt::core::util::toVec(shape); + this->min_shape = torch_tensorrt::core::util::toVec(shape); + this->max_shape = torch_tensorrt::core::util::toVec(shape); + this->shape = torch_tensorrt::core::util::toVec(shape); + this->dtype = DataType::kUnknown; + this->format = format; + this->input_is_dynamic = false; + this->tensor_domain = tensor_domain; } Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat format) { @@ -203,12 +239,41 @@ Input::Input(c10::IntArrayRef shape, DataType dtype, TensorFormat format) { this->dtype = dtype; this->format = format; this->input_is_dynamic = false; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input(c10::IntArrayRef shape, DataType dtype, std::vector tensor_domain, TensorFormat format) { + this->opt_shape = torch_tensorrt::core::util::toVec(shape); + this->min_shape = 
torch_tensorrt::core::util::toVec(shape); + this->max_shape = torch_tensorrt::core::util::toVec(shape); + this->shape = torch_tensorrt::core::util::toVec(shape); + this->dtype = dtype; + this->format = format; + this->input_is_dynamic = false; + this->tensor_domain = tensor_domain; +} + +Input::Input( + std::vector min_shape, + std::vector opt_shape, + std::vector max_shape, + TensorFormat format) { + this->opt_shape = opt_shape; + this->min_shape = min_shape; + this->max_shape = max_shape; + this->shape = torch_tensorrt::core::util::toVec( + torch_tensorrt::core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); + this->dtype = DataType::kUnknown; + this->format = format; + this->input_is_dynamic = true; + this->tensor_domain = std::vector{0, 2}; } Input::Input( std::vector min_shape, std::vector opt_shape, std::vector max_shape, + std::vector tensor_domain, TensorFormat format) { this->opt_shape = opt_shape; this->min_shape = min_shape; @@ -218,6 +283,7 @@ Input::Input( this->dtype = DataType::kUnknown; this->format = format; this->input_is_dynamic = true; + this->tensor_domain = tensor_domain; } Input::Input( @@ -234,6 +300,25 @@ Input::Input( this->dtype = dtype; this->format = format; this->input_is_dynamic = true; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input( + std::vector min_shape, + std::vector opt_shape, + std::vector max_shape, + DataType dtype, + std::vector tensor_domain, + TensorFormat format) { + this->opt_shape = opt_shape; + this->min_shape = min_shape; + this->max_shape = max_shape; + this->shape = torch_tensorrt::core::util::toVec( + torch_tensorrt::core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); + this->dtype = dtype; + this->format = format; + this->input_is_dynamic = true; + this->tensor_domain = tensor_domain; } Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, TensorFormat format) { @@ -245,6 +330,41 @@ 
Input::Input(c10::IntArrayRef min_shape, c10::IntArrayRef opt_shape, c10::IntArr this->dtype = DataType::kUnknown; this->format = format; this->input_is_dynamic = true; + this->tensor_domain = std::vector{0, 2}; +} + +Input::Input( + c10::IntArrayRef min_shape, + c10::IntArrayRef opt_shape, + c10::IntArrayRef max_shape, + std::vector tensor_domain, + TensorFormat format) { + this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape); + this->min_shape = torch_tensorrt::core::util::toVec(min_shape); + this->max_shape = torch_tensorrt::core::util::toVec(max_shape); + this->shape = torch_tensorrt::core::util::toVec( + torch_tensorrt::core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); + this->dtype = DataType::kUnknown; + this->format = format; + this->input_is_dynamic = true; + this->tensor_domain = tensor_domain; +} + +Input::Input( + c10::IntArrayRef min_shape, + c10::IntArrayRef opt_shape, + c10::IntArrayRef max_shape, + DataType dtype, + TensorFormat format) { + this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape); + this->min_shape = torch_tensorrt::core::util::toVec(min_shape); + this->max_shape = torch_tensorrt::core::util::toVec(max_shape); + this->shape = torch_tensorrt::core::util::toVec( + torch_tensorrt::core::ir::Input(this->min_shape, this->opt_shape, this->max_shape).input_shape); + this->dtype = dtype; + this->format = format; + this->input_is_dynamic = true; + this->tensor_domain = std::vector{0, 2}; } Input::Input( @@ -252,6 +372,7 @@ Input::Input( c10::IntArrayRef opt_shape, c10::IntArrayRef max_shape, DataType dtype, + std::vector tensor_domain, TensorFormat format) { this->opt_shape = torch_tensorrt::core::util::toVec(opt_shape); this->min_shape = torch_tensorrt::core::util::toVec(min_shape); @@ -261,6 +382,7 @@ Input::Input( this->dtype = dtype; this->format = format; this->input_is_dynamic = true; + this->tensor_domain = tensor_domain; } Input::Input(at::Tensor tensor) { @@ -280,6 +402,7 @@ 
Input::Input(at::Tensor tensor) { } this->format = frmt; this->input_is_dynamic = false; + this->tensor_domain = std::vector{0, 2}; } /* ==========================================*/ @@ -291,7 +414,8 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) { i.max_shape, toAtenDataType(i.dtype), toTRTTensorFormat(i.format), - !(i.dtype == DataType::kUnknown)); + !(i.dtype == DataType::kUnknown), + i.tensor_domain); } std::vector to_vec_internal_inputs(std::vector& external) { diff --git a/py/torch_tensorrt/_Input.py b/py/torch_tensorrt/_Input.py index 8780d4db91..324c385fab 100644 --- a/py/torch_tensorrt/_Input.py +++ b/py/torch_tensorrt/_Input.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple, Optional import torch @@ -38,6 +38,10 @@ class _ShapeMode(Enum): _enums.TensorFormat.contiguous ) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) + DOMAIN_OFFSET = 2.0 + low_tensor_domain_incl = 0.0 + high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET + def __init__(self, *args, **kwargs): """__init__ Method for torch_tensorrt.Input @@ -56,6 +60,8 @@ def __init__(self, *args, **kwargs): Note: All three of min_shape, opt_shape, max_shape must be provided, there must be no positional arguments, shape must not be defined and implictly this sets Input's shape_mode to DYNAMIC dtype (torch.dtype or torch_tensorrt.dtype): Expected data type for input tensor (default: torch_tensorrt.dtype.float32) format (torch.memory_format or torch_tensorrt.TensorFormat): The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW) + tensor_domain (Tuple(float, float), optional): The domain of allowed values for the tensor, as interval notation: [tensor_domain[0], tensor_domain[1]). 
+ Note: Entering "None" (or not specifying) will set the bound to [0, 2) Examples: - Input([1,3,32,32], dtype=torch.float32, format=torch.channel_last) @@ -138,18 +144,31 @@ def __init__(self, *args, **kwargs): if "format" in kwargs: self.format = Input._parse_format(kwargs["format"]) + if "tensor_domain" in kwargs: + domain = kwargs["tensor_domain"] + else: + domain = None + + self.tensor_domain = Input._parse_tensor_domain(domain) + def __str__(self) -> str: if self.shape_mode == Input._ShapeMode.STATIC: - return "Input(shape={}, dtype={}, format={})".format( - self.shape, str(self.dtype), str(self.format) + return "Input(shape={}, dtype={}, format={}, domain=[{}, {}))".format( + self.shape, + str(self.dtype), + str(self.format), + str(self.tensor_domain[0]), + str(self.tensor_domain[1]), ) elif self.shape_mode == Input._ShapeMode.DYNAMIC: - return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={})".format( + return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))".format( self.shape["min_shape"], self.shape["opt_shape"], self.shape["max_shape"], str(self.dtype), str(self.format), + str(self.tensor_domain[0]), + str(self.tensor_domain[1]), ) else: raise RuntimeError("Unknown input shape mode") @@ -203,6 +222,8 @@ def _to_internal(self) -> _C.Input: internal_in.dtype = Input._parse_dtype(self.dtype) internal_in._explicit_set_dtype = self._explicit_set_dtype internal_in.format = Input._parse_format(self.format) + + internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) return internal_in @staticmethod @@ -267,6 +288,51 @@ def _parse_format(format: Any) -> _enums.TensorFormat: "Tensor format needs to be specified with either torch.memory_format or torch_tensorrt.TensorFormat" ) + @staticmethod + def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: + """ + Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi) + + Args: + domain 
+        (Tuple[float, float]): A tuple of numbers (or NoneTypes) to verify
+
+        Returns:
+            A tuple of two valid floats
+        """
+        if domain is None:
+            result_domain = (
+                Input.low_tensor_domain_incl,
+                Input.high_tensor_domain_excl,
+            )
+        elif len(domain) == 2:
+            domain_lo, domain_hi = domain
+
+            # Validate type and provided values for domain
+            valid_type_lo = isinstance(domain_lo, int) or isinstance(domain_lo, float)
+            valid_type_hi = isinstance(domain_hi, int) or isinstance(domain_hi, float)
+
+            if not valid_type_lo:
+                raise ValueError(
+                    f"Expected value for tensor domain low specifier, got {domain_lo}"
+                )
+            elif not valid_type_hi:
+                raise ValueError(
+                    f"Expected value for tensor domain high specifier, got {domain_hi}"
+                )
+
+            if domain_hi <= domain_lo:
+                raise ValueError(
+                    "Expected provided range to have low tensor domain value "
+                    + f"< high tensor domain value, got invalid range [{domain_lo}, {domain_hi})"
+                )
+            result_domain = (float(domain_lo), float(domain_hi))
+        else:
+            raise ValueError(
+                f"Expected 2 values for domain, got {len(domain)}: {domain}"
+            )
+
+        return result_domain
+
     @classmethod
     def from_tensor(cls, t: torch.Tensor) -> "Input":
         """
diff --git a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
index ba072504d9..528ffde23f 100644
--- a/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
+++ b/py/torch_tensorrt/csrc/register_tensorrt_classes.cpp
@@ -20,6 +20,7 @@ void RegisterTRTCompileSpec() {
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, max);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, dtype);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, format);
+  ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, tensor_domain);
   ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input,
input_is_dynamic); ADD_FIELD_GET_SET_REGISTRATION(TRTInputRangeTSRegistration, torch_tensorrt::pyapi::Input, explicit_set_dtype); diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.cpp b/py/torch_tensorrt/csrc/tensorrt_classes.cpp index 9cf18b4eb9..a312832628 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.cpp +++ b/py/torch_tensorrt/csrc/tensorrt_classes.cpp @@ -95,9 +95,10 @@ std::string to_str(TensorFormat value) { core::ir::Input Input::toInternalInput() { if (!input_is_dynamic) { - return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input(opt, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype, tensor_domain); } else { - return core::ir::Input(min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype); + return core::ir::Input( + min, opt, max, toAtenDataType(dtype), toTRTTensorFormat(format), explicit_set_dtype, tensor_domain); } } @@ -112,6 +113,12 @@ std::string Input::to_str() { return ss.str(); }; + auto domain_to_str = [](std::vector domain) -> std::string { + std::stringstream ss; + ss << "[" << domain[0] << ", " << domain[1] << ")"; + return ss.str(); + }; + std::stringstream ss; ss << "Input("; @@ -124,7 +131,8 @@ std::string Input::to_str() { } ss << "dtype=" << pyapi::to_str(dtype) << ", "; - ss << "format=" << pyapi::to_str(format) << ')'; + ss << "format=" << pyapi::to_str(format) << ", "; + ss << "tensor_domain=" << domain_to_str(tensor_domain) << ")"; return ss.str(); } diff --git a/py/torch_tensorrt/csrc/tensorrt_classes.h b/py/torch_tensorrt/csrc/tensorrt_classes.h index 3470944c72..0b42b68729 100644 --- a/py/torch_tensorrt/csrc/tensorrt_classes.h +++ b/py/torch_tensorrt/csrc/tensorrt_classes.h @@ -40,6 +40,7 @@ struct Input : torch::CustomClassHolder { std::vector min; std::vector opt; std::vector max; + std::vector tensor_domain; bool input_is_dynamic; bool explicit_set_dtype; @@ -49,6 +50,7 @@ struct Input : 
torch::CustomClassHolder { ADD_FIELD_GET_SET(min, std::vector); ADD_FIELD_GET_SET(opt, std::vector); ADD_FIELD_GET_SET(max, std::vector); + ADD_FIELD_GET_SET(tensor_domain, std::vector); ADD_FIELD_GET_SET(input_is_dynamic, bool); ADD_FIELD_GET_SET(explicit_set_dtype, bool); ADD_ENUM_GET_SET(dtype, DataType, static_cast(DataType::kUnknown)); diff --git a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp index 7341ba9281..142a316c05 100644 --- a/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp +++ b/py/torch_tensorrt/csrc/torch_tensorrt_py.cpp @@ -223,6 +223,7 @@ PYBIND11_MODULE(_C, m) { .def_readwrite("input_is_dynamic", &Input::input_is_dynamic) .def_readwrite("_explicit_set_dtype", &Input::explicit_set_dtype) .def_readwrite("dtype", &Input::dtype) + .def_readwrite("tensor_domain", &Input::tensor_domain) .def_readwrite("format", &Input::format); py::class_(m, "InputSignature") diff --git a/py/torch_tensorrt/ts/_compile_spec.py b/py/torch_tensorrt/ts/_compile_spec.py index d76d259e29..50ba9b4ed5 100644 --- a/py/torch_tensorrt/ts/_compile_spec.py +++ b/py/torch_tensorrt/ts/_compile_spec.py @@ -17,6 +17,7 @@ def _internal_input_to_torch_class_input(i: _C.Input) -> torch.classes.tensorrt. 
clone._set_opt(i.opt) clone._set_max(i.max) clone._set_dtype(i.dtype) + clone._set_tensor_domain(i.tensor_domain) clone._set_format(i.format) clone._set_input_is_dynamic(i.input_is_dynamic) clone._set_explicit_set_dtype(i._explicit_set_dtype) diff --git a/tests/cpp/test_collections.cpp b/tests/cpp/test_collections.cpp index 982562923d..cbca9c7b98 100644 --- a/tests/cpp/test_collections.cpp +++ b/tests/cpp/test_collections.cpp @@ -89,6 +89,51 @@ TEST(CppAPITests, TestCollectionStandardTensorInputLongDtype) { out.toTensor().to(torch::kFloat), trt_out.toTensor().to(torch::kFloat))); } +TEST(CppAPITests, TestSpecifyDomainStandardTensorInput) { + std::string path = "tests/modules/tuple_input_scripted.jit.pt"; + torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); + + torch::jit::Module mod; + try { + // Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(path); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + } + mod.eval(); + mod.to(torch::kCUDA); + + std::vector complex_inputs, complex_inputs_list; + std::tuple input_tuple(in0, in0); + + complex_inputs.push_back(input_tuple); + + auto out = mod.forward(complex_inputs); + + // Specify input tensor domain argument + auto tensor_domain = std::vector{35, 377}; + auto input_shape = torch_tensorrt::Input(in0.sizes(), torch_tensorrt::DataType::kHalf, tensor_domain); + + auto input_shape_ivalue = torch::jit::IValue(std::move(c10::make_intrusive(input_shape))); + + std::tuple input_shape_tuple(input_shape_ivalue, input_shape_ivalue); + + torch::jit::IValue complex_input_shape(input_shape_tuple); + std::tuple input_tuple2(complex_input_shape); + torch::jit::IValue complex_input_shape2(input_tuple2); + + auto compile_settings = torch_tensorrt::ts::CompileSpec(complex_input_shape2); + compile_settings.min_block_size = 1; + + // // FP16 execution + compile_settings.enabled_precisions = {torch::kHalf}; + // // Compile module + auto 
trt_mod = torch_tensorrt::torchscript::compile(mod, compile_settings); + auto trt_out = trt_mod.forward(complex_inputs); + + ASSERT_TRUE(torch_tensorrt::tests::util::cosineSimEqual(out.toTensor(), trt_out.toTensor())); +} + TEST(CppAPITests, TestCollectionTupleInput) { std::string path = "tests/modules/tuple_input_scripted.jit.pt"; torch::Tensor in0 = torch::randn({1, 3, 512, 512}, torch::kCUDA).to(torch::kHalf); diff --git a/tests/py/api/test_collections.py b/tests/py/api/test_collections.py index 33b4608194..12c1ac9f50 100644 --- a/tests/py/api/test_collections.py +++ b/tests/py/api/test_collections.py @@ -80,6 +80,35 @@ def test_compile(self): ) +class TestStandardTensorInputDomain(unittest.TestCase): + def test_compile(self): + + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + self.model = ( + torch.jit.load(MODULE_DIR + "/standard_tensor_input_scripted.jit.pt") + .eval() + .to("cuda") + ) + + compile_spec = { + "inputs": [ + torchtrt.Input(self.input.shape, tensor_domain=(70.8, 800)), + torchtrt.Input(self.input.shape, tensor_domain=(-20, -17.9)), + ], + "device": torchtrt.Device("gpu:0"), + "enabled_precisions": {torch.float}, + } + + trt_mod = torchtrt.ts.compile(self.model, **compile_spec) + cos_sim = cosine_similarity( + self.model(self.input, self.input), trt_mod(self.input, self.input) + ) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"standard_tensor_input_scripted with tensor domain specified TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + + class TestTupleInput(unittest.TestCase): def test_compile(self): diff --git a/tests/py/api/test_operator_fallback.py b/tests/py/api/test_operator_fallback.py index 302a663e24..19ba891514 100644 --- a/tests/py/api/test_operator_fallback.py +++ b/tests/py/api/test_operator_fallback.py @@ -31,6 +31,32 @@ def test_fallback_resnet18(self): msg=f"Resnet18 TRT outputs don't match with the original model. 
Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) + def test_fallback_resnet18_with_tensor_domain(self): + self.model = models.resnet18(pretrained=True).eval().to("cuda") + self.input = torch.randn((1, 3, 224, 224)).to("cuda") + compile_spec = { + "inputs": [ + torchtrt.Input( + self.input.shape, + dtype=torch.float, + format=torch.contiguous_format, + tensor_domain=(-0.5, 0.5), + ) + ], + "device": { + "device_type": torchtrt.DeviceType.GPU, + "gpu_id": 0, + }, + "enabled_precisions": {torch.float}, + "torch_executed_ops": ["aten::add"], + } + trt_mod = torchtrt.compile(self.model, **compile_spec) + cos_sim = cosine_similarity(self.model(self.input), trt_mod(self.input)) + self.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) + def test_fallback_mobilenet_v2(self): self.model = models.mobilenet_v2(pretrained=True).eval().to("cuda") self.input = torch.randn((1, 3, 224, 224)).to("cuda")