diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index 223fd3d99..86d106c4e 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -195,6 +195,13 @@ async def drain_poll_queue(self) -> None: except temporalio.bridge.worker.PollShutdownError: return + # Only call this after run()/drain_poll_queue() have returned. This will not + # raise an exception. + async def wait_all_completed(self) -> None: + running_tasks = [v.task for v in self._running_activities.values() if v.task] + if running_tasks: + await asyncio.gather(*running_tasks, return_exceptions=False) + def _cancel( self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel ) -> None: diff --git a/temporalio/worker/_worker.py b/temporalio/worker/_worker.py index 15af12cee..67187deb7 100644 --- a/temporalio/worker/_worker.py +++ b/temporalio/worker/_worker.py @@ -467,6 +467,13 @@ async def raise_on_shutdown(): for task in tasks: task.cancel() + # If there's an activity worker, we have to let all activity completions + # finish. We cannot guarantee that because poll shutdown completed + # (which means activities completed) that they got flushed to the + # server. + if self._activity_worker: + await self._activity_worker.wait_all_completed() + # Do final shutdown try: await self._bridge_worker.finalize_shutdown()