From f38e9c836ad6e6aec47c7e1176d914083fe20a9c Mon Sep 17 00:00:00 2001
From: gs-olive <113141689+gs-olive@users.noreply.github.com>
Date: Wed, 7 Jun 2023 20:33:30 -0700
Subject: [PATCH] feat: Add support for output data types in Interpreter

- Add an argument to the interpreter specifying the output data types
  of TRT engines, to avoid type mismatches at runtime
- Add support for providing output data types in the Dynamo compile
  path, which also exercises the feature via the backend testing and
  e2e frameworks
---
 .../dynamo/backend/conversion.py  | 10 +++++++++
 .../dynamo/fx_ts_compat/fx2trt.py | 22 ++++++++++++++++---
 2 files changed, 29 insertions(+), 3 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/backend/conversion.py b/py/torch_tensorrt/dynamo/backend/conversion.py
index f359020bfb..f2631f0c87 100644
--- a/py/torch_tensorrt/dynamo/backend/conversion.py
+++ b/py/torch_tensorrt/dynamo/backend/conversion.py
@@ -24,11 +24,21 @@ def convert_module(
     Returns:
         TRTModule or TRTModuleNext
     """
+    # Specify module output data types to ensure TRT output types agree with
+    # those of the equivalent Torch module
+    module_outputs = module(*inputs)
+
+    if not isinstance(module_outputs, (list, tuple)):
+        module_outputs = [module_outputs]
+
+    output_dtypes = [output.dtype for output in module_outputs]
+
     interpreter = TRTInterpreter(
         module,
         InputTensorSpec.from_tensors(inputs),
         explicit_batch_dimension=True,
         logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
+        output_dtypes=output_dtypes,
     )
 
     interpreter_result = interpreter.run(
diff --git a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
index e4298600cb..444efc0f4e 100644
--- a/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
+++ b/py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -1,6 +1,7 @@
 import logging
 import warnings
 from datetime import datetime
+from packaging import version
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence
 
 import numpy
@@ -40,6 +41,7 @@ def __init__(
         explicit_batch_dimension: bool = True,
         explicit_precision: bool = False,
         logger_level=None,
+        output_dtypes=None,
     ):
         super().__init__(module)
 
@@ -78,6 +80,9 @@ def __init__(
             trt.tensorrt.ITensor, TensorMetadata
         ] = dict()
 
+        # Data types for TRT Module output Tensors
+        self.output_dtypes = output_dtypes
+
     def validate_input_specs(self):
         for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
             if not self.network.has_implicit_batch_dimension:
@@ -178,13 +183,17 @@ def run(
             algorithm_selector: set up algorithm selection for certain layer
             timing_cache: enable timing cache for TensorRT
             profiling_verbosity: TensorRT logging level
+            max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
+            version_compatible: Provide version forward-compatibility for engine plan files
+            optimization_level: Builder optimization level 0-5; higher levels imply longer
+                build times as more optimization options are searched. TRT defaults to 3
         Return:
             TRTInterpreterResult
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
         # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
-        # force_fp32_output=False.
+        # force_fp32_output=False. Overridden by specifying output_dtypes
         self.output_fp16 = (
             not force_fp32_output and lower_precision == LowerPrecision.FP16
         )
@@ -224,14 +233,14 @@ def run(
             cache = builder_config.create_timing_cache(b"")
             builder_config.set_timing_cache(cache, False)
 
-        if trt.__version__ >= "8.2":
+        if version.parse(trt.__version__) >= version.parse("8.2"):
             builder_config.profiling_verbosity = (
                 profiling_verbosity
                 if profiling_verbosity
                 else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
             )
 
-        if trt.__version__ >= "8.6":
+        if version.parse(trt.__version__) >= version.parse("8.6"):
             if max_aux_streams is not None:
                 _LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
                 builder_config.max_aux_streams = max_aux_streams
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
         if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
             raise RuntimeError("TensorRT requires all outputs to be Tensor!")
 
+        if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
+            raise RuntimeError(
+                f"Number of specified output dtypes ({len(self.output_dtypes)}) differs from number of outputs ({len(outputs)})"
+            )
+
         for i, output in enumerate(outputs):
             if any(
                 op_name in output.name.split("_")
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
             self.network.mark_output(output)
             if output_bool:
                 output.dtype = trt.bool
+            elif self.output_dtypes is not None:
+                output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
             elif self.output_fp16 and output.dtype == trt.float32:
                 output.dtype = trt.float16
             self._output_names.append(name)
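
A note on the conversion.py change: the interpreter learns each engine's output
types by running the Torch module once and recording the dtype of every output.
The standalone sketch below only mirrors that pattern; the helper name
infer_output_dtypes is hypothetical and not part of the patch:

import torch

def infer_output_dtypes(module: torch.nn.Module, inputs):
    # Run the Torch module once and record each output's dtype,
    # normalizing a single-tensor return to a list first.
    outputs = module(*inputs)
    if not isinstance(outputs, (list, tuple)):
        outputs = [outputs]
    return [out.dtype for out in outputs]

class TwoOutput(torch.nn.Module):
    def forward(self, x):
        # One floating-point output and one integer output
        return x * 2.0, x.argmax(dim=-1)

print(infer_output_dtypes(TwoOutput(), [torch.rand(2, 4)]))
# [torch.float32, torch.int64]

Passing these dtypes through output_dtypes lets the interpreter mark the TRT
engine outputs so they agree with the Torch module's outputs.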
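
The version-check fix is easy to miss but important: plain string comparison of
TensorRT versions misorders two-digit minor releases. A minimal demonstration,
no TensorRT install needed:

from packaging import version

# Lexicographic comparison fails once a component reaches two digits:
print("8.10" >= "8.2")                                # False, since "1" < "2"
print(version.parse("8.10") >= version.parse("8.2"))  # True, since 10 >= 2

Under the old string comparison, the guarded builder-config features would be
silently skipped on a hypothetical TRT 8.10 release.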
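
For reference, the dtype precedence the modified output() method ends up
implementing: boolean outputs keep trt.bool, an explicitly specified dtype is
applied next, and the fp16 fallback only downcasts float32. A condensed sketch
of that branch order (hypothetical helper, assuming the tensorrt Python
package is available):

import tensorrt as trt

def resolve_output_dtype(current, is_bool, specified_dtype, output_fp16):
    # Mirrors the branch order in TRTInterpreter.output()
    if is_bool:
        return trt.bool
    if specified_dtype is not None:
        return specified_dtype
    if output_fp16 and current == trt.float32:
        return trt.float16
    return current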