Skip to content

Commit f34c25c

Browse files
committed
fix: Add support for passing through build issues
- Add support for `pass_through_build_failures` keyword arg
- Add resnet18 testing to validate feature
- Add minor typo fixes
1 parent 960626c commit f34c25c

File tree

8 files changed

+61
-19
lines changed

8 files changed

+61
-19
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
DEBUG,
1717
MAX_WORKSPACE_SIZE,
1818
MIN_BLOCK_SIZE,
19+
PASS_THROUGH_BUILD_FAILURES,
1920
)
2021

2122

@@ -53,7 +54,8 @@ def compile(
5354
logger.warn(
5455
"The Dynamo backend is an experimental feature, for which only the "
5556
+ "following arguments are supported: "
56-
+ "{enabled_precisions, debug, workspace_size, min_block_size, torch_executed_ops}"
57+
+ "{enabled_precisions, debug, workspace_size, min_block_size, "
58+
+ "torch_executed_ops, pass_through_build_failures}"
5759
)
5860

5961
if not isinstance(inputs, collections.abc.Sequence):
@@ -107,6 +109,7 @@ def create_backend(
107109
workspace_size: int = MAX_WORKSPACE_SIZE,
108110
min_block_size: int = MIN_BLOCK_SIZE,
109111
torch_executed_ops: Sequence[str] = set(),
112+
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
110113
**kwargs,
111114
):
112115
"""Create torch.compile backend given specified arguments
@@ -125,6 +128,7 @@ def create_backend(
125128
workspace_size=workspace_size,
126129
min_block_size=min_block_size,
127130
torch_executed_ops=torch_executed_ops,
131+
pass_through_build_failures=pass_through_build_failures,
128132
)
129133

130134
return partial(

py/torch_tensorrt/dynamo/backend/_defaults.py

+1
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,4 @@
55
DEBUG = False
66
MAX_WORKSPACE_SIZE = 20 << 30
77
MIN_BLOCK_SIZE = 5
8+
PASS_THROUGH_BUILD_FAILURES = False

py/torch_tensorrt/dynamo/backend/_settings.py

+2
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
DEBUG,
88
MAX_WORKSPACE_SIZE,
99
MIN_BLOCK_SIZE,
10+
PASS_THROUGH_BUILD_FAILURES,
1011
)
1112

1213

@@ -17,3 +18,4 @@ class CompilationSettings:
1718
workspace_size: int = MAX_WORKSPACE_SIZE
1819
min_block_size: int = MIN_BLOCK_SIZE
1920
torch_executed_ops: Sequence[str] = field(default_factory=set)
21+
pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES

py/torch_tensorrt/dynamo/backend/backends.py

+16-15
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import logging
22
from typing import Sequence
33
import torch
4-
import traceback
54
from functools import partial
65
import torch._dynamo as td
76

@@ -23,14 +22,10 @@
2322
from torch._functorch.aot_autograd import aot_module_simplified, make_boxed_compiler
2423

2524

26-
<<<<<<< HEAD:py/torch_tensorrt/dynamo/backend/backends.py
27-
@td.register_backend(name="torch_tensorrt")
28-
=======
2925
logger = logging.getLogger(__name__)
3026

3127

32-
@td.register_backend(name="tensorrt")
33-
>>>>>>> 7e0f4405... feat: Prototype Module-Acceleration in Dynamo:py/torch_tensorrt/dynamo/torch_compile/backends.py
28+
@td.register_backend(name="torch_tensorrt")
3429
@fake_tensor_unsupported
3530
def torch_tensorrt_backend(
3631
gm: torch.fx.GraphModule,
@@ -85,25 +80,31 @@ def _pretraced_backend(
8580
Compiled FX GraphModule
8681
"""
8782
try:
88-
<<<<<<< HEAD:py/torch_tensorrt/dynamo/backend/backends.py
89-
trt_compiled = _compile_module(
90-
=======
9183
logger.debug("Post-AOT Autograd graph:\n" + str(gm.graph))
9284

93-
trt_compiled = compile_module(
94-
>>>>>>> 7e0f4405... feat: Prototype Module-Acceleration in Dynamo:py/torch_tensorrt/dynamo/torch_compile/backends.py
85+
trt_compiled = _compile_module(
9586
gm,
9687
sample_inputs,
9788
settings=settings,
9889
)
9990
return trt_compiled
10091
except:
101-
traceback.print_exc()
102-
print(
92+
logger.error(
10393
"FX2TRT conversion failed on the subgraph. See trace above. "
104-
+ "Returning GraphModule forward instead."
94+
+ "Returning GraphModule forward instead.",
95+
exc_info=True,
10596
)
106-
return gm.forward
97+
98+
if not settings.pass_through_build_failures:
99+
return gm.forward
100+
else:
101+
raise AssertionError(
102+
"Halting compilation on build failure since "
103+
+ "pass_through_build_failures was specified as True. "
104+
+ "To return the default Torch implementation and avoid "
105+
+ "halting compilation on engine build failures, "
106+
+ "specify pass_through_build_failures=False."
107+
)
107108

108109

109110
def _compile_module(
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from torch.testing._internal.common_utils import run_tests, TestCase
2+
import torch
3+
import torch_tensorrt
4+
import torchvision.models as models
5+
from torch_tensorrt.dynamo.common_utils.test_utils import (
6+
COSINE_THRESHOLD,
7+
cosine_similarity,
8+
)
9+
10+
11+
class TestResNet18(TestCase):
12+
def test_resnet18(ir):
13+
model = models.resnet18(pretrained=True).eval().to("cuda")
14+
input_ = torch.randn((1, 3, 224, 224)).to("cuda")
15+
16+
compile_spec = {
17+
"inputs": [input_],
18+
"enabled_precisions": {torch.float},
19+
"pass_through_build_failures": True,
20+
}
21+
22+
trt_mod = torch_tensorrt.dynamo.torch_compile(model, **compile_spec)
23+
cos_sim = cosine_similarity(model(input_), trt_mod(input_))
24+
assert (
25+
cos_sim > COSINE_THRESHOLD,
26+
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
27+
)
28+
29+
30+
if __name__ == "__main__":
31+
run_tests()

py/torch_tensorrt/dynamo/common_utils/__init__.py

Whitespace-only changes.

py/torch_tensorrt/dynamo/test/test_dynamo_backend.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77

88
from transformers import BertModel
99

10-
from utils import COSINE_THRESHOLD, cosine_similarity
10+
from torch_tensorrt.dynamo.common_utils.test_utils import (
11+
COSINE_THRESHOLD,
12+
cosine_similarity,
13+
)
1114

1215

1316
@pytest.mark.unit
@@ -30,7 +33,7 @@ def test_resnet18(ir):
3033
cos_sim = cosine_similarity(model(input), trt_mod(input))
3134
assert (
3235
cos_sim > COSINE_THRESHOLD,
33-
f"Resnet50 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
36+
f"Resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
3437
)
3538

3639
# Clean up model env
@@ -163,7 +166,7 @@ def test_resnet18_half(ir):
163166
cos_sim = cosine_similarity(model(input), trt_mod(input))
164167
assert (
165168
cos_sim > COSINE_THRESHOLD,
166-
f"Resnet50 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
169+
f"Resnet18 Half TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
167170
)
168171

169172
# Clean up model env

0 commit comments

Comments
 (0)