|
| 1 | +""" |
| 2 | +Find intermediate evalutation results in assert statements through builtin AST. |
| 3 | +This should replace _assertionold.py eventually. |
| 4 | +""" |
| 5 | + |
| 6 | +import ast |
| 7 | +import sys |
| 8 | + |
| 9 | +from .assertion import _format_explanation, BuiltinAssertionError |
| 10 | + |
| 11 | +if sys.platform.startswith("java") and sys.version_info < (2, 5, 2): |
| 12 | + # See http://bugs.jython.org/issue1497 |
| 13 | + _exprs = ("BoolOp", "BinOp", "UnaryOp", "Lambda", "IfExp", "Dict", |
| 14 | + "ListComp", "GeneratorExp", "Yield", "Compare", "Call", |
| 15 | + "Repr", "Num", "Str", "Attribute", "Subscript", "Name", |
| 16 | + "List", "Tuple") |
| 17 | + _stmts = ("FunctionDef", "ClassDef", "Return", "Delete", "Assign", |
| 18 | + "AugAssign", "Print", "For", "While", "If", "With", "Raise", |
| 19 | + "TryExcept", "TryFinally", "Assert", "Import", "ImportFrom", |
| 20 | + "Exec", "Global", "Expr", "Pass", "Break", "Continue") |
| 21 | + _expr_nodes = set(getattr(ast, name) for name in _exprs) |
| 22 | + _stmt_nodes = set(getattr(ast, name) for name in _stmts) |
| 23 | + def _is_ast_expr(node): |
| 24 | + return node.__class__ in _expr_nodes |
| 25 | + def _is_ast_stmt(node): |
| 26 | + return node.__class__ in _stmt_nodes |
| 27 | +else: |
| 28 | + def _is_ast_expr(node): |
| 29 | + return isinstance(node, ast.expr) |
| 30 | + def _is_ast_stmt(node): |
| 31 | + return isinstance(node, ast.stmt) |
| 32 | + |
| 33 | + |
| 34 | +class Failure(Exception): |
| 35 | + """Error found while interpreting AST.""" |
| 36 | + |
| 37 | + def __init__(self, explanation=""): |
| 38 | + self.cause = sys.exc_info() |
| 39 | + self.explanation = explanation |
| 40 | + |
| 41 | + |
| 42 | +def interpret(source, frame, should_fail=False): |
| 43 | + mod = ast.parse(source) |
| 44 | + visitor = DebugInterpreter(frame) |
| 45 | + try: |
| 46 | + visitor.visit(mod) |
| 47 | + except Failure: |
| 48 | + failure = sys.exc_info()[1] |
| 49 | + return getfailure(failure) |
| 50 | + if should_fail: |
| 51 | + return ("(assertion failed, but when it was re-run for " |
| 52 | + "printing intermediate values, it did not fail. Suggestions: " |
| 53 | + "compute assert expression before the assert or use --no-assert)") |
| 54 | + |
| 55 | +def run(offending_line, frame=None): |
| 56 | + from .code import Frame |
| 57 | + if frame is None: |
| 58 | + frame = Frame(sys._getframe(1)) |
| 59 | + return interpret(offending_line, frame) |
| 60 | + |
| 61 | +def getfailure(failure): |
| 62 | + explanation = _format_explanation(failure.explanation) |
| 63 | + value = failure.cause[1] |
| 64 | + if str(value): |
| 65 | + lines = explanation.splitlines() |
| 66 | + if not lines: |
| 67 | + lines.append("") |
| 68 | + lines[0] += " << %s" % (value,) |
| 69 | + explanation = "\n".join(lines) |
| 70 | + text = "%s: %s" % (failure.cause[0].__name__, explanation) |
| 71 | + if text.startswith("AssertionError: assert "): |
| 72 | + text = text[16:] |
| 73 | + return text |
| 74 | + |
| 75 | + |
| 76 | +operator_map = { |
| 77 | + ast.BitOr : "|", |
| 78 | + ast.BitXor : "^", |
| 79 | + ast.BitAnd : "&", |
| 80 | + ast.LShift : "<<", |
| 81 | + ast.RShift : ">>", |
| 82 | + ast.Add : "+", |
| 83 | + ast.Sub : "-", |
| 84 | + ast.Mult : "*", |
| 85 | + ast.Div : "/", |
| 86 | + ast.FloorDiv : "//", |
| 87 | + ast.Mod : "%", |
| 88 | + ast.Eq : "==", |
| 89 | + ast.NotEq : "!=", |
| 90 | + ast.Lt : "<", |
| 91 | + ast.LtE : "<=", |
| 92 | + ast.Gt : ">", |
| 93 | + ast.GtE : ">=", |
| 94 | + ast.Pow : "**", |
| 95 | + ast.Is : "is", |
| 96 | + ast.IsNot : "is not", |
| 97 | + ast.In : "in", |
| 98 | + ast.NotIn : "not in" |
| 99 | +} |
| 100 | + |
| 101 | +unary_map = { |
| 102 | + ast.Not : "not %s", |
| 103 | + ast.Invert : "~%s", |
| 104 | + ast.USub : "-%s", |
| 105 | + ast.UAdd : "+%s" |
| 106 | +} |
| 107 | + |
| 108 | + |
| 109 | +class DebugInterpreter(ast.NodeVisitor): |
| 110 | + """Interpret AST nodes to gleam useful debugging information. """ |
| 111 | + |
| 112 | + def __init__(self, frame): |
| 113 | + self.frame = frame |
| 114 | + |
| 115 | + def generic_visit(self, node): |
| 116 | + # Fallback when we don't have a special implementation. |
| 117 | + if _is_ast_expr(node): |
| 118 | + mod = ast.Expression(node) |
| 119 | + co = self._compile(mod) |
| 120 | + try: |
| 121 | + result = self.frame.eval(co) |
| 122 | + except Exception: |
| 123 | + raise Failure() |
| 124 | + explanation = self.frame.repr(result) |
| 125 | + return explanation, result |
| 126 | + elif _is_ast_stmt(node): |
| 127 | + mod = ast.Module([node]) |
| 128 | + co = self._compile(mod, "exec") |
| 129 | + try: |
| 130 | + self.frame.exec_(co) |
| 131 | + except Exception: |
| 132 | + raise Failure() |
| 133 | + return None, None |
| 134 | + else: |
| 135 | + raise AssertionError("can't handle %s" %(node,)) |
| 136 | + |
| 137 | + def _compile(self, source, mode="eval"): |
| 138 | + return compile(source, "<assertion interpretation>", mode) |
| 139 | + |
| 140 | + def visit_Expr(self, expr): |
| 141 | + return self.visit(expr.value) |
| 142 | + |
| 143 | + def visit_Module(self, mod): |
| 144 | + for stmt in mod.body: |
| 145 | + self.visit(stmt) |
| 146 | + |
| 147 | + def visit_Name(self, name): |
| 148 | + explanation, result = self.generic_visit(name) |
| 149 | + # See if the name is local. |
| 150 | + source = "%r in locals() is not globals()" % (name.id,) |
| 151 | + co = self._compile(source) |
| 152 | + try: |
| 153 | + local = self.frame.eval(co) |
| 154 | + except Exception: |
| 155 | + # have to assume it isn't |
| 156 | + local = False |
| 157 | + if not local: |
| 158 | + return name.id, result |
| 159 | + return explanation, result |
| 160 | + |
| 161 | + def visit_Compare(self, comp): |
| 162 | + left = comp.left |
| 163 | + left_explanation, left_result = self.visit(left) |
| 164 | + for op, next_op in zip(comp.ops, comp.comparators): |
| 165 | + next_explanation, next_result = self.visit(next_op) |
| 166 | + op_symbol = operator_map[op.__class__] |
| 167 | + explanation = "%s %s %s" % (left_explanation, op_symbol, |
| 168 | + next_explanation) |
| 169 | + source = "__exprinfo_left %s __exprinfo_right" % (op_symbol,) |
| 170 | + co = self._compile(source) |
| 171 | + try: |
| 172 | + result = self.frame.eval(co, __exprinfo_left=left_result, |
| 173 | + __exprinfo_right=next_result) |
| 174 | + except Exception: |
| 175 | + raise Failure(explanation) |
| 176 | + try: |
| 177 | + if not result: |
| 178 | + break |
| 179 | + except KeyboardInterrupt: |
| 180 | + raise |
| 181 | + except: |
| 182 | + break |
| 183 | + left_explanation, left_result = next_explanation, next_result |
| 184 | + |
| 185 | + import _pytest._code |
| 186 | + rcomp = _pytest._code._reprcompare |
| 187 | + if rcomp: |
| 188 | + res = rcomp(op_symbol, left_result, next_result) |
| 189 | + if res: |
| 190 | + explanation = res |
| 191 | + return explanation, result |
| 192 | + |
| 193 | + def visit_BoolOp(self, boolop): |
| 194 | + is_or = isinstance(boolop.op, ast.Or) |
| 195 | + explanations = [] |
| 196 | + for operand in boolop.values: |
| 197 | + explanation, result = self.visit(operand) |
| 198 | + explanations.append(explanation) |
| 199 | + if result == is_or: |
| 200 | + break |
| 201 | + name = is_or and " or " or " and " |
| 202 | + explanation = "(" + name.join(explanations) + ")" |
| 203 | + return explanation, result |
| 204 | + |
| 205 | + def visit_UnaryOp(self, unary): |
| 206 | + pattern = unary_map[unary.op.__class__] |
| 207 | + operand_explanation, operand_result = self.visit(unary.operand) |
| 208 | + explanation = pattern % (operand_explanation,) |
| 209 | + co = self._compile(pattern % ("__exprinfo_expr",)) |
| 210 | + try: |
| 211 | + result = self.frame.eval(co, __exprinfo_expr=operand_result) |
| 212 | + except Exception: |
| 213 | + raise Failure(explanation) |
| 214 | + return explanation, result |
| 215 | + |
| 216 | + def visit_BinOp(self, binop): |
| 217 | + left_explanation, left_result = self.visit(binop.left) |
| 218 | + right_explanation, right_result = self.visit(binop.right) |
| 219 | + symbol = operator_map[binop.op.__class__] |
| 220 | + explanation = "(%s %s %s)" % (left_explanation, symbol, |
| 221 | + right_explanation) |
| 222 | + source = "__exprinfo_left %s __exprinfo_right" % (symbol,) |
| 223 | + co = self._compile(source) |
| 224 | + try: |
| 225 | + result = self.frame.eval(co, __exprinfo_left=left_result, |
| 226 | + __exprinfo_right=right_result) |
| 227 | + except Exception: |
| 228 | + raise Failure(explanation) |
| 229 | + return explanation, result |
| 230 | + |
| 231 | + def visit_Call(self, call): |
| 232 | + func_explanation, func = self.visit(call.func) |
| 233 | + arg_explanations = [] |
| 234 | + ns = {"__exprinfo_func" : func} |
| 235 | + arguments = [] |
| 236 | + for arg in call.args: |
| 237 | + arg_explanation, arg_result = self.visit(arg) |
| 238 | + arg_name = "__exprinfo_%s" % (len(ns),) |
| 239 | + ns[arg_name] = arg_result |
| 240 | + arguments.append(arg_name) |
| 241 | + arg_explanations.append(arg_explanation) |
| 242 | + for keyword in call.keywords: |
| 243 | + arg_explanation, arg_result = self.visit(keyword.value) |
| 244 | + arg_name = "__exprinfo_%s" % (len(ns),) |
| 245 | + ns[arg_name] = arg_result |
| 246 | + keyword_source = "%s=%%s" % (keyword.arg) |
| 247 | + arguments.append(keyword_source % (arg_name,)) |
| 248 | + arg_explanations.append(keyword_source % (arg_explanation,)) |
| 249 | + if call.starargs: |
| 250 | + arg_explanation, arg_result = self.visit(call.starargs) |
| 251 | + arg_name = "__exprinfo_star" |
| 252 | + ns[arg_name] = arg_result |
| 253 | + arguments.append("*%s" % (arg_name,)) |
| 254 | + arg_explanations.append("*%s" % (arg_explanation,)) |
| 255 | + if call.kwargs: |
| 256 | + arg_explanation, arg_result = self.visit(call.kwargs) |
| 257 | + arg_name = "__exprinfo_kwds" |
| 258 | + ns[arg_name] = arg_result |
| 259 | + arguments.append("**%s" % (arg_name,)) |
| 260 | + arg_explanations.append("**%s" % (arg_explanation,)) |
| 261 | + args_explained = ", ".join(arg_explanations) |
| 262 | + explanation = "%s(%s)" % (func_explanation, args_explained) |
| 263 | + args = ", ".join(arguments) |
| 264 | + source = "__exprinfo_func(%s)" % (args,) |
| 265 | + co = self._compile(source) |
| 266 | + try: |
| 267 | + result = self.frame.eval(co, **ns) |
| 268 | + except Exception: |
| 269 | + raise Failure(explanation) |
| 270 | + pattern = "%s\n{%s = %s\n}" |
| 271 | + rep = self.frame.repr(result) |
| 272 | + explanation = pattern % (rep, rep, explanation) |
| 273 | + return explanation, result |
| 274 | + |
| 275 | + def _is_builtin_name(self, name): |
| 276 | + pattern = "%r not in globals() and %r not in locals()" |
| 277 | + source = pattern % (name.id, name.id) |
| 278 | + co = self._compile(source) |
| 279 | + try: |
| 280 | + return self.frame.eval(co) |
| 281 | + except Exception: |
| 282 | + return False |
| 283 | + |
| 284 | + def visit_Attribute(self, attr): |
| 285 | + if not isinstance(attr.ctx, ast.Load): |
| 286 | + return self.generic_visit(attr) |
| 287 | + source_explanation, source_result = self.visit(attr.value) |
| 288 | + explanation = "%s.%s" % (source_explanation, attr.attr) |
| 289 | + source = "__exprinfo_expr.%s" % (attr.attr,) |
| 290 | + co = self._compile(source) |
| 291 | + try: |
| 292 | + result = self.frame.eval(co, __exprinfo_expr=source_result) |
| 293 | + except Exception: |
| 294 | + raise Failure(explanation) |
| 295 | + explanation = "%s\n{%s = %s.%s\n}" % (self.frame.repr(result), |
| 296 | + self.frame.repr(result), |
| 297 | + source_explanation, attr.attr) |
| 298 | + # Check if the attr is from an instance. |
| 299 | + source = "%r in getattr(__exprinfo_expr, '__dict__', {})" |
| 300 | + source = source % (attr.attr,) |
| 301 | + co = self._compile(source) |
| 302 | + try: |
| 303 | + from_instance = self.frame.eval(co, __exprinfo_expr=source_result) |
| 304 | + except Exception: |
| 305 | + from_instance = True |
| 306 | + if from_instance: |
| 307 | + rep = self.frame.repr(result) |
| 308 | + pattern = "%s\n{%s = %s\n}" |
| 309 | + explanation = pattern % (rep, rep, explanation) |
| 310 | + return explanation, result |
| 311 | + |
| 312 | + def visit_Assert(self, assrt): |
| 313 | + test_explanation, test_result = self.visit(assrt.test) |
| 314 | + if test_explanation.startswith("False\n{False =") and \ |
| 315 | + test_explanation.endswith("\n"): |
| 316 | + test_explanation = test_explanation[15:-2] |
| 317 | + explanation = "assert %s" % (test_explanation,) |
| 318 | + if not test_result: |
| 319 | + try: |
| 320 | + raise BuiltinAssertionError |
| 321 | + except Exception: |
| 322 | + raise Failure(explanation) |
| 323 | + return explanation, test_result |
| 324 | + |
| 325 | + def visit_Assign(self, assign): |
| 326 | + value_explanation, value_result = self.visit(assign.value) |
| 327 | + explanation = "... = %s" % (value_explanation,) |
| 328 | + name = ast.Name("__exprinfo_expr", ast.Load(), |
| 329 | + lineno=assign.value.lineno, |
| 330 | + col_offset=assign.value.col_offset) |
| 331 | + new_assign = ast.Assign(assign.targets, name, lineno=assign.lineno, |
| 332 | + col_offset=assign.col_offset) |
| 333 | + mod = ast.Module([new_assign]) |
| 334 | + co = self._compile(mod, "exec") |
| 335 | + try: |
| 336 | + self.frame.exec_(co, __exprinfo_expr=value_result) |
| 337 | + except Exception: |
| 338 | + raise Failure(explanation) |
| 339 | + return explanation, value_result |
0 commit comments