Skip to content

Optimisations from profiling #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 30, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 67 additions & 59 deletions neo4j/v1/bolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from struct import pack as struct_pack, unpack as struct_unpack, unpack_from as struct_unpack_from

from .constants import DEFAULT_USER_AGENT, KNOWN_HOSTS, MAGIC_PREAMBLE, TRUST_DEFAULT, TRUST_ON_FIRST_USE
from .compat import hex2
from .exceptions import ProtocolError, Unauthorized
from .packstream import Packer, Unpacker
from .ssl_compat import SSL_AVAILABLE, HAS_SNI, SSLError
Expand Down Expand Up @@ -81,6 +80,42 @@
log_error = log.error


class BufferingSocket(object):

def __init__(self, socket):
self.socket = socket
self.buffer = bytearray()

def fill(self):
ready_to_read, _, _ = select((self.socket,), (), (), 0)
received = self.socket.recv(65539)
if received:
if __debug__:
log_debug("S: b%r", received)
self.buffer[len(self.buffer):] = received
else:
if ready_to_read is not None:
raise ProtocolError("Server closed connection")

def read_message(self):
message_data = bytearray()
p = 0
size = -1
while size != 0:
while len(self.buffer) - p < 2:
self.fill()
size = 0x100 * self.buffer[p] + self.buffer[p + 1]
p += 2
if size > 0:
while len(self.buffer) - p < size:
self.fill()
end = p + size
message_data[len(message_data):] = self.buffer[p:end]
p = end
self.buffer = self.buffer[p:]
return message_data


class ChunkChannel(object):
""" Reader/writer for chunked data.

Expand Down Expand Up @@ -137,45 +172,11 @@ def send(self):
"""
data = self.raw.getvalue()
if __debug__:
log_debug("C: %s", ":".join(map(hex2, data)))
log_debug("C: b%r", data)
self.socket.sendall(data)

self.raw.seek(self.raw.truncate(0))

def _recv(self, size):
# If data is needed, keep reading until all bytes have been received
remaining = size - len(self._recv_buffer)
ready_to_read = None
while remaining > 0:
# Read up to the required amount remaining
b = self.socket.recv(8192)
if b:
if __debug__: log_debug("S: %s", ":".join(map(hex2, b)))
else:
if ready_to_read is not None:
raise ProtocolError("Server closed connection")
remaining -= len(b)
self._recv_buffer += b

# If more is required, wait for available network data
if remaining > 0:
ready_to_read, _, _ = select((self.socket,), (), (), 0)
while not ready_to_read:
ready_to_read, _, _ = select((self.socket,), (), (), 0)

# Split off the amount of data required and keep the rest in the buffer
data, self._recv_buffer = self._recv_buffer[:size], self._recv_buffer[size:]
return data

def chunk_reader(self):
chunk_size = -1
while chunk_size != 0:
chunk_header = self._recv(2)
chunk_size, = struct_unpack_from(">H", chunk_header)
if chunk_size > 0:
data = self._recv(chunk_size)
yield data


class Response(object):
""" Subscriber object for a full response (zero or
Expand Down Expand Up @@ -208,9 +209,12 @@ class Connection(object):
"""

def __init__(self, sock, **config):
self.socket = sock
self.buffering_socket = BufferingSocket(sock)
self.defunct = False
self.channel = ChunkChannel(sock)
self.packer = Packer(self.channel)
self.unpacker = Unpacker()
self.responses = deque()
self.closed = False

Expand Down Expand Up @@ -318,33 +322,37 @@ def fetch(self):
raise ProtocolError("Cannot read from a closed connection")
if self.defunct:
raise ProtocolError("Cannot read from a defunct connection")
raw = BytesIO()
unpack = Unpacker(raw).unpack
try:
raw.writelines(self.channel.chunk_reader())
message_data = self.buffering_socket.read_message()
except ProtocolError:
self.defunct = True
self.close()
raise
# Unpack from the raw byte stream and call the relevant message handler(s)
raw.seek(0)
response = self.responses[0]
for signature, fields in unpack():
if __debug__:
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
if signature in SUMMARY:
response.complete = True
self.responses.popleft()
if signature == FAILURE:
self.acknowledge_failure()
handler_name = "on_%s" % message_names[signature].lower()
try:
handler = getattr(response, handler_name)
except AttributeError:
pass
else:
handler(*fields)
raw.close()
self.unpacker.load(message_data)
size, signature = self.unpacker.unpack_structure_header()
fields = [self.unpacker.unpack() for _ in range(size)]

if __debug__:
log_info("S: %s %r", message_names[signature], fields)

if signature == SUCCESS:
response = self.responses.popleft()
response.complete = True
response.on_success(*fields)
elif signature == RECORD:
response = self.responses[0]
response.on_record(*fields)
elif signature == IGNORED:
response = self.responses.popleft()
response.complete = True
response.on_ignored(*fields)
elif signature == FAILURE:
response = self.responses.popleft()
response.complete = True
response.on_failure(*fields)
else:
raise ProtocolError("Unexpected response message with signature %02X" % signature)

def fetch_all(self):
while self.responses:
Expand Down Expand Up @@ -454,7 +462,7 @@ def connect(host_port, ssl_context=None, **config):
handshake = [MAGIC_PREAMBLE] + supported_versions
if __debug__: log_info("C: [HANDSHAKE] 0x%X %r", MAGIC_PREAMBLE, supported_versions)
data = b"".join(struct_pack(">I", num) for num in handshake)
if __debug__: log_debug("C: %s", ":".join(map(hex2, data)))
if __debug__: log_debug("C: b%r", data)
s.sendall(data)

# Handle the handshake response
Expand All @@ -469,7 +477,7 @@ def connect(host_port, ssl_context=None, **config):
log_error("S: [CLOSE]")
raise ProtocolError("Server closed connection without responding to handshake")
if data_size == 4:
if __debug__: log_debug("S: %s", ":".join(map(hex2, data)))
if __debug__: log_debug("S: b%r", data)
else:
# Some other garbled data has been received
log_error("S: @*#!")
Expand Down
13 changes: 0 additions & 13 deletions neo4j/v1/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,6 @@ def ustr(x):
else:
return str(x)

def hex2(x):
if x < 0x10:
return "0" + hex(x)[2:].upper()
else:
return hex(x)[2:].upper()

else:
# Python 2

Expand All @@ -65,13 +59,6 @@ def ustr(x):
else:
return unicode(x)

def hex2(x):
x = ord(x)
if x < 0x10:
return "0" + hex(x)[2:].upper()
else:
return hex(x)[2:].upper()


try:
from multiprocessing import Array, Process
Expand Down
Loading