Skip to content

Commit 9e2502a

Browse files
committed
fix: Add truncate_long_and_double to Dynamo
- Add pass-through ability for feature to Dynamo compile frontend
1 parent 1a2fe99 commit 9e2502a

File tree

4 files changed

+18
-13
lines changed

4 files changed

+18
-13
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
19+
TRUNCATE_LONG_AND_DOUBLE,
1920
)
2021

2122

@@ -39,7 +40,7 @@ def compile(
3940
dla_local_dram_size=1073741824,
4041
dla_global_dram_size=536870912,
4142
calibrator=None,
42-
truncate_long_and_double=False,
43+
truncate_long_and_double=TRUNCATE_LONG_AND_DOUBLE,
4344
require_full_compilation=False,
4445
min_block_size=MIN_BLOCK_SIZE,
4546
torch_executed_ops=[],
@@ -50,7 +51,8 @@ def compile(
5051
logger.warn(
5152
"The Dynamo backend is an experimental feature, for which only the "
5253
+ "following arguments are supported: "
53-
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
54+
+ "{enabled_precisions, debug, workspace_size, "
55+
+ "truncate_long_and_double, min_block_size, torch_executed_ops}"
5456
)
5557

5658
if not isinstance(inputs, collections.abc.Sequence):
@@ -82,6 +84,7 @@ def compile(
8284
workspace_size=workspace_size,
8385
min_block_size=min_block_size,
8486
torch_executed_ops=torch_executed_ops,
87+
truncate_long_and_double=truncate_long_and_double,
8588
**kwargs,
8689
)
8790

@@ -104,6 +107,7 @@ def create_backend(
104107
workspace_size: int = MAX_WORKSPACE_SIZE,
105108
min_block_size: int = MIN_BLOCK_SIZE,
106109
torch_executed_ops: Sequence[str] = set(),
110+
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE,
107111
**kwargs,
108112
):
109113
"""Create torch.compile backend given specified arguments
@@ -122,6 +126,7 @@ def create_backend(
122126
workspace_size=workspace_size,
123127
min_block_size=min_block_size,
124128
torch_executed_ops=torch_executed_ops,
129+
truncate_long_and_double=truncate_long_and_double,
125130
)
126131

127132
return partial(

py/torch_tensorrt/dynamo/backend/_defaults.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
DEBUG = False
66
MAX_WORKSPACE_SIZE = 20 << 30
77
MIN_BLOCK_SIZE = 5
8+
TRUNCATE_LONG_AND_DOUBLE = False

py/torch_tensorrt/dynamo/backend/_settings.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DEBUG,
88
MAX_WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
10+
TRUNCATE_LONG_AND_DOUBLE,
1011
)
1112

1213

@@ -17,3 +18,4 @@ class CompilationSettings:
1718
workspace_size: int = MAX_WORKSPACE_SIZE
1819
min_block_size: int = MIN_BLOCK_SIZE
1920
torch_executed_ops: Sequence[str] = field(default_factory=set)
21+
truncate_long_and_double: bool = TRUNCATE_LONG_AND_DOUBLE

py/torch_tensorrt/dynamo/backend/conversion.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,45 +2,42 @@
22
import torch
33
from torch_tensorrt.fx.trt_module import TRTModule
44
from torch_tensorrt import TRTModuleNext
5+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
56
from torch_tensorrt.fx.fx2trt import (
67
InputTensorSpec,
78
TRTInterpreter,
89
)
9-
from torch_tensorrt.fx.utils import LowerPrecision
1010

1111
import tensorrt as trt
1212

1313

1414
def convert_module(
1515
module: torch.fx.GraphModule,
1616
inputs: Sequence[torch.Tensor],
17-
debug: bool = False,
18-
workspace_size: int = 20 << 30,
19-
precision: LowerPrecision = LowerPrecision.FP32,
17+
settings: CompilationSettings = CompilationSettings(),
2018
) -> Union[TRTModuleNext, TRTModule]:
2119
"""Convert an FX module to a TRT module
2220
Args:
2321
module: FX GraphModule to convert
2422
inputs: Sequence of Tensors representing inputs to the module
25-
debug: Whether to print out verbose debugging information
26-
workspace_size: Maximum workspace TRT is allowed to use for the module
27-
precision: Model Layer precision
23+
settings: Compilation settings
2824
Returns:
2925
TRTModule or TRTModuleNext
3026
"""
3127
interp = TRTInterpreter(
3228
module,
3329
InputTensorSpec.from_tensors(inputs),
3430
explicit_batch_dimension=True,
35-
logger_level=(trt.Logger.VERBOSE if debug else trt.Logger.WARNING),
31+
logger_level=(trt.Logger.VERBOSE if settings.debug else trt.Logger.WARNING),
32+
truncate_long_and_double=settings.truncate_long_and_double,
3633
)
3734

3835
r = interp.run(
39-
max_workspace_size=workspace_size,
40-
lower_precision=precision,
36+
max_workspace_size=settings.workspace_size,
37+
lower_precision=settings.precision,
4138
profiling_verbosity=(
4239
trt.ProfilingVerbosity.VERBOSE
43-
if debug
40+
if settings.debug
4441
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
4542
),
4643
)

0 commit comments

Comments (0)