Skip to content

Commit 433b17e

Browse files
Merge pull request #36 from sqlitecloud/fix-cursor-rowscount
fix(cursors): reset iterator counter between queries execution
2 parents eb4e061 + 1d638de commit 433b17e

File tree

6 files changed

+64
-7
lines changed

6 files changed

+64
-7
lines changed

.github/workflows/test.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ on:
99

1010
jobs:
1111
tests:
12-
runs-on: ubuntu-20.04
12+
runs-on: ubuntu-latest
1313
strategy:
1414
fail-fast: false
1515
matrix:
1616
# last supported for sqlitecloud, last security maintained, last release
17-
python-version: ["3.6", "3.8", "3.12"]
17+
python-version: ["3.9", "3.10", "3.12"]
1818

1919
steps:
2020
- uses: actions/checkout@v4

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
profile = "black"
77

88
[tool.pytest.ini_options]
9+
log_level = "INFO"
910
markers = [
1011
"slow: marks tests as slow (deselect with '-m \"not slow\"')",
1112
"serial",

src/sqlitecloud/dbapi2.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,7 @@ def execute(
584584
sql, parameters, self.connection.sqlitecloud_connection
585585
)
586586

587-
self._resultset = None
588-
self._result_operation = None
587+
self._reset()
589588

590589
if isinstance(result, SQLiteCloudResult):
591590
self._resultset = result
@@ -853,6 +852,12 @@ def _get_value(self, row: int, col: int) -> Optional[Any]:
853852

854853
return self._convert_value(value, colname, decltype)
855854

855+
def _reset(self) -> None:
856+
self._resultset = None
857+
self._result_operation = None
858+
859+
self._iter_row = 0
860+
856861
def __iter__(self) -> "Cursor":
857862
return self
858863

src/sqlitecloud/driver.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -207,10 +207,10 @@ def _internal_setup_pubsub(
207207
def _internal_pubsub_thread(self, connection: SQLiteCloudConnect) -> None:
208208
blen = 2048
209209
buffer: bytes = b""
210+
tread = 0
210211

211212
try:
212213
while True:
213-
tread = 0
214214

215215
try:
216216
if not connection.pubsub_socket:
@@ -240,7 +240,6 @@ def _internal_pubsub_thread(self, connection: SQLiteCloudConnect) -> None:
240240

241241
nread = len(data)
242242
tread += nread
243-
blen -= nread
244243
buffer += data
245244

246245
sqlitecloud_number = self._internal_parse_number(buffer)
@@ -262,11 +261,16 @@ def _internal_pubsub_thread(self, connection: SQLiteCloudConnect) -> None:
262261
connection.pubsub_callback(
263262
connection, SQLiteCloudResultSet(result), connection.pubsub_data
264263
)
264+
265+
# reset after having read the message
266+
tread = 0
267+
buffer: bytes = b""
265268
except Exception as e:
266269
logging.error(f"An error occurred while parsing data: {e}.")
267270

268271
finally:
269-
connection.pubsub_callback(connection, None, connection.pubsub_data)
272+
if connection and connection.pubsub_callback:
273+
connection.pubsub_callback(connection, None, connection.pubsub_data)
270274

271275
def upload_database(
272276
self,

src/tests/integration/test_dbapi2.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,3 +441,18 @@ def test_connection_is_connected(self, sqlitecloud_dbapi2_connection):
441441
connection.close()
442442

443443
assert not connection.is_connected()
444+
445+
def test_fetchall_returns_right_nrows_number(self, sqlitecloud_dbapi2_connection):
446+
connection = sqlitecloud_dbapi2_connection
447+
448+
cursor = connection.cursor()
449+
450+
cursor.execute("SELECT * FROM Genres LIMIT 3")
451+
452+
assert len(cursor.fetchall()) == 3
453+
assert cursor.rowcount == 3
454+
455+
cursor.execute("SELECT * FROM Albums LIMIT 4")
456+
457+
assert len(cursor.fetchall()) == 4
458+
assert cursor.rowcount == 4

src/tests/integration/test_pubsub.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,38 @@ def assert_callback(conn, result, data):
4747

4848
assert callback_called
4949

50+
def test_notify_multiple_messages(self, sqlitecloud_connection):
51+
connection, _ = sqlitecloud_connection
52+
53+
called_times = 3
54+
flag = threading.Event()
55+
56+
def assert_callback(conn, result, data):
57+
nonlocal called_times
58+
nonlocal flag
59+
60+
if isinstance(result, SQLiteCloudResultSet):
61+
assert data == ["somedataX"]
62+
called_times -= 1
63+
if called_times == 0:
64+
flag.set()
65+
66+
pubsub = SQLiteCloudPubSub()
67+
subject_type = SQLITECLOUD_PUBSUB_SUBJECT.CHANNEL
68+
channel = "channel" + str(uuid.uuid4())
69+
70+
pubsub.create_channel(connection, channel)
71+
pubsub.listen(connection, subject_type, channel, assert_callback, ["somedataX"])
72+
73+
pubsub.notify_channel(connection, channel, "somedataX")
74+
pubsub.notify_channel(connection, channel, "somedataX")
75+
pubsub.notify_channel(connection, channel, "somedataX")
76+
77+
# wait for callback to be called
78+
flag.wait(30)
79+
80+
assert called_times == 0
81+
5082
def test_unlisten_channel(self, sqlitecloud_connection):
5183
connection, _ = sqlitecloud_connection
5284

0 commit comments

Comments
 (0)