Skip to content

Features assertion pass hook #3479

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 25 commits into from
Jun 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9a89783
Assertion passed hook
Jun 24, 2019
52e695b
Removed debug code.
Jun 24, 2019
98b212c
Added "experimental" note.
Jun 25, 2019
f8c9a7b
Formatting and removed py2 support.
Jun 25, 2019
2280f28
Black formatting.
Jun 25, 2019
81e3f3c
Black formatting
Jun 25, 2019
db50a97
Reverted leak fixture test.
Jun 25, 2019
cfbfa53
Using pytester subprocess to avoid keeping references in the HookReco…
Jun 25, 2019
4db5488
Now dependent on command line option.
Jun 25, 2019
80ac910
Added msg to docstring for cleaning pyc.
Jun 25, 2019
7efdd50
Update src/_pytest/assertion/rewrite.py
Sup3rGeo Jun 26, 2019
0fb5241
Reverted changes.
Jun 26, 2019
d638da5
Using ini-file option instead of cmd option.
Jun 26, 2019
f755ff6
Black formatting.
Jun 26, 2019
d91a5d3
Further reverting changes.
Jun 26, 2019
9a34d88
Explanation variables only defined if failed or passed with plugins i…
Jun 26, 2019
6f851e6
Merge remote-tracking branch 'upstream/master' into features-assertio…
Jun 26, 2019
53234bf
Added config back to AssertionWriter and fixed typo in check_if_asser…
Jun 26, 2019
6854ff2
Fixed import order pep8.
Jun 26, 2019
eb90f3d
Fix default value of 'enable_assertion_pass_hook'
nicoddemus Jun 26, 2019
fcbe66f
Restore proper handling of '%' in assertion messages
nicoddemus Jun 26, 2019
3afee36
Improve docs and reference
nicoddemus Jun 26, 2019
8edf68f
Add a trivial note about astor
nicoddemus Jun 26, 2019
629eb3e
Move formatting variables under the "has impls" if
nicoddemus Jun 26, 2019
2ea2221
Cover assertions with messages when enable_assertion_pass_hook is ena…
nicoddemus Jun 26, 2019
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions changelog/3457.feature.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
New `pytest_assertion_pass <https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_assertion_pass>`__
hook, called with context information when an assertion *passes*.

This hook is still **experimental** so use it with caution.
1 change: 1 addition & 0 deletions changelog/3457.trivial.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.
7 changes: 3 additions & 4 deletions doc/en/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -665,15 +665,14 @@ Session related reporting hooks:
.. autofunction:: pytest_fixture_post_finalizer
.. autofunction:: pytest_warning_captured

And here is the central hook for reporting about
test execution:
Central hook for reporting about test execution:

.. autofunction:: pytest_runtest_logreport

You can also use this hook to customize assertion representation for some
types:
Assertion related hooks:

.. autofunction:: pytest_assertrepr_compare
.. autofunction:: pytest_assertion_pass


Debugging/Interaction hooks
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
"pluggy>=0.12,<1.0",
"importlib-metadata>=0.12",
"wcwidth",
"astor",
]


Expand Down
19 changes: 18 additions & 1 deletion src/_pytest/assertion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ def pytest_addoption(parser):
test modules on import to provide assert
expression information.""",
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook."
"Make sure to delete any previously generated pyc cache files.",
)


def register_assert_rewrite(*names):
Expand Down Expand Up @@ -92,7 +99,7 @@ def pytest_collection(session):


def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks

The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the
Expand Down Expand Up @@ -129,9 +136,19 @@ def callbinrepr(op, left, right):

util._reprcompare = callbinrepr

if item.ihook.pytest_assertion_pass.get_hookimpls():

def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)

util._assertion_pass = call_assertion_pass_hook


def pytest_runtest_teardown(item):
util._reprcompare = None
util._assertion_pass = None


def pytest_sessionfinish(session):
Expand Down
152 changes: 116 additions & 36 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import sys
import types

import astor
import atomicwrites

from _pytest._io.saferepr import saferepr
Expand Down Expand Up @@ -134,7 +135,7 @@ def exec_module(self, module):
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(fn)
source_stat, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
Expand Down Expand Up @@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
return True


def _rewrite_test(fn):
def _rewrite_test(fn, config):
"""read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
with open(fn, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn)
rewrite_asserts(tree, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co

Expand Down Expand Up @@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co


def rewrite_asserts(mod, module_path=None):
def rewrite_asserts(mod, module_path=None, config=None):
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path).run(mod)
AssertionRewriter(module_path, config).run(mod)


def _saferepr(obj):
Expand Down Expand Up @@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl


def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)


def _check_if_assertion_pass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False


unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}

binop_map = {
Expand Down Expand Up @@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in
case the expression is false.
case the expression is false and calls pytest_assertion_pass hook
if expression is true.

For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning
Expand All @@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and
have the form of "@py_assert0".

:on_failure: The AST statements which will be executed if the
assertion test fails. This is the code which will construct
the failure message and raises the AssertionError.
:expl_stmts: The AST statements which will be executed to get
data from the assertion. This is the code which will construct
the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.

:explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding
Expand All @@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):

"""

def __init__(self, module_path):
def __init__(self, module_path, config):
super().__init__()
self.module_path = module_path
self.config = config
if config is not None:
self.enable_assertion_pass_hook = config.getini(
"enable_assertion_pass_hook"
)
else:
self.enable_assertion_pass_hook = False

def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -642,7 +663,7 @@ def pop_format_context(self, expl_expr):

The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.

"""
Expand All @@ -653,7 +674,9 @@ def pop_format_context(self, expl_expr):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
if self.enable_assertion_pass_hook:
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())

def generic_visit(self, node):
Expand Down Expand Up @@ -687,8 +710,12 @@ def visit_Assert(self, assert_):
self.statements = []
self.variables = []
self.variable_counter = itertools.count()

if self.enable_assertion_pass_hook:
self.format_variables = []

self.stack = []
self.on_failure = []
self.expl_stmts = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
Expand All @@ -699,24 +726,77 @@ def visit_Assert(self, assert_):
top_condition, module_path=self.module_path, lineno=assert_.lineno
)
)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)

body.append(raise_)

if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation))

# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
statements_fail = []
statements_fail.extend(self.expl_stmts)
statements_fail.append(raise_)

# Passed
fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
fmt_pass,
)
)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertion_pass_impl"),
self.expl_stmts + [hook_call_pass],
[],
)
statements_pass = [hook_impl_test]

# Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass)
self.statements.append(main_test)
if self.format_variables:
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)

else: # Original assertion rewriting
# Create failure message.
body = self.expl_stmts
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)

body.append(raise_)

# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
Expand Down Expand Up @@ -770,22 +850,22 @@ def visit_BoolOp(self, boolop):
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements
fail_save = self.on_failure
fail_save = self.expl_stmts
levels = len(boolop.values) - 1
self.push_format_context()
# Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
fail_inner = []
# cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call))
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
if is_or:
Expand All @@ -794,7 +874,7 @@ def visit_BoolOp(self, boolop):
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
self.on_failure = fail_save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
Expand Down
4 changes: 4 additions & 0 deletions src/_pytest/assertion/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
# DebugInterpreter.
_reprcompare = None

# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None


def format_explanation(explanation):
"""This formats an explanation
Expand Down
Loading