13
13
import sys
14
14
import tokenize
15
15
import types
16
+ from collections import defaultdict
16
17
from pathlib import Path
17
18
from pathlib import PurePath
18
19
from typing import Callable
52
53
PYC_EXT = ".py" + (__debug__ and "c" or "o" )
53
54
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT
54
55
56
+ _SCOPE_END_MARKER = object ()
57
+
55
58
56
59
class AssertionRewritingHook (importlib .abc .MetaPathFinder , importlib .abc .Loader ):
57
60
"""PEP302/PEP451 import hook which rewrites asserts."""
@@ -634,6 +637,8 @@ class AssertionRewriter(ast.NodeVisitor):
634
637
.push_format_context() and .pop_format_context() which allows
635
638
to build another %-formatted string while already building one.
636
639
640
+ :scope: A tuple containing the current scope used for variables_overwrite.
641
+
637
642
:variables_overwrite: A dict filled with references to variables
638
643
that change value within an assert. This happens when a variable is
639
644
reassigned with the walrus operator
@@ -655,7 +660,10 @@ def __init__(
655
660
else :
656
661
self .enable_assertion_pass_hook = False
657
662
self .source = source
658
- self .variables_overwrite : Dict [str , str ] = {}
663
+ self .scope : tuple [ast .AST , ...] = ()
664
+ self .variables_overwrite : defaultdict [
665
+ tuple [ast .AST , ...], Dict [str , str ]
666
+ ] = defaultdict (dict )
659
667
660
668
def run (self , mod : ast .Module ) -> None :
661
669
"""Find all assert statements in *mod* and rewrite them."""
@@ -719,9 +727,17 @@ def run(self, mod: ast.Module) -> None:
719
727
mod .body [pos :pos ] = imports
720
728
721
729
# Collect asserts.
722
- nodes : List [ast .AST ] = [mod ]
730
+ self .scope = (mod ,)
731
+ nodes : List [Union [ast .AST , object ]] = [mod ]
723
732
while nodes :
724
733
node = nodes .pop ()
734
+ if isinstance (node , (ast .FunctionDef , ast .AsyncFunctionDef , ast .ClassDef )):
735
+ self .scope = tuple ((* self .scope , node ))
736
+ nodes .append (_SCOPE_END_MARKER )
737
+ if node == _SCOPE_END_MARKER :
738
+ self .scope = self .scope [:- 1 ]
739
+ continue
740
+ assert isinstance (node , ast .AST )
725
741
for name , field in ast .iter_fields (node ):
726
742
if isinstance (field , list ):
727
743
new : List [ast .AST ] = []
@@ -992,7 +1008,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
992
1008
]
993
1009
):
994
1010
pytest_temp = self .variable ()
995
- self .variables_overwrite [
1011
+ self .variables_overwrite [self . scope ][
996
1012
v .left .target .id
997
1013
] = v .left # type:ignore[assignment]
998
1014
v .left .target .id = pytest_temp
@@ -1035,17 +1051,20 @@ def visit_Call(self, call: ast.Call) -> Tuple[ast.Name, str]:
1035
1051
new_args = []
1036
1052
new_kwargs = []
1037
1053
for arg in call .args :
1038
- if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite :
1039
- arg = self .variables_overwrite [arg .id ] # type:ignore[assignment]
1054
+ if isinstance (arg , ast .Name ) and arg .id in self .variables_overwrite .get (
1055
+ self .scope , {}
1056
+ ):
1057
+ arg = self .variables_overwrite [self .scope ][
1058
+ arg .id
1059
+ ] # type:ignore[assignment]
1040
1060
res , expl = self .visit (arg )
1041
1061
arg_expls .append (expl )
1042
1062
new_args .append (res )
1043
1063
for keyword in call .keywords :
1044
- if (
1045
- isinstance (keyword .value , ast .Name )
1046
- and keyword .value .id in self .variables_overwrite
1047
- ):
1048
- keyword .value = self .variables_overwrite [
1064
+ if isinstance (
1065
+ keyword .value , ast .Name
1066
+ ) and keyword .value .id in self .variables_overwrite .get (self .scope , {}):
1067
+ keyword .value = self .variables_overwrite [self .scope ][
1049
1068
keyword .value .id
1050
1069
] # type:ignore[assignment]
1051
1070
res , expl = self .visit (keyword .value )
@@ -1081,12 +1100,14 @@ def visit_Attribute(self, attr: ast.Attribute) -> Tuple[ast.Name, str]:
1081
1100
def visit_Compare (self , comp : ast .Compare ) -> Tuple [ast .expr , str ]:
1082
1101
self .push_format_context ()
1083
1102
# We first check if we have overwritten a variable in the previous assert
1084
- if isinstance (comp .left , ast .Name ) and comp .left .id in self .variables_overwrite :
1085
- comp .left = self .variables_overwrite [
1103
+ if isinstance (
1104
+ comp .left , ast .Name
1105
+ ) and comp .left .id in self .variables_overwrite .get (self .scope , {}):
1106
+ comp .left = self .variables_overwrite [self .scope ][
1086
1107
comp .left .id
1087
1108
] # type:ignore[assignment]
1088
1109
if isinstance (comp .left , ast .NamedExpr ):
1089
- self .variables_overwrite [
1110
+ self .variables_overwrite [self . scope ][
1090
1111
comp .left .target .id
1091
1112
] = comp .left # type:ignore[assignment]
1092
1113
left_res , left_expl = self .visit (comp .left )
@@ -1106,7 +1127,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
1106
1127
and next_operand .target .id == left_res .id
1107
1128
):
1108
1129
next_operand .target .id = self .variable ()
1109
- self .variables_overwrite [
1130
+ self .variables_overwrite [self . scope ][
1110
1131
left_res .id
1111
1132
] = next_operand # type:ignore[assignment]
1112
1133
next_res , next_expl = self .visit (next_operand )
0 commit comments