From 92848462d136794af6d2d4a8db20a23a788091fa Mon Sep 17 00:00:00 2001 From: Pavel Kirilin Date: Sat, 1 Apr 2023 04:57:16 +0400 Subject: [PATCH] Updated inmemory broker. Signed-off-by: Pavel Kirilin --- taskiq/brokers/inmemory_broker.py | 9 +++- tests/brokers/test_inmemory.py | 87 +++++++++++++++++++++++++++++++ 2 files changed, 94 insertions(+), 2 deletions(-) create mode 100644 tests/brokers/test_inmemory.py diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 823169d2..809728f0 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,6 +1,7 @@ +import asyncio import inspect from collections import OrderedDict -from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, get_type_hints +from typing import Any, AsyncGenerator, Callable, Optional, Set, TypeVar, get_type_hints from taskiq_dependencies import DependencyGraph @@ -114,6 +115,7 @@ def __init__( # noqa: WPS211 log_collector_format=logs_format or WorkerArgs.log_collector_format, ), ) + self._running_tasks: "Set[asyncio.Task[Any]]" = set() async def kick(self, message: BrokerMessage) -> None: """ @@ -128,6 +130,7 @@ async def kick(self, message: BrokerMessage) -> None: target_task = self.available_tasks.get(message.task_name) if target_task is None: raise TaskiqError("Unknown task.") + if not self.receiver.dependency_graphs.get(target_task.task_name): self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph( target_task.original_func, @@ -141,7 +144,9 @@ async def kick(self, message: BrokerMessage) -> None: target_task.original_func, ) - await self.receiver.callback(message=message) + task = asyncio.create_task(self.receiver.callback(message=message)) + self._running_tasks.add(task) + task.add_done_callback(self._running_tasks.discard) def listen(self) -> AsyncGenerator[BrokerMessage, None]: """ diff --git a/tests/brokers/test_inmemory.py b/tests/brokers/test_inmemory.py new file mode 100644 index 00000000..46a65a30 --- /dev/null +++ b/tests/brokers/test_inmemory.py @@ -0,0 +1,87 @@ +import asyncio +import uuid + +import pytest + +from taskiq import InMemoryBroker +from taskiq.events import TaskiqEvents +from taskiq.state import TaskiqState + + +@pytest.mark.anyio +async def test_inmemory_success() -> None: + broker = InMemoryBroker() + test_val = uuid.uuid4().hex + + @broker.task + async def task() -> str: + return test_val + + kicked = await task.kiq() + result = await kicked.wait_result() + assert result.return_value == test_val + assert not broker._running_tasks + + +@pytest.mark.anyio +async def test_cannot_listen() -> None: + broker = InMemoryBroker() + + with pytest.raises(RuntimeError): + async for _ in broker.listen(): + pass + + +@pytest.mark.anyio +async def test_startup() -> None: + broker = InMemoryBroker() + test_value = uuid.uuid4().hex + + @broker.on_event(TaskiqEvents.WORKER_STARTUP) + async def _w_startup(state: TaskiqState) -> None: + state.from_worker = test_value + + @broker.on_event(TaskiqEvents.CLIENT_STARTUP) + async def _c_startup(state: TaskiqState) -> None: + state.from_client = test_value + + await broker.startup() + + assert broker.state.from_worker == test_value + assert broker.state.from_client == test_value + + +@pytest.mark.anyio +async def test_shutdown() -> None: + broker = InMemoryBroker() + test_value = uuid.uuid4().hex + + @broker.on_event(TaskiqEvents.WORKER_SHUTDOWN) + async def _w_startup(state: TaskiqState) -> None: + state.from_worker = test_value + + @broker.on_event(TaskiqEvents.CLIENT_SHUTDOWN) + async def _c_startup(state: TaskiqState) -> None: + state.from_client = test_value + + await broker.shutdown() + + assert broker.state.from_worker == test_value + assert broker.state.from_client == test_value + + +@pytest.mark.anyio +async def test_execution() -> None: + broker = InMemoryBroker() + test_value = uuid.uuid4().hex + + @broker.task + async def test_task() -> str: + await asyncio.sleep(0.5) + return test_value + + task = await test_task.kiq() + assert not await task.is_ready() + + result = await task.wait_result() + assert result.return_value == test_value