Skip to content

Commit b44581f

Browse files
authored
Minor type hint improvments (#95)
* ctx.get should return T instead of Any * Add additional type hints * more type hints around durable promise
1 parent 7c1c1d1 commit b44581f

File tree

4 files changed

+25
-14
lines changed

4 files changed

+25
-14
lines changed

examples/virtual_object.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
@counter.handler()
2020
async def increment(ctx: ObjectContext, value: int) -> int:
21-
n = await ctx.get("counter") or 0
21+
n = await ctx.get("counter", type_hint=int) or 0
2222
n += value
2323
ctx.set("counter", n)
2424
return n

examples/workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def payment_gateway():
5656

5757
@payment.handler()
5858
async def payment_verified(ctx: WorkflowSharedContext, result: str):
59-
promise = ctx.promise("verify.payment")
59+
promise = ctx.promise("verify.payment", type_hint=str)
6060
await promise.resolve(result)

python/restate/context.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -132,15 +132,15 @@ def get(self,
132132
name: str,
133133
serde: Serde[T] = DefaultSerde(),
134134
type_hint: Optional[typing.Type[T]] = None
135-
) -> Awaitable[Optional[Any]]:
135+
) -> Awaitable[Optional[T]]:
136136
"""
137137
Retrieves the value associated with the given name.
138138
139139
Args:
140140
name: The state name
141-
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
141+
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
142142
See also 'type_hint'.
143-
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
143+
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
144144
"""
145145

146146
@abc.abstractmethod
@@ -213,15 +213,15 @@ def run(self,
213213
Args:
214214
name: The name of the action.
215215
action: The action to run.
216-
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
216+
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
217217
See also 'type_hint'.
218218
max_attempts: The maximum number of retry attempts to complete the action.
219219
If None, the action will be retried indefinitely, until it succeeds.
220220
Otherwise, the action will be retried until the maximum number of attempts is reached and then it will raise a TerminalError.
221221
max_retry_duration: The maximum duration for retrying. If None, the action will be retried indefinitely, until it succeeds.
222222
Otherwise, the action will be retried until the maximum duration is reached and then it will raise a TerminalError.
223223
type_hint: The type hint of the return value of the action.
224-
This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
224+
This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
225225
226226
"""
227227

@@ -327,7 +327,7 @@ def generic_send(self,
327327
def awakeable(self,
328328
serde: Serde[T] = DefaultSerde(),
329329
type_hint: Optional[typing.Type[T]] = None
330-
) -> typing.Tuple[str, RestateDurableFuture[Any]]:
330+
) -> typing.Tuple[str, RestateDurableFuture[T]]:
331331
"""
332332
Returns the name of the awakeable and the future to be awaited.
333333
"""
@@ -388,15 +388,15 @@ def get(self,
388388
name: str,
389389
serde: Serde[T] = DefaultSerde(),
390390
type_hint: Optional[typing.Type[T]] = None
391-
) -> RestateDurableFuture[Optional[Any]]:
391+
) -> RestateDurableFuture[Optional[T]]:
392392
"""
393393
Retrieves the value associated with the given name.
394394
395395
Args:
396396
name: The state name
397-
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
397+
serde: The serialization/deserialization mechanism. - if the default serde is used, a default serializer will be used based on the type.
398398
See also 'type_hint'.
399-
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
399+
type_hint: The type hint of the return value. This is used to pick the serializer. If None, the type hint will be inferred from the action's return type, or the provided serializer.
400400
"""
401401

402402
@abc.abstractmethod
@@ -438,13 +438,19 @@ def value(self) -> RestateDurableFuture[T]:
438438
Returns the value of the promise if it is resolved, None otherwise.
439439
"""
440440

441+
@abc.abstractmethod
442+
def __await__(self) -> typing.Generator[Any, Any, T]:
443+
"""
444+
Returns the value of the promise. This is a shortcut for calling value() and awaiting it.
445+
"""
446+
441447
class WorkflowContext(ObjectContext):
442448
"""
443449
Represents the context of the current workflow invocation.
444450
"""
445451

446452
@abc.abstractmethod
447-
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
453+
def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
448454
"""
449455
Returns a durable promise with the given name.
450456
"""
@@ -455,7 +461,7 @@ class WorkflowSharedContext(ObjectSharedContext):
455461
"""
456462

457463
@abc.abstractmethod
458-
def promise(self, name: str, serde: Serde[T] = DefaultSerde()) -> DurablePromise[Any]:
464+
def promise(self, name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
459465
"""
460466
Returns a durable promise with the given name.
461467
"""

python/restate/server_context.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,9 @@ def peek(self) -> Awaitable[Any | None]:
216216
assert serde is not None
217217
return self.server_context.create_future(handle, serde)
218218

219+
def __await__(self):
220+
return self.value().__await__()
221+
219222

220223
# disable too many public method
221224
# pylint: disable=R0904
@@ -684,8 +687,10 @@ def resolve_awakeable(self,
684687
def reject_awakeable(self, name: str, failure_message: str, failure_code: int = 500) -> None:
685688
return self.vm.sys_reject_awakeable(name, Failure(code=failure_code, message=failure_message))
686689

687-
def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde()) -> DurablePromise[Any]:
690+
def promise(self, name: str, serde: typing.Optional[Serde[T]] = JsonSerde(), type_hint: Optional[typing.Type[T]] = None) -> DurablePromise[T]:
688691
"""Create a durable promise."""
692+
if isinstance(serde, DefaultSerde):
693+
serde = serde.with_maybe_type(type_hint)
689694
return ServerDurablePromise(self, name, serde)
690695

691696
def key(self) -> str:

0 commit comments

Comments
 (0)