Skip to content

Commit 30f5094

Browse files
peri044narendasan
authored andcommitted
feat: cherry-pick of Selectively enable different frontends (#2693) (#2761)
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]> Co-authored-by: Naren Dasan <[email protected]>
1 parent a5079ad commit 30f5094

File tree

18 files changed

+56
-81
lines changed

18 files changed

+56
-81
lines changed

.github/workflows/build-test.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,6 @@ jobs:
264264
pre-script: ${{ matrix.pre-script }}
265265
script: |
266266
export USE_HOST_DEPS=1
267-
export LD_LIBRARY_PATH=/opt/torch-tensorrt-builds/TensorRT-10.0.0.6/lib:$LD_LIBRARY_PATH
268267
pushd .
269268
cd tests/py/core
270269
${CONDA_RUN} python -m pip install --pre pytest-xdist timm transformers parameterized expecttest==0.1.6 --use-deprecated=legacy-resolver

py/torch_tensorrt/_Device.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,11 @@
99
else:
1010
from typing_extensions import Self
1111

12+
import tensorrt as trt
1213
import torch
1314
from torch_tensorrt._enums import DeviceType
1415
from torch_tensorrt._features import ENABLED_FEATURES
1516

16-
import tensorrt as trt
17-
1817

1918
class Device(object):
2019
"""

py/torch_tensorrt/_compile.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77

88
import torch
99
import torch.fx
10-
import torch_tensorrt.dynamo
11-
import torch_tensorrt.ts
1210
from torch_tensorrt._enums import dtype
1311
from torch_tensorrt._features import ENABLED_FEATURES
1412
from torch_tensorrt._Input import Input
@@ -343,18 +341,8 @@ def convert_method_to_trt_engine(
343341
"convert_method_to_trt_engine call is not supported for ir=fx"
344342
)
345343
elif target_ir == _IRType.dynamo:
346-
# Prepare torch and torchtrt inputs
347-
from torch_tensorrt.dynamo.utils import prepare_inputs
348-
349-
if not isinstance(inputs, collections.abc.Sequence):
350-
inputs = [inputs]
351-
352-
# Export the module
353-
torchtrt_inputs = prepare_inputs(inputs)
354-
exp_program = torch_tensorrt.dynamo.trace(module, torchtrt_inputs, **kwargs)
355-
356344
return dynamo_convert_module_to_trt_engine( # type: ignore[no-any-return]
357-
exp_program,
345+
module,
358346
inputs=inputs,
359347
enabled_precisions=enabled_precisions_set,
360348
**kwargs,

py/torch_tensorrt/_enums.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def _from(
107107
return dtype.f16
108108
elif t == trt.float32:
109109
return dtype.f32
110-
elif t == trt.bool:
110+
elif trt.__version__ >= "7.0" and t == trt.bool:
111111
return dtype.b
112112
else:
113113
raise TypeError(

py/torch_tensorrt/dynamo/_compiler.py

Lines changed: 10 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,33 +10,7 @@
1010
from torch_tensorrt._Device import Device
1111
from torch_tensorrt._enums import EngineCapability, dtype
1212
from torch_tensorrt._Input import Input
13-
from torch_tensorrt.dynamo import partitioning
14-
from torch_tensorrt.dynamo._defaults import (
15-
DEBUG,
16-
DEVICE,
17-
DISABLE_TF32,
18-
DLA_GLOBAL_DRAM_SIZE,
19-
DLA_LOCAL_DRAM_SIZE,
20-
DLA_SRAM_SIZE,
21-
DRYRUN,
22-
ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
23-
ENGINE_CAPABILITY,
24-
HARDWARE_COMPATIBLE,
25-
MAX_AUX_STREAMS,
26-
MIN_BLOCK_SIZE,
27-
NUM_AVG_TIMING_ITERS,
28-
OPTIMIZATION_LEVEL,
29-
PASS_THROUGH_BUILD_FAILURES,
30-
PRECISION,
31-
REFIT,
32-
REQUIRE_FULL_COMPILATION,
33-
SPARSE_WEIGHTS,
34-
TRUNCATE_LONG_AND_DOUBLE,
35-
USE_FAST_PARTITIONER,
36-
USE_PYTHON_RUNTIME,
37-
VERSION_COMPATIBLE,
38-
WORKSPACE_SIZE,
39-
)
13+
from torch_tensorrt.dynamo import _defaults, partitioning
4014
from torch_tensorrt.dynamo._DryRunTracker import (
4115
DryRunTracker,
4216
PerSubgraphData,
@@ -89,15 +63,15 @@ def compile(
8963
min_block_size: int = _defaults.MIN_BLOCK_SIZE,
9064
torch_executed_ops: Optional[Collection[Target]] = None,
9165
torch_executed_modules: Optional[List[str]] = None,
92-
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
93-
max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
94-
version_compatible: bool = VERSION_COMPATIBLE,
95-
optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
96-
use_python_runtime: bool = USE_PYTHON_RUNTIME,
97-
use_fast_partitioner: bool = USE_FAST_PARTITIONER,
98-
enable_experimental_decompositions: bool = ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
99-
dryrun: bool = DRYRUN,
100-
hardware_compatible: bool = HARDWARE_COMPATIBLE,
66+
pass_through_build_failures: bool = _defaults.PASS_THROUGH_BUILD_FAILURES,
67+
max_aux_streams: Optional[int] = _defaults.MAX_AUX_STREAMS,
68+
version_compatible: bool = _defaults.VERSION_COMPATIBLE,
69+
optimization_level: Optional[int] = _defaults.OPTIMIZATION_LEVEL,
70+
use_python_runtime: bool = _defaults.USE_PYTHON_RUNTIME,
71+
use_fast_partitioner: bool = _defaults.USE_FAST_PARTITIONER,
72+
enable_experimental_decompositions: bool = _defaults.ENABLE_EXPERIMENTAL_DECOMPOSITIONS,
73+
dryrun: bool = _defaults.DRYRUN,
74+
hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
10175
**kwargs: Any,
10276
) -> torch.fx.GraphModule:
10377
"""Compile a TorchScript module for NVIDIA GPUs using TensorRT

py/torch_tensorrt/dynamo/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
REQUIRE_FULL_COMPILATION = False
2727
DRYRUN = False
2828
HARDWARE_COMPATIBLE = False
29+
SUPPORTED_KERNEL_PRECISIONS = {dtype.f32, dtype.f16, dtype.i8}
2930

3031

3132
def default_device() -> Device:

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -313,7 +313,7 @@ def run(
313313
)
314314
timing_cache = self._create_timing_cache(builder_config, existing_cache)
315315

316-
engine = self.builder.build_serialized_network(self.ctx.net, builder_config)
316+
engine = self.builder.build_engine(self.ctx.net, builder_config)
317317
assert engine
318318

319319
serialized_cache = (

py/torch_tensorrt/dynamo/conversion/_conversion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def infer_module_output_dtypes(
3838
# such as aten.sum - such outputs can be truncated
3939
output_dtypes = []
4040
for output in module_outputs:
41+
if not isinstance(output, torch.Tensor):
42+
output = torch.tensor(output)
4143
if truncate_long_and_double and output.dtype == dtype.float64:
4244
output_dtypes.append(dtype.float32)
4345
elif truncate_long_and_double and output.dtype == dtype.int64:

py/torch_tensorrt/dynamo/conversion/impl/cast.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, Union
33

44
import numpy as np
5+
import tensorrt as trt
56
import torch
67
from torch.fx.node import Target
78
from torch_tensorrt import _enums

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional
22

3+
import tensorrt as trt
34
import torch
45
from torch.fx.node import Target
56
from torch_tensorrt import _enums
@@ -9,8 +10,6 @@
910
from torch_tensorrt.fx.converters.converter_utils import broadcast, set_layer_name
1011
from torch_tensorrt.fx.types import TRTTensor
1112

12-
import tensorrt as trt
13-
1413

1514
def matrix_multiply(
1615
ctx: ConversionContext,

0 commit comments

Comments
 (0)