@@ -399,6 +399,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-
399
399
):
400
400
raise SerializationError ("Source too large for non-framed message" )
401
401
402
+ self .__unframed_plaintext_cache = io .BytesIO ()
403
+
402
404
def ciphertext_length (self ):
403
405
"""Returns the length of the resulting ciphertext message in bytes.
404
406
@@ -486,14 +488,25 @@ def _write_header(self):
486
488
487
489
def _prep_non_framed (self ):
488
490
"""Prepare the opening data for a non-framed message."""
491
+ try :
492
+ plaintext_length = self .stream_length
493
+ self .__unframed_plaintext_cache = self .source_stream
494
+ except NotSupportedError :
495
+ # We need to know the plaintext length before we can start processing the data.
496
+ # If we cannot seek on the source then we need to read the entire source into memory.
497
+ self .__unframed_plaintext_cache = io .BytesIO ()
498
+ self .__unframed_plaintext_cache .write (self .source_stream .read ())
499
+ plaintext_length = self .__unframed_plaintext_cache .tell ()
500
+ self .__unframed_plaintext_cache .seek (0 )
501
+
489
502
aad_content_string = aws_encryption_sdk .internal .utils .get_aad_content_string (
490
503
content_type = self .content_type , is_final_frame = True
491
504
)
492
505
associated_data = assemble_content_aad (
493
506
message_id = self ._header .message_id ,
494
507
aad_content_string = aad_content_string ,
495
508
seq_num = 1 ,
496
- length = self . stream_length ,
509
+ length = plaintext_length ,
497
510
)
498
511
self .encryptor = Encryptor (
499
512
algorithm = self ._encryption_materials .algorithm ,
@@ -504,7 +517,7 @@ def _prep_non_framed(self):
504
517
self .output_buffer += serialize_non_framed_open (
505
518
algorithm = self ._encryption_materials .algorithm ,
506
519
iv = self .encryptor .iv ,
507
- plaintext_length = self . stream_length ,
520
+ plaintext_length = plaintext_length ,
508
521
signer = self .signer ,
509
522
)
510
523
@@ -516,7 +529,7 @@ def _read_bytes_to_non_framed_body(self, b):
516
529
:rtype: bytes
517
530
"""
518
531
_LOGGER .debug ("Reading %d bytes" , b )
519
- plaintext = self .source_stream .read (b )
532
+ plaintext = self .__unframed_plaintext_cache .read (b )
520
533
plaintext_length = len (plaintext )
521
534
if self .tell () + len (plaintext ) > MAX_NON_FRAMED_SIZE :
522
535
raise SerializationError ("Source too large for non-framed message" )
@@ -529,6 +542,7 @@ def _read_bytes_to_non_framed_body(self, b):
529
542
if len (plaintext ) < b :
530
543
_LOGGER .debug ("Closing encryptor after receiving only %d bytes of %d bytes requested" , plaintext_length , b )
531
544
self .source_stream .close ()
545
+ self .__unframed_plaintext_cache .close ()
532
546
closing = self .encryptor .finalize ()
533
547
534
548
if self .signer is not None :
0 commit comments