32 changes: 30 additions & 2 deletions tests/topics/test_topic_reader.py
@@ -5,12 +5,26 @@

@pytest.mark.asyncio
class TestTopicReaderAsyncIO:
async def test_read_batch(
self, driver, topic_path, topic_with_messages, topic_consumer
):
reader = driver.topic_client.reader(topic_consumer, topic_path)
batch = await reader.receive_batch()

assert batch is not None
assert len(batch.messages) > 0

await reader.close()

async def test_read_message(
self, driver, topic_path, topic_with_messages, topic_consumer
):
reader = driver.topic_client.reader(topic_consumer, topic_path)
msg = await reader.receive_message()

assert msg is not None
assert msg.seqno

assert await reader.receive_batch() is not None
await reader.close()

async def test_read_and_commit_message(
@@ -59,12 +73,26 @@ def decode(b: bytes):


class TestTopicReaderSync:
def test_read_batch(
self, driver_sync, topic_path, topic_with_messages, topic_consumer
):
reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
batch = reader.receive_batch()

assert batch is not None
assert len(batch.messages) > 0

reader.close()

def test_read_message(
self, driver_sync, topic_path, topic_with_messages, topic_consumer
):
reader = driver_sync.topic_client.reader(topic_consumer, topic_path)
msg = reader.receive_message()

assert msg is not None
assert msg.seqno

assert reader.receive_batch() is not None
reader.close()

def test_read_and_commit_message(
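Taken together, the new integration tests exercise a simple end-to-end pattern. A minimal sketch of that usage, assuming an already-connected driver and an existing topic with a consumer (the names below are illustrative placeholders, not fixtures from this repository):

import asyncio

async def read_one(driver, topic_path: str, consumer: str):
    # Same call order as in the tests above: consumer first, then topic path.
    reader = driver.topic_client.reader(consumer, topic_path)
    try:
        msg = await reader.receive_message()  # blocks until a message is available
        print(msg.seqno, msg.data)
    finally:
        await reader.close()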
6 changes: 6 additions & 0 deletions ydb/_topic_reader/datatypes.py
@@ -179,6 +179,9 @@ def _commit_get_offsets_range(self) -> OffsetsRange:
self.messages[-1]._commit_get_offsets_range().end,
)

def empty(self) -> bool:
return len(self.messages) == 0

# ISessionAlive implementation
@property
def is_alive(self) -> bool:
@@ -187,3 +190,6 @@ def is_alive(self) -> bool:
state == PartitionSession.State.Active
or state == PartitionSession.State.GracefulShutdown
)

def pop_message(self) -> PublicMessage:
return self.messages.pop(0)
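The two helpers give PublicBatch a small queue-like interface: pop_message removes and returns the oldest message, and empty reports when the batch has been drained. A rough sketch of the intended semantics (batch stands for any PublicBatch instance, e.g. one built with the stub helpers from the unit tests further down; this fragment is illustrative, not code from the PR):

# `batch` is assumed to hold two messages at the start
first = batch.pop_message()      # returns messages[0] and removes it from the batch
assert not batch.empty()         # one message still buffered
second = batch.pop_message()
assert batch.empty()             # the batch is now drained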
34 changes: 25 additions & 9 deletions ydb/_topic_reader/topic_reader_asyncio.py
@@ -95,14 +95,6 @@ def messages(
"""
raise NotImplementedError()

async def receive_message(self) -> typing.Union[topic_reader.PublicMessage, None]:
"""
Block until receive new message

use asyncio.wait_for for wait with timeout.
"""
raise NotImplementedError()

def batches(
self,
*,
@@ -133,6 +125,15 @@ async def receive_batch(
await self._reconnector.wait_message()
return self._reconnector.receive_batch_nowait()

async def receive_message(self) -> typing.Optional[datatypes.PublicMessage]:
"""
Block until a new message is received.

Use asyncio.wait_for to wait with a timeout.
"""
await self._reconnector.wait_message()
return self._reconnector.receive_message_nowait()

async def commit_on_exit(
self, mess: datatypes.ICommittable
) -> typing.AsyncContextManager:
@@ -244,6 +245,9 @@ async def wait_message(self):
def receive_batch_nowait(self):
return self._stream_reader.receive_batch_nowait()

def receive_message_nowait(self):
return self._stream_reader.receive_message_nowait()

def commit(
self, batch: datatypes.ICommittable
) -> datatypes.PartitionSession.CommitAckWaiter:
@@ -397,12 +401,24 @@ def receive_batch_nowait(self):
raise self._get_first_error()

if not self._message_batches:
return
return None

batch = self._message_batches.popleft()
self._buffer_release_bytes(batch._bytes_size)
return batch

def receive_message_nowait(self):
try:
batch = self._message_batches[0]
message = batch.pop_message()
except IndexError:
return None

if batch.empty():
self._message_batches.popleft()

return message

def commit(
self, batch: datatypes.ICommittable
) -> datatypes.PartitionSession.CommitAckWaiter:
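With this change receive_message follows the same two-step shape as receive_batch: await the reconnector's wait_message() and then drain one message from the buffer without blocking. As the docstring suggests, a timeout can be layered on top with asyncio.wait_for; a sketch of that pattern (reader is assumed to be an already-opened asynchronous reader):

import asyncio

async def read_with_timeout(reader, timeout: float = 5.0):
    try:
        # receive_message() itself blocks indefinitely; wait_for adds the timeout.
        return await asyncio.wait_for(reader.receive_message(), timeout)
    except asyncio.TimeoutError:
        return None  # nothing arrived within `timeout` seconds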
128 changes: 127 additions & 1 deletion ydb/_topic_reader/topic_reader_asyncio_test.py
@@ -4,6 +4,7 @@
import datetime
import gzip
import typing
from collections import deque
from dataclasses import dataclass
from unittest import mock

@@ -53,6 +54,34 @@ def default_executor():
executor.shutdown()


def stub_partition_session():
return datatypes.PartitionSession(
id=0,
state=datatypes.PartitionSession.State.Active,
topic_path="asd",
partition_id=1,
committed_offset=0,
reader_reconnector_id=415,
reader_stream_id=513,
)


def stub_message(id: int):
return PublicMessage(
seqno=id,
created_at=datetime.datetime(2023, 3, 18, 14, 15),
message_group_id="",
session_metadata={},
offset=0,
written_at=datetime.datetime(2023, 3, 18, 14, 15),
producer_id="",
data=bytes(),
_partition_session=stub_partition_session(),
_commit_start_offset=0,
_commit_end_offset=1,
)


@pytest.fixture()
def default_reader_settings(default_executor):
return PublicReaderSettings(
Expand Down Expand Up @@ -179,7 +208,9 @@ async def stream_reader_finish_with_error(

@staticmethod
def create_message(
partition_session: datatypes.PartitionSession, seqno: int, offset_delta: int
partition_session: typing.Optional[datatypes.PartitionSession],
seqno: int,
offset_delta: int,
):
return PublicMessage(
seqno=seqno,
@@ -963,6 +994,101 @@ async def test_read_batches(
_codec=Codec.CODEC_RAW,
)

@pytest.mark.parametrize(
"batches_before,expected_message,batches_after",
[
([], None, []),
(
[
PublicBatch(
session_metadata={},
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
)
],
stub_message(1),
[],
),
(
[
PublicBatch(
session_metadata={},
messages=[stub_message(1), stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
session_metadata={},
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
stub_message(1),
[
PublicBatch(
session_metadata={},
messages=[stub_message(2)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
session_metadata={},
messages=[stub_message(3), stub_message(4)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
),
(
[
PublicBatch(
session_metadata={},
messages=[stub_message(1)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
PublicBatch(
session_metadata={},
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
),
],
stub_message(1),
[
PublicBatch(
session_metadata={},
messages=[stub_message(2), stub_message(3)],
_partition_session=stub_partition_session(),
_bytes_size=0,
_codec=Codec.CODEC_RAW,
)
],
),
],
)
async def test_read_message(
self,
stream_reader,
batches_before: typing.List[datatypes.PublicBatch],
expected_message: PublicMessage,
batches_after: typing.List[datatypes.PublicBatch],
):
stream_reader._message_batches = deque(batches_before)
mess = stream_reader.receive_message_nowait()

assert mess == expected_message
assert list(stream_reader._message_batches) == batches_after

async def test_receive_batch_nowait(self, stream, stream_reader, partition_session):
assert stream_reader.receive_batch_nowait() is None

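The parametrized cases pin down a simple front-of-queue contract: receive_message_nowait returns the first message of the first buffered batch, drops that batch from the deque once it is empty, and returns None when nothing is buffered. The same contract, expressed against plain lists for illustration (this helper is not part of the SDK):

from collections import deque
from typing import Optional

def pop_front_message(batches: deque) -> Optional[int]:
    # Mirrors ReaderStream.receive_message_nowait(), with lists standing in for batches.
    try:
        batch = batches[0]
        message = batch.pop(0)
    except IndexError:
        return None
    if not batch:            # batch drained -> remove it from the queue
        batches.popleft()
    return message

batches = deque([[1, 2], [3]])
assert pop_front_message(batches) == 1
assert list(batches) == [[2], [3]]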
27 changes: 20 additions & 7 deletions ydb/_topic_reader/topic_reader_sync.py
@@ -83,19 +83,28 @@ def messages(
It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available.

if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration
if timeout <= 0 - it will fast non block method, get messages from internal buffer only.
if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses,
get messages from internal buffer only.
"""
raise NotImplementedError()

def receive_message(self, *, timeout: Union[float, None] = None) -> PublicMessage:
def receive_message(
self, *, timeout: TimeoutType = None
) -> datatypes.PublicMessage:
"""
Block until receive new message
It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available.
receive_message(timeout=0) may return None even right after async_wait_message() completes - for example if the partition
or the connection to the server was lost in between

if no new message in timeout seconds (default - infinite): raise TimeoutError()
if timeout <= 0 - it will fast non block method, get messages from internal buffer only.
if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only.
"""
raise NotImplementedError()
self._check_closed()

return self._caller.safe_call_with_result(
self._async_reader.receive_message(), timeout
)

def async_wait_message(self) -> concurrent.futures.Future:
"""
@@ -105,7 +114,11 @@ def async_wait_message(self) -> concurrent.futures.Future:
Possible situation when receive signal about message available, but no messages when try to receive a message.
If message expired between send event and try to retrieve message (for example connection broken).
"""
raise NotImplementedError()
self._check_closed()

return self._caller.unsafe_call_with_future(
self._async_reader._reconnector.wait_message()
)

def batches(
self,
Expand All @@ -119,7 +132,7 @@ def batches(
It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available.

if no new message in timeout seconds (default - infinite): stop iterations by raise StopIteration
if timeout <= 0 - it will fast non block method, get messages from internal buffer only.
if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only.
"""
raise NotImplementedError()

Expand All @@ -135,7 +148,7 @@ def receive_batch(
It has no async_ version for prevent lost messages, use async_wait_message as signal for new batches available.

if no new message in timeout seconds (default - infinite): raise TimeoutError()
if timeout <= 0 - it will fast non block method, get messages from internal buffer only.
if timeout <= 0 - it will fast wait only one event loop cycle - without wait any i/o operations or pauses, get messages from internal buffer only.
"""
self._check_closed()

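A sketch of how the synchronous additions can be combined: async_wait_message hands back a concurrent.futures.Future that signals when something may be readable, and receive_message then blocks (up to the given timeout) for the message itself. Per the docstrings above the signal is only a hint, so a timeout is still possible (reader is assumed to be an already-opened synchronous reader):

import concurrent.futures

def read_one_sync(reader, timeout: float = 5.0):
    try:
        waiter: concurrent.futures.Future = reader.async_wait_message()
        waiter.result(timeout=timeout)    # wait for the "message may be available" signal
        return reader.receive_message(timeout=timeout)
    except (TimeoutError, concurrent.futures.TimeoutError):
        return None                       # no message, or it was lost before we could read it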