diff --git a/stdlib/unittest/mock.pyi b/stdlib/unittest/mock.pyi index 0ed0701cc450..db1cc7d9bfc9 100644 --- a/stdlib/unittest/mock.pyi +++ b/stdlib/unittest/mock.pyi @@ -234,6 +234,8 @@ class _patch(Generic[_T]): def copy(self) -> _patch[_T]: ... @overload def __call__(self, func: _TT) -> _TT: ... + # If new==DEFAULT, this should add a MagicMock parameter to the function + # arguments. See the _patch_default_new class below for this functionality. @overload def __call__(self, func: Callable[_P, _R]) -> Callable[_P, _R]: ... if sys.version_info >= (3, 8): @@ -257,6 +259,22 @@ class _patch(Generic[_T]): def start(self) -> _T: ... def stop(self) -> None: ... +if sys.version_info >= (3, 8): + _Mock: TypeAlias = MagicMock | AsyncMock +else: + _Mock: TypeAlias = MagicMock + +# This class does not exist at runtime, it's a hack to make this work: +# @patch("foo") +# def bar(..., mock: MagicMock) -> None: ... +class _patch_default_new(_patch[_Mock]): + @overload + def __call__(self, func: _TT) -> _TT: ... + # Can't use the following as ParamSpec is only allowed as last parameter: + # def __call__(self, func: Callable[_P, _R]) -> Callable[Concatenate[_P, MagicMock], _R]: ... + @overload + def __call__(self, func: Callable[..., _R]) -> Callable[..., _R]: ... + class _patch_dict: in_dict: Any values: Any @@ -273,11 +291,8 @@ class _patch_dict: start: Any stop: Any -if sys.version_info >= (3, 8): - _Mock: TypeAlias = MagicMock | AsyncMock -else: - _Mock: TypeAlias = MagicMock - +# This class does not exist at runtime, it's a hack to add methods to the +# patch() function. class _patcher: TEST_PREFIX: str dict: type[_patch_dict] @@ -307,7 +322,7 @@ class _patcher: autospec: Any | None = ..., new_callable: Any | None = ..., **kwargs: Any, - ) -> _patch[_Mock]: ... + ) -> _patch_default_new: ... @overload @staticmethod def object( # type: ignore[misc] diff --git a/test_cases/stdlib/check_unittest.py b/test_cases/stdlib/check_unittest.py index 438250670559..3a9edafef051 100644 --- a/test_cases/stdlib/check_unittest.py +++ b/test_cases/stdlib/check_unittest.py @@ -5,7 +5,7 @@ from decimal import Decimal from fractions import Fraction from typing_extensions import assert_type -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch case = unittest.TestCase() @@ -94,13 +94,20 @@ def __gt__(self, other: Bacon) -> bool: ### -@patch("sys.exit", new=Mock()) -def f(i: int) -> str: +@patch("sys.exit") +def f_default_new(i: int, mock: MagicMock) -> str: + return "asdf" + + +@patch("sys.exit", new=42) +def f_explicit_new(i: int) -> str: return "asdf" -assert_type(f(1), str) -f("a") # type: ignore +assert_type(f_default_new(1), str) +f_default_new("a") # Not an error due to ParamSpec limitations +assert_type(f_explicit_new(1), str) +f_explicit_new("a") # type: ignore[arg-type] @patch("sys.exit", new=Mock())