@@ -326,7 +326,7 @@ def test_disallow_eval_train(self) -> None:
326
326
m .train ()
327
327
328
328
# After export: this is not OK
329
- m = export_for_training (m , example_inputs ).module ()
329
+ m = export_for_training (m , example_inputs , strict = True ).module ()
330
330
with self .assertRaises (NotImplementedError ):
331
331
m .eval ()
332
332
with self .assertRaises (NotImplementedError ):
@@ -380,7 +380,7 @@ def forward(self, x):
380
380
m = M ().train ()
381
381
example_inputs = (torch .randn (1 , 3 , 3 , 3 ),)
382
382
bn_train_op , bn_eval_op = self ._get_bn_train_eval_ops () # pyre-ignore[23]
383
- m = export_for_training (m , example_inputs ).module ()
383
+ m = export_for_training (m , example_inputs , strict = True ).module ()
384
384
385
385
def _assert_ops_are_correct (m : torch .fx .GraphModule , train : bool ) -> None :
386
386
bn_op = bn_train_op if train else bn_eval_op
@@ -449,10 +449,7 @@ def forward(self, x):
449
449
quantizer .set_global (operator_config )
450
450
example_inputs = (torch .randn (2 , 2 ),)
451
451
m = M ().eval ()
452
- m = export_for_training (
453
- m ,
454
- example_inputs ,
455
- ).module ()
452
+ m = export_for_training (m , example_inputs , strict = True ).module ()
456
453
weight_meta = None
457
454
for n in m .graph .nodes : # pyre-ignore[16]
458
455
if (
@@ -481,7 +478,7 @@ def test_reentrant(self) -> None:
481
478
get_symmetric_quantization_config (is_per_channel = True , is_qat = True )
482
479
)
483
480
m .conv_bn_relu = export_for_training ( # pyre-ignore[8]
484
- m .conv_bn_relu , example_inputs
481
+ m .conv_bn_relu , example_inputs , strict = True
485
482
).module ()
486
483
m .conv_bn_relu = prepare_qat_pt2e (m .conv_bn_relu , quantizer ) # pyre-ignore[6,8]
487
484
m (* example_inputs )
@@ -490,7 +487,7 @@ def test_reentrant(self) -> None:
490
487
quantizer = XNNPACKQuantizer ().set_module_type (
491
488
torch .nn .Linear , get_symmetric_quantization_config (is_per_channel = False )
492
489
)
493
- m = export_for_training (m , example_inputs ).module ()
490
+ m = export_for_training (m , example_inputs , strict = True ).module ()
494
491
m = prepare_pt2e (m , quantizer ) # pyre-ignore[6]
495
492
m = convert_pt2e (m )
496
493
@@ -553,7 +550,7 @@ def check_nn_module(node: torch.fx.Node) -> None:
553
550
)
554
551
555
552
m .conv_bn_relu = export_for_training ( # pyre-ignore[8]
556
- m .conv_bn_relu , example_inputs
553
+ m .conv_bn_relu , example_inputs , strict = True
557
554
).module ()
558
555
for node in m .conv_bn_relu .graph .nodes : # pyre-ignore[16]
559
556
if node .op not in ["placeholder" , "output" , "get_attr" ]:
@@ -568,7 +565,7 @@ def test_speed(self) -> None:
568
565
569
566
def dynamic_quantize_pt2e (model , example_inputs ) -> torch .fx .GraphModule :
570
567
torch ._dynamo .reset ()
571
- model = export_for_training (model , example_inputs ).module ()
568
+ model = export_for_training (model , example_inputs , strict = True ).module ()
572
569
# Per channel quantization for weight
573
570
# Dynamic quantization for activation
574
571
# Please read a detail: https://fburl.com/code/30zds51q
@@ -625,7 +622,7 @@ def forward(self, x):
625
622
626
623
example_inputs = (torch .randn (1 , 3 , 5 , 5 ),)
627
624
m = M ()
628
- m = export_for_training (m , example_inputs ).module ()
625
+ m = export_for_training (m , example_inputs , strict = True ).module ()
629
626
quantizer = XNNPACKQuantizer ().set_global (
630
627
get_symmetric_quantization_config (),
631
628
)
@@ -701,7 +698,6 @@ def test_save_load(self) -> None:
701
698
702
699
703
700
class TestNumericDebugger (TestCase ):
704
-
705
701
def _extract_debug_handles (self , model ) -> Dict [str , int ]:
706
702
debug_handle_map : Dict [str , int ] = {}
707
703
@@ -731,7 +727,7 @@ def _assert_node_has_debug_handle(node: torch.fx.Node) -> None:
731
727
def test_quantize_pt2e_preserve_handle (self ) -> None :
732
728
m = TestHelperModules .Conv2dThenConv1d ()
733
729
example_inputs = m .example_inputs ()
734
- ep = export_for_training (m , example_inputs )
730
+ ep = export_for_training (m , example_inputs , strict = True )
735
731
generate_numeric_debug_handle (ep )
736
732
m = ep .module ()
737
733
@@ -761,7 +757,7 @@ def test_quantize_pt2e_preserve_handle(self) -> None:
761
757
def test_extract_results_from_loggers (self ) -> None :
762
758
m = TestHelperModules .Conv2dThenConv1d ()
763
759
example_inputs = m .example_inputs ()
764
- ep = export_for_training (m , example_inputs )
760
+ ep = export_for_training (m , example_inputs , strict = True )
765
761
generate_numeric_debug_handle (ep )
766
762
m = ep .module ()
767
763
m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
@@ -779,18 +775,20 @@ def test_extract_results_from_loggers(self) -> None:
779
775
ref_results = extract_results_from_loggers (m_ref_logger )
780
776
quant_results = extract_results_from_loggers (m_quant_logger )
781
777
comparison_results = compare_results (
782
- ref_results , quant_results # pyre-ignore[6]
778
+ ref_results ,
779
+ quant_results , # pyre-ignore[6]
783
780
)
784
781
for node_summary in comparison_results .values ():
785
782
if len (node_summary .results ) > 0 :
786
783
self .assertGreaterEqual (
787
- node_summary .results [0 ].sqnr , 35 # pyre-ignore[6]
784
+ node_summary .results [0 ].sqnr ,
785
+ 35 , # pyre-ignore[6]
788
786
)
789
787
790
788
def test_extract_results_from_loggers_list_output (self ) -> None :
791
789
m = TestHelperModules .Conv2dWithSplit ()
792
790
example_inputs = m .example_inputs ()
793
- ep = export_for_training (m , example_inputs )
791
+ ep = export_for_training (m , example_inputs , strict = True )
794
792
generate_numeric_debug_handle (ep )
795
793
m = ep .module ()
796
794
m_ref_logger = prepare_for_propagation_comparison (m ) # pyre-ignore[6]
@@ -808,7 +806,8 @@ def test_extract_results_from_loggers_list_output(self) -> None:
808
806
ref_results = extract_results_from_loggers (m_ref_logger )
809
807
quant_results = extract_results_from_loggers (m_quant_logger )
810
808
comparison_results = compare_results (
811
- ref_results , quant_results # pyre-ignore[6]
809
+ ref_results ,
810
+ quant_results , # pyre-ignore[6]
812
811
)
813
812
for node_summary in comparison_results .values ():
814
813
if len (node_summary .results ) > 0 :
0 commit comments