Skip to content

Commit 56bd610

Browse files
authored
Separate the ASGI types from the restate types (#90)
1 parent 5137b6b commit 56bd610

File tree

3 files changed

+29
-25
lines changed

3 files changed

+29
-25
lines changed

python/restate/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,9 @@ async def process_invocation_to_completion(vm: VMWrapper,
107107
# everything ends here really ...
108108
return
109109
if message.get('type') == 'http.request':
110-
assert isinstance(message['body'], bytes)
111-
vm.notify_input(message['body'])
110+
body = message.get('body', None)
111+
assert isinstance(body, bytes)
112+
vm.notify_input(body)
112113
if not message.get('more_body', False):
113114
vm.notify_input_closed()
114115
break

python/restate/server_context.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ async def create_poll_or_cancel_coroutine(self, handles: typing.List[int]) -> No
386386
async def wrapper(f):
387387
await f()
388388
await self.take_and_send_output()
389-
await self.receive.tx({ 'type' : 'restate.run_completed', 'body' : bytes(), 'more_body' : True})
389+
await self.receive.enqueue_restate_event({ 'type' : 'restate.run_completed', 'data': None})
390390

391391
task = asyncio.create_task(wrapper(fn))
392392
self.tasks.add(task)
@@ -398,8 +398,9 @@ async def wrapper(f):
398398
if chunk.get('type') == 'http.disconnect':
399399
raise DisconnectedException()
400400
if chunk.get('body', None) is not None:
401-
assert isinstance(chunk['body'], bytes)
402-
self.vm.notify_input(chunk['body'])
401+
body = chunk.get('body')
402+
assert isinstance(body, bytes)
403+
self.vm.notify_input(body)
403404
if not chunk.get('more_body', False):
404405
self.vm.notify_input_closed()
405406

python/restate/server_types.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ class Scope(TypedDict):
4040
state: NotRequired[Dict[str, Any]]
4141
extensions: Optional[Dict[str, Dict[object, object]]]
4242

43+
class RestateEvent(TypedDict):
44+
"""An event that represents a run completion"""
45+
type: Literal["restate.run_completed"]
46+
data: Optional[Dict[str, Any]]
47+
4348
class HTTPRequestEvent(TypedDict):
4449
"""ASGI Request event"""
45-
type: Literal["http.request", "restate.run_completed"]
50+
type: Literal["http.request"]
4651
body: bytes
4752
more_body: bool
4853

@@ -91,38 +96,35 @@ def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str,
9196
class ReceiveChannel:
9297
"""ASGI receive channel."""
9398

94-
def __init__(self, receive: Receive):
95-
self.queue = asyncio.Queue[ASGIReceiveEvent]()
99+
def __init__(self, receive: Receive) -> None:
100+
self._queue = asyncio.Queue[Union[ASGIReceiveEvent, RestateEvent]]()
96101

97102
async def loop():
98103
"""Receive loop."""
99104
while True:
100105
event = await receive()
101-
await self.queue.put(event)
106+
await self._queue.put(event)
102107
if event.get('type') == 'http.disconnect':
103108
break
104109

105-
self.task = asyncio.create_task(loop())
110+
self._task = asyncio.create_task(loop())
106111

107-
async def rx(self) -> ASGIReceiveEvent:
112+
async def __call__(self) -> ASGIReceiveEvent | RestateEvent:
108113
"""Get the next message."""
109-
what = await self.queue.get()
110-
self.queue.task_done()
114+
what = await self._queue.get()
115+
self._queue.task_done()
111116
return what
112117

113-
async def __call__(self):
114-
"""Get the next message."""
115-
return await self.rx()
116-
117-
async def tx(self, what: ASGIReceiveEvent):
118+
async def enqueue_restate_event(self, what: RestateEvent):
118119
"""Add a message."""
119-
await self.queue.put(what)
120+
await self._queue.put(what)
120121

121122
async def close(self):
122123
"""Close the channel."""
123-
if self.task and not self.task.done():
124-
self.task.cancel()
125-
try:
126-
await self.task
127-
except asyncio.CancelledError:
128-
pass
124+
if self._task.done():
125+
return
126+
self._task.cancel()
127+
try:
128+
await self._task
129+
except asyncio.CancelledError:
130+
pass

0 commit comments

Comments
 (0)