39
39
from struct import pack as struct_pack , unpack as struct_unpack , unpack_from as struct_unpack_from
40
40
41
41
from .constants import DEFAULT_USER_AGENT , KNOWN_HOSTS , MAGIC_PREAMBLE , TRUST_DEFAULT , TRUST_ON_FIRST_USE
42
- from .compat import hex2
43
42
from .exceptions import ProtocolError , Unauthorized
44
43
from .packstream import Packer , Unpacker
45
44
from .ssl_compat import SSL_AVAILABLE , HAS_SNI , SSLError
81
80
log_error = log .error
82
81
83
82
83
+ class BufferingSocket (object ):
84
+
85
+ def __init__ (self , socket ):
86
+ self .socket = socket
87
+ self .buffer = bytearray ()
88
+
89
+ def fill (self ):
90
+ ready_to_read , _ , _ = select ((self .socket ,), (), (), 0 )
91
+ received = self .socket .recv (65539 )
92
+ if received :
93
+ if __debug__ :
94
+ log_debug ("S: b%r" , received )
95
+ self .buffer [len (self .buffer ):] = received
96
+ else :
97
+ if ready_to_read is not None :
98
+ raise ProtocolError ("Server closed connection" )
99
+
100
+ def read_message (self ):
101
+ message_data = bytearray ()
102
+ p = 0
103
+ size = - 1
104
+ while size != 0 :
105
+ while len (self .buffer ) - p < 2 :
106
+ self .fill ()
107
+ size = 0x100 * self .buffer [p ] + self .buffer [p + 1 ]
108
+ p += 2
109
+ if size > 0 :
110
+ while len (self .buffer ) - p < size :
111
+ self .fill ()
112
+ end = p + size
113
+ message_data [len (message_data ):] = self .buffer [p :end ]
114
+ p = end
115
+ self .buffer = self .buffer [p :]
116
+ return message_data
117
+
118
+
84
119
class ChunkChannel (object ):
85
120
""" Reader/writer for chunked data.
86
121
@@ -137,45 +172,11 @@ def send(self):
137
172
"""
138
173
data = self .raw .getvalue ()
139
174
if __debug__ :
140
- log_debug ("C: %s " , ":" . join ( map ( hex2 , data )) )
175
+ log_debug ("C: b%r " , data )
141
176
self .socket .sendall (data )
142
177
143
178
self .raw .seek (self .raw .truncate (0 ))
144
179
145
- def _recv (self , size ):
146
- # If data is needed, keep reading until all bytes have been received
147
- remaining = size - len (self ._recv_buffer )
148
- ready_to_read = None
149
- while remaining > 0 :
150
- # Read up to the required amount remaining
151
- b = self .socket .recv (8192 )
152
- if b :
153
- if __debug__ : log_debug ("S: %s" , ":" .join (map (hex2 , b )))
154
- else :
155
- if ready_to_read is not None :
156
- raise ProtocolError ("Server closed connection" )
157
- remaining -= len (b )
158
- self ._recv_buffer += b
159
-
160
- # If more is required, wait for available network data
161
- if remaining > 0 :
162
- ready_to_read , _ , _ = select ((self .socket ,), (), (), 0 )
163
- while not ready_to_read :
164
- ready_to_read , _ , _ = select ((self .socket ,), (), (), 0 )
165
-
166
- # Split off the amount of data required and keep the rest in the buffer
167
- data , self ._recv_buffer = self ._recv_buffer [:size ], self ._recv_buffer [size :]
168
- return data
169
-
170
- def chunk_reader (self ):
171
- chunk_size = - 1
172
- while chunk_size != 0 :
173
- chunk_header = self ._recv (2 )
174
- chunk_size , = struct_unpack_from (">H" , chunk_header )
175
- if chunk_size > 0 :
176
- data = self ._recv (chunk_size )
177
- yield data
178
-
179
180
180
181
class Response (object ):
181
182
""" Subscriber object for a full response (zero or
@@ -208,9 +209,12 @@ class Connection(object):
208
209
"""
209
210
210
211
def __init__ (self , sock , ** config ):
212
+ self .socket = sock
213
+ self .buffering_socket = BufferingSocket (sock )
211
214
self .defunct = False
212
215
self .channel = ChunkChannel (sock )
213
216
self .packer = Packer (self .channel )
217
+ self .unpacker = Unpacker ()
214
218
self .responses = deque ()
215
219
self .closed = False
216
220
@@ -318,33 +322,37 @@ def fetch(self):
318
322
raise ProtocolError ("Cannot read from a closed connection" )
319
323
if self .defunct :
320
324
raise ProtocolError ("Cannot read from a defunct connection" )
321
- raw = BytesIO ()
322
- unpack = Unpacker (raw ).unpack
323
325
try :
324
- raw . writelines ( self .channel . chunk_reader () )
326
+ message_data = self .buffering_socket . read_message ( )
325
327
except ProtocolError :
326
328
self .defunct = True
327
329
self .close ()
328
330
raise
329
331
# Unpack from the raw byte stream and call the relevant message handler(s)
330
- raw .seek (0 )
331
- response = self .responses [0 ]
332
- for signature , fields in unpack ():
333
- if __debug__ :
334
- log_info ("S: %s %s" , message_names [signature ], " " .join (map (repr , fields )))
335
- if signature in SUMMARY :
336
- response .complete = True
337
- self .responses .popleft ()
338
- if signature == FAILURE :
339
- self .acknowledge_failure ()
340
- handler_name = "on_%s" % message_names [signature ].lower ()
341
- try :
342
- handler = getattr (response , handler_name )
343
- except AttributeError :
344
- pass
345
- else :
346
- handler (* fields )
347
- raw .close ()
332
+ self .unpacker .load (message_data )
333
+ size , signature = self .unpacker .unpack_structure_header ()
334
+ fields = [self .unpacker .unpack () for _ in range (size )]
335
+
336
+ if __debug__ :
337
+ log_info ("S: %s %r" , message_names [signature ], fields )
338
+
339
+ if signature == SUCCESS :
340
+ response = self .responses .popleft ()
341
+ response .complete = True
342
+ response .on_success (* fields )
343
+ elif signature == RECORD :
344
+ response = self .responses [0 ]
345
+ response .on_record (* fields )
346
+ elif signature == IGNORED :
347
+ response = self .responses .popleft ()
348
+ response .complete = True
349
+ response .on_ignored (* fields )
350
+ elif signature == FAILURE :
351
+ response = self .responses .popleft ()
352
+ response .complete = True
353
+ response .on_failure (* fields )
354
+ else :
355
+ raise ProtocolError ("Unexpected response message with signature %02X" % signature )
348
356
349
357
def fetch_all (self ):
350
358
while self .responses :
@@ -454,7 +462,7 @@ def connect(host_port, ssl_context=None, **config):
454
462
handshake = [MAGIC_PREAMBLE ] + supported_versions
455
463
if __debug__ : log_info ("C: [HANDSHAKE] 0x%X %r" , MAGIC_PREAMBLE , supported_versions )
456
464
data = b"" .join (struct_pack (">I" , num ) for num in handshake )
457
- if __debug__ : log_debug ("C: %s " , ":" . join ( map ( hex2 , data )) )
465
+ if __debug__ : log_debug ("C: b%r " , data )
458
466
s .sendall (data )
459
467
460
468
# Handle the handshake response
@@ -469,7 +477,7 @@ def connect(host_port, ssl_context=None, **config):
469
477
log_error ("S: [CLOSE]" )
470
478
raise ProtocolError ("Server closed connection without responding to handshake" )
471
479
if data_size == 4 :
472
- if __debug__ : log_debug ("S: %s " , ":" . join ( map ( hex2 , data )) )
480
+ if __debug__ : log_debug ("S: b%r " , data )
473
481
else :
474
482
# Some other garbled data has been received
475
483
log_error ("S: @*#!" )
0 commit comments