diff --git a/temporalio/client.py b/temporalio/client.py index ce16feb59..98197b64c 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -819,6 +819,30 @@ def list_workflows( ) ) + async def count_workflows( + self, + query: Optional[str] = None, + rpc_metadata: Mapping[str, str] = {}, + rpc_timeout: Optional[timedelta] = None, + ) -> WorkflowExecutionCount: + """Count workflows. + + Args: + query: A Temporal visibility filter. See Temporal documentation + concerning visibility list filters. + rpc_metadata: Headers used on each RPC call. Keys here override + client-level RPC metadata keys. + rpc_timeout: Optional RPC deadline to set for each RPC call. + + Returns: + Count of workflows. + """ + return await self._impl.count_workflows( + CountWorkflowsInput( + query=query, rpc_metadata=rpc_metadata, rpc_timeout=rpc_timeout + ) + ) + @overload def get_async_activity_handle( self, *, workflow_id: str, run_id: Optional[str], activity_id: str @@ -2310,6 +2334,57 @@ class WorkflowExecutionStatus(IntEnum): ) +@dataclass +class WorkflowExecutionCount: + """Representation of a count from a count workflows call.""" + + count: int + """Approximate number of workflows matching the original query. + + If the query had a group-by clause, this is simply the sum of all the counts + in py:attr:`groups`. + """ + + groups: Sequence[WorkflowExecutionCountAggregationGroup] + """Groups if the query had a group-by clause, or empty if not.""" + + @staticmethod + def _from_raw( + raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse, + ) -> WorkflowExecutionCount: + return WorkflowExecutionCount( + count=raw.count, + groups=[ + WorkflowExecutionCountAggregationGroup._from_raw(g) for g in raw.groups + ], + ) + + +@dataclass +class WorkflowExecutionCountAggregationGroup: + """Aggregation group if the workflow count query had a group-by clause.""" + + count: int + """Approximate number of workflows matching the original query for this + group. + """ + + group_values: Sequence[temporalio.common.SearchAttributeValue] + """Search attribute values for this group.""" + + @staticmethod + def _from_raw( + raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup, + ) -> WorkflowExecutionCountAggregationGroup: + return WorkflowExecutionCountAggregationGroup( + count=raw.count, + group_values=[ + temporalio.converter._decode_search_attribute_value(v) + for v in raw.group_values + ], + ) + + class WorkflowExecutionAsyncIterator: """Asynchronous iterator for :py:class:`WorkflowExecution` values. @@ -4373,6 +4448,15 @@ class ListWorkflowsInput: rpc_timeout: Optional[timedelta] +@dataclass +class CountWorkflowsInput: + """Input for :py:meth:`OutboundInterceptor.count_workflows`.""" + + query: Optional[str] + rpc_metadata: Mapping[str, str] + rpc_timeout: Optional[timedelta] + + @dataclass class QueryWorkflowInput: """Input for :py:meth:`OutboundInterceptor.query_workflow`.""" @@ -4669,6 +4753,12 @@ def list_workflows( """Called for every :py:meth:`Client.list_workflows` call.""" return self.next.list_workflows(input) + async def count_workflows( + self, input: CountWorkflowsInput + ) -> WorkflowExecutionCount: + """Called for every :py:meth:`Client.count_workflows` call.""" + return await self.next.count_workflows(input) + async def query_workflow(self, input: QueryWorkflowInput) -> Any: """Called for every :py:meth:`WorkflowHandle.query` call.""" return await self.next.query_workflow(input) @@ -4928,6 +5018,21 @@ def list_workflows( ) -> WorkflowExecutionAsyncIterator: return WorkflowExecutionAsyncIterator(self._client, input) + async def count_workflows( + self, input: CountWorkflowsInput + ) -> WorkflowExecutionCount: + return WorkflowExecutionCount._from_raw( + await self._client.workflow_service.count_workflow_executions( + temporalio.api.workflowservice.v1.CountWorkflowExecutionsRequest( + namespace=self._client.namespace, + query=input.query or "", + ), + retry=True, + metadata=input.rpc_metadata, + timeout=input.rpc_timeout, + ) + ) + async def query_workflow(self, input: QueryWorkflowInput) -> Any: req = temporalio.api.workflowservice.v1.QueryWorkflowRequest( namespace=self._client.namespace, diff --git a/temporalio/converter.py b/temporalio/converter.py index 1bb847a09..155a6ee85 100644 --- a/temporalio/converter.py +++ b/temporalio/converter.py @@ -1341,6 +1341,15 @@ def decode_typed_search_attributes( return temporalio.common.TypedSearchAttributes(pairs) +def _decode_search_attribute_value( + payload: temporalio.api.common.v1.Payload, +) -> temporalio.common.SearchAttributeValue: + val = default().payload_converter.from_payload(payload) + if isinstance(val, str) and payload.metadata.get("type") == b"Datetime": + val = _get_iso_datetime_parser()(val) + return val # type: ignore + + def value_to_type( hint: Type, value: Any, diff --git a/tests/test_client.py b/tests/test_client.py index 57db98737..2a26a8615 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -3,7 +3,7 @@ import os import uuid from datetime import datetime, timedelta, timezone -from typing import Any, List, Optional, Tuple +from typing import Any, List, Optional, Tuple, cast import pytest from google.protobuf import json_format @@ -60,6 +60,8 @@ TaskReachabilityType, TerminateWorkflowInput, WorkflowContinuedAsNewError, + WorkflowExecutionCount, + WorkflowExecutionCountAggregationGroup, WorkflowExecutionStatus, WorkflowFailureError, WorkflowHandle, @@ -569,6 +571,59 @@ async def test_list_workflows_and_fetch_history( assert actual_id_and_input == expected_id_and_input +@workflow.defn +class CountableWorkflow: + @workflow.run + async def run(self, wait_forever: bool) -> None: + await workflow.wait_condition(lambda: not wait_forever) + + +async def test_count_workflows(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Java test server doesn't support newer workflow listing") + + # 3 workflows that complete, 2 that don't + async with new_worker(client, CountableWorkflow) as worker: + for _ in range(3): + await client.execute_workflow( + CountableWorkflow.run, + False, + id=f"id-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + for _ in range(2): + await client.start_workflow( + CountableWorkflow.run, + True, + id=f"id-{uuid.uuid4()}", + task_queue=worker.task_queue, + ) + + async def fetch_count() -> WorkflowExecutionCount: + resp = await client.count_workflows( + f"TaskQueue = '{worker.task_queue}' GROUP BY ExecutionStatus" + ) + cast(List[WorkflowExecutionCountAggregationGroup], resp.groups).sort( + key=lambda g: g.count + ) + return resp + + await assert_eq_eventually( + WorkflowExecutionCount( + count=5, + groups=[ + WorkflowExecutionCountAggregationGroup( + count=2, group_values=["Running"] + ), + WorkflowExecutionCountAggregationGroup( + count=3, group_values=["Completed"] + ), + ], + ), + fetch_count, + ) + + def test_history_from_json(): # Take proto, make JSON, convert to dict, alter some enums, confirm that it # alters the enums back and matches original history