40
40
class WriterAsyncIO :
41
41
_loop : asyncio .AbstractEventLoop
42
42
_reconnector : "WriterAsyncIOReconnector"
43
- _lock : asyncio .Lock
44
43
_closed : bool
45
44
46
45
@property
47
46
def last_seqno (self ) -> int :
48
47
raise NotImplementedError ()
49
48
50
49
def __init__ (self , driver : SupportedDriverType , settings : PublicWriterSettings ):
51
- self ._lock = asyncio .Lock ()
52
50
self ._loop = asyncio .get_running_loop ()
53
51
self ._closed = False
54
52
self ._reconnector = WriterAsyncIOReconnector (
@@ -68,10 +66,10 @@ def __del__(self):
68
66
self ._loop .call_soon (self .close )
69
67
70
68
async def close (self ):
71
- async with self ._lock :
72
- if self . _closed :
73
- return
74
- self ._closed = True
69
+ if self ._closed :
70
+ return
71
+
72
+ self ._closed = True
75
73
76
74
await self ._reconnector .close ()
77
75
@@ -164,65 +162,81 @@ class WriterAsyncIOReconnector:
164
162
_update_token_interval : int
165
163
_token_get_function : TokenGetterFuncType
166
164
_init_message : StreamWriteMessage .InitRequest
167
- _new_messages : asyncio .Queue
168
165
_init_info : asyncio .Future
169
166
_stream_connected : asyncio .Event
170
167
_settings : WriterSettings
171
168
172
- _lock : asyncio .Lock
173
169
_last_known_seq_no : int
174
170
_messages : Deque [InternalMessage ]
175
171
_messages_future : Deque [asyncio .Future ]
176
- _stop_reason : Optional [Exception ]
172
+ _new_messages : asyncio .Queue
173
+ _stop_reason : asyncio .Future
177
174
_background_tasks : List [asyncio .Task ]
178
175
179
176
def __init__ (self , driver : SupportedDriverType , settings : WriterSettings ):
180
177
self ._driver = driver
181
178
self ._credentials = driver ._credentials
182
179
self ._init_message = settings .create_init_request ()
183
- self ._new_messages = asyncio .Queue ()
184
180
self ._init_info = asyncio .Future ()
185
181
self ._stream_connected = asyncio .Event ()
186
182
self ._settings = settings
187
183
188
- self ._lock = asyncio .Lock ()
189
184
self ._last_known_seq_no = 0
190
185
self ._messages = deque ()
191
186
self ._messages_future = deque ()
192
- self ._stop_reason = None
187
+ self ._new_messages = asyncio .Queue ()
188
+ self ._stop_reason = asyncio .Future ()
193
189
self ._background_tasks = [
194
190
asyncio .create_task (self ._connection_loop (), name = "connection_loop" )
195
191
]
196
192
197
193
async def close (self ):
198
- await self ._check_stop ()
199
- await self ._stop (TopicWriterStopped ())
194
+ self ._check_stop ()
195
+ self ._stop (TopicWriterStopped ())
196
+
197
+ background_tasks = self ._background_tasks
198
+
199
+ for task in background_tasks :
200
+ task .cancel ()
201
+
202
+ await asyncio .wait (self ._background_tasks )
200
203
201
204
async def wait_init (self ) -> PublicWriterInitInfo :
202
- return await self ._init_info
205
+ done , _ = await asyncio .wait (
206
+ [self ._init_info , self ._stop_reason ], return_when = asyncio .FIRST_COMPLETED
207
+ )
208
+ res = done .pop () # type: asyncio.Future
209
+ res_val = res .result ()
210
+
211
+ if isinstance (res_val , Exception ):
212
+ raise res_val
213
+
214
+ return res_val
215
+
216
+ async def wait_stop (self ) -> Exception :
217
+ return await self ._stop_reason
203
218
204
219
async def write_with_ack (
205
220
self , messages : List [PublicMessage ]
206
221
) -> List [asyncio .Future ]:
207
222
# todo check internal buffer limit
208
- await self ._check_stop ()
223
+ self ._check_stop ()
209
224
210
225
if self ._settings .auto_seqno :
211
226
await self .wait_init ()
212
227
213
- async with self ._lock :
214
- internal_messages = self ._prepare_internal_messages_locked (messages )
215
- messages_future = [asyncio .Future () for _ in internal_messages ]
228
+ internal_messages = self ._prepare_internal_messages (messages )
229
+ messages_future = [asyncio .Future () for _ in internal_messages ]
216
230
217
- self ._messages .extend (internal_messages )
218
- self ._messages_future .extend (messages_future )
231
+ self ._messages .extend (internal_messages )
232
+ self ._messages_future .extend (messages_future )
219
233
220
234
for m in internal_messages :
221
235
self ._new_messages .put_nowait (m )
222
236
223
237
return messages_future
224
238
225
- def _prepare_internal_messages_locked (self , messages : List [PublicMessage ]):
239
+ def _prepare_internal_messages (self , messages : List [PublicMessage ]):
226
240
if self ._settings .auto_created_at :
227
241
now = datetime .datetime .now ()
228
242
else :
@@ -263,10 +277,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
263
277
264
278
return res
265
279
266
- async def _check_stop (self ):
267
- async with self ._lock :
268
- if self ._stop_reason is not None :
269
- raise self ._stop_reason
280
+ def _check_stop (self ):
281
+ if self ._stop_reason .done ():
282
+ raise self ._stop_reason .result ()
270
283
271
284
async def _connection_loop (self ):
272
285
retry_settings = RetrySettings () # todo
@@ -275,23 +288,16 @@ async def _connection_loop(self):
275
288
attempt = 0 # todo calc and reset
276
289
pending = []
277
290
278
- async def on_stop (e ):
279
- for t in pending :
280
- self ._background_tasks .append (t )
281
- pending .clear ()
282
- await self ._stop (e )
283
-
284
291
# noinspection PyBroadException
285
292
try :
286
293
stream_writer = await WriterAsyncIOStream .create (
287
294
self ._driver , self ._init_message , self ._get_token
288
295
)
289
296
try :
290
- async with self ._lock :
291
- self ._last_known_seq_no = stream_writer .last_seqno
292
- self ._init_info .set_result (
293
- PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
294
- )
297
+ self ._last_known_seq_no = stream_writer .last_seqno
298
+ self ._init_info .set_result (
299
+ PublicWriterInitInfo (last_seqno = stream_writer .last_seqno )
300
+ )
295
301
except asyncio .InvalidStateError :
296
302
pass
297
303
@@ -316,13 +322,13 @@ async def on_stop(e):
316
322
317
323
err_info = check_retriable_error (err , retry_settings , attempt )
318
324
if not err_info .is_retriable :
319
- await on_stop (err )
325
+ self . _stop (err )
320
326
return
321
327
322
328
await asyncio .sleep (err_info .sleep_timeout_seconds )
323
329
324
- except Exception as e :
325
- await on_stop ( e )
330
+ except ( asyncio . CancelledError , Exception ) as err :
331
+ self . _stop ( err )
326
332
return
327
333
finally :
328
334
if len (pending ) > 0 :
@@ -333,11 +339,11 @@ async def on_stop(e):
333
339
async def _read_loop (self , writer : "WriterAsyncIOStream" ):
334
340
while True :
335
341
resp = await writer .receive ()
336
- async with self ._lock :
337
- for ack in resp .acks :
338
- self ._handle_receive_ack_need_lock (ack )
339
342
340
- def _handle_receive_ack_need_lock (self , ack ):
343
+ for ack in resp .acks :
344
+ self ._handle_receive_ack (ack )
345
+
346
+ def _handle_receive_ack (self , ack ):
341
347
current_message = self ._messages .popleft ()
342
348
message_future = self ._messages_future .popleft ()
343
349
if current_message .seq_no != ack .seq_no :
@@ -351,8 +357,7 @@ def _handle_receive_ack_need_lock(self, ack):
351
357
352
358
async def _send_loop (self , writer : "WriterAsyncIOStream" ):
353
359
try :
354
- async with self ._lock :
355
- messages = list (self ._messages )
360
+ messages = list (self ._messages )
356
361
357
362
last_seq_no = 0
358
363
for m in messages :
@@ -364,24 +369,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
364
369
if m .seq_no > last_seq_no :
365
370
writer .write ([m ])
366
371
except Exception as e :
367
- await self ._stop (e )
372
+ self ._stop (e )
368
373
finally :
369
374
pass
370
375
371
- async def _stop (self , reason : Exception ):
376
+ def _stop (self , reason : Exception ):
372
377
if reason is None :
373
378
raise Exception ("writer stop reason can not be None" )
374
379
375
- async with self ._lock :
376
- if self ._stop_reason is not None :
377
- return
378
- self ._stop_reason = reason
379
- background_tasks = self ._background_tasks
380
-
381
- for task in background_tasks :
382
- task .cancel ()
380
+ if self ._stop_reason .done ():
381
+ return
383
382
384
- await asyncio . wait ( self ._background_tasks )
383
+ self ._stop_reason . set_result ( reason )
385
384
386
385
def _get_token (self ) -> str :
387
386
raise NotImplementedError ()
0 commit comments