Skip to content

Commit a2f35e1

Browse files
authored
Fix DeprecationWarning on closing driver multiple times (#964)
Closing the driver multiple times will still be allowed in future version. However, it will become a noop.
1 parent 79747bb commit a2f35e1

File tree

4 files changed

+54
-6
lines changed

4 files changed

+54
-6
lines changed

src/neo4j/_async/driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -563,7 +563,9 @@ def _prepare_session_config(cls, preview_check, config_kwargs):
563563
async def close(self) -> None:
564564
""" Shut down, closing any open connections in the pool.
565565
"""
566-
self._check_state()
566+
# TODO: 6.0 - NOOP if already closed
567+
# if self._closed:
568+
# return
567569
try:
568570
await self._pool.close()
569571
except asyncio.CancelledError:

src/neo4j/_sync/driver.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,9 @@ def _prepare_session_config(cls, preview_check, config_kwargs):
562562
def close(self) -> None:
563563
""" Shut down, closing any open connections in the pool.
564564
"""
565-
self._check_state()
565+
# TODO: 6.0 - NOOP if already closed
566+
# if self._closed:
567+
# return
566568
try:
567569
self._pool.close()
568570
except asyncio.CancelledError:

tests/unit/async_/test_driver.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import inspect
2222
import ssl
2323
import typing as t
24+
import warnings
2425

2526
import pytest
2627
import typing_extensions as te
@@ -963,12 +964,11 @@ async def test_supports_session_auth(session_cls_mock) -> None:
963964
("get_server_info", (), {}),
964965
("supports_multi_db", (), {}),
965966
("supports_session_auth", (), {}),
966-
("close", (), {}),
967967
968968
)
969969
)
970970
@mark_async_test
971-
async def test_using_closed_driver_is_deprecated(
971+
async def test_using_closed_driver_where_deprecated(
972972
method_name, args, kwargs, session_cls_mock
973973
) -> None:
974974
driver = AsyncGraphDatabase.driver("bolt://localhost")
@@ -983,3 +983,25 @@ async def test_using_closed_driver_is_deprecated(
983983
await method(*args, **kwargs)
984984
else:
985985
method(*args, **kwargs)
986+
987+
988+
@pytest.mark.parametrize(
989+
("method_name", "args", "kwargs"),
990+
(
991+
("close", (), {}),
992+
)
993+
)
994+
@mark_async_test
995+
async def test_using_closed_driver_where_not_deprecated(
996+
method_name, args, kwargs, session_cls_mock
997+
) -> None:
998+
driver = AsyncGraphDatabase.driver("bolt://localhost")
999+
await driver.close()
1000+
1001+
method = getattr(driver, method_name)
1002+
with warnings.catch_warnings():
1003+
warnings.simplefilter("error")
1004+
if inspect.iscoroutinefunction(method):
1005+
await method(*args, **kwargs)
1006+
else:
1007+
method(*args, **kwargs)

tests/unit/sync/test_driver.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import inspect
2222
import ssl
2323
import typing as t
24+
import warnings
2425

2526
import pytest
2627
import typing_extensions as te
@@ -962,12 +963,11 @@ def test_supports_session_auth(session_cls_mock) -> None:
962963
("get_server_info", (), {}),
963964
("supports_multi_db", (), {}),
964965
("supports_session_auth", (), {}),
965-
("close", (), {}),
966966
967967
)
968968
)
969969
@mark_sync_test
970-
def test_using_closed_driver_is_deprecated(
970+
def test_using_closed_driver_where_deprecated(
971971
method_name, args, kwargs, session_cls_mock
972972
) -> None:
973973
driver = GraphDatabase.driver("bolt://localhost")
@@ -982,3 +982,25 @@ def test_using_closed_driver_is_deprecated(
982982
method(*args, **kwargs)
983983
else:
984984
method(*args, **kwargs)
985+
986+
987+
@pytest.mark.parametrize(
988+
("method_name", "args", "kwargs"),
989+
(
990+
("close", (), {}),
991+
)
992+
)
993+
@mark_sync_test
994+
def test_using_closed_driver_where_not_deprecated(
995+
method_name, args, kwargs, session_cls_mock
996+
) -> None:
997+
driver = GraphDatabase.driver("bolt://localhost")
998+
driver.close()
999+
1000+
method = getattr(driver, method_name)
1001+
with warnings.catch_warnings():
1002+
warnings.simplefilter("error")
1003+
if inspect.iscoroutinefunction(method):
1004+
method(*args, **kwargs)
1005+
else:
1006+
method(*args, **kwargs)

0 commit comments

Comments
 (0)