Skip to content

Add Client.count_workflows #510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 105 additions & 0 deletions temporalio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,30 @@ def list_workflows(
)
)

async def count_workflows(
    self,
    query: Optional[str] = None,
    rpc_metadata: Mapping[str, str] = {},
    rpc_timeout: Optional[timedelta] = None,
) -> WorkflowExecutionCount:
    """Count workflows matching an optional visibility filter.

    Args:
        query: A Temporal visibility list filter. See Temporal
            documentation concerning visibility list filters.
        rpc_metadata: Headers used on each RPC call. Keys here override
            client-level RPC metadata keys.
        rpc_timeout: Optional RPC deadline to set for each RPC call.

    Returns:
        Count of workflows (with per-group counts if the query had a
        group-by clause).
    """
    # Route through the interceptor chain like every other client call.
    count_input = CountWorkflowsInput(
        query=query,
        rpc_metadata=rpc_metadata,
        rpc_timeout=rpc_timeout,
    )
    return await self._impl.count_workflows(count_input)

@overload
def get_async_activity_handle(
self, *, workflow_id: str, run_id: Optional[str], activity_id: str
Expand Down Expand Up @@ -2310,6 +2334,57 @@ class WorkflowExecutionStatus(IntEnum):
)


@dataclass
class WorkflowExecutionCount:
    """Representation of a count from a count workflows call."""

    count: int
    """Approximate number of workflows matching the original query.

    If the query had a group-by clause, this is simply the sum of all the counts
    in :py:attr:`groups`.
    """

    groups: Sequence[WorkflowExecutionCountAggregationGroup]
    """Groups if the query had a group-by clause, or empty if not."""

    @staticmethod
    def _from_raw(
        raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse,
    ) -> WorkflowExecutionCount:
        """Convert a raw count RPC response into this dataclass."""
        return WorkflowExecutionCount(
            count=raw.count,
            groups=[
                WorkflowExecutionCountAggregationGroup._from_raw(g) for g in raw.groups
            ],
        )


@dataclass
class WorkflowExecutionCountAggregationGroup:
    """Aggregation group if the workflow count query had a group-by clause."""

    count: int
    """Approximate number of workflows matching the original query for this
    group.
    """

    group_values: Sequence[temporalio.common.SearchAttributeValue]
    """Search attribute values for this group."""

    @staticmethod
    def _from_raw(
        raw: temporalio.api.workflowservice.v1.CountWorkflowExecutionsResponse.AggregationGroup,
    ) -> WorkflowExecutionCountAggregationGroup:
        """Convert a raw aggregation group proto into this dataclass."""
        # Decode each grouped search attribute payload to its Python value.
        decoded_values = [
            temporalio.converter._decode_search_attribute_value(payload)
            for payload in raw.group_values
        ]
        return WorkflowExecutionCountAggregationGroup(
            count=raw.count,
            group_values=decoded_values,
        )


class WorkflowExecutionAsyncIterator:
"""Asynchronous iterator for :py:class:`WorkflowExecution` values.

Expand Down Expand Up @@ -4373,6 +4448,15 @@ class ListWorkflowsInput:
rpc_timeout: Optional[timedelta]


@dataclass
class CountWorkflowsInput:
    """Input for :py:meth:`OutboundInterceptor.count_workflows`."""

    # Visibility list filter; None is sent to the service as an empty query
    # (presumably meaning "no filter" — confirm against service semantics).
    query: Optional[str]
    # Headers for the RPC call; keys override client-level RPC metadata keys.
    rpc_metadata: Mapping[str, str]
    # Optional RPC deadline for the call; None means no explicit deadline.
    rpc_timeout: Optional[timedelta]


@dataclass
class QueryWorkflowInput:
"""Input for :py:meth:`OutboundInterceptor.query_workflow`."""
Expand Down Expand Up @@ -4669,6 +4753,12 @@ def list_workflows(
"""Called for every :py:meth:`Client.list_workflows` call."""
return self.next.list_workflows(input)

async def count_workflows(
    self, input: CountWorkflowsInput
) -> WorkflowExecutionCount:
    """Called for every :py:meth:`Client.count_workflows` call."""
    # Default behavior: forward unchanged to the next interceptor.
    result = await self.next.count_workflows(input)
    return result

async def query_workflow(self, input: QueryWorkflowInput) -> Any:
    """Called for every :py:meth:`WorkflowHandle.query` call."""
    # Default behavior: forward unchanged to the next interceptor.
    next_interceptor = self.next
    return await next_interceptor.query_workflow(input)
Expand Down Expand Up @@ -4928,6 +5018,21 @@ def list_workflows(
) -> WorkflowExecutionAsyncIterator:
return WorkflowExecutionAsyncIterator(self._client, input)

async def count_workflows(
    self, input: CountWorkflowsInput
) -> WorkflowExecutionCount:
    """Issue the count RPC and convert the response to the public type."""
    # A None query is sent as an empty string on the wire.
    req = temporalio.api.workflowservice.v1.CountWorkflowExecutionsRequest(
        namespace=self._client.namespace,
        query=input.query or "",
    )
    raw_resp = await self._client.workflow_service.count_workflow_executions(
        req,
        retry=True,
        metadata=input.rpc_metadata,
        timeout=input.rpc_timeout,
    )
    return WorkflowExecutionCount._from_raw(raw_resp)

async def query_workflow(self, input: QueryWorkflowInput) -> Any:
req = temporalio.api.workflowservice.v1.QueryWorkflowRequest(
namespace=self._client.namespace,
Expand Down
9 changes: 9 additions & 0 deletions temporalio/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1341,6 +1341,15 @@ def decode_typed_search_attributes(
return temporalio.common.TypedSearchAttributes(pairs)


def _decode_search_attribute_value(
    payload: temporalio.api.common.v1.Payload,
) -> temporalio.common.SearchAttributeValue:
    """Decode a single search attribute payload to its Python value.

    Payloads whose metadata marks them as ``Datetime`` are delivered as ISO
    strings and are parsed back into datetime objects.
    """
    decoded = default().payload_converter.from_payload(payload)
    if payload.metadata.get("type") == b"Datetime" and isinstance(decoded, str):
        decoded = _get_iso_datetime_parser()(decoded)
    return decoded  # type: ignore


def value_to_type(
hint: Type,
value: Any,
Expand Down
57 changes: 56 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import os
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, cast

import pytest
from google.protobuf import json_format
Expand Down Expand Up @@ -60,6 +60,8 @@
TaskReachabilityType,
TerminateWorkflowInput,
WorkflowContinuedAsNewError,
WorkflowExecutionCount,
WorkflowExecutionCountAggregationGroup,
WorkflowExecutionStatus,
WorkflowFailureError,
WorkflowHandle,
Expand Down Expand Up @@ -569,6 +571,59 @@ async def test_list_workflows_and_fetch_history(
assert actual_id_and_input == expected_id_and_input


@workflow.defn
class CountableWorkflow:
    """Minimal workflow for count tests.

    Completes immediately when ``wait_forever`` is false; otherwise the wait
    condition can never become true within the workflow, so it stays running.
    """

    @workflow.run
    async def run(self, wait_forever: bool) -> None:
        # Nothing ever flips wait_forever, so True means "run indefinitely".
        await workflow.wait_condition(lambda: not wait_forever)


async def test_count_workflows(client: Client, env: WorkflowEnvironment):
    """Verify :py:meth:`Client.count_workflows` with a group-by query."""
    if env.supports_time_skipping:
        pytest.skip("Java test server doesn't support newer workflow listing")

    # 3 workflows that complete, 2 that don't
    async with new_worker(client, CountableWorkflow) as worker:
        for _ in range(3):
            await client.execute_workflow(
                CountableWorkflow.run,
                False,
                id=f"id-{uuid.uuid4()}",
                task_queue=worker.task_queue,
            )
        for _ in range(2):
            await client.start_workflow(
                CountableWorkflow.run,
                True,
                id=f"id-{uuid.uuid4()}",
                task_queue=worker.task_queue,
            )

    async def fetch_count() -> WorkflowExecutionCount:
        # Count grouped by status; sort groups by count so the equality
        # check below does not depend on server-side group ordering.
        resp = await client.count_workflows(
            f"TaskQueue = '{worker.task_queue}' GROUP BY ExecutionStatus"
        )
        cast(List[WorkflowExecutionCountAggregationGroup], resp.groups).sort(
            key=lambda g: g.count
        )
        return resp

    # Visibility data is eventually consistent, so poll until the expected
    # counts (2 running, 3 completed) appear.
    await assert_eq_eventually(
        WorkflowExecutionCount(
            count=5,
            groups=[
                WorkflowExecutionCountAggregationGroup(
                    count=2, group_values=["Running"]
                ),
                WorkflowExecutionCountAggregationGroup(
                    count=3, group_values=["Completed"]
                ),
            ],
        ),
        fetch_count,
    )


def test_history_from_json():
# Take proto, make JSON, convert to dict, alter some enums, confirm that it
# alters the enums back and matches original history
Expand Down
Loading