Skip to content

Commit 73d918d

Browse files
authored
Remove astor and reproduce the original assertion expression (#5512)
Remove astor and reproduce the original assertion expression
2 parents 3c9b46f + 7ee2444 commit 73d918d

File tree

4 files changed

+197
-55
lines changed

4 files changed

+197
-55
lines changed

changelog/3457.trivial.rst

Lines changed: 0 additions & 1 deletion
This file was deleted.

setup.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
"pluggy>=0.12,<1.0",
1414
"importlib-metadata>=0.12",
1515
"wcwidth",
16-
"astor",
1716
]
1817

1918

src/_pytest/assertion/rewrite.py

Lines changed: 66 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
"""Rewrite assertion AST to produce nice error messages"""
22
import ast
33
import errno
4+
import functools
45
import importlib.machinery
56
import importlib.util
7+
import io
68
import itertools
79
import marshal
810
import os
911
import struct
1012
import sys
13+
import tokenize
1114
import types
1215

13-
import astor
1416
import atomicwrites
1517

1618
from _pytest._io.saferepr import saferepr
@@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
285287
with open(fn, "rb") as f:
286288
source = f.read()
287289
tree = ast.parse(source, filename=fn)
288-
rewrite_asserts(tree, fn, config)
290+
rewrite_asserts(tree, source, fn, config)
289291
co = compile(tree, fn, "exec", dont_inherit=True)
290292
return stat, co
291293

@@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
327329
return co
328330

329331

330-
def rewrite_asserts(mod, module_path=None, config=None):
332+
def rewrite_asserts(mod, source, module_path=None, config=None):
331333
"""Rewrite the assert statements in mod."""
332-
AssertionRewriter(module_path, config).run(mod)
334+
AssertionRewriter(module_path, config, source).run(mod)
333335

334336

335337
def _saferepr(obj):
@@ -457,6 +459,59 @@ def _fix(node, lineno, col_offset):
457459
return node
458460

459461

462+
def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
463+
"""Returns a mapping from {lineno: "assertion test expression"}"""
464+
ret = {}
465+
466+
depth = 0
467+
lines = []
468+
assert_lineno = None
469+
seen_lines = set()
470+
471+
def _write_and_reset() -> None:
472+
nonlocal depth, lines, assert_lineno, seen_lines
473+
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
474+
depth = 0
475+
lines = []
476+
assert_lineno = None
477+
seen_lines = set()
478+
479+
tokens = tokenize.tokenize(io.BytesIO(src).readline)
480+
for tp, src, (lineno, offset), _, line in tokens:
481+
if tp == tokenize.NAME and src == "assert":
482+
assert_lineno = lineno
483+
elif assert_lineno is not None:
484+
# keep track of depth for the assert-message `,` lookup
485+
if tp == tokenize.OP and src in "([{":
486+
depth += 1
487+
elif tp == tokenize.OP and src in ")]}":
488+
depth -= 1
489+
490+
if not lines:
491+
lines.append(line[offset:])
492+
seen_lines.add(lineno)
493+
# a non-nested comma separates the expression from the message
494+
elif depth == 0 and tp == tokenize.OP and src == ",":
495+
# one line assert with message
496+
if lineno in seen_lines and len(lines) == 1:
497+
offset_in_trimmed = offset + len(lines[-1]) - len(line)
498+
lines[-1] = lines[-1][:offset_in_trimmed]
499+
# multi-line assert with message
500+
elif lineno in seen_lines:
501+
lines[-1] = lines[-1][:offset]
502+
# multi line assert with escapd newline before message
503+
else:
504+
lines.append(line[:offset])
505+
_write_and_reset()
506+
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
507+
_write_and_reset()
508+
elif lines and lineno not in seen_lines:
509+
lines.append(line)
510+
seen_lines.add(lineno)
511+
512+
return ret
513+
514+
460515
class AssertionRewriter(ast.NodeVisitor):
461516
"""Assertion rewriting implementation.
462517
@@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
511566
512567
"""
513568

514-
def __init__(self, module_path, config):
569+
def __init__(self, module_path, config, source):
515570
super().__init__()
516571
self.module_path = module_path
517572
self.config = config
@@ -521,6 +576,11 @@ def __init__(self, module_path, config):
521576
)
522577
else:
523578
self.enable_assertion_pass_hook = False
579+
self.source = source
580+
581+
@functools.lru_cache(maxsize=1)
582+
def _assert_expr_to_lineno(self):
583+
return _get_assertion_exprs(self.source)
524584

525585
def run(self, mod):
526586
"""Find all assert statements in *mod* and rewrite them."""
@@ -738,7 +798,7 @@ def visit_Assert(self, assert_):
738798

739799
# Passed
740800
fmt_pass = self.helper("_format_explanation", msg)
741-
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
801+
orig = self._assert_expr_to_lineno()[assert_.lineno]
742802
hook_call_pass = ast.Expr(
743803
self.helper(
744804
"_call_assertion_pass",

testing/test_assertrewrite.py

Lines changed: 131 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import _pytest._code
1414
import pytest
1515
from _pytest.assertion import util
16+
from _pytest.assertion.rewrite import _get_assertion_exprs
1617
from _pytest.assertion.rewrite import AssertionRewritingHook
1718
from _pytest.assertion.rewrite import PYTEST_TAG
1819
from _pytest.assertion.rewrite import rewrite_asserts
@@ -31,7 +32,7 @@ def teardown_module(mod):
3132

3233
def rewrite(src):
3334
tree = ast.parse(src)
34-
rewrite_asserts(tree)
35+
rewrite_asserts(tree, src.encode())
3536
return tree
3637

3738

@@ -1292,10 +1293,10 @@ def test_pattern_contains_subdirectories(self, testdir, hook):
12921293
"""
12931294
p = testdir.makepyfile(
12941295
**{
1295-
"tests/file.py": """
1296-
def test_simple_failure():
1297-
assert 1 + 1 == 3
1298-
"""
1296+
"tests/file.py": """\
1297+
def test_simple_failure():
1298+
assert 1 + 1 == 3
1299+
"""
12991300
}
13001301
)
13011302
testdir.syspathinsert(p.dirpath())
@@ -1315,19 +1316,19 @@ def test_cwd_changed(self, testdir, monkeypatch):
13151316

13161317
testdir.makepyfile(
13171318
**{
1318-
"test_setup_nonexisting_cwd.py": """
1319-
import os
1320-
import shutil
1321-
import tempfile
1322-
1323-
d = tempfile.mkdtemp()
1324-
os.chdir(d)
1325-
shutil.rmtree(d)
1326-
""",
1327-
"test_test.py": """
1328-
def test():
1329-
pass
1330-
""",
1319+
"test_setup_nonexisting_cwd.py": """\
1320+
import os
1321+
import shutil
1322+
import tempfile
1323+
1324+
d = tempfile.mkdtemp()
1325+
os.chdir(d)
1326+
shutil.rmtree(d)
1327+
""",
1328+
"test_test.py": """\
1329+
def test():
1330+
pass
1331+
""",
13311332
}
13321333
)
13331334
result = testdir.runpytest()
@@ -1339,23 +1340,22 @@ def test_option_default(self, testdir):
13391340
config = testdir.parseconfig()
13401341
assert config.getini("enable_assertion_pass_hook") is False
13411342

1342-
def test_hook_call(self, testdir):
1343+
@pytest.fixture
1344+
def flag_on(self, testdir):
1345+
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")
1346+
1347+
@pytest.fixture
1348+
def hook_on(self, testdir):
13431349
testdir.makeconftest(
1344-
"""
1350+
"""\
13451351
def pytest_assertion_pass(item, lineno, orig, expl):
13461352
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
13471353
"""
13481354
)
13491355

1350-
testdir.makeini(
1351-
"""
1352-
[pytest]
1353-
enable_assertion_pass_hook = True
1354-
"""
1355-
)
1356-
1356+
def test_hook_call(self, testdir, flag_on, hook_on):
13571357
testdir.makepyfile(
1358-
"""
1358+
"""\
13591359
def test_simple():
13601360
a=1
13611361
b=2
@@ -1371,10 +1371,21 @@ def test_fails():
13711371
)
13721372
result = testdir.runpytest()
13731373
result.stdout.fnmatch_lines(
1374-
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
1374+
"*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
1375+
)
1376+
1377+
def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
1378+
testdir.makepyfile(
1379+
"""\
1380+
def f(): return 1
1381+
def test():
1382+
assert f()
1383+
"""
13751384
)
1385+
result = testdir.runpytest()
1386+
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")
13761387

1377-
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
1388+
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
13781389
"""Assertion pass should not be called (and hence formatting should
13791390
not occur) if there is no hook declared for pytest_assertion_pass"""
13801391

@@ -1385,15 +1396,8 @@ def raise_on_assertionpass(*_, **__):
13851396
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
13861397
)
13871398

1388-
testdir.makeini(
1389-
"""
1390-
[pytest]
1391-
enable_assertion_pass_hook = True
1392-
"""
1393-
)
1394-
13951399
testdir.makepyfile(
1396-
"""
1400+
"""\
13971401
def test_simple():
13981402
a=1
13991403
b=2
@@ -1418,21 +1422,14 @@ def raise_on_assertionpass(*_, **__):
14181422
)
14191423

14201424
testdir.makeconftest(
1421-
"""
1425+
"""\
14221426
def pytest_assertion_pass(item, lineno, orig, expl):
14231427
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
14241428
"""
14251429
)
14261430

1427-
testdir.makeini(
1428-
"""
1429-
[pytest]
1430-
enable_assertion_pass_hook = False
1431-
"""
1432-
)
1433-
14341431
testdir.makepyfile(
1435-
"""
1432+
"""\
14361433
def test_simple():
14371434
a=1
14381435
b=2
@@ -1444,3 +1441,90 @@ def test_simple():
14441441
)
14451442
result = testdir.runpytest()
14461443
result.assert_outcomes(passed=1)
1444+
1445+
1446+
@pytest.mark.parametrize(
1447+
("src", "expected"),
1448+
(
1449+
# fmt: off
1450+
pytest.param(b"", {}, id="trivial"),
1451+
pytest.param(
1452+
b"def x(): assert 1\n",
1453+
{1: "1"},
1454+
id="assert statement not on own line",
1455+
),
1456+
pytest.param(
1457+
b"def x():\n"
1458+
b" assert 1\n"
1459+
b" assert 1+2\n",
1460+
{2: "1", 3: "1+2"},
1461+
id="multiple assertions",
1462+
),
1463+
pytest.param(
1464+
# changes in encoding cause the byte offsets to be different
1465+
"# -*- coding: latin1\n"
1466+
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
1467+
{2: "1"},
1468+
id="latin1 encoded on first line\n",
1469+
),
1470+
pytest.param(
1471+
# using the default utf-8 encoding
1472+
"def ÀÀÀÀÀ(): assert 1\n".encode(),
1473+
{1: "1"},
1474+
id="utf-8 encoded on first line",
1475+
),
1476+
pytest.param(
1477+
b"def x():\n"
1478+
b" assert (\n"
1479+
b" 1 + 2 # comment\n"
1480+
b" )\n",
1481+
{2: "(\n 1 + 2 # comment\n )"},
1482+
id="multi-line assertion",
1483+
),
1484+
pytest.param(
1485+
b"def x():\n"
1486+
b" assert y == [\n"
1487+
b" 1, 2, 3\n"
1488+
b" ]\n",
1489+
{2: "y == [\n 1, 2, 3\n ]"},
1490+
id="multi line assert with list continuation",
1491+
),
1492+
pytest.param(
1493+
b"def x():\n"
1494+
b" assert 1 + \\\n"
1495+
b" 2\n",
1496+
{2: "1 + \\\n 2"},
1497+
id="backslash continuation",
1498+
),
1499+
pytest.param(
1500+
b"def x():\n"
1501+
b" assert x, y\n",
1502+
{2: "x"},
1503+
id="assertion with message",
1504+
),
1505+
pytest.param(
1506+
b"def x():\n"
1507+
b" assert (\n"
1508+
b" f(1, 2, 3)\n"
1509+
b" ), 'f did not work!'\n",
1510+
{2: "(\n f(1, 2, 3)\n )"},
1511+
id="assertion with message, test spanning multiple lines",
1512+
),
1513+
pytest.param(
1514+
b"def x():\n"
1515+
b" assert \\\n"
1516+
b" x\\\n"
1517+
b" , 'failure message'\n",
1518+
{2: "x"},
1519+
id="escaped newlines plus message",
1520+
),
1521+
pytest.param(
1522+
b"def x(): assert 5",
1523+
{1: "5"},
1524+
id="no newline at end of file",
1525+
),
1526+
# fmt: on
1527+
),
1528+
)
1529+
def test_get_assertion_exprs(src, expected):
1530+
assert _get_assertion_exprs(src) == expected

0 commit comments

Comments
 (0)