13
13
import sys
14
14
import tokenize
15
15
import types
16
+ from collections import defaultdict
16
17
from pathlib import Path
17
18
from pathlib import PurePath
18
19
from typing import Callable
45
46
from _pytest .assertion import AssertionState
46
47
47
48
49
+ class Sentinel :
50
+ pass
51
+
52
+
48
53
assertstate_key = StashKey ["AssertionState" ]()
49
54
50
55
# pytest caches rewritten pycs in pycache dirs
51
56
PYTEST_TAG = f"{ sys .implementation .cache_tag } -pytest-{ version } "
52
57
PYC_EXT = ".py" + (__debug__ and "c" or "o" )
53
58
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
54
59
60
+ # Special marker that denotes we have just left a scope definition
61
+ _SCOPE_END_MARKER = Sentinel ()
62
+
55
63
56
64
class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
57
65
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +642,8 @@ class AssertionRewriter(ast.NodeVisitor):
634
642
.push_format_context() and .pop_format_context() which allows
635
643
to build another %-formatted string while already building one.
636
644
645
+ :scope: A tuple containing the current scope used for variables_overwrite.
646
+
637
647
:variables_overwrite: A dict filled with references to variables
638
648
that change value within an assert. This happens when a variable is
639
649
reassigned with the walrus operator
@@ -655,7 +665,10 @@ def __init__(
655
665
else :
656
666
self .enable_assertion_pass_hook = False
657
667
self .source = source
658
- self .variables_overwrite : Dict [str , str ] = {}
668
+ self .scope : tuple [ast .AST , ...] = ()
669
+ self .variables_overwrite : defaultdict [
670
+ tuple [ast .AST , ...], Dict [str , str ]
671
+ ] = defaultdict (dict )
659
672
660
673
def run (self , mod : ast .Module ) -> None :
661
674
"""Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +732,17 @@ def run(self, mod: ast.Module) -> None:
719
732
mod .body [pos :pos ] = imports
720
733
721
734
# Collect asserts.
722
- nodes : List [ast .AST ] = [mod ]
735
+ self .scope = (mod ,)
736
+ nodes : List [Union [ast .AST , Sentinel ]] = [mod ]
723
737
while nodes :
724
738
node = nodes .pop ()
739
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
740
+ self .scope = tuple ((* self .scope , node ))
741
+ nodes .append (_SCOPE_END_MARKER )
742
+ if node == _SCOPE_END_MARKER :
743
+ self .scope = self .scope [:- 1 ]
744
+ continue
745
+ assert isinstance (node , ast .AST )
725
746
for name , field in ast .iter_fields (node ):
726
747
if isinstance (field , list ):
727
748
new : List [ast .AST ] = []
@@ -992,7 +1013,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
992
1013
]
993
1014
):
994
1015
pytest_temp = self .variable ()
995
- self .variables_overwrite [
1016
+ self .variables_overwrite [self . scope ][
996
1017
v .left .target .id
997
1018
] = v .left # type:ignore[assignment]
998
1019
v .left .target .id = pytest_temp
@@ -1035,17 +1056,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
1035
1056
new_args = []
1036
1057
new_kwargs = []
1037
1058
for arg in call .args :
1038
- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1039
- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1059
+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1060
+ self .scope , {}
1061
+ ):
1062
+ arg = self .variables_overwrite [self .scope ][
1063
+ arg .id
1064
+ ] # type:ignore[assignment]
1040
1065
res , expl = self .visit (arg )
1041
1066
arg_expls .append (expl )
1042
1067
new_args .append (res )
1043
1068
for keyword in call .keywords :
1044
- if (
1045
- isinstance (keyword .value , ast .Name )
1046
- and keyword .value .id in self .variables_overwrite
1047
- ):
1048
- keyword .value = self .variables_overwrite [
1069
+ if isinstance (
1070
+ keyword .value , ast .Name
1071
+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1072
+ keyword .value = self .variables_overwrite [self .scope ][
1049
1073
keyword .value .id
1050
1074
] # type:ignore[assignment]
1051
1075
res , expl = self .visit (keyword .value )
@@ -1081,12 +1105,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1081
1105
def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
1082
1106
self .push_format_context ()
1083
1107
# We first check if we have overwritten a variable in the previous assert
1084
- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1085
- comp .left = self .variables_overwrite [
1108
+ if isinstance (
1109
+ comp .left , ast .Name
1110
+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1111
+ comp .left = self .variables_overwrite [self .scope ][
1086
1112
comp .left .id
1087
1113
] # type:ignore[assignment]
1088
1114
if isinstance (comp .left , ast .NamedExpr ):
1089
- self .variables_overwrite [
1115
+ self .variables_overwrite [self . scope ][
1090
1116
comp .left .target .id
1091
1117
] = comp .left # type:ignore[assignment]
1092
1118
left_res , left_expl = self .visit (comp .left )
@@ -1106,7 +1132,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1106
1132
and next_operand .target .id == left_res .id
1107
1133
):
1108
1134
next_operand .target .id = self .variable ()
1109
- self .variables_overwrite [
1135
+ self .variables_overwrite [self . scope ][
1110
1136
left_res .id
1111
1137
] = next_operand # type:ignore[assignment]
1112
1138
next_res , next_expl = self .visit (next_operand )
0 commit comments