Skip to content

Commit 6fc55b9

Browse files
committed
Type annotate CallSpec2
1 parent 56a5dbe commit 6fc55b9

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

src/_pytest/python.py

+43-16
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,16 @@
44
import inspect
55
import os
66
import sys
7+
import typing
78
import warnings
89
from collections import Counter
910
from collections import defaultdict
1011
from collections.abc import Sequence
1112
from functools import partial
1213
from typing import Dict
14+
from typing import Iterable
1315
from typing import List
16+
from typing import Mapping
1417
from typing import Optional
1518
from typing import Tuple
1619
from typing import Union
@@ -35,19 +38,24 @@
3538
from _pytest.compat import safe_getattr
3639
from _pytest.compat import safe_isclass
3740
from _pytest.compat import STRING_TYPES
41+
from _pytest.compat import TYPE_CHECKING
3842
from _pytest.config import hookimpl
3943
from _pytest.deprecated import FUNCARGNAMES
4044
from _pytest.mark import MARK_GEN
4145
from _pytest.mark import ParameterSet
4246
from _pytest.mark.structures import get_unpacked_marks
4347
from _pytest.mark.structures import Mark
48+
from _pytest.mark.structures import MarkDecorator
4449
from _pytest.mark.structures import normalize_mark_list
4550
from _pytest.outcomes import fail
4651
from _pytest.outcomes import skip
4752
from _pytest.pathlib import parts
4853
from _pytest.warning_types import PytestCollectionWarning
4954
from _pytest.warning_types import PytestUnhandledCoroutineWarning
5055

56+
if TYPE_CHECKING:
57+
from typing_extensions import Literal
58+
5159

5260
def pyobj_property(name):
5361
def get(self):
@@ -784,16 +792,17 @@ def hasnew(obj):
784792

785793

786794
class CallSpec2:
787-
def __init__(self, metafunc):
795+
def __init__(self, metafunc: "Metafunc") -> None:
788796
self.metafunc = metafunc
789-
self.funcargs = {}
790-
self._idlist = []
791-
self.params = {}
792-
self._arg2scopenum = {} # used for sorting parametrized resources
793-
self.marks = []
794-
self.indices = {}
795-
796-
def copy(self):
797+
self.funcargs = {} # type: Dict[str, object]
798+
self._idlist = [] # type: List[str]
799+
self.params = {} # type: Dict[str, object]
800+
# Used for sorting parametrized resources.
801+
self._arg2scopenum = {} # type: Dict[str, int]
802+
self.marks = [] # type: List[Mark]
803+
self.indices = {} # type: Dict[str, int]
804+
805+
def copy(self) -> "CallSpec2":
797806
cs = CallSpec2(self.metafunc)
798807
cs.funcargs.update(self.funcargs)
799808
cs.params.update(self.params)
@@ -803,25 +812,39 @@ def copy(self):
803812
cs._idlist = list(self._idlist)
804813
return cs
805814

806-
def _checkargnotcontained(self, arg):
815+
def _checkargnotcontained(self, arg: str) -> None:
807816
if arg in self.params or arg in self.funcargs:
808817
raise ValueError("duplicate {!r}".format(arg))
809818

810-
def getparam(self, name):
819+
def getparam(self, name: str) -> object:
811820
try:
812821
return self.params[name]
813822
except KeyError:
814823
raise ValueError(name)
815824

816825
@property
817-
def id(self):
826+
def id(self) -> str:
818827
return "-".join(map(str, self._idlist))
819828

820-
def setmulti2(self, valtypes, argnames, valset, id, marks, scopenum, param_index):
829+
def setmulti2(
830+
self,
831+
valtypes: "Mapping[str, Literal['params', 'funcargs']]",
832+
argnames: typing.Sequence[str],
833+
valset: Iterable[object],
834+
id: str,
835+
marks: Iterable[Union[Mark, MarkDecorator]],
836+
scopenum: int,
837+
param_index: int,
838+
) -> None:
821839
for arg, val in zip(argnames, valset):
822840
self._checkargnotcontained(arg)
823841
valtype_for_arg = valtypes[arg]
824-
getattr(self, valtype_for_arg)[arg] = val
842+
if valtype_for_arg == "params":
843+
self.params[arg] = val
844+
elif valtype_for_arg == "funcargs":
845+
self.funcargs[arg] = val
846+
else:
847+
assert False, "Unhandled valtype for arg: {}".format(valtype_for_arg)
825848
self.indices[arg] = param_index
826849
self._arg2scopenum[arg] = scopenum
827850
self._idlist.append(id)
@@ -1042,7 +1065,9 @@ def _validate_ids(self, ids, parameters, func_name):
10421065
)
10431066
return new_ids
10441067

1045-
def _resolve_arg_value_types(self, argnames: List[str], indirect) -> Dict[str, str]:
1068+
def _resolve_arg_value_types(
1069+
self, argnames: List[str], indirect
1070+
) -> Dict[str, "Literal['params', 'funcargs']"]:
10461071
"""Resolves if each parametrized argument must be considered a parameter to a fixture or a "funcarg"
10471072
to the function, based on the ``indirect`` parameter of the parametrized() call.
10481073
@@ -1054,7 +1079,9 @@ def _resolve_arg_value_types(self, argnames: List[str], indirect) -> Dict[str, s
10541079
* "funcargs" if the argname should be a parameter to the parametrized test function.
10551080
"""
10561081
if isinstance(indirect, bool):
1057-
valtypes = dict.fromkeys(argnames, "params" if indirect else "funcargs")
1082+
valtypes = dict.fromkeys(
1083+
argnames, "params" if indirect else "funcargs"
1084+
) # type: Dict[str, Literal["params", "funcargs"]]
10581085
elif isinstance(indirect, Sequence):
10591086
valtypes = dict.fromkeys(argnames, "funcargs")
10601087
for arg in indirect:

0 commit comments

Comments
 (0)