13
13
from torch_tensorrt .dynamo .lowering import (
14
14
apply_lowering_passes ,
15
15
get_decompositions ,
16
+ remove_sym_nodes ,
16
17
repair_input_aliasing ,
17
18
)
18
19
from torch_tensorrt .dynamo .utils import (
27
28
@td .register_backend (name = "tensorrt" ) # type: ignore[misc]
28
29
@td .register_backend (name = "torch_tensorrt" ) # type: ignore[misc]
29
30
def torch_tensorrt_backend (
30
- gm : torch .fx .GraphModule , sample_inputs : Sequence [torch . Tensor ], ** kwargs : Any
31
+ gm : torch .fx .GraphModule , sample_inputs : Sequence [Any ], ** kwargs : Any
31
32
) -> torch .nn .Module :
32
33
# Set log level at the top of compilation (torch_tensorrt.dynamo)
33
34
if (
@@ -44,15 +45,15 @@ def torch_tensorrt_backend(
44
45
45
46
@td.register_backend(name="aot_torch_tensorrt_aten")  # type: ignore[misc]
def aot_torch_tensorrt_aten_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[Any], **kwargs: Any
) -> torch.nn.Module:
    """Backend entry point registered under the name ``aot_torch_tensorrt_aten``.

    The keyword arguments are converted into a settings object by
    ``parse_dynamo_kwargs``; the graph module, the sample inputs and those
    settings are then handed to ``_pretraced_backend``, whose result is
    returned unchanged.

    Args:
        gm: FX graph module to compile.
        sample_inputs: Inputs accompanying ``gm``. Elements are typed ``Any``
            rather than ``torch.Tensor``; they are forwarded as-is.
        **kwargs: Compilation options, interpreted by ``parse_dynamo_kwargs``.

    Returns:
        Whatever ``_pretraced_backend`` returns for the given arguments.
    """
    # Forward everything in one step; the parsed settings are not needed
    # anywhere else in this function.
    return _pretraced_backend(gm, sample_inputs, parse_dynamo_kwargs(kwargs))
51
52
52
53
53
54
def _pretraced_backend (
54
55
gm : torch .fx .GraphModule ,
55
- sample_inputs : Sequence [torch . Tensor ],
56
+ sample_inputs : Sequence [Any ],
56
57
settings : CompilationSettings = CompilationSettings (),
57
58
) -> torch .fx .GraphModule | Callable [..., Any ]:
58
59
"""Helper function to manage translation of traced FX module to TRT engines
@@ -74,10 +75,17 @@ def _pretraced_backend(
74
75
fake_mode , "allow_non_fake_inputs" , True
75
76
), fake_mode :
76
77
repair_input_aliasing (gm )
78
+
79
+ # Remove sym_int placeholders and inputs
80
+ remove_sym_nodes (gm )
81
+ torch_inputs = [
82
+ input for input in sample_inputs if isinstance (input , torch .Tensor )
83
+ ]
84
+
77
85
# Invoke AOTAutograd to translate operators to aten
78
86
gm = aot_export_joint_simple (
79
87
gm ,
80
- sample_inputs ,
88
+ torch_inputs ,
81
89
trace_joint = False ,
82
90
decompositions = get_decompositions (
83
91
settings .enable_experimental_decompositions
@@ -86,10 +94,10 @@ def _pretraced_backend(
86
94
87
95
logger .debug ("Post-AOT Autograd graph:\n " + str (gm .graph ))
88
96
89
- gm = apply_lowering_passes (gm , sample_inputs )
97
+ gm = apply_lowering_passes (gm , torch_inputs )
90
98
91
99
torchtrt_inputs = prepare_inputs (
92
- sample_inputs , disable_memory_format_check = True
100
+ torch_inputs , disable_memory_format_check = True
93
101
)
94
102
trt_compiled = compile_module (
95
103
gm ,
0 commit comments