
feat: Add support for output data types in TRTInterpreter [2 / x] #2004

Merged
10 changes: 10 additions & 0 deletions py/torch_tensorrt/dynamo/backend/conversion.py
@@ -24,11 +24,21 @@ def convert_module(
Returns:
TRTModule or TRTModuleNext
"""
# Specify module output data types to ensure TRT output types agree with
# those of the equivalent Torch module
module_outputs = module(*inputs)

if not isinstance(module_outputs, (list, tuple)):
module_outputs = [module_outputs]

output_dtypes = list(output.dtype for output in module_outputs)

interpreter = TRTInterpreter(
module,
InputTensorSpec.from_tensors(inputs),
explicit_batch_dimension=True,
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
output_dtypes=output_dtypes,
)

interpreter_result = interpreter.run(
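The dtype bookkeeping above works by running the FX module eagerly once and recording the dtype of every output, so the TRT engine's outputs can later be pinned to the same types. A minimal sketch of that inference step, using a made-up two-output module for illustration (the module and inputs below are not part of this PR):

import torch

class TwoOutputs(torch.nn.Module):
    def forward(self, x):
        # One half-precision output and one integer output
        return x.half(), x.argmax(dim=-1)

module = TwoOutputs()
inputs = [torch.randn(2, 4)]

module_outputs = module(*inputs)
if not isinstance(module_outputs, (list, tuple)):
    module_outputs = [module_outputs]

# [torch.float16, torch.int64] -- handed to TRTInterpreter via output_dtypes
output_dtypes = [output.dtype for output in module_outputs]
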
22 changes: 19 additions & 3 deletions py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py
@@ -1,6 +1,7 @@
import logging
import warnings
from datetime import datetime
from packaging import version
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence

import numpy
@@ -40,6 +41,7 @@ def __init__(
explicit_batch_dimension: bool = True,
explicit_precision: bool = False,
logger_level=None,
output_dtypes=None,
):
super().__init__(module)

@@ -78,6 +80,9 @@ def __init__(
trt.tensorrt.ITensor, TensorMetadata
] = dict()

# Data types for TRT Module output Tensors
self.output_dtypes = output_dtypes

def validate_input_specs(self):
for shape, _, _, shape_ranges, has_batch_dim in self.input_specs:
if not self.network.has_implicit_batch_dimension:
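For callers that drive TRTInterpreter directly rather than through convert_module, the new argument is passed at construction time. A hedged sketch mirroring the conversion.py call above; the import paths are assumed from this PR's package layout, the stand-in module and inputs are purely illustrative, and the remaining constructor arguments keep their defaults:

import torch
import tensorrt as trt
from torch_tensorrt.dynamo.fx_ts_compat import TRTInterpreter, InputTensorSpec  # import path assumed

# Stand-in FX module and sample inputs, for illustration only
module = torch.fx.symbolic_trace(torch.nn.Linear(4, 2).half())
inputs = [torch.randn(2, 4, dtype=torch.float16)]

interpreter = TRTInterpreter(
    module,
    InputTensorSpec.from_tensors(inputs),
    explicit_batch_dimension=True,
    logger_level=trt.Logger.WARNING,
    output_dtypes=[torch.float16],  # one torch dtype per network output
)
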
@@ -178,13 +183,17 @@ def run(
algorithm_selector: set up algorithm selection for certain layer
timing_cache: enable timing cache for TensorRT
profiling_verbosity: TensorRT logging level
max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
version_compatible: Provide version forward-compatibility for engine plan files
optimization_level: Builder optimization level 0-5; higher levels imply longer build times,
searching for more optimization options. TRT defaults to 3
Return:
TRTInterpreterResult
"""
TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)

# For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
# force_fp32_output=False.
# force_fp32_output=False. This is overridden when output_dtypes is specified.
self.output_fp16 = (
not force_fp32_output and lower_precision == LowerPrecision.FP16
)
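
The three newly documented knobs are ordinary keyword arguments to run(). A hedged usage sketch, continuing from a constructed TRTInterpreter such as the one above, with every other argument left at its default and the LowerPrecision import path assumed:

from torch_tensorrt.fx.utils import LowerPrecision  # import path assumed

interpreter_result = interpreter.run(
    lower_precision=LowerPrecision.FP16,  # fp16 outputs unless output_dtypes overrides them
    max_aux_streams=2,                    # cap auxiliary TRT streams per engine (TRT >= 8.6)
    version_compatible=True,              # emit a forward-compatible engine plan (TRT >= 8.6)
    optimization_level=5,                 # longest build, widest tactic search (TRT default: 3)
)
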
@@ -224,14 +233,14 @@ def run(
cache = builder_config.create_timing_cache(b"")
builder_config.set_timing_cache(cache, False)

if trt.__version__ >= "8.2":
if version.parse(trt.__version__) >= version.parse("8.2"):
builder_config.profiling_verbosity = (
profiling_verbosity
if profiling_verbosity
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
)

if trt.__version__ >= "8.6":
if version.parse(trt.__version__) >= version.parse("8.6"):
if max_aux_streams is not None:
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
builder_config.max_aux_streams = max_aux_streams
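
Replacing the raw string comparisons with packaging.version is more than cosmetic: lexicographic comparison mis-orders double-digit minor versions, so a hypothetical TensorRT 8.10 would fail the "8.2" and "8.6" checks. A short illustration:

from packaging import version

# String comparison goes character by character, so '1' < '6' decides the result
assert not ("8.10.1" >= "8.6")
# Semantic version comparison orders the releases correctly
assert version.parse("8.10.1") >= version.parse("8.6")
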
@@ -372,6 +381,11 @@ def output(self, target, args, kwargs):
if not all(isinstance(output, trt.tensorrt.ITensor) for output in outputs):
raise RuntimeError("TensorRT requires all outputs to be Tensor!")

if self.output_dtypes is not None and len(self.output_dtypes) != len(outputs):
raise RuntimeError(
f"Specified output dtypes ({len(self.output_dtypes)}) differ from number of outputs ({len(outputs)})"
)

for i, output in enumerate(outputs):
if any(
op_name in output.name.split("_")
@@ -396,6 +410,8 @@ def output(self, target, args, kwargs):
self.network.mark_output(output)
if output_bool:
output.dtype = trt.bool
elif self.output_dtypes is not None:
output.dtype = torch_dtype_to_trt(self.output_dtypes[i])
elif self.output_fp16 and output.dtype == trt.float32:
output.dtype = trt.float16
self._output_names.append(name)
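
Taken together, the branches above give an explicit entry in output_dtypes priority over the fp16 demotion, while known-boolean outputs always keep trt.bool. A condensed, hypothetical restatement of that precedence; resolve_output_dtype and its tiny dtype map are illustrative stand-ins, not code from this PR:

import torch
import tensorrt as trt

def resolve_output_dtype(current, output_bool, specified, output_fp16):
    # Stand-in for torch_dtype_to_trt, covering only the dtypes used here
    torch_to_trt = {torch.float32: trt.float32, torch.float16: trt.float16,
                    torch.int32: trt.int32, torch.bool: trt.bool}
    if output_bool:
        return trt.bool                 # boolean-producing ops always win
    if specified is not None:
        return torch_to_trt[specified]  # explicit output_dtypes entry comes next
    if output_fp16 and current == trt.float32:
        return trt.float16              # fp16 demotion only as a fallback
    return current

# With output_dtypes given, an fp32 output survives an FP16 lowering pass
assert resolve_output_dtype(trt.float32, False, torch.float32, True) == trt.float32
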