Skip to content

Commit 1e74b11

Browse files
committed
simplify topic writer
1 parent 889d147 commit 1e74b11

File tree

1 file changed

+56
-57
lines changed

1 file changed

+56
-57
lines changed

ydb/_topic_writer/topic_writer_asyncio.py

Lines changed: 56 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,13 @@
4040
class WriterAsyncIO:
4141
_loop: asyncio.AbstractEventLoop
4242
_reconnector: "WriterAsyncIOReconnector"
43-
_lock: asyncio.Lock
4443
_closed: bool
4544

4645
@property
4746
def last_seqno(self) -> int:
4847
raise NotImplementedError()
4948

5049
def __init__(self, driver: SupportedDriverType, settings: PublicWriterSettings):
51-
self._lock = asyncio.Lock()
5250
self._loop = asyncio.get_running_loop()
5351
self._closed = False
5452
self._reconnector = WriterAsyncIOReconnector(
@@ -68,10 +66,10 @@ def __del__(self):
6866
self._loop.call_soon(self.close)
6967

7068
async def close(self):
71-
async with self._lock:
72-
if self._closed:
73-
return
74-
self._closed = True
69+
if self._closed:
70+
return
71+
72+
self._closed = True
7573

7674
await self._reconnector.close()
7775

@@ -164,65 +162,81 @@ class WriterAsyncIOReconnector:
164162
_update_token_interval: int
165163
_token_get_function: TokenGetterFuncType
166164
_init_message: StreamWriteMessage.InitRequest
167-
_new_messages: asyncio.Queue
168165
_init_info: asyncio.Future
169166
_stream_connected: asyncio.Event
170167
_settings: WriterSettings
171168

172-
_lock: asyncio.Lock
173169
_last_known_seq_no: int
174170
_messages: Deque[InternalMessage]
175171
_messages_future: Deque[asyncio.Future]
176-
_stop_reason: Optional[Exception]
172+
_new_messages: asyncio.Queue
173+
_stop_reason: asyncio.Future
177174
_background_tasks: List[asyncio.Task]
178175

179176
def __init__(self, driver: SupportedDriverType, settings: WriterSettings):
180177
self._driver = driver
181178
self._credentials = driver._credentials
182179
self._init_message = settings.create_init_request()
183-
self._new_messages = asyncio.Queue()
184180
self._init_info = asyncio.Future()
185181
self._stream_connected = asyncio.Event()
186182
self._settings = settings
187183

188-
self._lock = asyncio.Lock()
189184
self._last_known_seq_no = 0
190185
self._messages = deque()
191186
self._messages_future = deque()
192-
self._stop_reason = None
187+
self._new_messages = asyncio.Queue()
188+
self._stop_reason = asyncio.Future()
193189
self._background_tasks = [
194190
asyncio.create_task(self._connection_loop(), name="connection_loop")
195191
]
196192

197193
async def close(self):
198-
await self._check_stop()
199-
await self._stop(TopicWriterStopped())
194+
self._check_stop()
195+
self._stop(TopicWriterStopped())
196+
197+
background_tasks = self._background_tasks
198+
199+
for task in background_tasks:
200+
task.cancel()
201+
202+
await asyncio.wait(self._background_tasks)
200203

201204
async def wait_init(self) -> PublicWriterInitInfo:
202-
return await self._init_info
205+
done, _ = await asyncio.wait(
206+
[self._init_info, self._stop_reason], return_when=asyncio.FIRST_COMPLETED
207+
)
208+
res = done.pop() # type: asyncio.Future
209+
res_val = res.result()
210+
211+
if isinstance(res_val, Exception):
212+
raise res_val
213+
214+
return res_val
215+
216+
async def wait_stop(self) -> Exception:
217+
return await self._stop_reason
203218

204219
async def write_with_ack(
205220
self, messages: List[PublicMessage]
206221
) -> List[asyncio.Future]:
207222
# todo check internal buffer limit
208-
await self._check_stop()
223+
self._check_stop()
209224

210225
if self._settings.auto_seqno:
211226
await self.wait_init()
212227

213-
async with self._lock:
214-
internal_messages = self._prepare_internal_messages_locked(messages)
215-
messages_future = [asyncio.Future() for _ in internal_messages]
228+
internal_messages = self._prepare_internal_messages(messages)
229+
messages_future = [asyncio.Future() for _ in internal_messages]
216230

217-
self._messages.extend(internal_messages)
218-
self._messages_future.extend(messages_future)
231+
self._messages.extend(internal_messages)
232+
self._messages_future.extend(messages_future)
219233

220234
for m in internal_messages:
221235
self._new_messages.put_nowait(m)
222236

223237
return messages_future
224238

225-
def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
239+
def _prepare_internal_messages(self, messages: List[PublicMessage]):
226240
if self._settings.auto_created_at:
227241
now = datetime.datetime.now()
228242
else:
@@ -263,10 +277,9 @@ def _prepare_internal_messages_locked(self, messages: List[PublicMessage]):
263277

264278
return res
265279

266-
async def _check_stop(self):
267-
async with self._lock:
268-
if self._stop_reason is not None:
269-
raise self._stop_reason
280+
def _check_stop(self):
281+
if self._stop_reason.done():
282+
raise self._stop_reason.result()
270283

271284
async def _connection_loop(self):
272285
retry_settings = RetrySettings() # todo
@@ -275,23 +288,16 @@ async def _connection_loop(self):
275288
attempt = 0 # todo calc and reset
276289
pending = []
277290

278-
async def on_stop(e):
279-
for t in pending:
280-
self._background_tasks.append(t)
281-
pending.clear()
282-
await self._stop(e)
283-
284291
# noinspection PyBroadException
285292
try:
286293
stream_writer = await WriterAsyncIOStream.create(
287294
self._driver, self._init_message, self._get_token
288295
)
289296
try:
290-
async with self._lock:
291-
self._last_known_seq_no = stream_writer.last_seqno
292-
self._init_info.set_result(
293-
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
294-
)
297+
self._last_known_seq_no = stream_writer.last_seqno
298+
self._init_info.set_result(
299+
PublicWriterInitInfo(last_seqno=stream_writer.last_seqno)
300+
)
295301
except asyncio.InvalidStateError:
296302
pass
297303

@@ -316,13 +322,13 @@ async def on_stop(e):
316322

317323
err_info = check_retriable_error(err, retry_settings, attempt)
318324
if not err_info.is_retriable:
319-
await on_stop(err)
325+
self._stop(err)
320326
return
321327

322328
await asyncio.sleep(err_info.sleep_timeout_seconds)
323329

324-
except Exception as e:
325-
await on_stop(e)
330+
except (asyncio.CancelledError, Exception) as err:
331+
self._stop(err)
326332
return
327333
finally:
328334
if len(pending) > 0:
@@ -333,11 +339,11 @@ async def on_stop(e):
333339
async def _read_loop(self, writer: "WriterAsyncIOStream"):
334340
while True:
335341
resp = await writer.receive()
336-
async with self._lock:
337-
for ack in resp.acks:
338-
self._handle_receive_ack_need_lock(ack)
339342

340-
def _handle_receive_ack_need_lock(self, ack):
343+
for ack in resp.acks:
344+
self._handle_receive_ack(ack)
345+
346+
def _handle_receive_ack(self, ack):
341347
current_message = self._messages.popleft()
342348
message_future = self._messages_future.popleft()
343349
if current_message.seq_no != ack.seq_no:
@@ -351,8 +357,7 @@ def _handle_receive_ack_need_lock(self, ack):
351357

352358
async def _send_loop(self, writer: "WriterAsyncIOStream"):
353359
try:
354-
async with self._lock:
355-
messages = list(self._messages)
360+
messages = list(self._messages)
356361

357362
last_seq_no = 0
358363
for m in messages:
@@ -364,24 +369,18 @@ async def _send_loop(self, writer: "WriterAsyncIOStream"):
364369
if m.seq_no > last_seq_no:
365370
writer.write([m])
366371
except Exception as e:
367-
await self._stop(e)
372+
self._stop(e)
368373
finally:
369374
pass
370375

371-
async def _stop(self, reason: Exception):
376+
def _stop(self, reason: Exception):
372377
if reason is None:
373378
raise Exception("writer stop reason can not be None")
374379

375-
async with self._lock:
376-
if self._stop_reason is not None:
377-
return
378-
self._stop_reason = reason
379-
background_tasks = self._background_tasks
380-
381-
for task in background_tasks:
382-
task.cancel()
380+
if self._stop_reason.done():
381+
return
383382

384-
await asyncio.wait(self._background_tasks)
383+
self._stop_reason.set_result(reason)
385384

386385
def _get_token(self) -> str:
387386
raise NotImplementedError()

0 commit comments

Comments
 (0)