Skip to content

Fix coroutine not awaited warning #1129

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions src/neo4j/_async_compat/network/_bolt_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def __init__(self, reader, protocol, writer):
self._timeout = None
self._deadline = None

async def _wait_for_io(self, io_fut):
async def _wait_for_io(self, io_async_fn, *args, **kwargs):
timeout = self._timeout
to_raise = SocketTimeout
if self._deadline is not None:
Expand All @@ -109,6 +109,7 @@ async def _wait_for_io(self, io_fut):
timeout = deadline_timeout
to_raise = SocketDeadlineExceededError

io_fut = io_async_fn(*args, **kwargs)
if timeout is not None and timeout <= 0:
# give the io-operation time for one loop cycle to do its thing
io_fut = asyncio.create_task(io_fut)
Expand Down Expand Up @@ -157,20 +158,17 @@ def settimeout(self, timeout):
self._timeout = timeout

async def recv(self, n):
io_fut = self._reader.read(n)
return await self._wait_for_io(io_fut)
return await self._wait_for_io(self._reader.read, n)

async def recv_into(self, buffer, nbytes):
# FIXME: not particularly memory or time efficient
io_fut = self._reader.read(nbytes)
res = await self._wait_for_io(io_fut)
res = await self._wait_for_io(self._reader.read, nbytes)
buffer[: len(res)] = res
return len(res)

async def sendall(self, data):
self._writer.write(data)
io_fut = self._writer.drain()
return await self._wait_for_io(io_fut)
return await self._wait_for_io(self._writer.drain)

async def close(self):
self._writer.close()
Expand Down
164 changes: 164 additions & 0 deletions tests/unit/mixed/async_compat/test_network.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import asyncio
import socket
import typing as t

import freezegun
import pytest

from neo4j._async_compat.network import AsyncBoltSocket
from neo4j._exceptions import SocketDeadlineExceededError

from ...._async_compat.mark_decorator import mark_async_test


if t.TYPE_CHECKING:
import typing_extensions as te
from freezegun.api import (
FrozenDateTimeFactory,
StepTickTimeFactory,
TickingDateTimeFactory,
)

TFreezeTime: te.TypeAlias = (
StepTickTimeFactory | TickingDateTimeFactory | FrozenDateTimeFactory
)


@pytest.fixture
def reader_factory(mocker):
def factory():
return mocker.create_autospec(asyncio.StreamReader)

return factory


@pytest.fixture
def writer_factory(mocker):
def factory():
return mocker.create_autospec(asyncio.StreamWriter)

return factory


@pytest.fixture
def socket_factory(reader_factory, writer_factory):
def factory():
protocol = None
return AsyncBoltSocket(reader_factory(), protocol, writer_factory())

return factory


def reader(s: AsyncBoltSocket):
return s._reader


def writer(s: AsyncBoltSocket):
return s._writer


@pytest.mark.parametrize(
("timeout", "deadline", "pre_tick", "tick", "exception"),
(
(None, None, 60 * 60 * 10, 60 * 60 * 10, None),
# test timeout
(5, None, 0, 4, None),
# timeout is not affected by time passed before the call
(5, None, 7, 4, None),
(5, None, 0, 6, socket.timeout),
# test deadline
(None, 5, 0, 4, None),
(None, 5, 2, 2, None),
# deadline is affected by time passed before the call
(None, 5, 2, 4, SocketDeadlineExceededError),
(None, 5, 6, 0, SocketDeadlineExceededError),
(None, 5, 0, 6, SocketDeadlineExceededError),
# test combination
(5, 5, 0, 4, None),
(5, 5, 2, 2, None),
# deadline triggered by time passed before
(5, 5, 2, 4, SocketDeadlineExceededError),
# the shorter one determines the error
(4, 5, 0, 6, socket.timeout),
(5, 4, 0, 6, SocketDeadlineExceededError),
),
)
@pytest.mark.parametrize("method", ("recv", "recv_into", "sendall"))
@mark_async_test
async def test_async_bolt_socket_read_timeout(
socket_factory, timeout, deadline, pre_tick, tick, exception, method
):
def make_read_side_effect(freeze_time: TFreezeTime):
async def read_side_effect(n):
assert n == 1
freeze_time.tick(tick)
for _ in range(10):
await asyncio.sleep(0)
return b"y"

return read_side_effect

def make_drain_side_effect(freeze_time: TFreezeTime):
async def drain_side_effect():
freeze_time.tick(tick)
for _ in range(10):
await asyncio.sleep(0)

return drain_side_effect

async def call_method(s: AsyncBoltSocket):
if method == "recv":
res = await s.recv(1)
assert res == b"y"
elif method == "recv_into":
b = bytearray(1)
await s.recv_into(b, 1)
assert b == b"y"
elif method == "sendall":
await s.sendall(b"y")
else:
raise NotImplementedError(f"method: {method}")

with freezegun.freeze_time("1970-01-01T00:00:00") as frozen_time:
socket = socket_factory()
if timeout is not None:
socket.settimeout(timeout)
if deadline is not None:
socket.set_deadline(deadline)
if pre_tick:
frozen_time.tick(pre_tick)

if method in {"recv", "recv_into"}:
reader(socket).read.side_effect = make_read_side_effect(
frozen_time
)
elif method == "sendall":
writer(socket).drain.side_effect = make_drain_side_effect(
frozen_time
)
else:
raise NotImplementedError(f"method: {method}")

if exception:
with pytest.raises(exception):
await call_method(socket)
else:
await call_method(socket)