21
21
22
22
from __future__ import division
23
23
24
+ from base64 import b64encode
24
25
from collections import deque
25
26
from io import BytesIO
26
27
import logging
27
- from os import environ
28
+ from os import makedirs , open as os_open , write as os_write , close as os_close , O_CREAT , O_APPEND , O_WRONLY
29
+ from os .path import dirname , isfile
28
30
from select import select
29
31
from socket import create_connection , SHUT_RDWR
32
+ from ssl import HAS_SNI , SSLError
30
33
from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
31
34
32
- from ..meta import version
33
- from .compat import hex2 , secure_socket
35
+ from .constants import DEFAULT_PORT , DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , \
36
+ SECURITY_DEFAULT , SECURITY_TRUST_ON_FIRST_USE
37
+ from .compat import hex2
34
38
from .exceptions import ProtocolError
35
39
from .packstream import Packer , Unpacker
36
40
37
41
38
- DEFAULT_PORT = 7687
39
- DEFAULT_USER_AGENT = "neo4j-python/%s" % version
40
-
41
- MAGIC_PREAMBLE = 0x6060B017
42
-
43
42
# Signature bytes for each message type
44
43
INIT = b"\x01 " # 0000 0001 // INIT <user_agent>
45
44
RESET = b"\x0F " # 0000 1111 // RESET
@@ -211,14 +210,18 @@ def __init__(self, sock, **config):
211
210
user_agent = config .get ("user_agent" , DEFAULT_USER_AGENT )
212
211
if isinstance (user_agent , bytes ):
213
212
user_agent = user_agent .decode ("UTF-8" )
213
+ self .user_agent = user_agent
214
+
215
+ # Pick up the server certificate, if any
216
+ self .der_encoded_server_certificate = config .get ("der_encoded_server_certificate" )
214
217
215
218
def on_failure (metadata ):
216
219
raise ProtocolError ("Initialisation failed" )
217
220
218
221
response = Response (self )
219
222
response .on_failure = on_failure
220
223
221
- self .append (INIT , (user_agent ,), response = response )
224
+ self .append (INIT , (self . user_agent ,), response = response )
222
225
self .send ()
223
226
while not response .complete :
224
227
self .fetch ()
@@ -313,7 +316,53 @@ def close(self):
313
316
self .closed = True
314
317
315
318
316
- def connect (host , port = None , ** config ):
319
+ class CertificateStore (object ):
320
+
321
+ def match_or_trust (self , host , der_encoded_certificate ):
322
+ """ Check whether the supplied certificate matches that stored for the
323
+ specified host. If it does, return ``True``, if it doesn't, return
324
+ ``False``. If no entry for that host is found, add it to the store
325
+ and return ``True``.
326
+
327
+ :arg host:
328
+ :arg der_encoded_certificate:
329
+ :return:
330
+ """
331
+ raise NotImplementedError ()
332
+
333
+
334
+ class PersonalCertificateStore (CertificateStore ):
335
+
336
+ def __init__ (self , path = None ):
337
+ self .path = path or KNOWN_HOSTS
338
+
339
+ def match_or_trust (self , host , der_encoded_certificate ):
340
+ base64_encoded_certificate = b64encode (der_encoded_certificate )
341
+ if isfile (self .path ):
342
+ with open (self .path ) as f_in :
343
+ for line in f_in :
344
+ known_host , _ , known_cert = line .strip ().partition (":" )
345
+ known_cert = known_cert .encode ("utf-8" )
346
+ if host == known_host :
347
+ return base64_encoded_certificate == known_cert
348
+ # First use (no hosts match)
349
+ try :
350
+ makedirs (dirname (self .path ))
351
+ except OSError :
352
+ pass
353
+ f_out = os_open (self .path , O_CREAT | O_APPEND | O_WRONLY , 0o600 ) # TODO: Windows
354
+ if isinstance (host , bytes ):
355
+ os_write (f_out , host )
356
+ else :
357
+ os_write (f_out , host .encode ("utf-8" ))
358
+ os_write (f_out , b":" )
359
+ os_write (f_out , base64_encoded_certificate )
360
+ os_write (f_out , b"\n " )
361
+ os_close (f_out )
362
+ return True
363
+
364
+
365
+ def connect (host , port = None , ssl_context = None , ** config ):
317
366
""" Connect and perform a handshake and return a valid Connection object, assuming
318
367
a protocol version can be agreed.
319
368
"""
@@ -323,14 +372,28 @@ def connect(host, port=None, **config):
323
372
if __debug__ : log_info ("~~ [CONNECT] %s %d" , host , port )
324
373
s = create_connection ((host , port ))
325
374
326
- # Secure the connection if so requested
327
- try :
328
- secure = environ ["NEO4J_SECURE" ]
329
- except KeyError :
330
- secure = config .get ("secure" , False )
331
- if secure :
375
+ # Secure the connection if an SSL context has been provided
376
+ if ssl_context :
332
377
if __debug__ : log_info ("~~ [SECURE] %s" , host )
333
- s = secure_socket (s , host )
378
+ try :
379
+ s = ssl_context .wrap_socket (s , server_hostname = host if HAS_SNI else None )
380
+ except SSLError as cause :
381
+ error = ProtocolError ("Cannot establish secure connection; %s" % cause .args [1 ])
382
+ error .__cause__ = cause
383
+ raise error
384
+ else :
385
+ # Check that the server provides a certificate
386
+ der_encoded_server_certificate = s .getpeercert (binary_form = True )
387
+ if der_encoded_server_certificate is None :
388
+ raise ProtocolError ("When using a secure socket, the server should always provide a certificate" )
389
+ security = config .get ("security" , SECURITY_DEFAULT )
390
+ if security == SECURITY_TRUST_ON_FIRST_USE :
391
+ store = PersonalCertificateStore ()
392
+ if not store .match_or_trust (host , der_encoded_server_certificate ):
393
+ raise ProtocolError ("Server certificate does not match known certificate for %r; check "
394
+ "details in file %r" % (host , KNOWN_HOSTS ))
395
+ else :
396
+ der_encoded_server_certificate = None
334
397
335
398
# Send details of the protocol versions supported
336
399
supported_versions = [1 , 0 , 0 , 0 ]
@@ -364,4 +427,4 @@ def connect(host, port=None, **config):
364
427
s .shutdown (SHUT_RDWR )
365
428
s .close ()
366
429
else :
367
- return Connection (s , ** config )
430
+ return Connection (s , der_encoded_server_certificate = der_encoded_server_certificate , ** config )
0 commit comments