Skip to content

Commit f54ec30

Browse files
authored
Merge pull request #7862 from asottile/comm2ann
py36+: com2ann
2 parents 703e891 + 33d119f commit f54ec30

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

62 files changed

+431
-443
lines changed

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ repos:
4444
- id: setup-cfg-fmt
4545
# TODO: when upgrading setup-cfg-fmt this can be removed
4646
args: [--max-py-version=3.9]
47+
- repo: https://github.com/pre-commit/pygrep-hooks
48+
rev: v1.6.0
49+
hooks:
50+
- id: python-use-type-annotations
4751
- repo: https://github.com/pre-commit/mirrors-mypy
4852
rev: v0.782 # NOTE: keep this in sync with setup.cfg.
4953
hooks:

src/_pytest/_argcomplete.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __call__(self, prefix: str, **kwargs: Any) -> List[str]:
103103
import argcomplete.completers
104104
except ImportError:
105105
sys.exit(-1)
106-
filescompleter = FastFilesCompleter() # type: Optional[FastFilesCompleter]
106+
filescompleter: Optional[FastFilesCompleter] = FastFilesCompleter()
107107

108108
def try_argcomplete(parser: argparse.ArgumentParser) -> None:
109109
argcomplete.autocomplete(parser, always_complete_options=False)

src/_pytest/_code/code.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def getargs(self, var: bool = False):
164164
class TracebackEntry:
165165
"""A single entry in a Traceback."""
166166

167-
_repr_style = None # type: Optional[Literal["short", "long"]]
167+
_repr_style: Optional['Literal["short", "long"]'] = None
168168
exprinfo = None
169169

170170
def __init__(
@@ -246,9 +246,9 @@ def ishidden(self) -> bool:
246246
247247
Mostly for internal use.
248248
"""
249-
tbh = (
249+
tbh: Union[bool, Callable[[Optional[ExceptionInfo[BaseException]]], bool]] = (
250250
False
251-
) # type: Union[bool, Callable[[Optional[ExceptionInfo[BaseException]]], bool]]
251+
)
252252
for maybe_ns_dct in (self.frame.f_locals, self.frame.f_globals):
253253
# in normal cases, f_locals and f_globals are dictionaries
254254
# however via `exec(...)` / `eval(...)` they can be other types
@@ -301,7 +301,7 @@ def __init__(
301301
if isinstance(tb, TracebackType):
302302

303303
def f(cur: TracebackType) -> Iterable[TracebackEntry]:
304-
cur_ = cur # type: Optional[TracebackType]
304+
cur_: Optional[TracebackType] = cur
305305
while cur_ is not None:
306306
yield TracebackEntry(cur_, excinfo=excinfo)
307307
cur_ = cur_.tb_next
@@ -381,7 +381,7 @@ def getcrashentry(self) -> TracebackEntry:
381381
def recursionindex(self) -> Optional[int]:
382382
"""Return the index of the frame/TracebackEntry where recursion originates if
383383
appropriate, None if no recursion occurred."""
384-
cache = {} # type: Dict[Tuple[Any, int, int], List[Dict[str, Any]]]
384+
cache: Dict[Tuple[Any, int, int], List[Dict[str, Any]]] = {}
385385
for i, entry in enumerate(self):
386386
# id for the code.raw is needed to work around
387387
# the strange metaprogramming in the decorator lib from pypi
@@ -760,7 +760,7 @@ def repr_traceback_entry(
760760
entry: TracebackEntry,
761761
excinfo: Optional[ExceptionInfo[BaseException]] = None,
762762
) -> "ReprEntry":
763-
lines = [] # type: List[str]
763+
lines: List[str] = []
764764
style = entry._repr_style if entry._repr_style is not None else self.style
765765
if style in ("short", "long"):
766766
source = self._getentrysource(entry)
@@ -842,7 +842,7 @@ def _truncate_recursive_traceback(
842842
recursionindex = traceback.recursionindex()
843843
except Exception as e:
844844
max_frames = 10
845-
extraline = (
845+
extraline: Optional[str] = (
846846
"!!! Recursion error detected, but an error occurred locating the origin of recursion.\n"
847847
" The following exception happened when comparing locals in the stack frame:\n"
848848
" {exc_type}: {exc_msg}\n"
@@ -852,7 +852,7 @@ def _truncate_recursive_traceback(
852852
exc_msg=str(e),
853853
max_frames=max_frames,
854854
total=len(traceback),
855-
) # type: Optional[str]
855+
)
856856
# Type ignored because adding two instaces of a List subtype
857857
# currently incorrectly has type List instead of the subtype.
858858
traceback = traceback[:max_frames] + traceback[-max_frames:] # type: ignore
@@ -868,20 +868,20 @@ def _truncate_recursive_traceback(
868868
def repr_excinfo(
869869
self, excinfo: ExceptionInfo[BaseException]
870870
) -> "ExceptionChainRepr":
871-
repr_chain = (
872-
[]
873-
) # type: List[Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]]
874-
e = excinfo.value # type: Optional[BaseException]
875-
excinfo_ = excinfo # type: Optional[ExceptionInfo[BaseException]]
871+
repr_chain: List[
872+
Tuple[ReprTraceback, Optional[ReprFileLocation], Optional[str]]
873+
] = []
874+
e: Optional[BaseException] = excinfo.value
875+
excinfo_: Optional[ExceptionInfo[BaseException]] = excinfo
876876
descr = None
877-
seen = set() # type: Set[int]
877+
seen: Set[int] = set()
878878
while e is not None and id(e) not in seen:
879879
seen.add(id(e))
880880
if excinfo_:
881881
reprtraceback = self.repr_traceback(excinfo_)
882-
reprcrash = (
882+
reprcrash: Optional[ReprFileLocation] = (
883883
excinfo_._getreprcrash() if self.style != "value" else None
884-
) # type: Optional[ReprFileLocation]
884+
)
885885
else:
886886
# Fallback to native repr if the exception doesn't have a traceback:
887887
# ExceptionInfo objects require a full traceback to work.
@@ -936,11 +936,11 @@ def toterminal(self, tw: TerminalWriter) -> None:
936936
@attr.s(eq=False)
937937
class ExceptionRepr(TerminalRepr):
938938
# Provided by subclasses.
939-
reprcrash = None # type: Optional[ReprFileLocation]
940-
reprtraceback = None # type: ReprTraceback
939+
reprcrash: Optional["ReprFileLocation"]
940+
reprtraceback: "ReprTraceback"
941941

942942
def __attrs_post_init__(self) -> None:
943-
self.sections = [] # type: List[Tuple[str, str, str]]
943+
self.sections: List[Tuple[str, str, str]] = []
944944

945945
def addsection(self, name: str, content: str, sep: str = "-") -> None:
946946
self.sections.append((name, content, sep))
@@ -1022,7 +1022,7 @@ def __init__(self, tblines: Sequence[str]) -> None:
10221022
@attr.s(eq=False)
10231023
class ReprEntryNative(TerminalRepr):
10241024
lines = attr.ib(type=Sequence[str])
1025-
style = "native" # type: _TracebackStyle
1025+
style: "_TracebackStyle" = "native"
10261026

10271027
def toterminal(self, tw: TerminalWriter) -> None:
10281028
tw.write("".join(self.lines))
@@ -1058,9 +1058,9 @@ def _write_entry_lines(self, tw: TerminalWriter) -> None:
10581058
# such as "> assert 0"
10591059
fail_marker = f"{FormattedExcinfo.fail_marker} "
10601060
indent_size = len(fail_marker)
1061-
indents = [] # type: List[str]
1062-
source_lines = [] # type: List[str]
1063-
failure_lines = [] # type: List[str]
1061+
indents: List[str] = []
1062+
source_lines: List[str] = []
1063+
failure_lines: List[str] = []
10641064
for index, line in enumerate(self.lines):
10651065
is_failure_line = line.startswith(fail_marker)
10661066
if is_failure_line:

src/_pytest/_code/source.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Source:
2121

2222
def __init__(self, obj: object = None) -> None:
2323
if not obj:
24-
self.lines = [] # type: List[str]
24+
self.lines: List[str] = []
2525
elif isinstance(obj, Source):
2626
self.lines = obj.lines
2727
elif isinstance(obj, (tuple, list)):
@@ -144,12 +144,12 @@ def deindent(lines: Iterable[str]) -> List[str]:
144144
def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[int]]:
145145
# Flatten all statements and except handlers into one lineno-list.
146146
# AST's line numbers start indexing at 1.
147-
values = [] # type: List[int]
147+
values: List[int] = []
148148
for x in ast.walk(node):
149149
if isinstance(x, (ast.stmt, ast.ExceptHandler)):
150150
values.append(x.lineno - 1)
151151
for name in ("finalbody", "orelse"):
152-
val = getattr(x, name, None) # type: Optional[List[ast.stmt]]
152+
val: Optional[List[ast.stmt]] = getattr(x, name, None)
153153
if val:
154154
# Treat the finally/orelse part as its own statement.
155155
values.append(val[0].lineno - 1 - 1)

src/_pytest/_io/terminalwriter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def __init__(self, file: Optional[TextIO] = None) -> None:
7676
self._file = file
7777
self.hasmarkup = should_do_markup(file)
7878
self._current_line = ""
79-
self._terminal_width = None # type: Optional[int]
79+
self._terminal_width: Optional[int] = None
8080
self.code_highlight = True
8181

8282
@property
@@ -204,7 +204,7 @@ def _highlight(self, source: str) -> str:
204204
except ImportError:
205205
return source
206206
else:
207-
highlighted = highlight(
207+
highlighted: str = highlight(
208208
source, PythonLexer(), TerminalFormatter(bg="dark")
209-
) # type: str
209+
)
210210
return highlighted

src/_pytest/assertion/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ class AssertionState:
8383
def __init__(self, config: Config, mode) -> None:
8484
self.mode = mode
8585
self.trace = config.trace.root.get("assertion")
86-
self.hook = None # type: Optional[rewrite.AssertionRewritingHook]
86+
self.hook: Optional[rewrite.AssertionRewritingHook] = None
8787

8888

8989
def install_importhook(config: Config) -> rewrite.AssertionRewritingHook:

src/_pytest/assertion/rewrite.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,14 @@ def __init__(self, config: Config) -> None:
6262
self.fnpats = config.getini("python_files")
6363
except ValueError:
6464
self.fnpats = ["test_*.py", "*_test.py"]
65-
self.session = None # type: Optional[Session]
66-
self._rewritten_names = set() # type: Set[str]
67-
self._must_rewrite = set() # type: Set[str]
65+
self.session: Optional[Session] = None
66+
self._rewritten_names: Set[str] = set()
67+
self._must_rewrite: Set[str] = set()
6868
# flag to guard against trying to rewrite a pyc file while we are already writing another pyc file,
6969
# which might result in infinite recursion (#3506)
7070
self._writing_pyc = False
7171
self._basenames_to_check_rewrite = {"conftest"}
72-
self._marked_for_rewrite_cache = {} # type: Dict[str, bool]
72+
self._marked_for_rewrite_cache: Dict[str, bool] = {}
7373
self._session_paths_checked = False
7474

7575
def set_session(self, session: Optional[Session]) -> None:
@@ -529,12 +529,12 @@ def _fix(node, lineno, col_offset):
529529

530530
def _get_assertion_exprs(src: bytes) -> Dict[int, str]:
531531
"""Return a mapping from {lineno: "assertion test expression"}."""
532-
ret = {} # type: Dict[int, str]
532+
ret: Dict[int, str] = {}
533533

534534
depth = 0
535-
lines = [] # type: List[str]
536-
assert_lineno = None # type: Optional[int]
537-
seen_lines = set() # type: Set[int]
535+
lines: List[str] = []
536+
assert_lineno: Optional[int] = None
537+
seen_lines: Set[int] = set()
538538

539539
def _write_and_reset() -> None:
540540
nonlocal depth, lines, assert_lineno, seen_lines
@@ -699,12 +699,12 @@ def run(self, mod: ast.Module) -> None:
699699
]
700700
mod.body[pos:pos] = imports
701701
# Collect asserts.
702-
nodes = [mod] # type: List[ast.AST]
702+
nodes: List[ast.AST] = [mod]
703703
while nodes:
704704
node = nodes.pop()
705705
for name, field in ast.iter_fields(node):
706706
if isinstance(field, list):
707-
new = [] # type: List[ast.AST]
707+
new: List[ast.AST] = []
708708
for i, child in enumerate(field):
709709
if isinstance(child, ast.Assert):
710710
# Transform assert.
@@ -776,7 +776,7 @@ def push_format_context(self) -> None:
776776
to format a string of %-formatted values as added by
777777
.explanation_param().
778778
"""
779-
self.explanation_specifiers = {} # type: Dict[str, ast.expr]
779+
self.explanation_specifiers: Dict[str, ast.expr] = {}
780780
self.stack.append(self.explanation_specifiers)
781781

782782
def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
@@ -828,15 +828,15 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
828828
lineno=assert_.lineno,
829829
)
830830

831-
self.statements = [] # type: List[ast.stmt]
832-
self.variables = [] # type: List[str]
831+
self.statements: List[ast.stmt] = []
832+
self.variables: List[str] = []
833833
self.variable_counter = itertools.count()
834834

835835
if self.enable_assertion_pass_hook:
836-
self.format_variables = [] # type: List[str]
836+
self.format_variables: List[str] = []
837837

838-
self.stack = [] # type: List[Dict[str, ast.expr]]
839-
self.expl_stmts = [] # type: List[ast.stmt]
838+
self.stack: List[Dict[str, ast.expr]] = []
839+
self.expl_stmts: List[ast.stmt] = []
840840
self.push_format_context()
841841
# Rewrite assert into a bunch of statements.
842842
top_condition, explanation = self.visit(assert_.test)
@@ -943,7 +943,7 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
943943
# Process each operand, short-circuiting if needed.
944944
for i, v in enumerate(boolop.values):
945945
if i:
946-
fail_inner = [] # type: List[ast.stmt]
946+
fail_inner: List[ast.stmt] = []
947947
# cond is set in a prior loop iteration below
948948
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
949949
self.expl_stmts = fail_inner
@@ -954,10 +954,10 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> Tuple[ast.Name, str]:
954954
call = ast.Call(app, [expl_format], [])
955955
self.expl_stmts.append(ast.Expr(call))
956956
if i < levels:
957-
cond = res # type: ast.expr
957+
cond: ast.expr = res
958958
if is_or:
959959
cond = ast.UnaryOp(ast.Not(), cond)
960-
inner = [] # type: List[ast.stmt]
960+
inner: List[ast.stmt] = []
961961
self.statements.append(ast.If(cond, inner, []))
962962
self.statements = body = inner
963963
self.statements = save
@@ -1053,7 +1053,7 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
10531053
ast.Tuple(results, ast.Load()),
10541054
)
10551055
if len(comp.ops) > 1:
1056-
res = ast.BoolOp(ast.And(), load_names) # type: ast.expr
1056+
res: ast.expr = ast.BoolOp(ast.And(), load_names)
10571057
else:
10581058
res = load_names[0]
10591059
return res, self.explanation_param(self.pop_format_context(expl_call))

src/_pytest/assertion/util.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,11 @@
2121
# interpretation code and assertion rewriter to detect this plugin was
2222
# loaded and in turn call the hooks defined here as part of the
2323
# DebugInterpreter.
24-
_reprcompare = None # type: Optional[Callable[[str, object, object], Optional[str]]]
24+
_reprcompare: Optional[Callable[[str, object, object], Optional[str]]] = None
2525

2626
# Works similarly as _reprcompare attribute. Is populated with the hook call
2727
# when pytest_runtest_setup is called.
28-
_assertion_pass = None # type: Optional[Callable[[int, str, str], None]]
28+
_assertion_pass: Optional[Callable[[int, str, str], None]] = None
2929

3030

3131
def format_explanation(explanation: str) -> str:
@@ -197,7 +197,7 @@ def _diff_text(left: str, right: str, verbose: int = 0) -> List[str]:
197197
"""
198198
from difflib import ndiff
199199

200-
explanation = [] # type: List[str]
200+
explanation: List[str] = []
201201

202202
if verbose < 1:
203203
i = 0 # just in case left or right has zero length
@@ -242,7 +242,7 @@ def _compare_eq_verbose(left: Any, right: Any) -> List[str]:
242242
left_lines = repr(left).splitlines(keepends)
243243
right_lines = repr(right).splitlines(keepends)
244244

245-
explanation = [] # type: List[str]
245+
explanation: List[str] = []
246246
explanation += ["+" + line for line in left_lines]
247247
explanation += ["-" + line for line in right_lines]
248248

@@ -296,7 +296,7 @@ def _compare_eq_sequence(
296296
left: Sequence[Any], right: Sequence[Any], verbose: int = 0
297297
) -> List[str]:
298298
comparing_bytes = isinstance(left, bytes) and isinstance(right, bytes)
299-
explanation = [] # type: List[str]
299+
explanation: List[str] = []
300300
len_left = len(left)
301301
len_right = len(right)
302302
for i in range(min(len_left, len_right)):
@@ -365,7 +365,7 @@ def _compare_eq_set(
365365
def _compare_eq_dict(
366366
left: Mapping[Any, Any], right: Mapping[Any, Any], verbose: int = 0
367367
) -> List[str]:
368-
explanation = [] # type: List[str]
368+
explanation: List[str] = []
369369
set_left = set(left)
370370
set_right = set(right)
371371
common = set_left.intersection(set_right)

0 commit comments

Comments
 (0)