Skip to content

Commit 0995ae0

Browse files
authored
Workflow init (#645)
* Introduce @workflow.init decorator
1 parent 09ac120 commit 0995ae0

File tree

5 files changed

+286
-29
lines changed

5 files changed

+286
-29
lines changed

README.md

+5
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,11 @@ Here are the decorators that can be applied:
575575
* The method's arguments are the workflow's arguments
576576
* The first parameter must be `self`, followed by positional arguments. Best practice is to only take a single
577577
argument that is an object/dataclass of fields that can be added to as needed.
578+
* `@workflow.init` - Specifies that the `__init__` method accepts the workflow's arguments.
579+
* If present, may only be applied to the `__init__` method, the parameters of which must then be identical to those of
580+
the `@workflow.run` method.
581+
* The purpose of this decorator is to allow operations involving workflow arguments to be performed in the `__init__`
582+
method, before any signal or update handler has a chance to execute.
578583
* `@workflow.signal` - Defines a method as a signal
579584
* Can be defined on an `async` or non-`async` function at any hierarchy depth, but if decorated method is overridden,
580585
the override must also be decorated

temporalio/worker/_workflow_instance.py

+40-17
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ def __init__(self, det: WorkflowInstanceDetails) -> None:
201201
self._payload_converter = det.payload_converter_class()
202202
self._failure_converter = det.failure_converter_class()
203203
self._defn = det.defn
204+
self._workflow_input: Optional[ExecuteWorkflowInput] = None
204205
self._info = det.info
205206
self._extern_functions = det.extern_functions
206207
self._disable_eager_activity_execution = det.disable_eager_activity_execution
@@ -318,8 +319,9 @@ def get_thread_id(self) -> Optional[int]:
318319
return self._current_thread_id
319320

320321
#### Activation functions ####
321-
# These are in alphabetical order and besides "activate", all other calls
322-
# are "_apply_" + the job field name.
322+
# These are in alphabetical order and, besides "activate" and
323+
# "_make_workflow_input", all other calls are "_apply_" + the job field
324+
# name.
323325

324326
def activate(
325327
self, act: temporalio.bridge.proto.workflow_activation.WorkflowActivation
@@ -342,6 +344,7 @@ def activate(
342344
try:
343345
# Split into job sets with patches, then signals + updates, then
344346
# non-queries, then queries
347+
start_job = None
345348
job_sets: List[
346349
List[temporalio.bridge.proto.workflow_activation.WorkflowActivationJob]
347350
] = [[], [], [], []]
@@ -351,10 +354,15 @@ def activate(
351354
elif job.HasField("signal_workflow") or job.HasField("do_update"):
352355
job_sets[1].append(job)
353356
elif not job.HasField("query_workflow"):
357+
if job.HasField("start_workflow"):
358+
start_job = job.start_workflow
354359
job_sets[2].append(job)
355360
else:
356361
job_sets[3].append(job)
357362

363+
if start_job:
364+
self._workflow_input = self._make_workflow_input(start_job)
365+
358366
# Apply every job set, running after each set
359367
for index, job_set in enumerate(job_sets):
360368
if not job_set:
@@ -863,34 +871,41 @@ async def run_workflow(input: ExecuteWorkflowInput) -> None:
863871
return
864872
raise
865873

874+
if not self._workflow_input:
875+
raise RuntimeError(
876+
"Expected workflow input to be set. This is an SDK Python bug."
877+
)
878+
self._primary_task = self.create_task(
879+
self._run_top_level_workflow_function(run_workflow(self._workflow_input)),
880+
name="run",
881+
)
882+
883+
def _apply_update_random_seed(
884+
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
885+
) -> None:
886+
self._random.seed(job.randomness_seed)
887+
888+
def _make_workflow_input(
889+
self, start_job: temporalio.bridge.proto.workflow_activation.StartWorkflow
890+
) -> ExecuteWorkflowInput:
866891
# Set arg types, using raw values for dynamic
867892
arg_types = self._defn.arg_types
868893
if not self._defn.name:
869894
# Dynamic is just the raw value for each input value
870-
arg_types = [temporalio.common.RawValue] * len(job.arguments)
871-
args = self._convert_payloads(job.arguments, arg_types)
895+
arg_types = [temporalio.common.RawValue] * len(start_job.arguments)
896+
args = self._convert_payloads(start_job.arguments, arg_types)
872897
# Put args in a list if dynamic
873898
if not self._defn.name:
874899
args = [args]
875900

876-
# Schedule it
877-
input = ExecuteWorkflowInput(
901+
return ExecuteWorkflowInput(
878902
type=self._defn.cls,
879903
# TODO(cretz): Remove cast when https://github.com/python/mypy/issues/5485 fixed
880904
run_fn=cast(Callable[..., Awaitable[Any]], self._defn.run_fn),
881905
args=args,
882-
headers=job.headers,
883-
)
884-
self._primary_task = self.create_task(
885-
self._run_top_level_workflow_function(run_workflow(input)),
886-
name="run",
906+
headers=start_job.headers,
887907
)
888908

889-
def _apply_update_random_seed(
890-
self, job: temporalio.bridge.proto.workflow_activation.UpdateRandomSeed
891-
) -> None:
892-
self._random.seed(job.randomness_seed)
893-
894909
#### _Runtime direct workflow call overrides ####
895910
# These are in alphabetical order and all start with "workflow_".
896911

@@ -1617,6 +1632,14 @@ def _convert_payloads(
16171632
except Exception as err:
16181633
raise RuntimeError("Failed decoding arguments") from err
16191634

1635+
def _instantiate_workflow_object(self) -> Any:
1636+
if not self._workflow_input:
1637+
raise RuntimeError("Expected workflow input. This is a Python SDK bug.")
1638+
if hasattr(self._defn.cls.__init__, "__temporal_workflow_init"):
1639+
return self._defn.cls(*self._workflow_input.args)
1640+
else:
1641+
return self._defn.cls()
1642+
16201643
def _is_workflow_failure_exception(self, err: BaseException) -> bool:
16211644
# An exception is a failure instead of a task fail if it's already a
16221645
# failure error or if it is an instance of any of the failure types in
@@ -1752,7 +1775,7 @@ def _run_once(self, *, check_conditions: bool) -> None:
17521775
# We instantiate the workflow class _inside_ here because __init__
17531776
# needs to run with this event loop set
17541777
if not self._object:
1755-
self._object = self._defn.cls()
1778+
self._object = self._instantiate_workflow_object()
17561779

17571780
# Run while there is anything ready
17581781
while self._ready:

temporalio/workflow.py

+57-9
Original file line numberDiff line numberDiff line change
@@ -143,10 +143,38 @@ def decorator(cls: ClassType) -> ClassType:
143143
return decorator
144144

145145

146+
def init(
147+
init_fn: CallableType,
148+
) -> CallableType:
149+
"""Decorator for the workflow init method.
150+
151+
This may be used on the __init__ method of the workflow class to specify
152+
that it accepts the same workflow input arguments as the ``@workflow.run``
153+
method. It may not be used on any other method.
154+
155+
If used, the workflow will be instantiated as
156+
``MyWorkflow(*workflow_input_args)``. If not used, the workflow will be
157+
instantiated as ``MyWorkflow()``.
158+
159+
Note that the ``@workflow.run`` method is always called as
160+
``my_workflow.my_run_method(*workflow_input_args)``. If you use the
161+
``@workflow.init`` decorator, the parameter list of your __init__ and
162+
``@workflow.run`` methods must be identical.
163+
164+
Args:
165+
init_fn: The __init__ function to decorate.
166+
"""
167+
if init_fn.__name__ != "__init__":
168+
raise ValueError("@workflow.init may only be used on the __init__ method")
169+
170+
setattr(init_fn, "__temporal_workflow_init", True)
171+
return init_fn
172+
173+
146174
def run(fn: CallableAsyncType) -> CallableAsyncType:
147175
"""Decorator for the workflow run method.
148176
149-
This must be set on one and only one async method defined on the same class
177+
This must be used on one and only one async method defined on the same class
150178
as ``@workflow.defn``. This can be defined on a base class method but must
151179
then be explicitly overridden and defined on the workflow class.
152180
@@ -238,7 +266,7 @@ def signal(
238266
):
239267
"""Decorator for a workflow signal method.
240268
241-
This is set on any async or non-async method that you wish to be called upon
269+
This is used on any async or non-async method that you wish to be called upon
242270
receiving a signal. If a function overrides one with this decorator, it too
243271
must be decorated.
244272
@@ -309,7 +337,7 @@ def query(
309337
):
310338
"""Decorator for a workflow query method.
311339
312-
This is set on any non-async method that expects to handle a query. If a
340+
This is used on any non-async method that expects to handle a query. If a
313341
function overrides one with this decorator, it too must be decorated.
314342
315343
Query methods can only have positional parameters. Best practice for
@@ -983,7 +1011,7 @@ def update(
9831011
):
9841012
"""Decorator for a workflow update handler method.
9851013
986-
This is set on any async or non-async method that you wish to be called upon
1014+
This is used on any async or non-async method that you wish to be called upon
9871015
receiving an update. If a function overrides one with this decorator, it too
9881016
must be decorated.
9891017
@@ -1307,13 +1335,13 @@ def _apply_to_class(
13071335
issues: List[str] = []
13081336

13091337
# Collect init fn, run fn and all signal/query/update fns
1310-
members = inspect.getmembers(cls)
1338+
init_fn: Optional[Callable[..., None]] = None
13111339
run_fn: Optional[Callable[..., Awaitable[Any]]] = None
13121340
seen_run_attr = False
13131341
signals: Dict[Optional[str], _SignalDefinition] = {}
13141342
queries: Dict[Optional[str], _QueryDefinition] = {}
13151343
updates: Dict[Optional[str], _UpdateDefinition] = {}
1316-
for name, member in members:
1344+
for name, member in inspect.getmembers(cls):
13171345
if hasattr(member, "__temporal_workflow_run"):
13181346
seen_run_attr = True
13191347
if not _is_unbound_method_on_cls(member, cls):
@@ -1354,6 +1382,8 @@ def _apply_to_class(
13541382
)
13551383
else:
13561384
queries[query_defn.name] = query_defn
1385+
elif name == "__init__" and hasattr(member, "__temporal_workflow_init"):
1386+
init_fn = member
13571387
elif isinstance(member, UpdateMethodMultiParam):
13581388
update_defn = member._defn
13591389
if update_defn.name in updates:
@@ -1406,9 +1436,14 @@ def _apply_to_class(
14061436

14071437
if not seen_run_attr:
14081438
issues.append("Missing @workflow.run method")
1409-
if len(issues) == 1:
1410-
raise ValueError(f"Invalid workflow class: {issues[0]}")
1411-
elif issues:
1439+
if init_fn and run_fn:
1440+
if not _parameters_identical_up_to_naming(init_fn, run_fn):
1441+
issues.append(
1442+
"@workflow.init and @workflow.run method parameters do not match"
1443+
)
1444+
if issues:
1445+
if len(issues) == 1:
1446+
raise ValueError(f"Invalid workflow class: {issues[0]}")
14121447
raise ValueError(
14131448
f"Invalid workflow class for {len(issues)} reasons: {', '.join(issues)}"
14141449
)
@@ -1444,6 +1479,19 @@ def __post_init__(self) -> None:
14441479
object.__setattr__(self, "ret_type", ret_type)
14451480

14461481

1482+
def _parameters_identical_up_to_naming(fn1: Callable, fn2: Callable) -> bool:
1483+
"""Return True if the functions have identical parameter lists, ignoring parameter names."""
1484+
1485+
def params(fn: Callable) -> List[inspect.Parameter]:
1486+
# Ignore name when comparing parameters (remaining fields are kind,
1487+
# default, and annotation).
1488+
return [p.replace(name="x") for p in inspect.signature(fn).parameters.values()]
1489+
1490+
# We require that any type annotations present match exactly; i.e. we do
1491+
# not support any notion of subtype compatibility.
1492+
return params(fn1) == params(fn2)
1493+
1494+
14471495
# Async safe version of partial
14481496
def _bind_method(obj: Any, fn: Callable[..., Any]) -> Callable[..., Any]:
14491497
# Curry instance on the definition function since that represents an

tests/test_workflow.py

+68-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Sequence
1+
import inspect
2+
import itertools
3+
from typing import Sequence
24

35
import pytest
46

@@ -342,3 +344,68 @@ def test_workflow_defn_dynamic_handler_warnings():
342344
# We want to make sure they are reporting the right stacklevel
343345
warnings[0].filename.endswith("test_workflow.py")
344346
warnings[1].filename.endswith("test_workflow.py")
347+
348+
349+
class _TestParametersIdenticalUpToNaming:
350+
def a1(self, a):
351+
pass
352+
353+
def a2(self, b):
354+
pass
355+
356+
def b1(self, a: int):
357+
pass
358+
359+
def b2(self, b: int) -> str:
360+
return ""
361+
362+
def c1(self, a1: int, a2: str) -> str:
363+
return ""
364+
365+
def c2(self, b1: int, b2: str) -> int:
366+
return 0
367+
368+
def d1(self, a1, a2: str) -> None:
369+
pass
370+
371+
def d2(self, b1, b2: str) -> str:
372+
return ""
373+
374+
def e1(self, a1, a2: str = "") -> None:
375+
return None
376+
377+
def e2(self, b1, b2: str = "") -> str:
378+
return ""
379+
380+
def f1(self, a1, a2: str = "a") -> None:
381+
return None
382+
383+
384+
def test_parameters_identical_up_to_naming():
385+
fns = [
386+
f
387+
for _, f in inspect.getmembers(_TestParametersIdenticalUpToNaming)
388+
if inspect.isfunction(f)
389+
]
390+
for f1, f2 in itertools.combinations(fns, 2):
391+
name1, name2 = f1.__name__, f2.__name__
392+
expect_equal = name1[0] == name2[0]
393+
assert (
394+
workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal)
395+
), f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal"
396+
397+
398+
@workflow.defn
399+
class BadWorkflowInit:
400+
def not__init__(self):
401+
pass
402+
403+
@workflow.run
404+
async def run(self):
405+
pass
406+
407+
408+
def test_workflow_init_not__init__():
409+
with pytest.raises(ValueError) as err:
410+
workflow.init(BadWorkflowInit.not__init__)
411+
assert "@workflow.init may only be used on the __init__ method" in str(err.value)

0 commit comments

Comments
 (0)