Skip to content

Commit 7259e8d

Browse files
authored
Fix assert rewriting with assignment expressions (#11414)
Fixes #11239
1 parent dd7beb3 commit 7259e8d

File tree

4 files changed

+63
-14
lines changed

4 files changed

+63
-14
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ Maho
235235
Maik Figura
236236
Mandeep Bhutani
237237
Manuel Krebber
238+
Marc Mueller
238239
Marc Schlaich
239240
Marcelo Duarte Trevisani
240241
Marcin Bachry

changelog/11239.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed ``:=`` in asserts impacting unrelated test cases.

src/_pytest/assertion/rewrite.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import tokenize
1515
import types
16+
from collections import defaultdict
1617
from pathlib import Path
1718
from pathlib import PurePath
1819
from typing import Callable
@@ -45,13 +46,20 @@
4546
from _pytest.assertion import AssertionState
4647

4748

49+
class Sentinel:
50+
pass
51+
52+
4853
assertstate_key = StashKey["AssertionState"]()
4954

5055
# pytest caches rewritten pycs in pycache dirs
5156
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
5257
PYC_EXT = ".py" + (__debug__ and "c" or "o")
5358
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5459

60+
# Special marker that denotes we have just left a scope definition
61+
_SCOPE_END_MARKER = Sentinel()
62+
5563

5664
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
5765
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor):
634642
.push_format_context() and .pop_format_context() which allows
635643
to build another %-formatted string while already building one.
636644
645+
:scope: A tuple containing the current scope used for variables_overwrite.
646+
637647
:variables_overwrite: A dict filled with references to variables
638648
that change value within an assert. This happens when a variable is
639649
reassigned with the walrus operator
@@ -655,7 +665,10 @@ def __init__(
655665
else:
656666
self.enable_assertion_pass_hook = False
657667
self.source = source
658-
self.variables_overwrite: Dict[str, str] = {}
668+
self.scope: tuple[ast.AST, ...] = ()
669+
self.variables_overwrite: defaultdict[
670+
tuple[ast.AST, ...], Dict[str, str]
671+
] = defaultdict(dict)
659672

660673
def run(self, mod: ast.Module) -> None:
661674
"""Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +732,17 @@ def run(self, mod: ast.Module) -> None:
719732
mod.body[pos:pos] = imports
720733

721734
# Collect asserts.
722-
nodes: List[ast.AST] = [mod]
735+
self.scope = (mod,)
736+
nodes: List[Union[ast.AST, Sentinel]] = [mod]
723737
while nodes:
724738
node = nodes.pop()
739+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
740+
self.scope = tuple((*self.scope, node))
741+
nodes.append(_SCOPE_END_MARKER)
742+
if node == _SCOPE_END_MARKER:
743+
self.scope = self.scope[:-1]
744+
continue
745+
assert isinstance(node, ast.AST)
725746
for name, field in ast.iter_fields(node):
726747
if isinstance(field, list):
727748
new: List[ast.AST] = []
@@ -992,7 +1013,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921013
]
9931014
):
9941015
pytest_temp = self.variable()
995-
self.variables_overwrite[
1016+
self.variables_overwrite[self.scope][
9961017
v.left.target.id
9971018
] = v.left # type:ignore[assignment]
9981019
v.left.target.id = pytest_temp
@@ -1035,17 +1056,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351056
new_args = []
10361057
new_kwargs = []
10371058
for arg in call.args:
1038-
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
1039-
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
1059+
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
1060+
self.scope, {}
1061+
):
1062+
arg = self.variables_overwrite[self.scope][
1063+
arg.id
1064+
] # type:ignore[assignment]
10401065
res, expl = self.visit(arg)
10411066
arg_expls.append(expl)
10421067
new_args.append(res)
10431068
for keyword in call.keywords:
1044-
if (
1045-
isinstance(keyword.value, ast.Name)
1046-
and keyword.value.id in self.variables_overwrite
1047-
):
1048-
keyword.value = self.variables_overwrite[
1069+
if isinstance(
1070+
keyword.value, ast.Name
1071+
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
1072+
keyword.value = self.variables_overwrite[self.scope][
10491073
keyword.value.id
10501074
] # type:ignore[assignment]
10511075
res, expl = self.visit(keyword.value)
@@ -1081,12 +1105,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811105
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10821106
self.push_format_context()
10831107
# We first check if we have overwritten a variable in the previous assert
1084-
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
1085-
comp.left = self.variables_overwrite[
1108+
if isinstance(
1109+
comp.left, ast.Name
1110+
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1111+
comp.left = self.variables_overwrite[self.scope][
10861112
comp.left.id
10871113
] # type:ignore[assignment]
10881114
if isinstance(comp.left, ast.NamedExpr):
1089-
self.variables_overwrite[
1115+
self.variables_overwrite[self.scope][
10901116
comp.left.target.id
10911117
] = comp.left # type:ignore[assignment]
10921118
left_res, left_expl = self.visit(comp.left)
@@ -1106,7 +1132,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061132
and next_operand.target.id == left_res.id
11071133
):
11081134
next_operand.target.id = self.variable()
1109-
self.variables_overwrite[
1135+
self.variables_overwrite[self.scope][
11101136
left_res.id
11111137
] = next_operand # type:ignore[assignment]
11121138
next_res, next_expl = self.visit(next_operand)

testing/test_assertrewrite.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,27 @@ def test_gt():
15431543
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
15441544

15451545

1546+
class TestIssue11239:
1547+
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
1548+
"""Regression for (#11239)
1549+
1550+
Walrus operator rewriting would leak to separate test cases if they used the same variables.
1551+
"""
1552+
pytester.makepyfile(
1553+
"""
1554+
def test_1():
1555+
state = {"x": 2}.get("x")
1556+
assert state is not None
1557+
1558+
def test_2():
1559+
db = {"x": 2}
1560+
assert (state := db.get("x")) is not None
1561+
"""
1562+
)
1563+
result = pytester.runpytest()
1564+
assert result.ret == 0
1565+
1566+
15461567
@pytest.mark.skipif(
15471568
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
15481569
)

0 commit comments

Comments
 (0)