10
10
import sys
11
11
import types
12
12
13
+ import astor
13
14
import atomicwrites
14
15
15
16
from _pytest ._io .saferepr import saferepr
@@ -134,7 +135,7 @@ def exec_module(self, module):
134
135
co = _read_pyc (fn , pyc , state .trace )
135
136
if co is None :
136
137
state .trace ("rewriting {!r}" .format (fn ))
137
- source_stat , co = _rewrite_test (fn )
138
+ source_stat , co = _rewrite_test (fn , self . config )
138
139
if write :
139
140
self ._writing_pyc = True
140
141
try :
@@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
278
279
return True
279
280
280
281
281
- def _rewrite_test (fn ):
282
+ def _rewrite_test (fn , config ):
282
283
"""read and rewrite *fn* and return the code object."""
283
284
stat = os .stat (fn )
284
285
with open (fn , "rb" ) as f :
285
286
source = f .read ()
286
287
tree = ast .parse (source , filename = fn )
287
- rewrite_asserts (tree , fn )
288
+ rewrite_asserts (tree , fn , config )
288
289
co = compile (tree , fn , "exec" , dont_inherit = True )
289
290
return stat , co
290
291
@@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
326
327
return co
327
328
328
329
329
- def rewrite_asserts (mod , module_path = None ):
330
+ def rewrite_asserts (mod , module_path = None , config = None ):
330
331
"""Rewrite the assert statements in mod."""
331
- AssertionRewriter (module_path ).run (mod )
332
+ AssertionRewriter (module_path , config ).run (mod )
332
333
333
334
334
335
def _saferepr (obj ):
@@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
401
402
return expl
402
403
403
404
405
+ def _call_assertion_pass (lineno , orig , expl ):
406
+ if util ._assertion_pass is not None :
407
+ util ._assertion_pass (lineno = lineno , orig = orig , expl = expl )
408
+
409
+
410
+ def _check_if_assertion_pass_impl ():
411
+ """Checks if any plugins implement the pytest_assertion_pass hook
412
+ in order not to generate explanation unecessarily (might be expensive)"""
413
+ return True if util ._assertion_pass else False
414
+
415
+
404
416
unary_map = {ast .Not : "not %s" , ast .Invert : "~%s" , ast .USub : "-%s" , ast .UAdd : "+%s" }
405
417
406
418
binop_map = {
@@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
473
485
original assert statement: it rewrites the test of an assertion
474
486
to provide intermediate values and replace it with an if statement
475
487
which raises an assertion error with a detailed explanation in
476
- case the expression is false.
488
+ case the expression is false and calls pytest_assertion_pass hook
489
+ if expression is true.
477
490
478
491
For this .visit_Assert() uses the visitor pattern to visit all the
479
492
AST nodes of the ast.Assert.test field, each visit call returning
@@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
491
504
by statements. Variables are created using .variable() and
492
505
have the form of "@py_assert0".
493
506
494
- :on_failure: The AST statements which will be executed if the
495
- assertion test fails. This is the code which will construct
496
- the failure message and raises the AssertionError.
507
+ :expl_stmts: The AST statements which will be executed to get
508
+ data from the assertion. This is the code which will construct
509
+ the detailed assertion message that is used in the AssertionError
510
+ or for the pytest_assertion_pass hook.
497
511
498
512
:explanation_specifiers: A dict filled by .explanation_param()
499
513
with %-formatting placeholders and their corresponding
@@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):
509
523
510
524
"""
511
525
512
- def __init__ (self , module_path ):
526
+ def __init__ (self , module_path , config ):
513
527
super ().__init__ ()
514
528
self .module_path = module_path
529
+ self .config = config
530
+ if config is not None :
531
+ self .enable_assertion_pass_hook = config .getini (
532
+ "enable_assertion_pass_hook"
533
+ )
534
+ else :
535
+ self .enable_assertion_pass_hook = False
515
536
516
537
def run (self , mod ):
517
538
"""Find all assert statements in *mod* and rewrite them."""
@@ -642,7 +663,7 @@ def pop_format_context(self, expl_expr):
642
663
643
664
The expl_expr should be an ast.Str instance constructed from
644
665
the %-placeholders created by .explanation_param(). This will
645
- add the required code to format said string to .on_failure and
666
+ add the required code to format said string to .expl_stmts and
646
667
return the ast.Name instance of the formatted string.
647
668
648
669
"""
@@ -653,7 +674,9 @@ def pop_format_context(self, expl_expr):
653
674
format_dict = ast .Dict (keys , list (current .values ()))
654
675
form = ast .BinOp (expl_expr , ast .Mod (), format_dict )
655
676
name = "@py_format" + str (next (self .variable_counter ))
656
- self .on_failure .append (ast .Assign ([ast .Name (name , ast .Store ())], form ))
677
+ if self .enable_assertion_pass_hook :
678
+ self .format_variables .append (name )
679
+ self .expl_stmts .append (ast .Assign ([ast .Name (name , ast .Store ())], form ))
657
680
return ast .Name (name , ast .Load ())
658
681
659
682
def generic_visit (self , node ):
@@ -687,8 +710,12 @@ def visit_Assert(self, assert_):
687
710
self .statements = []
688
711
self .variables = []
689
712
self .variable_counter = itertools .count ()
713
+
714
+ if self .enable_assertion_pass_hook :
715
+ self .format_variables = []
716
+
690
717
self .stack = []
691
- self .on_failure = []
718
+ self .expl_stmts = []
692
719
self .push_format_context ()
693
720
# Rewrite assert into a bunch of statements.
694
721
top_condition , explanation = self .visit (assert_ .test )
@@ -699,24 +726,77 @@ def visit_Assert(self, assert_):
699
726
top_condition , module_path = self .module_path , lineno = assert_ .lineno
700
727
)
701
728
)
702
- # Create failure message.
703
- body = self .on_failure
704
- negation = ast .UnaryOp (ast .Not (), top_condition )
705
- self .statements .append (ast .If (negation , body , []))
706
- if assert_ .msg :
707
- assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
708
- explanation = "\n >assert " + explanation
709
- else :
710
- assertmsg = ast .Str ("" )
711
- explanation = "assert " + explanation
712
- template = ast .BinOp (assertmsg , ast .Add (), ast .Str (explanation ))
713
- msg = self .pop_format_context (template )
714
- fmt = self .helper ("_format_explanation" , msg )
715
- err_name = ast .Name ("AssertionError" , ast .Load ())
716
- exc = ast .Call (err_name , [fmt ], [])
717
- raise_ = ast .Raise (exc , None )
718
-
719
- body .append (raise_ )
729
+
730
+ if self .enable_assertion_pass_hook : # Experimental pytest_assertion_pass hook
731
+ negation = ast .UnaryOp (ast .Not (), top_condition )
732
+ msg = self .pop_format_context (ast .Str (explanation ))
733
+
734
+ # Failed
735
+ if assert_ .msg :
736
+ assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
737
+ gluestr = "\n >assert "
738
+ else :
739
+ assertmsg = ast .Str ("" )
740
+ gluestr = "assert "
741
+ err_explanation = ast .BinOp (ast .Str (gluestr ), ast .Add (), msg )
742
+ err_msg = ast .BinOp (assertmsg , ast .Add (), err_explanation )
743
+ err_name = ast .Name ("AssertionError" , ast .Load ())
744
+ fmt = self .helper ("_format_explanation" , err_msg )
745
+ exc = ast .Call (err_name , [fmt ], [])
746
+ raise_ = ast .Raise (exc , None )
747
+ statements_fail = []
748
+ statements_fail .extend (self .expl_stmts )
749
+ statements_fail .append (raise_ )
750
+
751
+ # Passed
752
+ fmt_pass = self .helper ("_format_explanation" , msg )
753
+ orig = astor .to_source (assert_ .test ).rstrip ("\n " ).lstrip ("(" ).rstrip (")" )
754
+ hook_call_pass = ast .Expr (
755
+ self .helper (
756
+ "_call_assertion_pass" ,
757
+ ast .Num (assert_ .lineno ),
758
+ ast .Str (orig ),
759
+ fmt_pass ,
760
+ )
761
+ )
762
+ # If any hooks implement assert_pass hook
763
+ hook_impl_test = ast .If (
764
+ self .helper ("_check_if_assertion_pass_impl" ),
765
+ self .expl_stmts + [hook_call_pass ],
766
+ [],
767
+ )
768
+ statements_pass = [hook_impl_test ]
769
+
770
+ # Test for assertion condition
771
+ main_test = ast .If (negation , statements_fail , statements_pass )
772
+ self .statements .append (main_test )
773
+ if self .format_variables :
774
+ variables = [
775
+ ast .Name (name , ast .Store ()) for name in self .format_variables
776
+ ]
777
+ clear_format = ast .Assign (variables , _NameConstant (None ))
778
+ self .statements .append (clear_format )
779
+
780
+ else : # Original assertion rewriting
781
+ # Create failure message.
782
+ body = self .expl_stmts
783
+ negation = ast .UnaryOp (ast .Not (), top_condition )
784
+ self .statements .append (ast .If (negation , body , []))
785
+ if assert_ .msg :
786
+ assertmsg = self .helper ("_format_assertmsg" , assert_ .msg )
787
+ explanation = "\n >assert " + explanation
788
+ else :
789
+ assertmsg = ast .Str ("" )
790
+ explanation = "assert " + explanation
791
+ template = ast .BinOp (assertmsg , ast .Add (), ast .Str (explanation ))
792
+ msg = self .pop_format_context (template )
793
+ fmt = self .helper ("_format_explanation" , msg )
794
+ err_name = ast .Name ("AssertionError" , ast .Load ())
795
+ exc = ast .Call (err_name , [fmt ], [])
796
+ raise_ = ast .Raise (exc , None )
797
+
798
+ body .append (raise_ )
799
+
720
800
# Clear temporary variables by setting them to None.
721
801
if self .variables :
722
802
variables = [ast .Name (name , ast .Store ()) for name in self .variables ]
@@ -770,22 +850,22 @@ def visit_BoolOp(self, boolop):
770
850
app = ast .Attribute (expl_list , "append" , ast .Load ())
771
851
is_or = int (isinstance (boolop .op , ast .Or ))
772
852
body = save = self .statements
773
- fail_save = self .on_failure
853
+ fail_save = self .expl_stmts
774
854
levels = len (boolop .values ) - 1
775
855
self .push_format_context ()
776
856
# Process each operand, short-circuiting if needed.
777
857
for i , v in enumerate (boolop .values ):
778
858
if i :
779
859
fail_inner = []
780
860
# cond is set in a prior loop iteration below
781
- self .on_failure .append (ast .If (cond , fail_inner , [])) # noqa
782
- self .on_failure = fail_inner
861
+ self .expl_stmts .append (ast .If (cond , fail_inner , [])) # noqa
862
+ self .expl_stmts = fail_inner
783
863
self .push_format_context ()
784
864
res , expl = self .visit (v )
785
865
body .append (ast .Assign ([ast .Name (res_var , ast .Store ())], res ))
786
866
expl_format = self .pop_format_context (ast .Str (expl ))
787
867
call = ast .Call (app , [expl_format ], [])
788
- self .on_failure .append (ast .Expr (call ))
868
+ self .expl_stmts .append (ast .Expr (call ))
789
869
if i < levels :
790
870
cond = res
791
871
if is_or :
@@ -794,7 +874,7 @@ def visit_BoolOp(self, boolop):
794
874
self .statements .append (ast .If (cond , inner , []))
795
875
self .statements = body = inner
796
876
self .statements = save
797
- self .on_failure = fail_save
877
+ self .expl_stmts = fail_save
798
878
expl_template = self .helper ("_format_boolop" , expl_list , ast .Num (is_or ))
799
879
expl = self .pop_format_context (expl_template )
800
880
return ast .Name (res_var , ast .Load ()), self .explanation_param (expl )
0 commit comments