from typing import Callable, List, Optional, Set, Tuple

import torch
-import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
-from torch.fx.passes.infra.pass_base import PassResult
from torch.testing._internal.common_utils import TestCase
from torch_tensorrt import Input
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion import TRTInterpreter
from torch_tensorrt.dynamo.lowering import apply_lowering_passes
from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
-from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
-    compose_bmm,
-    compose_chunk,
-    compose_getitem_slice,
-    remove_ops,
-    replace_aten_op_with_indices,
-    replace_aten_reshape_alias_with_replace,
-    replace_builtin_ops,
-    replace_native_layernorm_with_layernorm,
-    replace_transpose_mm_op_with_linear,
-    run_const_fold,
-)
-from torch_tensorrt.fx.passes.pass_utils import chain_passes

_LOGGER: logging.Logger = logging.getLogger(__name__)

@@ -62,8 +47,6 @@ def run_test(
        self,
        mod,
        inputs,
-        expected_ops,
-        unexpected_ops,
        interpreter,
        rtol,
        atol,
@@ -76,10 +59,6 @@ def run_test(
                cuda_inputs.append(i.cuda())

            mod.eval()
-            if len(expected_ops):
-                self.assert_has_op(mod, expected_ops)
-            if unexpected_ops:
-                self.assert_unexpected_op(mod, unexpected_ops)
            start = time.perf_counter()
            interpreter_result = interpreter.run(precision=precision)
            sec = time.perf_counter() - start
@@ -215,75 +194,44 @@ def generate_graph(
        self,
        mod: torch.nn.Module,
        original_inputs: List[torch.Tensor],
-        expected_ops: Set[Callable],
-        unexpected_ops: Optional[Set[Callable]] = None,
-        customized_passes: List[Callable] = None,
-        disable_passes: bool = False,
+        use_dynamo_tracer: bool,
+        enable_passes: bool,
    ):
-        # Torchdynamo+aot proxytensor tracer
-        # Below are common passes
-        passes_list = [
-            compose_bmm,
-            compose_chunk,
-            compose_getitem_slice,
-            replace_aten_reshape_alias_with_replace,
-            replace_aten_op_with_indices,
-            replace_transpose_mm_op_with_linear,  # after compose_bmm
-            replace_native_layernorm_with_layernorm,
-            remove_ops,
-            replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        ]
-        # Combine with customized passes specific to any model
-        if customized_passes:
-            passes_list.extend(customized_passes)
-
-        if disable_passes:
-            passes_list = []
-
-        fx_module, _ = aten_tracer.trace(mod, original_inputs)
-        for passes in passes_list:
-            pr: PassResult = passes(fx_module)
-            fx_module = pr.graph_module
-        fx_module(*original_inputs)
-
-        fx_module = run_const_fold(fx_module)
+        if use_dynamo_tracer:
+            fx_module = torch._dynamo.export(
+                mod,
+                *original_inputs,
+                aten_graph=True,
+                assume_static_by_default=True,
+                tracing_mode="real",
+            ).graph_module
+        else:
+            fx_module = torch.fx.symbolic_trace(mod)
+        if enable_passes:
+            fx_module = apply_lowering_passes(fx_module, original_inputs)
        _LOGGER.info(f"FX graph= {fx_module.graph}")
-
-        if len(expected_ops):
-            self.assert_has_op(fx_module, expected_ops)
-        if unexpected_ops:
-            self.assert_unexpected_op(fx_module, unexpected_ops)
-
        return fx_module

    def run_test(
        self,
        mod,
        inputs,
-        expected_ops,
-        unexpected_ops=None,
-        apply_passes=None,
        rtol=1e-03,
        atol=1e-03,
        precision=torch.float,
        check_dtype=True,
-        disable_passes=False,
        output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
    ):
        mod.eval()
        mod = self.generate_graph(
            mod,
            inputs,
-            expected_ops,
-            unexpected_ops,
-            None,
-            disable_passes=disable_passes,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
        )

-        if apply_passes is not None:
-            pass_tracer = chain_passes(*apply_passes)
-            mod = pass_tracer(mod, inputs)
-
        # Previous instance of the interpreter auto-casted 64-bit inputs
        # We replicate this behavior here
        compilation_settings = CompilationSettings(truncate_long_and_double=True)
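
As a side note, here is a minimal sketch of the two tracing paths the reworked generate_graph chooses between; the AddOne toy module and the standalone-script framing are illustrative assumptions, while the export arguments and the lowering call mirror the diff above:

import torch
from torch_tensorrt.dynamo.lowering import apply_lowering_passes

class AddOne(torch.nn.Module):  # hypothetical toy module for illustration
    def forward(self, x):
        return x + 1.0

mod = AddOne().eval()
inputs = [torch.randn(2, 3)]

# use_dynamo_tracer=True: export an ATen-level FX graph via torch._dynamo.export
aten_gm = torch._dynamo.export(
    mod,
    *inputs,
    aten_graph=True,
    assume_static_by_default=True,
    tracing_mode="real",
).graph_module

# use_dynamo_tracer=False: fall back to plain FX symbolic tracing
fx_gm = torch.fx.symbolic_trace(mod)

# enable_passes=True: additionally run the Dynamo lowering passes on the traced graph
lowered_gm = apply_lowering_passes(aten_gm, inputs)
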
@@ -297,8 +245,6 @@ def run_test(
        super().run_test(
            mod,
            inputs,
-            expected_ops,
-            unexpected_ops,
            interp,
            rtol,
            atol,
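
For reference, a hedged usage sketch of the updated static-shape entry point; the harness class name (DispatchTestCase), its import path, and the elementwise-add test are assumptions for illustration, and only the new keyword flags come from this change:

import torch
from harness import DispatchTestCase  # import path is an assumption

class TestElementwiseAdd(DispatchTestCase):
    def test_add(self):
        class Add(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        inputs = [torch.randn(2, 3), torch.randn(2, 3)]
        # expected_ops / unexpected_ops are gone; tracing and lowering are now
        # selected via the keyword flags introduced in this diff
        self.run_test(Add(), inputs, use_dynamo_tracer=True, enable_passes=True)
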
@@ -310,22 +256,19 @@ def run_test_with_dynamic_shape(
        self,
        mod,
        input_specs,
-        expected_ops,
-        unexpected_ops=None,
        rtol=1e-03,
        atol=1e-03,
-        disable_passes=False,
        output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
    ):
        mod.eval()
        inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
        mod = self.generate_graph(
            mod,
            inputs,
-            expected_ops,
-            unexpected_ops,
-            None,
-            disable_passes=disable_passes,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
        )

        # Previous instance of the interpreter auto-casted 64-bit inputs
@@ -341,6 +284,4 @@ def run_test_with_dynamic_shape(
        # Since the lowering is based on optimal shape. We need to test with
        # different shape(for ex. max shape) for testing dynamic shape
        inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(
-            mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
-        )
+        super().run_test(mod, inputs_max, interp, rtol, atol)
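
Similarly, a hedged sketch of a dynamic-shape test after this change; the harness class name, its import path, and the ReLU example are assumptions, while the Input spec fields and the opt/max shape handling mirror the harness code above:

import torch
from torch_tensorrt import Input
from harness import DispatchTestCase  # import path is an assumption

class TestReluDynamic(DispatchTestCase):
    def test_relu_dynamic_shape(self):
        class Relu(torch.nn.Module):
            def forward(self, x):
                return torch.relu(x)

        input_specs = [
            Input(
                min_shape=(1, 3, 8, 8),
                opt_shape=(4, 3, 16, 16),
                max_shape=(8, 3, 32, 32),
                dtype=torch.float32,
            )
        ]
        # The harness traces at opt_shape, then re-runs the engine at max_shape
        self.run_test_with_dynamic_shape(Relu(), input_specs, use_dynamo_tracer=True)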