19
19
from fx2trt_oss .fx .utils import LowerPrecision
20
20
from fx2trt_oss .tracer .acc_tracer import acc_ops
21
21
from torch .ao .quantization import default_qconfig
22
- from torch .ao .quantization ._quantize_fx_do_not_use import (
23
- _convert_fx_do_not_use ,
22
+ from torch .ao .quantization .quantize_fx import (
23
+ convert_fx ,
24
24
)
25
25
from torch .ao .quantization .fx .backend_config .observation_type import ObservationType
26
26
from torch .ao .quantization .fx .match_utils import (
@@ -95,7 +95,7 @@ def forward(self, x):
95
95
prepare_custom_config_dict = prepare_custom_config_dict )
96
96
self .checkGraphModuleNodes (mp , expected_node_occurrence = prepare_count_check )
97
97
mp (torch .randn (1 , 1 , 4 , 4 ))
98
- mq = _convert_fx_do_not_use (
98
+ mq = convert_fx (
99
99
mp , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
100
100
self .checkGraphModuleNodes (mq , expected_node_occurrence = convert_count_check )
101
101
@@ -229,15 +229,15 @@ def forward(self, x):
229
229
self .checkGraphModuleNodes (m .standalone , expected_node_occurrence = standalone_prepare_count_check )
230
230
231
231
# check converted/quantized model
232
- m = _convert_fx_do_not_use (m , is_reference = True , backend_config_dict = backend_config_dict )
232
+ m = convert_fx (m , is_reference = True , backend_config_dict = backend_config_dict )
233
233
self .checkGraphModuleNodes (m , expected_node_occurrence = convert_count_check )
234
234
self .checkGraphModuleNodes (m .standalone , expected_node_occurrence = standalone_convert_count_check )
235
235
res = m (data )
236
236
237
237
# quantize the reference model
238
238
ref_m = prepare_fx (original_ref_m_copy , qconfig_dict , backend_config_dict = backend_config_dict )
239
239
ref_m (data )
240
- ref_m = _convert_fx_do_not_use (ref_m , is_reference = True , backend_config_dict = backend_config_dict )
240
+ ref_m = convert_fx (ref_m , is_reference = True , backend_config_dict = backend_config_dict )
241
241
ref_res = ref_m (data )
242
242
self .assertEqual (res , ref_res )
243
243
@@ -395,7 +395,7 @@ def _test_module(
395
395
self .checkGraphModuleNodes (prepared , expected_node_occurrence = no_prepare )
396
396
# calibration
397
397
prepared (* inputs )
398
- quantized = _convert_fx_do_not_use (
398
+ quantized = convert_fx (
399
399
prepared , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
400
400
self .checkGraphModuleNodes (quantized , expected_node_occurrence = no_convert )
401
401
# lower to trt
@@ -500,7 +500,7 @@ def forward(self, x):
500
500
501
501
m = M ().eval ()
502
502
m = prepare_fx (m , {"" : default_qconfig })
503
- m = _convert_fx_do_not_use (
503
+ m = convert_fx (
504
504
m , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
505
505
expected_occurrence = {
506
506
ns .call_function (torch .quantize_per_tensor ): 5 ,
@@ -530,7 +530,7 @@ def forward(self, x):
530
530
prepared = prepare_fx (m , {"" : trt_unsupported_qconfig }, backend_config_dict = self .trt_backend_config_dict )
531
531
# calibration
532
532
prepared (linear_module_input )
533
- quantized = _convert_fx_do_not_use (
533
+ quantized = convert_fx (
534
534
prepared , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
535
535
node_occurrence = {
536
536
ns .call_function (torch .quantize_per_tensor ): 0 ,
@@ -553,7 +553,7 @@ def forward(self, x):
553
553
prepared = prepare_fx (
554
554
m , {"" : self .qconfig }, backend_config_dict = self .trt_backend_config_dict )
555
555
self .assertTrue (len (dict (prepared .named_children ())) == 1 )
556
- quantized = _convert_fx_do_not_use (
556
+ quantized = convert_fx (
557
557
prepared , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
558
558
node_occurrence = {
559
559
ns .call_function (torch .quantize_per_tensor ): 2 ,
@@ -582,7 +582,7 @@ def forward(self, x):
582
582
ns .call_module (torch .ao .quantization .HistogramObserver ): 2 ,
583
583
}
584
584
self .checkGraphModuleNodes (prepared , expected_node_occurrence = node_occurrence )
585
- quantized = _convert_fx_do_not_use (
585
+ quantized = convert_fx (
586
586
prepared , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
587
587
node_occurrence = {
588
588
# input activation, output activation and weight
@@ -630,7 +630,7 @@ def forward(self, x, y):
630
630
ns .call_module (torch .ao .quantization .HistogramObserver ): 3 ,
631
631
}
632
632
self .checkGraphModuleNodes (m , expected_node_occurrence = node_occurrence )
633
- m = _convert_fx_do_not_use (m , is_reference = True , backend_config_dict = modified_backend_config_dict )
633
+ m = convert_fx (m , is_reference = True , backend_config_dict = modified_backend_config_dict )
634
634
node_occurrence = {
635
635
ns .call_function (torch .quantize_per_tensor ): 3 ,
636
636
ns .call_method ("dequantize" ): 3 ,
@@ -725,7 +725,7 @@ def forward(self, x, y):
725
725
ns .call_module (torch .ao .quantization .HistogramObserver ): 1 ,
726
726
}
727
727
self .checkGraphModuleNodes (m .standalone , expected_node_occurrence = standalone_node_occurrence )
728
- m = _convert_fx_do_not_use (m , is_reference = True , backend_config_dict = backend_config_dict )
728
+ m = convert_fx (m , is_reference = True , backend_config_dict = backend_config_dict )
729
729
node_occurrence = {
730
730
# two inputs for standalone module
731
731
ns .call_function (torch .quantize_per_tensor ): 2 ,
@@ -757,7 +757,7 @@ def forward(self, x):
757
757
inputs = torch .rand (8 , 5 )
758
758
759
759
prepared = prepare_fx (model , {"" : self .qconfig }, backend_config_dict = self .trt_backend_config_dict )
760
- quantized = _convert_fx_do_not_use (
760
+ quantized = convert_fx (
761
761
prepared , is_reference = True , backend_config_dict = self .trt_backend_config_dict )
762
762
763
763
model = acc_tracer .trace (quantized , inputs )
0 commit comments