Skip to content

Commit e57aef2

Browse files
committed
Merge pull request #29 from neo4j/1.0-session-pool
1.0 session pool
2 parents 02321e9 + 309826f commit e57aef2

File tree

9 files changed

+308
-85
lines changed

9 files changed

+308
-85
lines changed

neo4j/__main__.py

Lines changed: 1 addition & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,53 +26,10 @@
2626
from json import loads as json_loads
2727
from sys import stdout, stderr
2828

29+
from .util import Watcher
2930
from .v1.session import GraphDatabase, CypherError
3031

3132

32-
class ColourFormatter(logging.Formatter):
33-
""" Colour formatter for pretty log output.
34-
"""
35-
36-
def format(self, record):
37-
s = super(ColourFormatter, self).format(record)
38-
if record.levelno == logging.CRITICAL:
39-
return "\x1b[31;1m%s\x1b[0m" % s # bright red
40-
elif record.levelno == logging.ERROR:
41-
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
42-
elif record.levelno == logging.WARNING:
43-
return "\x1b[33m%s\x1b[0m" % s # yellow
44-
elif record.levelno == logging.INFO:
45-
return "\x1b[36m%s\x1b[0m" % s # cyan
46-
elif record.levelno == logging.DEBUG:
47-
return "\x1b[34m%s\x1b[0m" % s # blue
48-
else:
49-
return s
50-
51-
52-
class Watcher(object):
53-
""" Log watcher for debug output.
54-
"""
55-
56-
handlers = {}
57-
58-
def __init__(self, logger_name):
59-
super(Watcher, self).__init__()
60-
self.logger_name = logger_name
61-
self.logger = logging.getLogger(self.logger_name)
62-
self.formatter = ColourFormatter("%(asctime)s %(message)s")
63-
64-
def watch(self, level=logging.INFO, out=stdout):
65-
try:
66-
self.logger.removeHandler(self.handlers[self.logger_name])
67-
except KeyError:
68-
pass
69-
handler = logging.StreamHandler(out)
70-
handler.setFormatter(self.formatter)
71-
self.handlers[self.logger_name] = handler
72-
self.logger.addHandler(handler)
73-
self.logger.setLevel(level)
74-
75-
7633
def main():
7734
parser = ArgumentParser(description="Execute one or more Cypher statements using Bolt.")
7835
parser.add_argument("statement", nargs="+")

neo4j/util.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
#!/usr/bin/env python
2+
# -*- encoding: utf-8 -*-
3+
4+
# Copyright (c) 2002-2016 "Neo Technology,"
5+
# Network Engine for Objects in Lund AB [http://neotechnology.com]
6+
#
7+
# This file is part of Neo4j.
8+
#
9+
# Licensed under the Apache License, Version 2.0 (the "License");
10+
# you may not use this file except in compliance with the License.
11+
# You may obtain a copy of the License at
12+
#
13+
# http://www.apache.org/licenses/LICENSE-2.0
14+
#
15+
# Unless required by applicable law or agreed to in writing, software
16+
# distributed under the License is distributed on an "AS IS" BASIS,
17+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18+
# See the License for the specific language governing permissions and
19+
# limitations under the License.
20+
21+
22+
from __future__ import unicode_literals
23+
24+
import logging
25+
from argparse import ArgumentParser
26+
from json import loads as json_loads
27+
from sys import stdout, stderr
28+
29+
from .v1.session import GraphDatabase, CypherError
30+
31+
32+
class ColourFormatter(logging.Formatter):
33+
""" Colour formatter for pretty log output.
34+
"""
35+
36+
def format(self, record):
37+
s = super(ColourFormatter, self).format(record)
38+
if record.levelno == logging.CRITICAL:
39+
return "\x1b[31;1m%s\x1b[0m" % s # bright red
40+
elif record.levelno == logging.ERROR:
41+
return "\x1b[33;1m%s\x1b[0m" % s # bright yellow
42+
elif record.levelno == logging.WARNING:
43+
return "\x1b[33m%s\x1b[0m" % s # yellow
44+
elif record.levelno == logging.INFO:
45+
return "\x1b[36m%s\x1b[0m" % s # cyan
46+
elif record.levelno == logging.DEBUG:
47+
return "\x1b[34m%s\x1b[0m" % s # blue
48+
else:
49+
return s
50+
51+
52+
class Watcher(object):
53+
""" Log watcher for debug output.
54+
"""
55+
56+
handlers = {}
57+
58+
def __init__(self, logger_name):
59+
super(Watcher, self).__init__()
60+
self.logger_name = logger_name
61+
self.logger = logging.getLogger(self.logger_name)
62+
self.formatter = ColourFormatter("%(asctime)s %(message)s")
63+
64+
def watch(self, level=logging.INFO, out=stdout):
65+
self.stop()
66+
handler = logging.StreamHandler(out)
67+
handler.setFormatter(self.formatter)
68+
self.handlers[self.logger_name] = handler
69+
self.logger.addHandler(handler)
70+
self.logger.setLevel(level)
71+
72+
def stop(self):
73+
try:
74+
self.logger.removeHandler(self.handlers[self.logger_name])
75+
except KeyError:
76+
pass

neo4j/v1/connection.py

Lines changed: 53 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242

4343
# Signature bytes for each message type
4444
INIT = b"\x01" # 0000 0001 // INIT <user_agent>
45-
ACK_FAILURE = b"\x0F" # 0000 1111 // ACK_FAILURE
45+
RESET = b"\x0F" # 0000 1111 // RESET
4646
RUN = b"\x10" # 0001 0000 // RUN <statement> <parameters>
4747
DISCARD_ALL = b"\x2F" # 0010 1111 // DISCARD *
4848
PULL_ALL = b"\x3F" # 0011 1111 // PULL *
@@ -56,7 +56,7 @@
5656

5757
message_names = {
5858
INIT: "INIT",
59-
ACK_FAILURE: "ACK_FAILURE",
59+
RESET: "RESET",
6060
RUN: "RUN",
6161
DISCARD_ALL: "DISCARD_ALL",
6262
PULL_ALL: "PULL_ALL",
@@ -169,14 +169,6 @@ def chunk_reader(self):
169169
data = self._recv(chunk_size)
170170
yield data
171171

172-
def close(self):
173-
""" Shut down and close the connection.
174-
"""
175-
if __debug__: log_info("~~ [CLOSE]")
176-
socket = self.socket
177-
socket.shutdown(SHUT_RDWR)
178-
socket.close()
179-
180172

181173
class Response(object):
182174
""" Subscriber object for a full response (zero or
@@ -200,12 +192,6 @@ def on_ignored(self, metadata=None):
200192
pass
201193

202194

203-
class AckFailureResponse(Response):
204-
205-
def on_failure(self, metadata):
206-
raise ProtocolError("Could not acknowledge failure")
207-
208-
209195
class Connection(object):
210196
""" Server connection through which all protocol messages
211197
are sent and received. This class is designed for protocol
@@ -215,9 +201,11 @@ class Connection(object):
215201
"""
216202

217203
def __init__(self, sock, **config):
204+
self.defunct = False
218205
self.channel = ChunkChannel(sock)
219206
self.packer = Packer(self.channel)
220207
self.responses = deque()
208+
self.closed = False
221209

222210
# Determine the user agent and ensure it is a Unicode value
223211
user_agent = config.get("user_agent", DEFAULT_USER_AGENT)
@@ -235,8 +223,15 @@ def on_failure(metadata):
235223
while not response.complete:
236224
self.fetch_next()
237225

226+
def __del__(self):
227+
self.close()
228+
238229
def append(self, signature, fields=(), response=None):
239230
""" Add a message to the outgoing queue.
231+
232+
:arg signature: the signature of the message
233+
:arg fields: the fields of the message as a tuple
234+
:arg response: a response object to handle callbacks
240235
"""
241236
if __debug__:
242237
log_info("C: %s %s", message_names[signature], " ".join(map(repr, fields)))
@@ -247,42 +242,75 @@ def append(self, signature, fields=(), response=None):
247242
self.channel.flush(end_of_message=True)
248243
self.responses.append(response)
249244

245+
def reset(self):
246+
""" Add a RESET message to the outgoing queue, send
247+
it and consume all remaining messages.
248+
"""
249+
response = Response(self)
250+
251+
def on_failure(metadata):
252+
raise ProtocolError("Reset failed")
253+
254+
response.on_failure = on_failure
255+
256+
self.append(RESET, response=response)
257+
self.send()
258+
fetch_next = self.fetch_next
259+
while not response.complete:
260+
fetch_next()
261+
250262
def send(self):
251263
""" Send all queued messages to the server.
252264
"""
265+
if self.closed:
266+
raise ProtocolError("Cannot write to a closed connection")
267+
if self.defunct:
268+
raise ProtocolError("Cannot write to a defunct connection")
253269
self.channel.send()
254270

255271
def fetch_next(self):
256272
""" Receive exactly one message from the server.
257273
"""
274+
if self.closed:
275+
raise ProtocolError("Cannot read from a closed connection")
276+
if self.defunct:
277+
raise ProtocolError("Cannot read from a defunct connection")
258278
raw = BytesIO()
259279
unpack = Unpacker(raw).unpack
260-
raw.writelines(self.channel.chunk_reader())
261-
280+
try:
281+
raw.writelines(self.channel.chunk_reader())
282+
except ProtocolError:
283+
self.defunct = True
284+
self.close()
285+
raise
262286
# Unpack from the raw byte stream and call the relevant message handler(s)
263287
raw.seek(0)
264288
response = self.responses[0]
265289
for signature, fields in unpack():
266290
if __debug__:
267291
log_info("S: %s %s", message_names[signature], " ".join(map(repr, fields)))
292+
if signature in SUMMARY:
293+
response.complete = True
294+
self.responses.popleft()
295+
if signature == FAILURE:
296+
self.reset()
268297
handler_name = "on_%s" % message_names[signature].lower()
269298
try:
270299
handler = getattr(response, handler_name)
271300
except AttributeError:
272301
pass
273302
else:
274303
handler(*fields)
275-
if signature in SUMMARY:
276-
response.complete = True
277-
self.responses.popleft()
278-
if signature == FAILURE:
279-
self.append(ACK_FAILURE, response=AckFailureResponse(self))
280304
raw.close()
281305

282306
def close(self):
283-
""" Shut down and close the connection.
307+
""" Close the connection.
284308
"""
285-
self.channel.close()
309+
if not self.closed:
310+
if __debug__:
311+
log_info("~~ [CLOSE]")
312+
self.channel.socket.close()
313+
self.closed = True
286314

287315

288316
def connect(host, port=None, **config):

0 commit comments

Comments
 (0)