@@ -143,10 +143,38 @@ def decorator(cls: ClassType) -> ClassType:
143
143
return decorator
144
144
145
145
146
+ def init (
147
+ init_fn : CallableType ,
148
+ ) -> CallableType :
149
+ """Decorator for the workflow init method.
150
+
151
+ This may be used on the __init__ method of the workflow class to specify
152
+ that it accepts the same workflow input arguments as the ``@workflow.run``
153
+ method. It may not be used on any other method.
154
+
155
+ If used, the workflow will be instantiated as
156
+ ``MyWorkflow(**workflow_input_args)``. If not used, the workflow will be
157
+ instantiated as ``MyWorkflow()``.
158
+
159
+ Note that the ``@workflow.run`` method is always called as
160
+ ``my_workflow.my_run_method(**workflow_input_args)``. If you use the
161
+ ``@workflow.init`` decorator, the parameter list of your __init__ and
162
+ ``@workflow.run`` methods must be identical.
163
+
164
+ Args:
165
+ init_fn: The __init__function to decorate.
166
+ """
167
+ if init_fn .__name__ != "__init__" :
168
+ raise ValueError ("@workflow.init may only be used on the __init__ method" )
169
+
170
+ setattr (init_fn , "__temporal_workflow_init" , True )
171
+ return init_fn
172
+
173
+
146
174
def run (fn : CallableAsyncType ) -> CallableAsyncType :
147
175
"""Decorator for the workflow run method.
148
176
149
- This must be set on one and only one async method defined on the same class
177
+ This must be used on one and only one async method defined on the same class
150
178
as ``@workflow.defn``. This can be defined on a base class method but must
151
179
then be explicitly overridden and defined on the workflow class.
152
180
@@ -238,7 +266,7 @@ def signal(
238
266
):
239
267
"""Decorator for a workflow signal method.
240
268
241
- This is set on any async or non-async method that you wish to be called upon
269
+ This is used on any async or non-async method that you wish to be called upon
242
270
receiving a signal. If a function overrides one with this decorator, it too
243
271
must be decorated.
244
272
@@ -309,7 +337,7 @@ def query(
309
337
):
310
338
"""Decorator for a workflow query method.
311
339
312
- This is set on any non-async method that expects to handle a query. If a
340
+ This is used on any non-async method that expects to handle a query. If a
313
341
function overrides one with this decorator, it too must be decorated.
314
342
315
343
Query methods can only have positional parameters. Best practice for
@@ -983,7 +1011,7 @@ def update(
983
1011
):
984
1012
"""Decorator for a workflow update handler method.
985
1013
986
- This is set on any async or non-async method that you wish to be called upon
1014
+ This is used on any async or non-async method that you wish to be called upon
987
1015
receiving an update. If a function overrides one with this decorator, it too
988
1016
must be decorated.
989
1017
@@ -1307,13 +1335,13 @@ def _apply_to_class(
1307
1335
issues : List [str ] = []
1308
1336
1309
1337
# Collect run fn and all signal/query/update fns
1310
- members = inspect . getmembers ( cls )
1338
+ init_fn : Optional [ Callable [..., None ]] = None
1311
1339
run_fn : Optional [Callable [..., Awaitable [Any ]]] = None
1312
1340
seen_run_attr = False
1313
1341
signals : Dict [Optional [str ], _SignalDefinition ] = {}
1314
1342
queries : Dict [Optional [str ], _QueryDefinition ] = {}
1315
1343
updates : Dict [Optional [str ], _UpdateDefinition ] = {}
1316
- for name , member in members :
1344
+ for name , member in inspect . getmembers ( cls ) :
1317
1345
if hasattr (member , "__temporal_workflow_run" ):
1318
1346
seen_run_attr = True
1319
1347
if not _is_unbound_method_on_cls (member , cls ):
@@ -1354,6 +1382,8 @@ def _apply_to_class(
1354
1382
)
1355
1383
else :
1356
1384
queries [query_defn .name ] = query_defn
1385
+ elif name == "__init__" and hasattr (member , "__temporal_workflow_init" ):
1386
+ init_fn = member
1357
1387
elif isinstance (member , UpdateMethodMultiParam ):
1358
1388
update_defn = member ._defn
1359
1389
if update_defn .name in updates :
@@ -1406,9 +1436,14 @@ def _apply_to_class(
1406
1436
1407
1437
if not seen_run_attr :
1408
1438
issues .append ("Missing @workflow.run method" )
1409
- if len (issues ) == 1 :
1410
- raise ValueError (f"Invalid workflow class: { issues [0 ]} " )
1411
- elif issues :
1439
+ if init_fn and run_fn :
1440
+ if not _parameters_identical_up_to_naming (init_fn , run_fn ):
1441
+ issues .append (
1442
+ "@workflow.init and @workflow.run method parameters do not match"
1443
+ )
1444
+ if issues :
1445
+ if len (issues ) == 1 :
1446
+ raise ValueError (f"Invalid workflow class: { issues [0 ]} " )
1412
1447
raise ValueError (
1413
1448
f"Invalid workflow class for { len (issues )} reasons: { ', ' .join (issues )} "
1414
1449
)
@@ -1444,6 +1479,19 @@ def __post_init__(self) -> None:
1444
1479
object .__setattr__ (self , "ret_type" , ret_type )
1445
1480
1446
1481
1482
+ def _parameters_identical_up_to_naming (fn1 : Callable , fn2 : Callable ) -> bool :
1483
+ """Return True if the functions have identical parameter lists, ignoring parameter names."""
1484
+
1485
+ def params (fn : Callable ) -> List [inspect .Parameter ]:
1486
+ # Ignore name when comparing parameters (remaining fields are kind,
1487
+ # default, and annotation).
1488
+ return [p .replace (name = "x" ) for p in inspect .signature (fn ).parameters .values ()]
1489
+
1490
+ # We require that any type annotations present match exactly; i.e. we do
1491
+ # not support any notion of subtype compatibility.
1492
+ return params (fn1 ) == params (fn2 )
1493
+
1494
+
1447
1495
# Async safe version of partial
1448
1496
def _bind_method (obj : Any , fn : Callable [..., Any ]) -> Callable [..., Any ]:
1449
1497
# Curry instance on the definition function since that represents an
0 commit comments