Skip to content

Commit a9c6ac1

Browse files
committed
Fix assert rewriting with assignment expressions
1 parent 0a06db0 commit a9c6ac1

File tree

4 files changed

+54
-14
lines changed

4 files changed

+54
-14
lines changed

AUTHORS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ Maho
235235
Maik Figura
236236
Mandeep Bhutani
237237
Manuel Krebber
238+
Marc Mueller
238239
Marc Schlaich
239240
Marcelo Duarte Trevisani
240241
Marcin Bachry

changelog/11239.bugfix.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Fixed ``:=`` in asserts impacting unrelated test cases.

src/_pytest/assertion/rewrite.py

Lines changed: 35 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import sys
1414
import tokenize
1515
import types
16+
from collections import defaultdict
1617
from pathlib import Path
1718
from pathlib import PurePath
1819
from typing import Callable
@@ -52,6 +53,8 @@
5253
PYC_EXT = ".py" + (__debug__ and "c" or "o")
5354
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
5455

56+
_SCOPE_END_MARKER = object()
57+
5558

5659
class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
5760
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor):
634637
.push_format_context() and .pop_format_context() which allows
635638
to build another %-formatted string while already building one.
636639
640+
:scope: A tuple containing the current scope used for variables_overwrite.
641+
637642
:variables_overwrite: A dict filled with references to variables
638643
that change value within an assert. This happens when a variable is
639644
reassigned with the walrus operator
@@ -655,7 +660,10 @@ def __init__(
655660
else:
656661
self.enable_assertion_pass_hook = False
657662
self.source = source
658-
self.variables_overwrite: Dict[str, str] = {}
663+
self.scope: tuple[ast.AST, ...] = ()
664+
self.variables_overwrite: defaultdict[
665+
tuple[ast.AST, ...], Dict[str, str]
666+
] = defaultdict(dict)
659667

660668
def run(self, mod: ast.Module) -> None:
661669
"""Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +727,17 @@ def run(self, mod: ast.Module) -> None:
719727
mod.body[pos:pos] = imports
720728

721729
# Collect asserts.
722-
nodes: List[ast.AST] = [mod]
730+
self.scope = (mod,)
731+
nodes: List[Union[ast.AST, object]] = [mod]
723732
while nodes:
724733
node = nodes.pop()
734+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)):
735+
self.scope = tuple((*self.scope, node))
736+
nodes.append(_SCOPE_END_MARKER)
737+
if node == _SCOPE_END_MARKER:
738+
self.scope = self.scope[:-1]
739+
continue
740+
assert isinstance(node, ast.AST)
725741
for name, field in ast.iter_fields(node):
726742
if isinstance(field, list):
727743
new: List[ast.AST] = []
@@ -992,7 +1008,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
9921008
]
9931009
):
9941010
pytest_temp = self.variable()
995-
self.variables_overwrite[
1011+
self.variables_overwrite[self.scope][
9961012
v.left.target.id
9971013
] = v.left # type:ignore[assignment]
9981014
v.left.target.id = pytest_temp
@@ -1035,17 +1051,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
10351051
new_args = []
10361052
new_kwargs = []
10371053
for arg in call.args:
1038-
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite:
1039-
arg = self.variables_overwrite[arg.id] # type:ignore[assignment]
1054+
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
1055+
self.scope, {}
1056+
):
1057+
arg = self.variables_overwrite[self.scope][
1058+
arg.id
1059+
] # type:ignore[assignment]
10401060
res, expl = self.visit(arg)
10411061
arg_expls.append(expl)
10421062
new_args.append(res)
10431063
for keyword in call.keywords:
1044-
if (
1045-
isinstance(keyword.value, ast.Name)
1046-
and keyword.value.id in self.variables_overwrite
1047-
):
1048-
keyword.value = self.variables_overwrite[
1064+
if isinstance(
1065+
keyword.value, ast.Name
1066+
) and keyword.value.id in self.variables_overwrite.get(self.scope, {}):
1067+
keyword.value = self.variables_overwrite[self.scope][
10491068
keyword.value.id
10501069
] # type:ignore[assignment]
10511070
res, expl = self.visit(keyword.value)
@@ -1081,12 +1100,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
10811100
def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10821101
self.push_format_context()
10831102
# We first check if we have overwritten a variable in the previous assert
1084-
if isinstance(comp.left, ast.Name) and comp.left.id in self.variables_overwrite:
1085-
comp.left = self.variables_overwrite[
1103+
if isinstance(
1104+
comp.left, ast.Name
1105+
) and comp.left.id in self.variables_overwrite.get(self.scope, {}):
1106+
comp.left = self.variables_overwrite[self.scope][
10861107
comp.left.id
10871108
] # type:ignore[assignment]
10881109
if isinstance(comp.left, ast.NamedExpr):
1089-
self.variables_overwrite[
1110+
self.variables_overwrite[self.scope][
10901111
comp.left.target.id
10911112
] = comp.left # type:ignore[assignment]
10921113
left_res, left_expl = self.visit(comp.left)
@@ -1106,7 +1127,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
11061127
and next_operand.target.id == left_res.id
11071128
):
11081129
next_operand.target.id = self.variable()
1109-
self.variables_overwrite[
1130+
self.variables_overwrite[self.scope][
11101131
left_res.id
11111132
] = next_operand # type:ignore[assignment]
11121133
next_res, next_expl = self.visit(next_operand)

testing/test_assertrewrite.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1543,6 +1543,23 @@ def test_gt():
15431543
result.stdout.fnmatch_lines(["*assert 4 > 5", "*where 5 = add_one(4)"])
15441544

15451545

1546+
class TestIssue11239:
1547+
def test_assertion_walrus_different_test_cases(self, pytester: Pytester) -> None:
1548+
pytester.makepyfile(
1549+
"""
1550+
def test_1():
1551+
state = {"x": 2}.get("x")
1552+
assert state is not None
1553+
1554+
def test_2():
1555+
db = {"x": 2}
1556+
assert (state := db.get("x")) is not None
1557+
"""
1558+
)
1559+
result = pytester.runpytest()
1560+
assert result.ret == 0
1561+
1562+
15461563
@pytest.mark.skipif(
15471564
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
15481565
)

0 commit comments

Comments
 (0)