diff --git a/docs/source/reference-testing.rst b/docs/source/reference-testing.rst index 76ecd4a2d4..89573bab2f 100644 --- a/docs/source/reference-testing.rst +++ b/docs/source/reference-testing.rst @@ -219,3 +219,16 @@ Testing checkpoints .. autofunction:: assert_no_checkpoints :with: + + +ExceptionGroup helpers +---------------------- + +.. autoclass:: RaisesGroup + :members: + +.. autoclass:: Matcher + :members: + +.. autoclass:: trio.testing._raises_group._ExceptionInfo + :members: diff --git a/newsfragments/2785.feature.rst b/newsfragments/2785.feature.rst new file mode 100644 index 0000000000..8dff767e4b --- /dev/null +++ b/newsfragments/2785.feature.rst @@ -0,0 +1,4 @@ +New helper classes: :class:`~.testing.RaisesGroup` and :class:`~.testing.Matcher`. + +In preparation for changing the default of ``strict_exception_groups`` to `True`, we're introducing a set of helper classes that can be used in place of `pytest.raises `_ in tests, to check for an expected `ExceptionGroup`. +These are provisional, and only planned to be supplied until there's a good solution in ``pytest``. See https://github.com/pytest-dev/pytest/issues/11538 diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 71b95c8dcd..8c22528989 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -19,7 +19,13 @@ from ... import _core from ..._threads import to_thread_run_sync from ..._timeouts import fail_after, sleep -from ...testing import Sequencer, assert_checkpoints, wait_all_tasks_blocked +from ...testing import ( + Matcher, + RaisesGroup, + Sequencer, + assert_checkpoints, + wait_all_tasks_blocked, +) from .._run import DEADLINE_HEAP_MIN_PRUNE_THRESHOLD from .tutil import ( check_sequence_matches, @@ -192,13 +198,8 @@ async def main() -> NoReturn: nursery.start_soon(crasher) raise KeyError - with pytest.raises(ExceptionGroup) as excinfo: + with RaisesGroup(ValueError, KeyError): _core.run(main) - print(excinfo.value) - assert {type(exc) for exc in excinfo.value.exceptions} == { - ValueError, - KeyError, - } def test_two_child_crashes() -> None: @@ -210,12 +211,8 @@ async def main() -> None: nursery.start_soon(crasher, KeyError) nursery.start_soon(crasher, ValueError) - with pytest.raises(ExceptionGroup) as excinfo: + with RaisesGroup(ValueError, KeyError): _core.run(main) - assert {type(exc) for exc in excinfo.value.exceptions} == { - ValueError, - KeyError, - } async def test_child_crash_wakes_parent() -> None: @@ -429,16 +426,18 @@ async def test_cancel_scope_exceptiongroup_filtering() -> None: async def crasher() -> NoReturn: raise KeyError - # check that the inner except is properly executed. - # alternative would be to have a `except BaseException` and an `else` - exception_group_caught_inner = False - # This is outside the outer scope, so all the Cancelled # exceptions should have been absorbed, leaving just a regular # KeyError from crasher() with pytest.raises(KeyError): # noqa: PT012 with _core.CancelScope() as outer: - try: + # Since the outer scope became cancelled before the + # nursery block exited, all cancellations inside the + # nursery block continue propagating to reach the + # outer scope. + with RaisesGroup( + _core.Cancelled, _core.Cancelled, _core.Cancelled, KeyError + ) as excinfo: async with _core.open_nursery() as nursery: # Two children that get cancelled by the nursery scope nursery.start_soon(sleep_forever) # t1 @@ -452,22 +451,9 @@ async def crasher() -> NoReturn: # And one that raises a different error nursery.start_soon(crasher) # t4 # and then our __aexit__ also receives an outer Cancelled - except BaseExceptionGroup as multi_exc: - exception_group_caught_inner = True - # Since the outer scope became cancelled before the - # nursery block exited, all cancellations inside the - # nursery block continue propagating to reach the - # outer scope. - # the noqa is for "Found assertion on exception `multi_exc` in `except` block" - assert len(multi_exc.exceptions) == 4 # noqa: PT017 - summary: dict[type, int] = {} - for exc in multi_exc.exceptions: - summary.setdefault(type(exc), 0) - summary[type(exc)] += 1 - assert summary == {_core.Cancelled: 3, KeyError: 1} - raise - - assert exception_group_caught_inner + # reraise the exception caught by RaisesGroup for the + # CancelScope to handle + raise excinfo.value async def test_precancelled_task() -> None: @@ -788,14 +774,22 @@ async def task2() -> None: RuntimeError, match="which had already been exited" ) as exc_info: await nursery_mgr.__aexit__(*sys.exc_info()) - assert type(exc_info.value.__context__) is ExceptionGroup - assert len(exc_info.value.__context__.exceptions) == 3 - cancelled_in_context = False - for exc in exc_info.value.__context__.exceptions: - assert isinstance(exc, RuntimeError) - assert "closed before the task exited" in str(exc) - cancelled_in_context |= isinstance(exc.__context__, _core.Cancelled) - assert cancelled_in_context # for the sleep_forever + + def no_context(exc: RuntimeError) -> bool: + return exc.__context__ is None + + msg = "closed before the task exited" + group = RaisesGroup( + Matcher(RuntimeError, match=msg, check=no_context), + Matcher(RuntimeError, match=msg, check=no_context), + # sleep_forever + Matcher( + RuntimeError, + match=msg, + check=lambda x: isinstance(x.__context__, _core.Cancelled), + ), + ) + assert group.matches(exc_info.value.__context__) # Trying to exit a cancel scope from an unrelated task raises an error # without affecting any state @@ -949,11 +943,7 @@ async def main() -> None: with pytest.raises(_core.TrioInternalError) as excinfo: _core.run(main) - me = excinfo.value.__cause__ - assert isinstance(me, ExceptionGroup) - assert len(me.exceptions) == 2 - for exc in me.exceptions: - assert isinstance(exc, (KeyError, ValueError)) + assert RaisesGroup(KeyError, ValueError).matches(excinfo.value.__cause__) def test_system_task_crash_plus_Cancelled() -> None: @@ -1210,12 +1200,11 @@ async def test_nursery_exception_chaining_doesnt_make_context_loops() -> None: async def crasher() -> NoReturn: raise KeyError - with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012 + # the ExceptionGroup should not have the KeyError or ValueError as context + with RaisesGroup(ValueError, KeyError, check=lambda x: x.__context__ is None): async with _core.open_nursery() as nursery: nursery.start_soon(crasher) raise ValueError - # the ExceptionGroup should not have the KeyError or ValueError as context - assert excinfo.value.__context__ is None def test_TrioToken_identity() -> None: @@ -1980,11 +1969,10 @@ async def test_nursery_stop_iteration() -> None: async def fail() -> NoReturn: raise ValueError - with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012 + with RaisesGroup(StopIteration, ValueError): async with _core.open_nursery() as nursery: nursery.start_soon(fail) raise StopIteration - assert tuple(map(type, excinfo.value.exceptions)) == (StopIteration, ValueError) async def test_nursery_stop_async_iteration() -> None: @@ -2033,7 +2021,18 @@ async def test_traceback_frame_removal() -> None: async def my_child_task() -> NoReturn: raise KeyError() - with pytest.raises(ExceptionGroup) as excinfo: # noqa: PT012 + def check_traceback(exc: KeyError) -> bool: + # The top frame in the exception traceback should be inside the child + # task, not trio/contextvars internals. And there's only one frame + # inside the child task, so this will also detect if our frame-removal + # is too eager. + tb = exc.__traceback__ + assert tb is not None + return tb.tb_frame.f_code is my_child_task.__code__ + + expected_exception = Matcher(KeyError, check=check_traceback) + + with RaisesGroup(expected_exception, expected_exception): # Trick: For now cancel/nursery scopes still leave a bunch of tb gunk # behind. But if there's an ExceptionGroup, they leave it on the group, # which lets us get a clean look at the KeyError itself. Someday I @@ -2042,15 +2041,6 @@ async def my_child_task() -> NoReturn: async with _core.open_nursery() as nursery: nursery.start_soon(my_child_task) nursery.start_soon(my_child_task) - first_exc = excinfo.value.exceptions[0] - assert isinstance(first_exc, KeyError) - # The top frame in the exception traceback should be inside the child - # task, not trio/contextvars internals. And there's only one frame - # inside the child task, so this will also detect if our frame-removal - # is too eager. - tb = first_exc.__traceback__ - assert tb is not None - assert tb.tb_frame.f_code is my_child_task.__code__ def test_contextvar_support() -> None: @@ -2529,15 +2519,12 @@ async def main() -> NoReturn: async with _core.open_nursery(): raise Exception("foo") - with pytest.raises( - ExceptionGroup, match="^Exceptions from Trio nursery \\(1 sub-exception\\)$" - ) as exc: + with RaisesGroup( + Matcher(Exception, match="^foo$"), + match="^Exceptions from Trio nursery \\(1 sub-exception\\)$", + ): _core.run(main, strict_exception_groups=True) - assert len(exc.value.exceptions) == 1 - assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) - def test_run_strict_exception_groups_nursery_override() -> None: """ @@ -2555,14 +2542,10 @@ async def main() -> NoReturn: async def test_nursery_strict_exception_groups() -> None: """Test that strict exception groups can be enabled on a per-nursery basis.""" - with pytest.raises(ExceptionGroup) as exc: + with RaisesGroup(Matcher(Exception, match="^foo$")): async with _core.open_nursery(strict_exception_groups=True): raise Exception("foo") - assert len(exc.value.exceptions) == 1 - assert type(exc.value.exceptions[0]) is Exception - assert exc.value.exceptions[0].args == ("foo",) - async def test_nursery_loose_exception_groups() -> None: """Test that loose exception groups can be enabled on a per-nursery basis.""" @@ -2573,20 +2556,18 @@ async def raise_error() -> NoReturn: with pytest.raises(RuntimeError, match="^test error$"): async with _core.open_nursery(strict_exception_groups=False) as nursery: nursery.start_soon(raise_error) - - with pytest.raises( # noqa: PT012 # multiple statements - ExceptionGroup, match="^Exceptions from Trio nursery \\(2 sub-exceptions\\)$" - ) as exc: + m = Matcher(RuntimeError, match="^test error$") + + with RaisesGroup( + m, + m, + match="Exceptions from Trio nursery \\(2 sub-exceptions\\)", + check=lambda x: x.__notes__ == [_core._run.NONSTRICT_EXCEPTIONGROUP_NOTE], + ): async with _core.open_nursery(strict_exception_groups=False) as nursery: nursery.start_soon(raise_error) nursery.start_soon(raise_error) - assert exc.value.__notes__ == [_core._run.NONSTRICT_EXCEPTIONGROUP_NOTE] - assert len(exc.value.exceptions) == 2 - for subexc in exc.value.exceptions: - assert type(subexc) is RuntimeError - assert subexc.args == ("test error",) - async def test_nursery_collapse_strict() -> None: """ @@ -2597,7 +2578,7 @@ async def test_nursery_collapse_strict() -> None: async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(ExceptionGroup) as exc: # noqa: PT012 + with RaisesGroup(RuntimeError, RaisesGroup(RuntimeError)): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(raise_error) @@ -2606,13 +2587,6 @@ async def raise_error() -> NoReturn: nursery2.start_soon(raise_error) nursery.cancel_scope.cancel() - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], ExceptionGroup) - assert len(exceptions[1].exceptions) == 1 - assert isinstance(exceptions[1].exceptions[0], RuntimeError) - async def test_nursery_collapse_loose() -> None: """ @@ -2623,7 +2597,7 @@ async def test_nursery_collapse_loose() -> None: async def raise_error() -> NoReturn: raise RuntimeError("test error") - with pytest.raises(ExceptionGroup) as exc: # noqa: PT012 + with RaisesGroup(RuntimeError, RuntimeError): async with _core.open_nursery() as nursery: nursery.start_soon(sleep_forever) nursery.start_soon(raise_error) @@ -2632,11 +2606,6 @@ async def raise_error() -> NoReturn: nursery2.start_soon(raise_error) nursery.cancel_scope.cancel() - exceptions = exc.value.exceptions - assert len(exceptions) == 2 - assert isinstance(exceptions[0], RuntimeError) - assert isinstance(exceptions[1], RuntimeError) - async def test_cancel_scope_no_cancellederror() -> None: """ @@ -2644,7 +2613,7 @@ async def test_cancel_scope_no_cancellederror() -> None: a Cancelled exception, it will NOT set the ``cancelled_caught`` flag. """ - with pytest.raises(ExceptionGroup): # noqa: PT012 + with RaisesGroup(RuntimeError, RuntimeError, match="test"): with _core.CancelScope() as scope: scope.cancel() raise ExceptionGroup("test", [RuntimeError(), RuntimeError()]) diff --git a/src/trio/_tests/test_exports.py b/src/trio/_tests/test_exports.py index 4138df0e5e..7418f11da8 100644 --- a/src/trio/_tests/test_exports.py +++ b/src/trio/_tests/test_exports.py @@ -317,6 +317,10 @@ def lookup_symbol(symbol: str) -> dict[str, str]: if module_name == "trio.socket" and class_name in dir(stdlib_socket): continue + # ignore class that does dirty tricks + if class_ is trio.testing.RaisesGroup: + continue + # dir() and inspect.getmembers doesn't display properties from the metaclass # also ignore some dunder methods that tend to differ but are of no consequence ignore_names = set(dir(type(class_))) | { @@ -429,7 +433,9 @@ def lookup_symbol(symbol: str) -> dict[str, str]: if tool == "mypy" and class_ == trio.Nursery: extra.remove("cancel_scope") - # TODO: I'm not so sure about these, but should still be looked at. + # These are (mostly? solely?) *runtime* attributes, often set in + # __init__, which doesn't show up with dir() or inspect.getmembers, + # but we get them in the way we query mypy & jedi EXTRAS = { trio.DTLSChannel: {"peer_address", "endpoint"}, trio.DTLSEndpoint: {"socket", "incoming_packets_buffer"}, @@ -444,6 +450,11 @@ def lookup_symbol(symbol: str) -> dict[str, str]: "send_all_hook", "wait_send_all_might_not_block_hook", }, + trio.testing.Matcher: { + "exception_type", + "match", + "check", + }, } if tool == "mypy" and class_ in EXTRAS: before = len(extra) diff --git a/src/trio/_tests/test_highlevel_open_tcp_stream.py b/src/trio/_tests/test_highlevel_open_tcp_stream.py index b62ecf032f..5f738ba4cc 100644 --- a/src/trio/_tests/test_highlevel_open_tcp_stream.py +++ b/src/trio/_tests/test_highlevel_open_tcp_stream.py @@ -16,6 +16,7 @@ reorder_for_rfc_6555_section_5_4, ) from trio.socket import AF_INET, AF_INET6, IPPROTO_TCP, SOCK_STREAM, SocketType +from trio.testing import Matcher, RaisesGroup if TYPE_CHECKING: from trio.testing import MockClock @@ -530,8 +531,12 @@ async def test_all_fail(autojump_clock: MockClock) -> None: expect_error=OSError, ) assert isinstance(exc, OSError) - assert isinstance(exc.__cause__, BaseExceptionGroup) - assert len(exc.__cause__.exceptions) == 4 + + subexceptions = (Matcher(OSError, match="^sorry$"),) * 4 + assert RaisesGroup( + *subexceptions, match="all attempts to connect to test.example.com:80 failed" + ).matches(exc.__cause__) + assert trio.current_time() == (0.1 + 0.2 + 10) assert scenario.connect_times == { "1.1.1.1": 0, diff --git a/src/trio/_tests/test_testing_raisesgroup.py b/src/trio/_tests/test_testing_raisesgroup.py new file mode 100644 index 0000000000..9b6b2a6fb6 --- /dev/null +++ b/src/trio/_tests/test_testing_raisesgroup.py @@ -0,0 +1,262 @@ +from __future__ import annotations + +import re +import sys +from types import TracebackType +from typing import Any + +import pytest + +import trio +from trio.testing import Matcher, RaisesGroup + +if sys.version_info < (3, 11): + from exceptiongroup import ExceptionGroup + + +def wrap_escape(s: str) -> str: + return "^" + re.escape(s) + "$" + + +def test_raises_group() -> None: + with pytest.raises( + ValueError, + match=wrap_escape( + f'Invalid argument "{TypeError()!r}" must be exception type, Matcher, or RaisesGroup.' + ), + ): + RaisesGroup(TypeError()) + + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (ValueError(),)) + + with RaisesGroup(SyntaxError): + with RaisesGroup(ValueError): + raise ExceptionGroup("foo", (SyntaxError(),)) + + # multiple exceptions + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # order doesn't matter + with RaisesGroup(SyntaxError, ValueError): + raise ExceptionGroup("foo", (ValueError(), SyntaxError())) + + # nested exceptions + with RaisesGroup(RaisesGroup(ValueError)): + raise ExceptionGroup("foo", (ExceptionGroup("bar", (ValueError(),)),)) + + with RaisesGroup( + SyntaxError, + RaisesGroup(ValueError), + RaisesGroup(RuntimeError), + ): + raise ExceptionGroup( + "foo", + ( + SyntaxError(), + ExceptionGroup("bar", (ValueError(),)), + ExceptionGroup("", (RuntimeError(),)), + ), + ) + + # will error if there's excess exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (ValueError(), ValueError())) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError): + raise ExceptionGroup("", (RuntimeError(), ValueError())) + + # will error if there's missing exceptions + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, ValueError): + raise ExceptionGroup("", (ValueError(),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, SyntaxError): + raise ExceptionGroup("", (ValueError(),)) + + # loose semantics, as with expect* + with RaisesGroup(ValueError, strict=False): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + # mixed loose is possible if you want it to be at least N deep + with RaisesGroup(RaisesGroup(ValueError, strict=False)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + with RaisesGroup(RaisesGroup(ValueError, strict=False)): + raise ExceptionGroup( + "", (ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)),) + ) + with pytest.raises(ExceptionGroup): + with RaisesGroup(RaisesGroup(ValueError, strict=False)): + raise ExceptionGroup("", (ValueError(),)) + + # but not the other way around + with pytest.raises( + ValueError, + match="^You cannot specify a nested structure inside a RaisesGroup with strict=False$", + ): + RaisesGroup(RaisesGroup(ValueError), strict=False) + + # currently not fully identical in behaviour to expect*, which would also catch an unwrapped exception + with pytest.raises(ValueError, match="^value error text$"): + with RaisesGroup(ValueError, strict=False): + raise ValueError("value error text") + + +def test_match() -> None: + # supports match string + with RaisesGroup(ValueError, match="bar"): + raise ExceptionGroup("bar", (ValueError(),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, match="foo"): + raise ExceptionGroup("bar", (ValueError(),)) + + +def test_check() -> None: + exc = ExceptionGroup("", (ValueError(),)) + with RaisesGroup(ValueError, check=lambda x: x is exc): + raise exc + with pytest.raises(ExceptionGroup): + with RaisesGroup(ValueError, check=lambda x: x is exc): + raise ExceptionGroup("", (ValueError(),)) + + +def test_RaisesGroup_matches() -> None: + rg = RaisesGroup(ValueError) + assert not rg.matches(None) + assert not rg.matches(ValueError()) + assert rg.matches(ExceptionGroup("", (ValueError(),))) + + +def test_message() -> None: + def check_message(message: str, body: RaisesGroup[Any]) -> None: + with pytest.raises( + AssertionError, + match=f"^DID NOT RAISE any exception, expected {re.escape(message)}$", + ): + with body: + ... + + # basic + check_message("ExceptionGroup(ValueError)", RaisesGroup(ValueError)) + # multiple exceptions + check_message( + "ExceptionGroup(ValueError, ValueError)", RaisesGroup(ValueError, ValueError) + ) + # nested + check_message( + "ExceptionGroup(ExceptionGroup(ValueError))", + RaisesGroup(RaisesGroup(ValueError)), + ) + + # Matcher + check_message( + "ExceptionGroup(Matcher(ValueError, match='my_str'))", + RaisesGroup(Matcher(ValueError, "my_str")), + ) + check_message( + "ExceptionGroup(Matcher(match='my_str'))", + RaisesGroup(Matcher(match="my_str")), + ) + + # BaseExceptionGroup + check_message( + "BaseExceptionGroup(KeyboardInterrupt)", RaisesGroup(KeyboardInterrupt) + ) + # BaseExceptionGroup with type inside Matcher + check_message( + "BaseExceptionGroup(Matcher(KeyboardInterrupt))", + RaisesGroup(Matcher(KeyboardInterrupt)), + ) + # Base-ness transfers to parent containers + check_message( + "BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt))", + RaisesGroup(RaisesGroup(KeyboardInterrupt)), + ) + # but not to child containers + check_message( + "BaseExceptionGroup(BaseExceptionGroup(KeyboardInterrupt), ExceptionGroup(ValueError))", + RaisesGroup(RaisesGroup(KeyboardInterrupt), RaisesGroup(ValueError)), + ) + + +def test_matcher() -> None: + with pytest.raises( + ValueError, match="^You must specify at least one parameter to match on.$" + ): + Matcher() # type: ignore[call-overload] + with pytest.raises( + ValueError, + match=f"^exception_type {re.escape(repr(object))} must be a subclass of BaseException$", + ): + Matcher(object) # type: ignore[type-var] + + with RaisesGroup(Matcher(ValueError)): + raise ExceptionGroup("", (ValueError(),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(TypeError)): + raise ExceptionGroup("", (ValueError(),)) + + +def test_matcher_match() -> None: + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(ValueError, "foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + + # Can be used without specifying the type + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("foo"),)) + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(match="foo")): + raise ExceptionGroup("", (ValueError("bar"),)) + + +def test_Matcher_check() -> None: + def check_oserror_and_errno_is_5(e: BaseException) -> bool: + return isinstance(e, OSError) and e.errno == 5 + + with RaisesGroup(Matcher(check=check_oserror_and_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + # specifying exception_type narrows the parameter type to the callable + def check_errno_is_5(e: OSError) -> bool: + return e.errno == 5 + + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(5, ""),)) + + with pytest.raises(ExceptionGroup): + with RaisesGroup(Matcher(OSError, check=check_errno_is_5)): + raise ExceptionGroup("", (OSError(6, ""),)) + + +def test_matcher_tostring() -> None: + assert str(Matcher(ValueError)) == "Matcher(ValueError)" + assert str(Matcher(match="[a-z]")) == "Matcher(match='[a-z]')" + pattern_no_flags = re.compile("noflag", 0) + assert str(Matcher(match=pattern_no_flags)) == "Matcher(match='noflag')" + pattern_flags = re.compile("noflag", re.IGNORECASE) + assert str(Matcher(match=pattern_flags)) == f"Matcher(match={pattern_flags!r})" + assert ( + str(Matcher(ValueError, match="re", check=bool)) + == f"Matcher(ValueError, match='re', check={bool!r})" + ) + + +def test__ExceptionInfo(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr( + trio.testing._raises_group, + "ExceptionInfo", + trio.testing._raises_group._ExceptionInfo, + ) + with trio.testing.RaisesGroup(ValueError) as excinfo: + raise ExceptionGroup("", (ValueError("hello"),)) + assert excinfo.type is ExceptionGroup + assert excinfo.value.exceptions[0].args == ("hello",) + assert isinstance(excinfo.tb, TracebackType) diff --git a/src/trio/_tests/type_tests/raisesgroup.py b/src/trio/_tests/type_tests/raisesgroup.py new file mode 100644 index 0000000000..e00c20d1ba --- /dev/null +++ b/src/trio/_tests/type_tests/raisesgroup.py @@ -0,0 +1,135 @@ +"""The typing of RaisesGroup involves a lot of deception and lies, since AFAIK what we +actually want to achieve is ~impossible. This is because we specify what we expect with +instances of RaisesGroup and exception classes, but excinfo.value will be instances of +[Base]ExceptionGroup and instances of exceptions. So we need to "translate" from +RaisesGroup to ExceptionGroup. + +The way it currently works is that RaisesGroup[E] corresponds to +ExceptionInfo[BaseExceptionGroup[E]], so the top-level group will be correct. But +RaisesGroup[RaisesGroup[ValueError]] will become +ExceptionInfo[BaseExceptionGroup[RaisesGroup[ValueError]]]. To get around that we specify +RaisesGroup as a subclass of BaseExceptionGroup during type checking - which should mean +that most static type checking for end users should be mostly correct. +""" +from __future__ import annotations + +import sys +from typing import Union + +from trio.testing import Matcher, RaisesGroup +from typing_extensions import assert_type + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup, ExceptionGroup + +# split into functions to isolate the different scopes + + +def check_inheritance_and_assignments() -> None: + # Check inheritance + _: BaseExceptionGroup[ValueError] = RaisesGroup(ValueError) + _ = RaisesGroup(RaisesGroup(ValueError)) # type: ignore + + a: BaseExceptionGroup[BaseExceptionGroup[ValueError]] + a = RaisesGroup(RaisesGroup(ValueError)) + # pyright-ignore due to bug in exceptiongroup + # https://github.com/agronholm/exceptiongroup/pull/101 + # once fixed we'll get errors for unnecessary-pyright-ignore and can clean up + a = BaseExceptionGroup( + "", (BaseExceptionGroup("", (ValueError(),)),) # pyright: ignore + ) + assert a + + +def check_basic_contextmanager() -> None: + # One level of Group is correctly translated - except it's a BaseExceptionGroup + # instead of an ExceptionGroup. + with RaisesGroup(ValueError) as e: + raise ExceptionGroup("foo", (ValueError(),)) + assert_type(e.value, BaseExceptionGroup[ValueError]) + + +def check_basic_matches() -> None: + # check that matches gets rid of the naked ValueError in the union + exc: ExceptionGroup[ValueError] | ValueError = ExceptionGroup("", (ValueError(),)) + if RaisesGroup(ValueError).matches(exc): + assert_type(exc, BaseExceptionGroup[ValueError]) + + +def check_matches_with_different_exception_type() -> None: + # This should probably raise some type error somewhere, since + # ValueError != KeyboardInterrupt + e: BaseExceptionGroup[KeyboardInterrupt] = BaseExceptionGroup( + "", (KeyboardInterrupt(),) + ) + if RaisesGroup(ValueError).matches(e): + assert_type(e, BaseExceptionGroup[ValueError]) + + +def check_matcher_init() -> None: + def check_exc(exc: BaseException) -> bool: + return isinstance(exc, ValueError) + + def check_filenotfound(exc: FileNotFoundError) -> bool: + return not exc.filename.endswith(".tmp") + + # Check various combinations of constructor signatures. + # At least 1 arg must be provided. If exception_type is provided, that narrows + # check's argument. + Matcher() # type: ignore + Matcher(ValueError) + Matcher(ValueError, "regex") + Matcher(ValueError, "regex", check_exc) + Matcher(exception_type=ValueError) + Matcher(match="regex") + Matcher(check=check_exc) + Matcher(check=check_filenotfound) # type: ignore + Matcher(ValueError, match="regex") + Matcher(FileNotFoundError, check=check_filenotfound) + Matcher(match="regex", check=check_exc) + Matcher(FileNotFoundError, match="regex", check=check_filenotfound) + + +def check_matcher_transparent() -> None: + with RaisesGroup(Matcher(ValueError)) as e: + ... + _: BaseExceptionGroup[ValueError] = e.value + assert_type(e.value, BaseExceptionGroup[ValueError]) + + +def check_nested_raisesgroups_contextmanager() -> None: + with RaisesGroup(RaisesGroup(ValueError)) as excinfo: + raise ExceptionGroup("foo", (ValueError(),)) + + # thanks to inheritance this assignment works + _: BaseExceptionGroup[BaseExceptionGroup[ValueError]] = excinfo.value + # and it can mostly be treated like an exceptiongroup + print(excinfo.value.exceptions[0].exceptions[0]) + + # but assert_type reveals the lies + print(type(excinfo.value)) # would print "ExceptionGroup" + # typing says it's a BaseExceptionGroup + assert_type( + excinfo.value, + BaseExceptionGroup[RaisesGroup[ValueError]], + ) + + print(type(excinfo.value.exceptions[0])) # would print "ExceptionGroup" + # but type checkers are utterly confused + assert_type( + excinfo.value.exceptions[0], + Union[RaisesGroup[ValueError], BaseExceptionGroup[RaisesGroup[ValueError]]], + ) + + +def check_nested_raisesgroups_matches() -> None: + """Check nested RaisesGroups with .matches""" + # pyright-ignore due to bug in exceptiongroup + # https://github.com/agronholm/exceptiongroup/pull/101 + # once fixed we'll get errors for unnecessary-pyright-ignore and can clean up + exc: ExceptionGroup[ExceptionGroup[ValueError]] = ExceptionGroup( + "", (ExceptionGroup("", (ValueError(),)),) # pyright: ignore + ) + # has the same problems as check_nested_raisesgroups_contextmanager + if RaisesGroup(RaisesGroup(ValueError)).matches(exc): + assert_type(exc, BaseExceptionGroup[RaisesGroup[ValueError]]) diff --git a/src/trio/testing/__init__.py b/src/trio/testing/__init__.py index fa683e1145..f5dc97f0cd 100644 --- a/src/trio/testing/__init__.py +++ b/src/trio/testing/__init__.py @@ -24,6 +24,7 @@ memory_stream_pump as memory_stream_pump, ) from ._network import open_stream_to_socket_listener as open_stream_to_socket_listener +from ._raises_group import Matcher as Matcher, RaisesGroup as RaisesGroup from ._sequencer import Sequencer as Sequencer from ._trio_test import trio_test as trio_test diff --git a/src/trio/testing/_raises_group.py b/src/trio/testing/_raises_group.py new file mode 100644 index 0000000000..516f71f375 --- /dev/null +++ b/src/trio/testing/_raises_group.py @@ -0,0 +1,461 @@ +from __future__ import annotations + +import re +import sys +from typing import ( + TYPE_CHECKING, + Callable, + ContextManager, + Generic, + Iterable, + Pattern, + TypeVar, + cast, + overload, +) + +from trio._util import final + +if TYPE_CHECKING: + import builtins + + # sphinx will *only* work if we use types.TracebackType, and import + # *inside* TYPE_CHECKING. No other combination works..... + import types + + from _pytest._code.code import ExceptionChainRepr, ReprExceptionInfo, Traceback + from typing_extensions import TypeGuard + +if sys.version_info < (3, 11): + from exceptiongroup import BaseExceptionGroup + +E = TypeVar("E", bound=BaseException) + + +@final +class _ExceptionInfo(Generic[E]): + """Minimal re-implementation of pytest.ExceptionInfo, only used if pytest is not available. Supports a subset of its features necessary for functionality of :class:`trio.testing.RaisesGroup` and :class:`trio.testing.Matcher`.""" + + _excinfo: tuple[type[E], E, types.TracebackType] | None + + def __init__(self, excinfo: tuple[type[E], E, types.TracebackType] | None): + self._excinfo = excinfo + + def fill_unfilled(self, exc_info: tuple[type[E], E, types.TracebackType]) -> None: + """Fill an unfilled ExceptionInfo created with ``for_later()``.""" + assert self._excinfo is None, "ExceptionInfo was already filled" + self._excinfo = exc_info + + @classmethod + def for_later(cls) -> _ExceptionInfo[E]: + """Return an unfilled ExceptionInfo.""" + return cls(None) + + @property + def type(self) -> type[E]: + """The exception class.""" + assert ( + self._excinfo is not None + ), ".type can only be used after the context manager exits" + return self._excinfo[0] + + @property + def value(self) -> E: + """The exception value.""" + assert ( + self._excinfo is not None + ), ".value can only be used after the context manager exits" + return self._excinfo[1] + + @property + def tb(self) -> types.TracebackType: + """The exception raw traceback.""" + assert ( + self._excinfo is not None + ), ".tb can only be used after the context manager exits" + return self._excinfo[2] + + def exconly(self, tryshort: bool = False) -> str: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + def errisinstance( + self, + exc: builtins.type[BaseException] | tuple[builtins.type[BaseException], ...], + ) -> bool: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + def getrepr( + self, + showlocals: bool = False, + style: str = "long", + abspath: bool = False, + tbfilter: bool | Callable[[_ExceptionInfo[BaseException]], Traceback] = True, + funcargs: bool = False, + truncate_locals: bool = True, + chain: bool = True, + ) -> ReprExceptionInfo | ExceptionChainRepr: + raise NotImplementedError( + "This is a helper method only available if you use RaisesGroup with the pytest package installed" + ) + + +# Type checkers are not able to do conditional types depending on installed packages, so +# we've added signatures for all helpers to _ExceptionInfo, and then always use that. +# If this ends up leading to problems, we can resort to always using _ExceptionInfo and +# users that want to use getrepr/errisinstance/exconly can write helpers on their own, or +# we reimplement them ourselves...or get this merged in upstream pytest. +if TYPE_CHECKING: + ExceptionInfo = _ExceptionInfo + +else: + try: + from pytest import ExceptionInfo # noqa: PT013 + except ImportError: # pragma: no cover + ExceptionInfo = _ExceptionInfo + + +# copied from pytest.ExceptionInfo +def _stringify_exception(exc: BaseException) -> str: + return "\n".join( + [ + str(exc), + *getattr(exc, "__notes__", []), + ] + ) + + +# String patterns default to including the unicode flag. +_regex_no_flags = re.compile("").flags + + +@final +class Matcher(Generic[E]): + """Helper class to be used together with RaisesGroups when you want to specify requirements on sub-exceptions. Only specifying the type is redundant, and it's also unnecessary when the type is a nested `RaisesGroup` since it supports the same arguments. + The type is checked with `isinstance`, and does not need to be an exact match. If that is wanted you can use the ``check`` parameter. + :meth:`trio.testing.Matcher.matches` can also be used standalone to check individual exceptions. + + Examples:: + + with RaisesGroups(Matcher(ValueError, match="string")) + ... + with RaisesGroups(Matcher(check=lambda x: x.args == (3, "hello"))): + ... + with RaisesGroups(Matcher(check=lambda x: type(x) is ValueError)): + ... + + """ + + # At least one of the three parameters must be passed. + @overload + def __init__( + self: Matcher[E], + exception_type: type[E], + match: str | Pattern[str] = ..., + check: Callable[[E], bool] = ..., + ): + ... + + @overload + def __init__( + self: Matcher[BaseException], # Give E a value. + *, + match: str | Pattern[str], + # If exception_type is not provided, check() must do any typechecks itself. + check: Callable[[BaseException], bool] = ..., + ): + ... + + @overload + def __init__(self, *, check: Callable[[BaseException], bool]): + ... + + def __init__( + self, + exception_type: type[E] | None = None, + match: str | Pattern[str] | None = None, + check: Callable[[E], bool] | None = None, + ): + if exception_type is None and match is None and check is None: + raise ValueError("You must specify at least one parameter to match on.") + if exception_type is not None and not issubclass(exception_type, BaseException): + raise ValueError( + f"exception_type {exception_type} must be a subclass of BaseException" + ) + self.exception_type = exception_type + self.match: Pattern[str] | None + if isinstance(match, str): + self.match = re.compile(match) + else: + self.match = match + self.check = check + + def matches(self, exception: BaseException) -> TypeGuard[E]: + """Check if an exception matches the requirements of this Matcher. + + Examples:: + + assert Matcher(ValueError).matches(my_exception): + # is equivalent to + assert isinstance(my_exception, ValueError) + + # this can be useful when checking e.g. the ``__cause__`` of an exception. + with pytest.raises(ValueError) as excinfo: + ... + assert Matcher(SyntaxError, match="foo").matches(excinfo.value.__cause__) + # above line is equivalent to + assert isinstance(excinfo.value.__cause__, SyntaxError) + assert re.search("foo", str(excinfo.value.__cause__) + + """ + if self.exception_type is not None and not isinstance( + exception, self.exception_type + ): + return False + if self.match is not None and not re.search( + self.match, _stringify_exception(exception) + ): + return False + # If exception_type is None check() accepts BaseException. + # If non-none, we have done an isinstance check above. + if self.check is not None and not self.check(cast(E, exception)): + return False + return True + + def __str__(self) -> str: + reqs = [] + if self.exception_type is not None: + reqs.append(self.exception_type.__name__) + if (match := self.match) is not None: + # If no flags were specified, discard the redundant re.compile() here. + reqs.append( + f"match={match.pattern if match.flags == _regex_no_flags else match!r}" + ) + if self.check is not None: + reqs.append(f"check={self.check!r}") + return f'Matcher({", ".join(reqs)})' + + +# typing this has been somewhat of a nightmare, with the primary difficulty making +# the return type of __enter__ correct. Ideally it would function like this +# with RaisesGroup(RaisesGroup(ValueError)) as excinfo: +# ... +# assert_type(excinfo.value, ExceptionGroup[ExceptionGroup[ValueError]]) +# in addition to all the simple cases, but getting all the way to the above seems maybe +# impossible. The type being RaisesGroup[RaisesGroup[ValueError]] is probably also fine, +# as long as I add fake properties corresponding to the properties of exceptiongroup. But +# I had trouble with it handling recursive cases properly. + +# Current solution settles on the above giving BaseExceptionGroup[RaisesGroup[ValueError]], and it not +# being a type error to do `with RaisesGroup(ValueError()): ...` - but that will error on runtime. + +# We lie to type checkers that we inherit, so excinfo.value and sub-exceptiongroups can be treated as ExceptionGroups +if TYPE_CHECKING: + SuperClass = BaseExceptionGroup +# Inheriting at runtime leads to a series of TypeErrors, so we do not want to do that. +else: + SuperClass = Generic + + +@final +class RaisesGroup(ContextManager[ExceptionInfo[BaseExceptionGroup[E]]], SuperClass[E]): + """Contextmanager for checking for an expected `ExceptionGroup`. + This works similar to ``pytest.raises``, and a version of it will hopefully be added upstream, after which this can be deprecated and removed. See https://github.com/pytest-dev/pytest/issues/11538 + + + This differs from :ref:`except* ` in that all specified exceptions must be present, *and no others*. It will similarly not catch exceptions *not* wrapped in an exceptiongroup. + If you don't care for the nesting level of the exceptions you can pass ``strict=False``. + It currently does not care about the order of the exceptions, so ``RaisesGroups(ValueError, TypeError)`` is equivalent to ``RaisesGroups(TypeError, ValueError)``. + + This class is not as polished as ``pytest.raises``, and is currently not as helpful in e.g. printing diffs when strings don't match, suggesting you use ``re.escape``, etc. + + Examples:: + + with RaisesGroups(ValueError): + raise ExceptionGroup("", (ValueError(),)) + with RaisesGroups(ValueError, ValueError, Matcher(TypeError, match="expected int")): + ... + with RaisesGroups(KeyboardInterrupt, match="hello", check=lambda x: type(x) is BaseExceptionGroup): + ... + with RaisesGroups(RaisesGroups(ValueError)): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + with RaisesGroups(ValueError, strict=False): + raise ExceptionGroup("", (ExceptionGroup("", (ValueError(),)),)) + + + `RaisesGroup.matches` can also be used directly to check a standalone exception group. + + + This class is also not perfectly smart, e.g. this will likely fail currently:: + + with RaisesGroups(ValueError, Matcher(ValueError, match="hello")): + raise ExceptionGroup("", (ValueError("hello"), ValueError("goodbye"))) + + even though it generally does not care about the order of the exceptions in the group. + To avoid the above you should specify the first ValueError with a Matcher as well. + + It is also not typechecked perfectly, and that's likely not possible with the current approach. Most common usage should work without issue though. + """ + + # needed for pyright, since BaseExceptionGroup.__new__ takes two arguments + if TYPE_CHECKING: + + def __new__(cls, *args: object, **kwargs: object) -> RaisesGroup[E]: + ... + + def __init__( + self, + exception: type[E] | Matcher[E] | E, + *other_exceptions: type[E] | Matcher[E] | E, + strict: bool = True, + match: str | Pattern[str] | None = None, + check: Callable[[BaseExceptionGroup[E]], bool] | None = None, + ): + self.expected_exceptions: tuple[type[E] | Matcher[E] | E, ...] = ( + exception, + *other_exceptions, + ) + self.strict = strict + self.match_expr = match + self.check = check + self.is_baseexceptiongroup = False + + for exc in self.expected_exceptions: + if isinstance(exc, RaisesGroup): + if not strict: + raise ValueError( + "You cannot specify a nested structure inside a RaisesGroup with" + " strict=False" + ) + self.is_baseexceptiongroup |= exc.is_baseexceptiongroup + elif isinstance(exc, Matcher): + if exc.exception_type is None: + continue + # Matcher __init__ assures it's a subclass of BaseException + self.is_baseexceptiongroup |= not issubclass( + exc.exception_type, Exception + ) + elif isinstance(exc, type) and issubclass(exc, BaseException): + self.is_baseexceptiongroup |= not issubclass(exc, Exception) + else: + raise ValueError( + f'Invalid argument "{exc!r}" must be exception type, Matcher, or' + " RaisesGroup." + ) + + def __enter__(self) -> ExceptionInfo[BaseExceptionGroup[E]]: + self.excinfo: ExceptionInfo[BaseExceptionGroup[E]] = ExceptionInfo.for_later() + return self.excinfo + + def _unroll_exceptions( + self, exceptions: Iterable[BaseException] + ) -> Iterable[BaseException]: + """Used in non-strict mode.""" + res: list[BaseException] = [] + for exc in exceptions: + if isinstance(exc, BaseExceptionGroup): + res.extend(self._unroll_exceptions(exc.exceptions)) + + else: + res.append(exc) + return res + + def matches( + self, + exc_val: BaseException | None, + ) -> TypeGuard[BaseExceptionGroup[E]]: + """Check if an exception matches the requirements of this RaisesGroup. + + Example:: + + with pytest.raises(TypeError) as excinfo: + ... + assert RaisesGroups(ValueError).matches(excinfo.value.__cause__) + # the above line is equivalent to + myexc = excinfo.value.__cause + assert isinstance(myexc, BaseExceptionGroup) + assert len(myexc.exceptions) == 1 + assert isinstance(myexc.exceptions[0], ValueError) + """ + if exc_val is None: + return False + # TODO: print/raise why a match fails, in a way that works properly in nested cases + # maybe have a list of strings logging failed matches, that __exit__ can + # recursively step through and print on a failing match. + if not isinstance(exc_val, BaseExceptionGroup): + return False + if len(exc_val.exceptions) != len(self.expected_exceptions): + return False + if self.match_expr is not None and not re.search( + self.match_expr, _stringify_exception(exc_val) + ): + return False + if self.check is not None and not self.check(exc_val): + return False + remaining_exceptions = list(self.expected_exceptions) + actual_exceptions: Iterable[BaseException] = exc_val.exceptions + if not self.strict: + actual_exceptions = self._unroll_exceptions(actual_exceptions) + + # it should be possible to get RaisesGroup.matches typed so as not to + # need these type: ignores, but I'm not sure that's possible while also having it + # transparent for the end user. + for e in actual_exceptions: + for rem_e in remaining_exceptions: + if ( + (isinstance(rem_e, type) and isinstance(e, rem_e)) + or ( + isinstance(e, BaseExceptionGroup) + and isinstance(rem_e, RaisesGroup) + and rem_e.matches(e) + ) + or (isinstance(rem_e, Matcher) and rem_e.matches(e)) + ): + remaining_exceptions.remove(rem_e) # type: ignore[arg-type] + break + else: + return False + return True + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: types.TracebackType | None, + ) -> bool: + __tracebackhide__ = True + assert ( + exc_type is not None + ), f"DID NOT RAISE any exception, expected {self.expected_type()}" + assert ( + self.excinfo is not None + ), "Internal error - should have been constructed in __enter__" + + if not self.matches(exc_val): + return False + + # Cast to narrow the exception type now that it's verified. + exc_info = cast( + "tuple[type[BaseExceptionGroup[E]], BaseExceptionGroup[E], types.TracebackType]", + (exc_type, exc_val, exc_tb), + ) + self.excinfo.fill_unfilled(exc_info) + return True + + def expected_type(self) -> str: + subexcs = [] + for e in self.expected_exceptions: + if isinstance(e, Matcher): + subexcs.append(str(e)) + elif isinstance(e, RaisesGroup): + subexcs.append(e.expected_type()) + elif isinstance(e, type): + subexcs.append(e.__name__) + else: # pragma: no cover + raise AssertionError("unknown type") + group_type = "Base" if self.is_baseexceptiongroup else "" + return f"{group_type}ExceptionGroup({', '.join(subexcs)})"