Skip to content

Commit 0180ee7

Browse files
Merge pull request #7 from davidbrochart/wraps
Fix mypy
2 parents a3ba4c6 + 7f5ab25 commit 0180ee7

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

sqlite_anyio/sqlite.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import sqlite3
66
from collections.abc import Callable, Sequence
7-
from functools import partial, wraps
7+
from functools import partial, update_wrapper
88
from logging import Logger, getLogger
99
from types import TracebackType
1010
from typing import Any
@@ -46,18 +46,21 @@ async def __aexit__(
4646
exception_handled = self._exception_handler(exc_type, exc_val, exc_tb, self._log)
4747
return exception_handled
4848

49-
@wraps(sqlite3.Connection.close)
5049
async def close(self):
5150
return await to_thread.run_sync(self._real_connection.close, limiter=self._limiter)
5251

53-
@wraps(sqlite3.Connection.commit)
52+
update_wrapper(close, sqlite3.Connection.close)
53+
5454
async def commit(self):
5555
return await to_thread.run_sync(self._real_connection.commit, limiter=self._limiter)
5656

57-
@wraps(sqlite3.Connection.rollback)
57+
update_wrapper(commit, sqlite3.Connection.commit)
58+
5859
async def rollback(self):
5960
return await to_thread.run_sync(self._real_connection.rollback, limiter=self._limiter)
6061

62+
update_wrapper(rollback, sqlite3.Connection.rollback)
63+
6164
async def cursor(self, factory: Callable[[sqlite3.Connection], sqlite3.Cursor] = sqlite3.Cursor) -> Cursor:
6265
real_cursor = await to_thread.run_sync(self._real_connection.cursor, factory, limiter=self._limiter)
6366
return Cursor(real_cursor, self._limiter)
@@ -80,37 +83,44 @@ def rowcount(self) -> int:
8083
def arraysize(self) -> int:
8184
return self._real_cursor.arraysize
8285

83-
@wraps(sqlite3.Cursor.close)
8486
async def close(self) -> None:
8587
await to_thread.run_sync(self._real_cursor.close, limiter=self._limiter)
8688

87-
@wraps(sqlite3.Cursor.execute)
89+
update_wrapper(close, sqlite3.Cursor.close)
90+
8891
async def execute(self, sql: str, parameters: Sequence[Any] = (), /) -> Cursor:
8992
real_cursor = await to_thread.run_sync(self._real_cursor.execute, sql, parameters, limiter=self._limiter)
9093
return Cursor(real_cursor, self._limiter)
9194

92-
@wraps(sqlite3.Cursor.executemany)
95+
update_wrapper(execute, sqlite3.Cursor.execute)
96+
9397
async def executemany(self, sql: str, parameters: Sequence[Any], /) -> Cursor:
9498
real_cursor = await to_thread.run_sync(self._real_cursor.executemany, sql, parameters, limiter=self._limiter)
9599
return Cursor(real_cursor, self._limiter)
96100

97-
@wraps(sqlite3.Cursor.executescript)
101+
update_wrapper(executemany, sqlite3.Cursor.executemany)
102+
98103
async def executescript(self, sql_script: str, /) -> Cursor:
99104
real_cursor = await to_thread.run_sync(self._real_cursor.executescript, sql_script, limiter=self._limiter)
100105
return Cursor(real_cursor, self._limiter)
101106

102-
@wraps(sqlite3.Cursor.fetchone)
107+
update_wrapper(executescript, sqlite3.Cursor.executescript)
108+
103109
async def fetchone(self) -> tuple[Any, ...] | None:
104110
return await to_thread.run_sync(self._real_cursor.fetchone, limiter=self._limiter)
105111

106-
@wraps(sqlite3.Cursor.fetchmany)
112+
update_wrapper(fetchone, sqlite3.Cursor.fetchone)
113+
107114
async def fetchmany(self, size: int) -> list[tuple[Any, ...]]:
108115
return await to_thread.run_sync(self._real_cursor.fetchmany, size, limiter=self._limiter)
109116

110-
@wraps(sqlite3.Cursor.fetchall)
117+
update_wrapper(fetchmany, sqlite3.Cursor.fetchmany)
118+
111119
async def fetchall(self) -> list[tuple[Any, ...]]:
112120
return await to_thread.run_sync(self._real_cursor.fetchall, limiter=self._limiter)
113121

122+
update_wrapper(fetchall, sqlite3.Cursor.fetchall)
123+
114124

115125
async def connect(
116126
database: str,

tests/test_context_manager.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,8 @@
55
import sqlite_anyio
66

77

8-
pytestmark = pytest.mark.anyio
9-
10-
11-
async def test_context_manager_commit():
12-
mem_uri = "file:mem0?mode=memory&cache=shared"
8+
async def test_context_manager_commit(anyio_backend):
9+
mem_uri = f"file:{anyio_backend}_mem0?mode=memory&cache=shared"
1310
acon0 = await sqlite_anyio.connect(mem_uri, uri=True)
1411
acur0 = await acon0.cursor()
1512
async with acon0:
@@ -22,8 +19,8 @@ async def test_context_manager_commit():
2219
assert await acur1.fetchone() == ("Python",)
2320

2421

25-
async def test_context_manager_rollback():
26-
mem_uri = "file:mem1?mode=memory&cache=shared"
22+
async def test_context_manager_rollback(anyio_backend):
23+
mem_uri = f"file:{anyio_backend}_mem1?mode=memory&cache=shared"
2724
acon0 = await sqlite_anyio.connect(mem_uri, uri=True)
2825
acur0 = await acon0.cursor()
2926
with pytest.raises(RuntimeError):
@@ -38,9 +35,9 @@ async def test_context_manager_rollback():
3835
assert await acur1.fetchone() is None
3936

4037

41-
async def test_exception_logger(caplog):
38+
async def test_exception_logger(anyio_backend, caplog):
4239
caplog.set_level(logging.INFO)
43-
mem_uri = "file:mem2?mode=memory&cache=shared"
40+
mem_uri = f"file:{anyio_backend}_mem2?mode=memory&cache=shared"
4441
log = logging.getLogger("logger")
4542
acon0 = await sqlite_anyio.connect(mem_uri, uri=True, exception_handler=sqlite_anyio.exception_logger, log=log)
4643
acur0 = await acon0.cursor()

0 commit comments

Comments
 (0)