Skip to content

1/X Fix check_untyped_defs = True mypy errors #5673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Aug 20, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 17 additions & 10 deletions src/_pytest/_code/code.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,15 @@
from inspect import CO_VARARGS
from inspect import CO_VARKEYWORDS
from traceback import format_exception_only
from types import CodeType
from types import TracebackType
from typing import Any
from typing import Dict
from typing import Generic
from typing import List
from typing import Optional
from typing import Pattern
from typing import Set
from typing import Tuple
from typing import TypeVar
from typing import Union
Expand All @@ -29,7 +34,7 @@
class Code:
""" wrapper around Python code objects """

def __init__(self, rawcode):
def __init__(self, rawcode) -> None:
if not hasattr(rawcode, "co_filename"):
rawcode = getrawcode(rawcode)
try:
Expand All @@ -38,7 +43,7 @@ def __init__(self, rawcode):
self.name = rawcode.co_name
except AttributeError:
raise TypeError("not a code object: {!r}".format(rawcode))
self.raw = rawcode
self.raw = rawcode # type: CodeType

def __eq__(self, other):
return self.raw == other.raw
Expand Down Expand Up @@ -351,7 +356,7 @@ def recursionindex(self):
""" return the index of the frame/TracebackEntry where recursion
originates if appropriate, None if no recursion occurred
"""
cache = {}
cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
for i, entry in enumerate(self):
# id for the code.raw is needed to work around
# the strange metaprogramming in the decorator lib from pypi
Expand Down Expand Up @@ -650,7 +655,7 @@ def repr_args(self, entry):
args.append((argname, saferepr(argvalue)))
return ReprFuncArgs(args)

def get_source(self, source, line_index=-1, excinfo=None, short=False):
def get_source(self, source, line_index=-1, excinfo=None, short=False) -> List[str]:
""" return formatted and marked up source lines. """
import _pytest._code

Expand Down Expand Up @@ -722,7 +727,7 @@ def repr_traceback_entry(self, entry, excinfo=None):
else:
line_index = entry.lineno - entry.getfirstlinesource()

lines = []
lines = [] # type: List[str]
style = entry._repr_style
if style is None:
style = self.style
Expand Down Expand Up @@ -799,7 +804,7 @@ def _truncate_recursive_traceback(self, traceback):
exc_msg=str(e),
max_frames=max_frames,
total=len(traceback),
)
) # type: Optional[str]
traceback = traceback[:max_frames] + traceback[-max_frames:]
else:
if recursionindex is not None:
Expand All @@ -812,10 +817,12 @@ def _truncate_recursive_traceback(self, traceback):

def repr_excinfo(self, excinfo):

repr_chain = []
repr_chain = (
[]
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
e = excinfo.value
descr = None
seen = set()
seen = set() # type: Set[int]
while e is not None and id(e) not in seen:
seen.add(id(e))
if excinfo:
Expand Down Expand Up @@ -868,8 +875,8 @@ def __repr__(self):


class ExceptionRepr(TerminalRepr):
def __init__(self):
self.sections = []
def __init__(self) -> None:
self.sections = [] # type: List[Tuple[str, str, str]]

def addsection(self, name, content, sep="-"):
self.sections.append((name, content, sep))
Expand Down
13 changes: 7 additions & 6 deletions src/_pytest/_code/source.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import warnings
from ast import PyCF_ONLY_AST as _AST_FLAG
from bisect import bisect_right
from typing import List

import py

Expand All @@ -19,11 +20,11 @@ class Source:
_compilecounter = 0

def __init__(self, *parts, **kwargs):
self.lines = lines = []
self.lines = lines = [] # type: List[str]
de = kwargs.get("deindent", True)
for part in parts:
if not part:
partlines = []
partlines = [] # type: List[str]
elif isinstance(part, Source):
partlines = part.lines
elif isinstance(part, (tuple, list)):
Expand Down Expand Up @@ -157,8 +158,7 @@ def compile(
source = "\n".join(self.lines) + "\n"
try:
co = compile(source, filename, mode, flag)
except SyntaxError:
ex = sys.exc_info()[1]
except SyntaxError as ex:
# re-represent syntax errors from parsing python strings
msglines = self.lines[: ex.lineno]
if ex.offset:
Expand All @@ -173,7 +173,8 @@ def compile(
if flag & _AST_FLAG:
return co
lines = [(x + "\n") for x in self.lines]
linecache.cache[filename] = (1, None, lines, filename)
# Type ignored because linecache.cache is private.
linecache.cache[filename] = (1, None, lines, filename) # type: ignore
return co


Expand Down Expand Up @@ -282,7 +283,7 @@ def get_statement_startend2(lineno, node):
return start, end


def getstatementrange_ast(lineno, source, assertion=False, astnode=None):
def getstatementrange_ast(lineno, source: Source, assertion=False, astnode=None):
if astnode is None:
content = str(source)
# See #4260:
Expand Down
11 changes: 8 additions & 3 deletions src/_pytest/assertion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
support for presenting detailed information in failing assertions.
"""
import sys
from typing import Optional

from _pytest.assertion import rewrite
from _pytest.assertion import truncate
Expand Down Expand Up @@ -52,7 +53,9 @@ def register_assert_rewrite(*names):
importhook = hook
break
else:
importhook = DummyRewriteHook()
# TODO(typing): Add a protocol for mark_rewrite() and use it
# for importhook and for PytestPluginManager.rewrite_hook.
importhook = DummyRewriteHook() # type: ignore
importhook.mark_rewrite(*names)


Expand All @@ -69,7 +72,7 @@ class AssertionState:
def __init__(self, config, mode):
self.mode = mode
self.trace = config.trace.root.get("assertion")
self.hook = None
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]


def install_importhook(config):
Expand Down Expand Up @@ -108,6 +111,7 @@ def pytest_runtest_setup(item):
"""

def callbinrepr(op, left, right):
# type: (str, object, object) -> Optional[str]
"""Call the pytest_assertrepr_compare hook and prepare the result

This uses the first result from the hook and then ensures the
Expand All @@ -133,12 +137,13 @@ def callbinrepr(op, left, right):
if item.config.getvalue("assertmode") == "rewrite":
res = res.replace("%", "%%")
return res
return None

util._reprcompare = callbinrepr

if item.ihook.pytest_assertion_pass.get_hookimpls():

def call_assertion_pass_hook(lineno, expl, orig):
def call_assertion_pass_hook(lineno, orig, expl):
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)
Expand Down
50 changes: 28 additions & 22 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import ast
import errno
import functools
import importlib.abc
import importlib.machinery
import importlib.util
import io
Expand All @@ -16,6 +17,7 @@
from typing import List
from typing import Optional
from typing import Set
from typing import Tuple

import atomicwrites

Expand All @@ -37,7 +39,7 @@
AST_NONE = ast.NameConstant(None)


class AssertionRewritingHook:
class AssertionRewritingHook(importlib.abc.MetaPathFinder):
"""PEP302/PEP451 import hook which rewrites asserts."""

def __init__(self, config):
Expand All @@ -47,13 +49,13 @@ def __init__(self, config):
except ValueError:
self.fnpats = ["test_*.py", "*_test.py"]
self.session = None
self._rewritten_names = set()
self._must_rewrite = set()
self._rewritten_names = set() # type: Set[str]
self._must_rewrite = set() # type: Set[str]
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
# which might result in infinite recursion (#3506)
self._writing_pyc = False
self._basenames_to_check_rewrite = {"conftest"}
self._marked_for_rewrite_cache = {}
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
self._session_paths_checked = False

def set_session(self, session):
Expand Down Expand Up @@ -202,7 +204,7 @@ def _should_rewrite(self, name, fn, state):

return self._is_marked_for_rewrite(name, state)

def _is_marked_for_rewrite(self, name, state):
def _is_marked_for_rewrite(self, name: str, state):
try:
return self._marked_for_rewrite_cache[name]
except KeyError:
Expand All @@ -217,7 +219,7 @@ def _is_marked_for_rewrite(self, name, state):
self._marked_for_rewrite_cache[name] = False
return False

def mark_rewrite(self, *names):
def mark_rewrite(self, *names: str) -> None:
"""Mark import names as needing to be rewritten.

The named module or package as well as any nested modules will
Expand Down Expand Up @@ -384,6 +386,7 @@ def _format_boolop(explanations, is_or):


def _call_reprcompare(ops, results, expls, each_obj):
# type: (Tuple[str, ...], Tuple[bool, ...], Tuple[str, ...], Tuple[object, ...]) -> str
for i, res, expl in zip(range(len(ops)), results, expls):
try:
done = not res
Expand All @@ -399,11 +402,13 @@ def _call_reprcompare(ops, results, expls, each_obj):


def _call_assertion_pass(lineno, orig, expl):
# type: (int, str, str) -> None
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)
util._assertion_pass(lineno, orig, expl)


def _check_if_assertion_pass_impl():
# type: () -> bool
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False
Expand Down Expand Up @@ -577,7 +582,7 @@ def __init__(self, module_path, config, source):
def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source)

def run(self, mod):
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
# Nothing to do.
Expand Down Expand Up @@ -619,12 +624,12 @@ def run(self, mod):
]
mod.body[pos:pos] = imports
# Collect asserts.
nodes = [mod]
nodes = [mod] # type: List[ast.AST]
while nodes:
node = nodes.pop()
for name, field in ast.iter_fields(node):
if isinstance(field, list):
new = []
new = [] # type: List
for i, child in enumerate(field):
if isinstance(child, ast.Assert):
# Transform assert.
Expand Down Expand Up @@ -698,7 +703,7 @@ def push_format_context(self):
.explanation_param().

"""
self.explanation_specifiers = {}
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
self.stack.append(self.explanation_specifiers)

def pop_format_context(self, expl_expr):
Expand Down Expand Up @@ -741,7 +746,8 @@ def visit_Assert(self, assert_):
from _pytest.warning_types import PytestAssertRewriteWarning
import warnings

warnings.warn_explicit(
# Ignore type: typeshed bug https://github.com/python/typeshed/pull/3121
warnings.warn_explicit( # type: ignore
PytestAssertRewriteWarning(
"assertion is always true, perhaps remove parentheses?"
),
Expand All @@ -750,15 +756,15 @@ def visit_Assert(self, assert_):
lineno=assert_.lineno,
)

self.statements = []
self.variables = []
self.statements = [] # type: List[ast.stmt]
self.variables = [] # type: List[str]
self.variable_counter = itertools.count()

if self.enable_assertion_pass_hook:
self.format_variables = []
self.format_variables = [] # type: List[str]

self.stack = []
self.expl_stmts = []
self.stack = [] # type: List[Dict[str, ast.expr]]
self.expl_stmts = [] # type: List[ast.stmt]
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
Expand Down Expand Up @@ -896,7 +902,7 @@ def visit_BoolOp(self, boolop):
# Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
fail_inner = []
fail_inner = [] # type: List[ast.stmt]
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
Expand All @@ -907,10 +913,10 @@ def visit_BoolOp(self, boolop):
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
cond = res # type: ast.expr
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
inner = []
inner = [] # type: List[ast.stmt]
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
Expand Down Expand Up @@ -976,7 +982,7 @@ def visit_Attribute(self, attr):
expl = pat % (res_expl, res_expl, value_expl, attr.attr)
return res, expl

def visit_Compare(self, comp):
def visit_Compare(self, comp: ast.Compare):
self.push_format_context()
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
Expand Down Expand Up @@ -1009,7 +1015,7 @@ def visit_Compare(self, comp):
ast.Tuple(results, ast.Load()),
)
if len(comp.ops) > 1:
res = ast.BoolOp(ast.And(), load_names)
res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
else:
res = load_names[0]
return res, self.explanation_param(self.pop_format_context(expl_call))
Expand Down
Loading