Skip to content

Commit c5b570d

Browse files
committed
Fixing RTT in executemany
1 parent 9f80b0c commit c5b570d

File tree

3 files changed

+52
-40
lines changed

3 files changed

+52
-40
lines changed

asyncpg/protocol/coreproto.pxd

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,6 @@ cdef class CoreProtocol:
7575
bint _skip_discard
7676
bint _discard_data
7777

78-
# executemany support data
79-
object _execute_iter
80-
str _execute_portal_name
81-
str _execute_stmt_name
82-
8378
ConnectionStatus con_status
8479
ProtocolState state
8580
TransactionStatus xact_status

asyncpg/protocol/coreproto.pyx

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ cdef class CoreProtocol:
2525

2626
self._skip_discard = False
2727

28-
# executemany support data
29-
self._execute_iter = None
30-
self._execute_portal_name = None
31-
self._execute_stmt_name = None
32-
3328
self._reset_result()
3429

3530
cdef _write(self, buf):
@@ -256,22 +251,7 @@ cdef class CoreProtocol:
256251
elif mtype == b'Z':
257252
# ReadyForQuery
258253
self._parse_msg_ready_for_query()
259-
if self.result_type == RESULT_FAILED:
260-
self._push_result()
261-
else:
262-
try:
263-
buf = <WriteBuffer>next(self._execute_iter)
264-
except StopIteration:
265-
self._push_result()
266-
except Exception as e:
267-
self.result_type = RESULT_FAILED
268-
self.result = e
269-
self._push_result()
270-
else:
271-
# Next iteration over the executemany() arg sequence
272-
self._send_bind_message(
273-
self._execute_portal_name, self._execute_stmt_name,
274-
buf, 0)
254+
self._push_result()
275255

276256
elif mtype == b'I':
277257
# EmptyQueryResponse
@@ -799,27 +779,42 @@ cdef class CoreProtocol:
799779
cdef _bind_execute_many(self, str portal_name, str stmt_name,
800780
object bind_data):
801781

802-
cdef WriteBuffer buf
782+
cdef:
783+
WriteBuffer packet
784+
WriteBuffer buf
803785

804786
self._ensure_connected()
805787
self._set_state(PROTOCOL_BIND_EXECUTE_MANY)
806788

789+
packet = WriteBuffer.new()
790+
807791
self.result = None
808792
self._discard_data = True
809-
self._execute_iter = bind_data
810-
self._execute_portal_name = portal_name
811-
self._execute_stmt_name = stmt_name
812793

813-
try:
814-
buf = <WriteBuffer>next(bind_data)
815-
except StopIteration:
816-
self._push_result()
817-
except Exception as e:
818-
self.result_type = RESULT_FAILED
819-
self.result = e
820-
self._push_result()
821-
else:
822-
self._send_bind_message(portal_name, stmt_name, buf, 0)
794+
while True:
795+
try:
796+
buf = <WriteBuffer>next(bind_data)
797+
except StopIteration:
798+
if packet.len() > 0:
799+
packet.write_bytes(SYNC_MESSAGE)
800+
self.transport.write(memoryview(packet))
801+
else:
802+
self._push_result()
803+
break
804+
except Exception as e:
805+
self.result_type = RESULT_FAILED
806+
self.result = e
807+
self._push_result()
808+
break
809+
else:
810+
buf = self._build_bind_message(portal_name, stmt_name, buf)
811+
packet.write_buffer(buf)
812+
813+
buf = WriteBuffer.new_message(b'E')
814+
buf.write_str(portal_name, self.encoding) # name of the portal
815+
buf.write_int32(0) # number of rows to return; 0 - all
816+
buf.end_message()
817+
packet.write_buffer(buf)
823818

824819
cdef _execute(self, str portal_name, int32_t limit):
825820
cdef WriteBuffer buf

tests/test_execute.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,3 +151,25 @@ async def test_execute_many_2(self):
151151
''', good_data)
152152
finally:
153153
await self.con.execute('DROP TABLE exmany')
154+
155+
async def test_execute_many_atomic(self):
156+
from asyncpg.exceptions import UniqueViolationError
157+
158+
await self.con.execute('CREATE TEMP TABLE exmany '
159+
'(a text, b int PRIMARY KEY)')
160+
161+
try:
162+
with self.assertRaises(UniqueViolationError):
163+
await self.con.executemany('''
164+
INSERT INTO exmany VALUES($1, $2)
165+
''', [
166+
('a', 1), ('b', 2), ('c', 2), ('d', 4)
167+
])
168+
169+
result = await self.con.fetch('''
170+
SELECT * FROM exmany
171+
''')
172+
173+
self.assertEqual(result, [])
174+
finally:
175+
await self.con.execute('DROP TABLE exmany')

0 commit comments

Comments
 (0)