diff --git a/temporalio/worker/_interceptor.py b/temporalio/worker/_interceptor.py index 1b412cb7f..358d34090 100644 --- a/temporalio/worker/_interceptor.py +++ b/temporalio/worker/_interceptor.py @@ -298,6 +298,7 @@ class StartNexusOperationInput(Generic[InputT, OutputT]): operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]] input: InputT schedule_to_close_timeout: Optional[timedelta] + cancellation_type: temporalio.workflow.NexusOperationCancellationType headers: Optional[Mapping[str, str]] output_type: Optional[Type[OutputT]] = None diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index 81f21588f..692306ddc 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -54,13 +54,13 @@ import temporalio.bridge.proto.activity_result import temporalio.bridge.proto.child_workflow import temporalio.bridge.proto.common +import temporalio.bridge.proto.nexus import temporalio.bridge.proto.workflow_activation import temporalio.bridge.proto.workflow_commands import temporalio.bridge.proto.workflow_completion import temporalio.common import temporalio.converter import temporalio.exceptions -import temporalio.nexus import temporalio.workflow from temporalio.service import __version__ @@ -1502,9 +1502,10 @@ async def workflow_start_nexus_operation( service: str, operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[OutputT]] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - headers: Optional[Mapping[str, str]] = None, + output_type: Optional[Type[OutputT]], + schedule_to_close_timeout: Optional[timedelta], + cancellation_type: temporalio.workflow.NexusOperationCancellationType, + headers: Optional[Mapping[str, str]], ) -> temporalio.workflow.NexusOperationHandle[OutputT]: # start_nexus_operation return await self._outbound.start_nexus_operation( @@ -1515,6 +1516,7 @@ async def workflow_start_nexus_operation( input=input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) ) @@ -2757,7 +2759,7 @@ def _apply_schedule_command( if self._input.retry_policy: self._input.retry_policy.apply_to_proto(v.retry_policy) v.cancellation_type = cast( - "temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType", + temporalio.bridge.proto.workflow_commands.ActivityCancellationType.ValueType, int(self._input.cancellation_type), ) @@ -2893,7 +2895,7 @@ def _apply_start_command(self) -> None: if self._input.task_timeout: v.workflow_task_timeout.FromTimedelta(self._input.task_timeout) v.parent_close_policy = cast( - "temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType", + temporalio.bridge.proto.child_workflow.ParentClosePolicy.ValueType, int(self._input.parent_close_policy), ) v.workflow_id_reuse_policy = cast( @@ -2915,7 +2917,7 @@ def _apply_start_command(self) -> None: self._input.search_attributes, v.search_attributes ) v.cancellation_type = cast( - "temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType", + temporalio.bridge.proto.child_workflow.ChildWorkflowCancellationType.ValueType, int(self._input.cancellation_type), ) if self._input.versioning_intent: @@ -3011,11 +3013,6 @@ def __init__( @property def operation_token(self) -> Optional[str]: - # TODO(nexus-preview): How should this behave? - # Java has a separate class that only exists if the operation token exists: - # https://github.com/temporalio/sdk-java/blob/master/temporal-sdk/src/main/java/io/temporal/internal/sync/NexusOperationExecutionImpl.java#L26 - # And Go similar: - # https://github.com/temporalio/sdk-go/blob/master/internal/workflow.go#L2770-L2771 try: return self._start_fut.result() except BaseException: @@ -3064,6 +3061,11 @@ def _apply_schedule_command(self) -> None: v.schedule_to_close_timeout.FromTimedelta( self._input.schedule_to_close_timeout ) + v.cancellation_type = cast( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.ValueType, + int(self._input.cancellation_type), + ) + if self._input.headers: for key, val in self._input.headers.items(): v.nexus_header[key] = val diff --git a/temporalio/workflow.py b/temporalio/workflow.py index b72097d5a..6a5b1f776 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -858,9 +858,10 @@ async def workflow_start_nexus_operation( service: str, operation: Union[nexusrpc.Operation[InputT, OutputT], str, Callable[..., Any]], input: Any, - output_type: Optional[Type[OutputT]] = None, - schedule_to_close_timeout: Optional[timedelta] = None, - headers: Optional[Mapping[str, str]] = None, + output_type: Optional[Type[OutputT]], + schedule_to_close_timeout: Optional[timedelta], + cancellation_type: temporalio.workflow.NexusOperationCancellationType, + headers: Optional[Mapping[str, str]], ) -> NexusOperationHandle[OutputT]: ... @abstractmethod @@ -5137,6 +5138,46 @@ def _to_proto(self) -> temporalio.bridge.proto.common.VersioningIntent.ValueType ServiceT = TypeVar("ServiceT") +class NexusOperationCancellationType(IntEnum): + """Defines behavior of a Nexus operation when the caller workflow initiates cancellation. + + Pass one of these values to :py:meth:`NexusClient.start_operation` to define cancellation + behavior. + + To initiate cancellation, use :py:meth:`NexusOperationHandle.cancel` and then `await` the + operation handle. This will result in a :py:class:`exceptions.NexusOperationError`. The values + of this enum define what is guaranteed to have happened by that point. + """ + + ABANDON = int(temporalio.bridge.proto.nexus.NexusOperationCancellationType.ABANDON) + """Do not send any cancellation request to the operation handler; just report cancellation to the caller""" + + TRY_CANCEL = int( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.TRY_CANCEL + ) + """Send a cancellation request but immediately report cancellation to the caller. Note that this + does not guarantee that cancellation is delivered to the operation handler if the caller exits + before the delivery is done. + """ + + # TODO(nexus-preview): core needs to be updated to handle + # NexusOperationCancelRequestCompleted and NexusOperationCancelRequestFailed + # see https://github.com/temporalio/sdk-core/issues/911 + # WAIT_REQUESTED = int( + # temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_REQUESTED + # ) + # """Send a cancellation request and wait for confirmation that the request was received. + # Does not wait for the operation to complete. + # """ + + WAIT_COMPLETED = int( + temporalio.bridge.proto.nexus.NexusOperationCancellationType.WAIT_CANCELLATION_COMPLETED + ) + """Send a cancellation request and wait for the operation to complete. + Note that the operation may not complete as cancelled (for example, if it catches the + :py:exc:`asyncio.CancelledError` resulting from the cancellation request).""" + + class NexusClient(ABC, Generic[ServiceT]): """A client for invoking Nexus operations. @@ -5167,6 +5208,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5180,6 +5222,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5196,6 +5239,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5212,6 +5256,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5228,6 +5273,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> NexusOperationHandle[OutputT]: ... @@ -5239,6 +5285,7 @@ async def start_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: """Start a Nexus operation and return its handle. @@ -5268,6 +5315,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5281,6 +5329,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5297,6 +5346,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5316,6 +5366,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5332,6 +5383,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> OutputT: ... @@ -5343,6 +5395,7 @@ async def execute_operation( *, output_type: Optional[Type[OutputT]] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: """Execute a Nexus operation and return its result. @@ -5394,6 +5447,7 @@ async def start_operation( *, output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: return ( @@ -5404,6 +5458,7 @@ async def start_operation( input=input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) ) @@ -5415,6 +5470,7 @@ async def execute_operation( *, output_type: Optional[Type] = None, schedule_to_close_timeout: Optional[timedelta] = None, + cancellation_type: NexusOperationCancellationType = NexusOperationCancellationType.WAIT_COMPLETED, headers: Optional[Mapping[str, str]] = None, ) -> Any: handle = await self.start_operation( @@ -5422,6 +5478,7 @@ async def execute_operation( input, output_type=output_type, schedule_to_close_timeout=schedule_to_close_timeout, + cancellation_type=cancellation_type, headers=headers, ) return await handle diff --git a/tests/helpers/__init__.py b/tests/helpers/__init__.py index a352877d5..a141cfbc8 100644 --- a/tests/helpers/__init__.py +++ b/tests/helpers/__init__.py @@ -7,7 +7,9 @@ from typing import Any, Awaitable, Callable, Optional, Sequence, Type, TypeVar from temporalio.api.common.v1 import WorkflowExecution +from temporalio.api.enums.v1 import EventType as EventType from temporalio.api.enums.v1 import IndexedValueType +from temporalio.api.history.v1 import HistoryEvent from temporalio.api.operatorservice.v1 import ( AddSearchAttributesRequest, ListSearchAttributesRequest, @@ -287,3 +289,40 @@ async def check_unpaused() -> bool: return not info.paused await assert_eventually(check_unpaused) + + +async def print_history(handle: WorkflowHandle): + i = 1 + async for evt in handle.fetch_history_events(): + event = EventType.Name(evt.event_type).removeprefix("EVENT_TYPE_") + print(f"{i:2}: {event}") + i += 1 + + +async def print_interleaved_histories(*handles: WorkflowHandle) -> None: + """ + Print the interleaved history events from multiple workflow handles in columns. + """ + all_events: list[tuple[WorkflowHandle, HistoryEvent, int]] = [] + for handle in handles: + event_num = 1 + async for event in handle.fetch_history_events(): + all_events.append((handle, event, event_num)) + event_num += 1 + all_events.sort(key=lambda item: item[1].event_time.ToDatetime()) + col_width = 40 + + def _format_row(items: list[str], truncate: bool = False) -> str: + if truncate: + items = [item[: col_width - 3] for item in items] + return " | ".join(f"{item:<{col_width - 3}}" for item in items) + + headers = [handle.id for handle in handles] + print("\n" + _format_row(headers, truncate=True)) + print("-" * (col_width * len(handles) + len(handles) - 1)) + for handle, event, event_num in all_events: + event_type = EventType.Name(event.event_type).removeprefix("EVENT_TYPE_") + row = [""] * len(handles) + col_idx = handles.index(handle) + row[col_idx] = f"{event_num:2}: {event_type[: col_width - 5]}" + print(_format_row(row)) diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py new file mode 100644 index 000000000..af1e8397b --- /dev/null +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -0,0 +1,294 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from typing import Any, Optional + +import nexusrpc +import pytest + +from temporalio import exceptions, nexus, workflow +from temporalio.api.enums.v1 import EventType +from temporalio.api.history.v1 import HistoryEvent +from temporalio.client import ( + WorkflowExecutionStatus, + WorkflowHandle, +) +from temporalio.testing import WorkflowEnvironment +from temporalio.worker import Worker +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name + + +@workflow.defn +class HandlerWorkflow: + @workflow.run + async def run(self) -> None: + await asyncio.Future() + + +@nexusrpc.service +class Service: + workflow_op: nexusrpc.Operation[None, None] + + +@nexusrpc.handler.service_handler(service=Service) +class ServiceHandler: + @nexus.workflow_run_operation + async def workflow_op( + self, ctx: nexus.WorkflowRunOperationContext, _input: None + ) -> nexus.WorkflowHandle[None]: + return await ctx.start_workflow( + HandlerWorkflow.run, + id="handler-wf-" + str(uuid.uuid4()), + ) + + +@dataclass +class Input: + endpoint: str + cancellation_type: Optional[workflow.NexusOperationCancellationType] + + +@dataclass +class CancellationResult: + operation_token: str + caller_unblock_time: datetime + + +@workflow.defn +class CallerWorkflow: + @workflow.init + def __init__(self, input: Input): + self.nexus_client = workflow.create_nexus_client( + service=Service, + endpoint=input.endpoint, + ) + + @workflow.run + async def run(self, input: Input) -> CancellationResult: + op_handle = await ( + self.nexus_client.start_operation( + Service.workflow_op, + input=None, + cancellation_type=input.cancellation_type, + ) + if input.cancellation_type is not None + else self.nexus_client.start_operation(Service.workflow_op, input=None) + ) + op_handle.cancel() + try: + await op_handle + except exceptions.NexusOperationError: + caller_unblock_time = workflow.now() + # Give time for handler wf cancellation event to be written if a request was sent + await asyncio.sleep(0.5) + assert op_handle.operation_token + return CancellationResult( + operation_token=op_handle.operation_token, + caller_unblock_time=caller_unblock_time, + ) + else: + pytest.fail("Expected NexusOperationError") + + +async def check_behavior_for_abandon( + caller_wf: WorkflowHandle, + handler_wf: WorkflowHandle, +) -> None: + """ + Check that a cancellation request is not sent. + """ + handler_status = (await handler_wf.describe()).status + assert handler_status == WorkflowExecutionStatus.RUNNING + await _assert_event_subsequence( + [ + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED), + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED), + ] + ) + assert not await _has_event( + caller_wf, + EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED, + ) + + +async def check_behavior_for_try_cancel( + caller_wf: WorkflowHandle, + handler_wf: WorkflowHandle, +) -> None: + """ + Check that a cancellation request is sent and the caller workflow exits before the operation is + canceled. + """ + handler_status = (await handler_wf.describe()).status + assert handler_status == WorkflowExecutionStatus.CANCELED + await _assert_event_subsequence( + [ + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED), + (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED), + ( + handler_wf, + EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED, + ), + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED), + ] + ) + + +async def check_behavior_for_wait_cancellation_completed( + caller_wf: WorkflowHandle[Any, CancellationResult], + handler_wf: WorkflowHandle, +) -> None: + """ + Check that a cancellation request is sent and the caller workflow nexus operation future is + unblocked after the operation is canceled. + """ + handler_status = (await handler_wf.describe()).status + assert handler_status == WorkflowExecutionStatus.CANCELED + await _assert_event_subsequence( + [ + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED), + (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED), + ( + handler_wf, + EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCEL_REQUESTED, + ), + (handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED), + (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED), + (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED), + ] + ) + result = await caller_wf.result() + handler_wf_canceled_event_time = await _get_event_time( + handler_wf, + EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED, + ) + assert ( + result.caller_unblock_time > handler_wf_canceled_event_time + ), "For WAIT_COMPLETED, the future should be unblocked after handler workflow cancellation. " + + +@pytest.mark.parametrize( + "cancellation_type", + [ + None, + workflow.NexusOperationCancellationType.ABANDON.name, + workflow.NexusOperationCancellationType.TRY_CANCEL.name, + workflow.NexusOperationCancellationType.WAIT_COMPLETED.name, + ], +) +async def test_cancellation_type( + env: WorkflowEnvironment, + cancellation_type: Optional[str], +): + if env.supports_time_skipping: + pytest.skip("Nexus tests don't work with time-skipping server") + + client = env.client + + async with Worker( + client, + task_queue=str(uuid.uuid4()), + workflows=[CallerWorkflow, HandlerWorkflow], + nexus_service_handlers=[ServiceHandler()], + ) as worker: + await create_nexus_endpoint(worker.task_queue, client) + + input = Input( + endpoint=make_nexus_endpoint_name(worker.task_queue), + cancellation_type=( + workflow.NexusOperationCancellationType[cancellation_type] + if cancellation_type + else None + ), + ) + caller_wf = await client.start_workflow( + CallerWorkflow.run, + input, + id="caller-wf-" + str(uuid.uuid4()), + task_queue=worker.task_queue, + ) + operation_token = (await caller_wf.result()).operation_token + handler_wf = ( + nexus.WorkflowHandle[None] + .from_token(operation_token) + ._to_client_workflow_handle(client) + ) + + if input.cancellation_type == workflow.NexusOperationCancellationType.ABANDON: + await check_behavior_for_abandon(caller_wf, handler_wf) + elif ( + input.cancellation_type + == workflow.NexusOperationCancellationType.TRY_CANCEL + ): + await check_behavior_for_try_cancel(caller_wf, handler_wf) + elif input.cancellation_type in [ + None, + workflow.NexusOperationCancellationType.WAIT_COMPLETED, + ]: + await check_behavior_for_wait_cancellation_completed(caller_wf, handler_wf) + else: + pytest.fail(f"Invalid cancellation type: {input.cancellation_type}") + + +async def _has_event(wf_handle: WorkflowHandle, event_type: EventType.ValueType): + async for e in wf_handle.fetch_history_events(): + if e.event_type == event_type: + return True + return False + + +async def _get_event_time( + wf_handle: WorkflowHandle, + event_type: EventType.ValueType, +) -> datetime: + async for event in wf_handle.fetch_history_events(): + if event.event_type == event_type: + return event.event_time.ToDatetime().replace(tzinfo=timezone.utc) + assert False, f"Event {event_type} not found in {wf_handle.id}" + + +async def _assert_event_subsequence( + expected_events: list[tuple[WorkflowHandle, EventType.ValueType]], +) -> None: + """ + Given a sequence of (WorkflowHandle, EventType) pairs, assert that the sorted sequence of events + from both workflows contains that subsequence. + """ + + def _event_time( + item: tuple[WorkflowHandle, HistoryEvent], + ) -> datetime: + return item[1].event_time.ToDatetime() + + all_events = [] + handles = {h for h, _ in expected_events} + for h in handles: + async for e in h.fetch_history_events(): + all_events.append((h, e)) + _all_events = iter(sorted(all_events, key=_event_time)) + _expected_events = iter(expected_events) + + previous_expected_handle, previous_expected_event_type_name = None, None + for expected_handle, expected_event_type in _expected_events: + expected_event_type_name = EventType.Name(expected_event_type).removeprefix( + "EVENT_TYPE_" + ) + has_expected = next( + ( + (h, e) + for h, e in _all_events + if h == expected_handle and e.event_type == expected_event_type + ), + None, + ) + if not has_expected: + if previous_expected_handle is not None: + prefix = f"After {previous_expected_event_type_name} in {previous_expected_handle.id}, " + else: + prefix = "" + pytest.fail( + f"{prefix}expected {expected_event_type_name} in {expected_handle.id}" + ) + previous_expected_event_type_name = expected_event_type_name + previous_expected_handle = expected_handle