@@ -40,9 +40,14 @@ class Scope(TypedDict):
40
40
state : NotRequired [Dict [str , Any ]]
41
41
extensions : Optional [Dict [str , Dict [object , object ]]]
42
42
43
+ class RestateEvent (TypedDict ):
44
+ """An event that represents a run completion"""
45
+ type : Literal ["restate.run_completed" ]
46
+ data : Optional [Dict [str , Any ]]
47
+
43
48
class HTTPRequestEvent (TypedDict ):
44
49
"""ASGI Request event"""
45
- type : Literal ["http.request" , "restate.run_completed" ]
50
+ type : Literal ["http.request" ]
46
51
body : bytes
47
52
more_body : bool
48
53
@@ -91,38 +96,35 @@ def binary_to_header(headers: Iterable[Tuple[bytes, bytes]]) -> List[Tuple[str,
91
96
class ReceiveChannel :
92
97
"""ASGI receive channel."""
93
98
94
- def __init__ (self , receive : Receive ):
95
- self .queue = asyncio .Queue [ASGIReceiveEvent ]()
99
+ def __init__ (self , receive : Receive ) -> None :
100
+ self ._queue = asyncio .Queue [Union [ ASGIReceiveEvent , RestateEvent ] ]()
96
101
97
102
async def loop ():
98
103
"""Receive loop."""
99
104
while True :
100
105
event = await receive ()
101
- await self .queue .put (event )
106
+ await self ._queue .put (event )
102
107
if event .get ('type' ) == 'http.disconnect' :
103
108
break
104
109
105
- self .task = asyncio .create_task (loop ())
110
+ self ._task = asyncio .create_task (loop ())
106
111
107
- async def rx (self ) -> ASGIReceiveEvent :
112
+ async def __call__ (self ) -> ASGIReceiveEvent | RestateEvent :
108
113
"""Get the next message."""
109
- what = await self .queue .get ()
110
- self .queue .task_done ()
114
+ what = await self ._queue .get ()
115
+ self ._queue .task_done ()
111
116
return what
112
117
113
- async def __call__ (self ):
114
- """Get the next message."""
115
- return await self .rx ()
116
-
117
- async def tx (self , what : ASGIReceiveEvent ):
118
+ async def enqueue_restate_event (self , what : RestateEvent ):
118
119
"""Add a message."""
119
- await self .queue .put (what )
120
+ await self ._queue .put (what )
120
121
121
122
async def close (self ):
122
123
"""Close the channel."""
123
- if self .task and not self .task .done ():
124
- self .task .cancel ()
125
- try :
126
- await self .task
127
- except asyncio .CancelledError :
128
- pass
124
+ if self ._task .done ():
125
+ return
126
+ self ._task .cancel ()
127
+ try :
128
+ await self ._task
129
+ except asyncio .CancelledError :
130
+ pass
0 commit comments