
Commit 22cf701

peri044 and gs-olive authored
chore: Switch converter tests to generate standalone ops using fx.symbolic_trace (#2361)
Signed-off-by: Dheeraj Peri <[email protected]>
Co-authored-by: gs-olive <[email protected]>
1 parent 4cffd6e commit 22cf701
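
In short: the converter test modules now call the aten overload under test directly (for example torch.ops.aten.abs.default) and are traced with torch.fx.symbolic_trace, so the harness can drop the expected_ops/unexpected_ops assertions and the legacy torch_tensorrt.fx aten lowering passes. A minimal sketch of the idea, not taken from this commit:

import torch
import torch.nn as nn


class Abs(nn.Module):
    def forward(self, x):
        # Calling the standalone aten overload means the traced graph
        # contains exactly the op the converter test targets.
        return torch.ops.aten.abs.default(x)


# symbolic_trace yields a call_function node targeting torch.ops.aten.abs.default
gm = torch.fx.symbolic_trace(Abs())
print(gm.graph)
print(gm(torch.randn(2, 3)))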

File tree

80 files changed: +421, -817 lines changed


tests/py/dynamo/conversion/harness.py

+23-82
@@ -4,8 +4,6 @@
 from typing import Callable, List, Optional, Set, Tuple
 
 import torch
-import torch_tensorrt.fx.tracer.dispatch_tracer.aten_tracer as aten_tracer
-from torch.fx.passes.infra.pass_base import PassResult
 from torch.testing._internal.common_utils import TestCase
 from torch_tensorrt import Input
 from torch_tensorrt.dynamo._settings import CompilationSettings
@@ -14,19 +12,6 @@
 from torch_tensorrt.dynamo.conversion import TRTInterpreter
 from torch_tensorrt.dynamo.lowering import apply_lowering_passes
 from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
-from torch_tensorrt.fx.passes.lower_basic_pass_aten import (
-    compose_bmm,
-    compose_chunk,
-    compose_getitem_slice,
-    remove_ops,
-    replace_aten_op_with_indices,
-    replace_aten_reshape_alias_with_replace,
-    replace_builtin_ops,
-    replace_native_layernorm_with_layernorm,
-    replace_transpose_mm_op_with_linear,
-    run_const_fold,
-)
-from torch_tensorrt.fx.passes.pass_utils import chain_passes
 
 _LOGGER: logging.Logger = logging.getLogger(__name__)
 
@@ -62,8 +47,6 @@ def run_test(
         self,
         mod,
         inputs,
-        expected_ops,
-        unexpected_ops,
         interpreter,
         rtol,
         atol,
@@ -76,10 +59,6 @@ def run_test(
             cuda_inputs.append(i.cuda())
 
         mod.eval()
-        if len(expected_ops):
-            self.assert_has_op(mod, expected_ops)
-        if unexpected_ops:
-            self.assert_unexpected_op(mod, unexpected_ops)
         start = time.perf_counter()
         interpreter_result = interpreter.run(precision=precision)
         sec = time.perf_counter() - start
@@ -215,75 +194,44 @@ def generate_graph(
         self,
         mod: torch.nn.Module,
         original_inputs: List[torch.Tensor],
-        expected_ops: Set[Callable],
-        unexpected_ops: Optional[Set[Callable]] = None,
-        customized_passes: List[Callable] = None,
-        disable_passes: bool = False,
+        use_dynamo_tracer: bool,
+        enable_passes: bool,
     ):
-        # Torchdynamo+aot proxytensor tracer
-        # Below are common passes
-        passes_list = [
-            compose_bmm,
-            compose_chunk,
-            compose_getitem_slice,
-            replace_aten_reshape_alias_with_replace,
-            replace_aten_op_with_indices,
-            replace_transpose_mm_op_with_linear,  # after compose_bmm
-            replace_native_layernorm_with_layernorm,
-            remove_ops,
-            replace_builtin_ops,  # after replace_native_layernorm_with_layernorm
-        ]
-        # Combine with customized passes specific to any model
-        if customized_passes:
-            passes_list.extend(customized_passes)
-
-        if disable_passes:
-            passes_list = []
-
-        fx_module, _ = aten_tracer.trace(mod, original_inputs)
-        for passes in passes_list:
-            pr: PassResult = passes(fx_module)
-            fx_module = pr.graph_module
-        fx_module(*original_inputs)
-
-        fx_module = run_const_fold(fx_module)
+        if use_dynamo_tracer:
+            fx_module = torch._dynamo.export(
+                mod,
+                *original_inputs,
+                aten_graph=True,
+                assume_static_by_default=True,
+                tracing_mode="real",
+            ).graph_module
+        else:
+            fx_module = torch.fx.symbolic_trace(mod)
+        if enable_passes:
+            fx_module = apply_lowering_passes(fx_module, original_inputs)
         _LOGGER.info(f"FX graph= {fx_module.graph}")
-
-        if len(expected_ops):
-            self.assert_has_op(fx_module, expected_ops)
-        if unexpected_ops:
-            self.assert_unexpected_op(fx_module, unexpected_ops)
-
         return fx_module
 
     def run_test(
         self,
         mod,
         inputs,
-        expected_ops,
-        unexpected_ops=None,
-        apply_passes=None,
         rtol=1e-03,
         atol=1e-03,
         precision=torch.float,
         check_dtype=True,
-        disable_passes=False,
         output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
     ):
         mod.eval()
         mod = self.generate_graph(
             mod,
             inputs,
-            expected_ops,
-            unexpected_ops,
-            None,
-            disable_passes=disable_passes,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
         )
 
-        if apply_passes is not None:
-            pass_tracer = chain_passes(*apply_passes)
-            mod = pass_tracer(mod, inputs)
-
         # Previous instance of the interpreter auto-casted 64-bit inputs
         # We replicate this behavior here
         compilation_settings = CompilationSettings(truncate_long_and_double=True)
@@ -297,8 +245,6 @@ def run_test(
         super().run_test(
             mod,
             inputs,
-            expected_ops,
-            unexpected_ops,
             interp,
             rtol,
             atol,
@@ -310,22 +256,19 @@ def run_test_with_dynamic_shape(
         self,
         mod,
         input_specs,
-        expected_ops,
-        unexpected_ops=None,
         rtol=1e-03,
         atol=1e-03,
-        disable_passes=False,
         output_dtypes=None,
+        use_dynamo_tracer=False,
+        enable_passes=False,
    ):
         mod.eval()
         inputs = [spec.example_tensor("opt_shape") for spec in input_specs]
         mod = self.generate_graph(
             mod,
             inputs,
-            expected_ops,
-            unexpected_ops,
-            None,
-            disable_passes=disable_passes,
+            use_dynamo_tracer=use_dynamo_tracer,
+            enable_passes=enable_passes,
         )
 
         # Previous instance of the interpreter auto-casted 64-bit inputs
@@ -341,6 +284,4 @@ def run_test_with_dynamic_shape(
         # Since the lowering is based on optimal shape. We need to test with
         # different shape(for ex. max shape) for testing dynamic shape
         inputs_max = [spec.example_tensor("max_shape") for spec in input_specs]
-        super().run_test(
-            mod, inputs_max, expected_ops, unexpected_ops, interp, rtol, atol
-        )
+        super().run_test(mod, inputs_max, interp, rtol, atol)
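
With the harness reworked as above, tests choose a tracer through two new keyword flags rather than op assertions: use_dynamo_tracer (default False, meaning torch.fx.symbolic_trace) and enable_passes (default False, meaning apply_lowering_passes is skipped). A hypothetical test against the new signature could look like the sketch below; the Relu module, the shapes, and the relative import path are assumptions, not code from this commit:

import torch
from torch import nn

from .harness import DispatchTestCase  # import path assumed


class TestReluConverter(DispatchTestCase):
    def test_relu(self):
        class Relu(nn.Module):
            def forward(self, x):
                return torch.ops.aten.relu.default(x)

        inputs = [torch.randn(2, 3)]
        # Default path: torch.fx.symbolic_trace, no lowering passes, and no
        # expected_ops/unexpected_ops arguments anymore.
        self.run_test(Relu(), inputs)

    def test_relu_dynamo(self):
        class Relu(nn.Module):
            def forward(self, x):
                return torch.ops.aten.relu.default(x)

        inputs = [torch.randn(2, 3)]
        # Opt in to torch._dynamo.export(aten_graph=True) plus the Dynamo
        # lowering passes via the new flags.
        self.run_test(Relu(), inputs, use_dynamo_tracer=True, enable_passes=True)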

tests/py/dynamo/conversion/test_abs_aten.py

+2-4
@@ -18,13 +18,12 @@ class TestAbsConverter(DispatchTestCase):
     def test_abs_float(self, input_shape, dtype):
         class abs(nn.Module):
             def forward(self, input):
-                return torch.abs(input)
+                return torch.ops.aten.abs.default(input)
 
         inputs = [torch.randn(input_shape, dtype=dtype)]
         self.run_test(
             abs(),
             inputs,
-            expected_ops={torch.ops.aten.abs.default},
         )
 
     @parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
     def test_abs_int(self, input_shape, dtype, low, high):
         class abs(nn.Module):
             def forward(self, input):
-                return torch.abs(input)
+                return torch.ops.aten.abs.default(input)
 
         inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
         self.run_test(
             abs(),
             inputs,
-            expected_ops={torch.ops.aten.abs.default},
             output_dtypes=[torch.int],
         )
 
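
The swap from torch.abs to torch.ops.aten.abs.default here (and in the acos/acosh tests below) is what makes plain torch.fx.symbolic_trace sufficient: symbolic_trace records torch.abs as the Python-level call, while the Dynamo converters are registered against aten targets, so the test module now emits the aten overload itself. A standalone sketch, not from the repository, contrasting the two traced graphs:

import torch
from torch import nn


class OldStyle(nn.Module):
    def forward(self, x):
        return torch.abs(x)  # traced as a call_function node targeting torch.abs


class NewStyle(nn.Module):
    def forward(self, x):
        return torch.ops.aten.abs.default(x)  # traced as the aten overload itself


print(torch.fx.symbolic_trace(OldStyle()).graph)
print(torch.fx.symbolic_trace(NewStyle()).graph)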

tests/py/dynamo/conversion/test_acos_aten.py

+2-4
@@ -18,13 +18,12 @@ class TestAcosConverter(DispatchTestCase):
     def test_acos_float(self, input_shape, dtype):
         class acos(nn.Module):
             def forward(self, input):
-                return torch.acos(input)
+                return torch.ops.aten.acos.default(input)
 
         inputs = [torch.randn(input_shape, dtype=dtype)]
         self.run_test(
             acos(),
             inputs,
-            expected_ops={torch.ops.aten.acos.default},
         )
 
     @parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
     def test_acos_int(self, input_shape, dtype, low, high):
         class acos(nn.Module):
             def forward(self, input):
-                return torch.acos(input)
+                return torch.ops.aten.acos.default(input)
 
         inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
         self.run_test(
             acos(),
             inputs,
-            expected_ops={torch.ops.aten.acos.default},
         )
 
 

tests/py/dynamo/conversion/test_acosh_aten.py

+2-4
@@ -18,13 +18,12 @@ class TestAcoshConverter(DispatchTestCase):
     def test_acosh_float(self, input_shape, dtype):
         class acosh(nn.Module):
             def forward(self, input):
-                return torch.acosh(input)
+                return torch.ops.aten.acosh.default(input)
 
         inputs = [torch.randn(input_shape, dtype=dtype)]
         self.run_test(
             acosh(),
             inputs,
-            expected_ops={torch.ops.aten.acosh.default},
         )
 
     @parameterized.expand(
@@ -37,13 +36,12 @@ def forward(self, input):
     def test_acosh_int(self, input_shape, dtype, low, high):
         class acosh(nn.Module):
             def forward(self, input):
-                return torch.acosh(input)
+                return torch.ops.aten.acosh.default(input)
 
         inputs = [torch.randint(low, high, input_shape, dtype=dtype)]
         self.run_test(
             acosh(),
             inputs,
-            expected_ops={torch.ops.aten.acosh.default},
         )
 
 

tests/py/dynamo/conversion/test_adaptive_avgpool_aten.py

+5-23
@@ -7,27 +7,11 @@
 
 
 class TestAdaptiveAvgPoolConverter(DispatchTestCase):
-    def test_adaptive_avgpool_mean(self):
-        class TestModule(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.pool = torch.nn.AdaptiveAvgPool2d((1, 1))
-
-            def forward(self, x):
-                return self.pool(x)
-
-        inputs = [torch.randn(1, 3, 256, 256)]
-        self.run_test(
-            TestModule(),
-            inputs,
-            expected_ops={torch.ops.aten.mean.dim},
-        )
-
     @parameterized.expand(
         [
             ((64, 64),),
             ((128, 64),),
-            (64,),
+            # (64,), This case has been there in previous code but it isn't a valid pytorch code.
         ]
     )
     def test_adaptive_avgpool(
@@ -46,7 +30,7 @@ def forward(self, x):
         self.run_test(
             TestModule(),
             inputs,
-            expected_ops={torch.ops.aten._adaptive_avg_pool2d.default},
+            use_dynamo_tracer=True,
         )
 
     def test_adaptive_avgpool_with_dynamic_shape(self):
@@ -66,9 +50,7 @@ def forward(self, x):
             ),
         ]
         self.run_test_with_dynamic_shape(
-            TestModule(),
-            input_specs,
-            expected_ops={torch.ops.aten._adaptive_avg_pool2d.default},
+            TestModule(), input_specs, use_dynamo_tracer=True
         )
 
     @parameterized.expand(
@@ -94,7 +76,7 @@ def forward(self, x):
         self.run_test(
             TestModule(),
             inputs,
-            expected_ops={torch.ops.aten._adaptive_avg_pool3d.default},
+            use_dynamo_tracer=True,
         )
 
     def test_adaptive_avgpool3d_with_dynamic_shape(self):
@@ -118,7 +100,7 @@ def forward(self, x):
         self.run_test_with_dynamic_shape(
             TestModule(),
             input_specs,
-            expected_ops={torch.ops.aten._adaptive_avg_pool3d.default},
+            use_dynamo_tracer=True,
         )
 
     # Testing with shape(-1, -1, -1, -1) results into error: "AdaptiveAvgPool2d and AdaptiveAvgPool3d currently doesn't support dynamic shapes for last two dims."
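
The adaptive avgpool tests keep use_dynamo_tracer=True because the module under test wraps torch.nn.AdaptiveAvgPool2d/3d rather than calling an aten overload directly: plain torch.fx.symbolic_trace would record an opaque call_module node, while the torch._dynamo.export path used by the harness lowers the layer to aten pooling ops that the converters can match. A rough standalone illustration, with a made-up module and shape:

import torch
from torch import nn


class Pool(nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d((2, 2))

    def forward(self, x):
        return self.pool(x)


example = torch.randn(1, 3, 8, 8)

# symbolic_trace keeps the pooling layer as a call_module node.
print(torch.fx.symbolic_trace(Pool()).graph)

# The Dynamo/aten path (same call the harness uses) decomposes it into
# aten ops such as torch.ops.aten._adaptive_avg_pool2d.default.
dynamo_gm = torch._dynamo.export(
    Pool(),
    example,
    aten_graph=True,
    assume_static_by_default=True,
    tracing_mode="real",
).graph_module
print(dynamo_gm.graph)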
