Skip to content

Commit 37fb50a

Browse files
authored
Features assertion pass hook (#3479)
Features assertion pass hook
2 parents 790806e + 2ea2221 commit 37fb50a

File tree

11 files changed

+302
-42
lines changed

11 files changed

+302
-42
lines changed

changelog/3457.feature.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
New `pytest_assertion_pass <https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_assertion_pass>`__
2+
hook, called with context information when an assertion *passes*.
3+
4+
This hook is still **experimental** so use it with caution.

changelog/3457.trivial.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.

doc/en/reference.rst

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -665,15 +665,14 @@ Session related reporting hooks:
665665
.. autofunction:: pytest_fixture_post_finalizer
666666
.. autofunction:: pytest_warning_captured
667667

668-
And here is the central hook for reporting about
669-
test execution:
668+
Central hook for reporting about test execution:
670669

671670
.. autofunction:: pytest_runtest_logreport
672671

673-
You can also use this hook to customize assertion representation for some
674-
types:
672+
Assertion related hooks:
675673

676674
.. autofunction:: pytest_assertrepr_compare
675+
.. autofunction:: pytest_assertion_pass
677676

678677

679678
Debugging/Interaction hooks

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
"pluggy>=0.12,<1.0",
1414
"importlib-metadata>=0.12",
1515
"wcwidth",
16+
"astor",
1617
]
1718

1819

src/_pytest/assertion/__init__.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ def pytest_addoption(parser):
2323
test modules on import to provide assert
2424
expression information.""",
2525
)
26+
parser.addini(
27+
"enable_assertion_pass_hook",
28+
type="bool",
29+
default=False,
30+
help="Enables the pytest_assertion_pass hook."
31+
"Make sure to delete any previously generated pyc cache files.",
32+
)
2633

2734

2835
def register_assert_rewrite(*names):
@@ -92,7 +99,7 @@ def pytest_collection(session):
9299

93100

94101
def pytest_runtest_setup(item):
95-
"""Setup the pytest_assertrepr_compare hook
102+
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
96103
97104
The newinterpret and rewrite modules will use util._reprcompare if
98105
it exists to use custom reporting via the
@@ -129,9 +136,19 @@ def callbinrepr(op, left, right):
129136

130137
util._reprcompare = callbinrepr
131138

139+
if item.ihook.pytest_assertion_pass.get_hookimpls():
140+
141+
def call_assertion_pass_hook(lineno, expl, orig):
142+
item.ihook.pytest_assertion_pass(
143+
item=item, lineno=lineno, orig=orig, expl=expl
144+
)
145+
146+
util._assertion_pass = call_assertion_pass_hook
147+
132148

133149
def pytest_runtest_teardown(item):
134150
util._reprcompare = None
151+
util._assertion_pass = None
135152

136153

137154
def pytest_sessionfinish(session):

src/_pytest/assertion/rewrite.py

Lines changed: 116 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import sys
1111
import types
1212

13+
import astor
1314
import atomicwrites
1415

1516
from _pytest._io.saferepr import saferepr
@@ -134,7 +135,7 @@ def exec_module(self, module):
134135
co = _read_pyc(fn, pyc, state.trace)
135136
if co is None:
136137
state.trace("rewriting {!r}".format(fn))
137-
source_stat, co = _rewrite_test(fn)
138+
source_stat, co = _rewrite_test(fn, self.config)
138139
if write:
139140
self._writing_pyc = True
140141
try:
@@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
278279
return True
279280

280281

281-
def _rewrite_test(fn):
282+
def _rewrite_test(fn, config):
282283
"""read and rewrite *fn* and return the code object."""
283284
stat = os.stat(fn)
284285
with open(fn, "rb") as f:
285286
source = f.read()
286287
tree = ast.parse(source, filename=fn)
287-
rewrite_asserts(tree, fn)
288+
rewrite_asserts(tree, fn, config)
288289
co = compile(tree, fn, "exec", dont_inherit=True)
289290
return stat, co
290291

@@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
326327
return co
327328

328329

329-
def rewrite_asserts(mod, module_path=None):
330+
def rewrite_asserts(mod, module_path=None, config=None):
330331
"""Rewrite the assert statements in mod."""
331-
AssertionRewriter(module_path).run(mod)
332+
AssertionRewriter(module_path, config).run(mod)
332333

333334

334335
def _saferepr(obj):
@@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
401402
return expl
402403

403404

405+
def _call_assertion_pass(lineno, orig, expl):
406+
if util._assertion_pass is not None:
407+
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
408+
409+
410+
def _check_if_assertion_pass_impl():
411+
"""Checks if any plugins implement the pytest_assertion_pass hook
412+
in order not to generate explanation unecessarily (might be expensive)"""
413+
return True if util._assertion_pass else False
414+
415+
404416
unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}
405417

406418
binop_map = {
@@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
473485
original assert statement: it rewrites the test of an assertion
474486
to provide intermediate values and replace it with an if statement
475487
which raises an assertion error with a detailed explanation in
476-
case the expression is false.
488+
case the expression is false and calls pytest_assertion_pass hook
489+
if expression is true.
477490
478491
For this .visit_Assert() uses the visitor pattern to visit all the
479492
AST nodes of the ast.Assert.test field, each visit call returning
@@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
491504
by statements. Variables are created using .variable() and
492505
have the form of "@py_assert0".
493506
494-
:on_failure: The AST statements which will be executed if the
495-
assertion test fails. This is the code which will construct
496-
the failure message and raises the AssertionError.
507+
:expl_stmts: The AST statements which will be executed to get
508+
data from the assertion. This is the code which will construct
509+
the detailed assertion message that is used in the AssertionError
510+
or for the pytest_assertion_pass hook.
497511
498512
:explanation_specifiers: A dict filled by .explanation_param()
499513
with %-formatting placeholders and their corresponding
@@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):
509523
510524
"""
511525

512-
def __init__(self, module_path):
526+
def __init__(self, module_path, config):
513527
super().__init__()
514528
self.module_path = module_path
529+
self.config = config
530+
if config is not None:
531+
self.enable_assertion_pass_hook = config.getini(
532+
"enable_assertion_pass_hook"
533+
)
534+
else:
535+
self.enable_assertion_pass_hook = False
515536

516537
def run(self, mod):
517538
"""Find all assert statements in *mod* and rewrite them."""
@@ -642,7 +663,7 @@ def pop_format_context(self, expl_expr):
642663
643664
The expl_expr should be an ast.Str instance constructed from
644665
the %-placeholders created by .explanation_param(). This will
645-
add the required code to format said string to .on_failure and
666+
add the required code to format said string to .expl_stmts and
646667
return the ast.Name instance of the formatted string.
647668
648669
"""
@@ -653,7 +674,9 @@ def pop_format_context(self, expl_expr):
653674
format_dict = ast.Dict(keys, list(current.values()))
654675
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
655676
name = "@py_format" + str(next(self.variable_counter))
656-
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
677+
if self.enable_assertion_pass_hook:
678+
self.format_variables.append(name)
679+
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
657680
return ast.Name(name, ast.Load())
658681

659682
def generic_visit(self, node):
@@ -687,8 +710,12 @@ def visit_Assert(self, assert_):
687710
self.statements = []
688711
self.variables = []
689712
self.variable_counter = itertools.count()
713+
714+
if self.enable_assertion_pass_hook:
715+
self.format_variables = []
716+
690717
self.stack = []
691-
self.on_failure = []
718+
self.expl_stmts = []
692719
self.push_format_context()
693720
# Rewrite assert into a bunch of statements.
694721
top_condition, explanation = self.visit(assert_.test)
@@ -699,24 +726,77 @@ def visit_Assert(self, assert_):
699726
top_condition, module_path=self.module_path, lineno=assert_.lineno
700727
)
701728
)
702-
# Create failure message.
703-
body = self.on_failure
704-
negation = ast.UnaryOp(ast.Not(), top_condition)
705-
self.statements.append(ast.If(negation, body, []))
706-
if assert_.msg:
707-
assertmsg = self.helper("_format_assertmsg", assert_.msg)
708-
explanation = "\n>assert " + explanation
709-
else:
710-
assertmsg = ast.Str("")
711-
explanation = "assert " + explanation
712-
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
713-
msg = self.pop_format_context(template)
714-
fmt = self.helper("_format_explanation", msg)
715-
err_name = ast.Name("AssertionError", ast.Load())
716-
exc = ast.Call(err_name, [fmt], [])
717-
raise_ = ast.Raise(exc, None)
718-
719-
body.append(raise_)
729+
730+
if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
731+
negation = ast.UnaryOp(ast.Not(), top_condition)
732+
msg = self.pop_format_context(ast.Str(explanation))
733+
734+
# Failed
735+
if assert_.msg:
736+
assertmsg = self.helper("_format_assertmsg", assert_.msg)
737+
gluestr = "\n>assert "
738+
else:
739+
assertmsg = ast.Str("")
740+
gluestr = "assert "
741+
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
742+
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
743+
err_name = ast.Name("AssertionError", ast.Load())
744+
fmt = self.helper("_format_explanation", err_msg)
745+
exc = ast.Call(err_name, [fmt], [])
746+
raise_ = ast.Raise(exc, None)
747+
statements_fail = []
748+
statements_fail.extend(self.expl_stmts)
749+
statements_fail.append(raise_)
750+
751+
# Passed
752+
fmt_pass = self.helper("_format_explanation", msg)
753+
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
754+
hook_call_pass = ast.Expr(
755+
self.helper(
756+
"_call_assertion_pass",
757+
ast.Num(assert_.lineno),
758+
ast.Str(orig),
759+
fmt_pass,
760+
)
761+
)
762+
# If any hooks implement assert_pass hook
763+
hook_impl_test = ast.If(
764+
self.helper("_check_if_assertion_pass_impl"),
765+
self.expl_stmts + [hook_call_pass],
766+
[],
767+
)
768+
statements_pass = [hook_impl_test]
769+
770+
# Test for assertion condition
771+
main_test = ast.If(negation, statements_fail, statements_pass)
772+
self.statements.append(main_test)
773+
if self.format_variables:
774+
variables = [
775+
ast.Name(name, ast.Store()) for name in self.format_variables
776+
]
777+
clear_format = ast.Assign(variables, _NameConstant(None))
778+
self.statements.append(clear_format)
779+
780+
else: # Original assertion rewriting
781+
# Create failure message.
782+
body = self.expl_stmts
783+
negation = ast.UnaryOp(ast.Not(), top_condition)
784+
self.statements.append(ast.If(negation, body, []))
785+
if assert_.msg:
786+
assertmsg = self.helper("_format_assertmsg", assert_.msg)
787+
explanation = "\n>assert " + explanation
788+
else:
789+
assertmsg = ast.Str("")
790+
explanation = "assert " + explanation
791+
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
792+
msg = self.pop_format_context(template)
793+
fmt = self.helper("_format_explanation", msg)
794+
err_name = ast.Name("AssertionError", ast.Load())
795+
exc = ast.Call(err_name, [fmt], [])
796+
raise_ = ast.Raise(exc, None)
797+
798+
body.append(raise_)
799+
720800
# Clear temporary variables by setting them to None.
721801
if self.variables:
722802
variables = [ast.Name(name, ast.Store()) for name in self.variables]
@@ -770,22 +850,22 @@ def visit_BoolOp(self, boolop):
770850
app = ast.Attribute(expl_list, "append", ast.Load())
771851
is_or = int(isinstance(boolop.op, ast.Or))
772852
body = save = self.statements
773-
fail_save = self.on_failure
853+
fail_save = self.expl_stmts
774854
levels = len(boolop.values) - 1
775855
self.push_format_context()
776856
# Process each operand, short-circuiting if needed.
777857
for i, v in enumerate(boolop.values):
778858
if i:
779859
fail_inner = []
780860
# cond is set in a prior loop iteration below
781-
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
782-
self.on_failure = fail_inner
861+
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
862+
self.expl_stmts = fail_inner
783863
self.push_format_context()
784864
res, expl = self.visit(v)
785865
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
786866
expl_format = self.pop_format_context(ast.Str(expl))
787867
call = ast.Call(app, [expl_format], [])
788-
self.on_failure.append(ast.Expr(call))
868+
self.expl_stmts.append(ast.Expr(call))
789869
if i < levels:
790870
cond = res
791871
if is_or:
@@ -794,7 +874,7 @@ def visit_BoolOp(self, boolop):
794874
self.statements.append(ast.If(cond, inner, []))
795875
self.statements = body = inner
796876
self.statements = save
797-
self.on_failure = fail_save
877+
self.expl_stmts = fail_save
798878
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
799879
expl = self.pop_format_context(expl_template)
800880
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)

src/_pytest/assertion/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# DebugInterpreter.
1313
_reprcompare = None
1414

15+
# Works similarly as _reprcompare attribute. Is populated with the hook call
16+
# when pytest_runtest_setup is called.
17+
_assertion_pass = None
18+
1519

1620
def format_explanation(explanation):
1721
"""This formats an explanation

0 commit comments

Comments
 (0)