Skip to content

Commit e3768c5

Browse files
committed
remove requirement for source_stream.tell() on encrypt
1 parent 5b6c04b commit e3768c5

File tree

3 files changed

+24
-5
lines changed

3 files changed

+24
-5
lines changed

src/aws_encryption_sdk/streaming_client.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-
399399
):
400400
raise SerializationError("Source too large for non-framed message")
401401

402+
self.__unframed_plaintext_cache = io.BytesIO()
403+
402404
def ciphertext_length(self):
403405
"""Returns the length of the resulting ciphertext message in bytes.
404406
@@ -486,14 +488,25 @@ def _write_header(self):
486488

487489
def _prep_non_framed(self):
488490
"""Prepare the opening data for a non-framed message."""
491+
try:
492+
plaintext_length = self.stream_length
493+
self.__unframed_plaintext_cache = self.source_stream
494+
except NotSupportedError:
495+
# We need to know the plaintext length before we can start processing the data.
496+
# If we cannot seek on the source then we need to read the entire source into memory.
497+
self.__unframed_plaintext_cache = io.BytesIO()
498+
self.__unframed_plaintext_cache.write(self.source_stream.read())
499+
plaintext_length = self.__unframed_plaintext_cache.tell()
500+
self.__unframed_plaintext_cache.seek(0)
501+
489502
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
490503
content_type=self.content_type, is_final_frame=True
491504
)
492505
associated_data = assemble_content_aad(
493506
message_id=self._header.message_id,
494507
aad_content_string=aad_content_string,
495508
seq_num=1,
496-
length=self.stream_length,
509+
length=plaintext_length,
497510
)
498511
self.encryptor = Encryptor(
499512
algorithm=self._encryption_materials.algorithm,
@@ -504,7 +517,7 @@ def _prep_non_framed(self):
504517
self.output_buffer += serialize_non_framed_open(
505518
algorithm=self._encryption_materials.algorithm,
506519
iv=self.encryptor.iv,
507-
plaintext_length=self.stream_length,
520+
plaintext_length=plaintext_length,
508521
signer=self.signer,
509522
)
510523

@@ -516,7 +529,7 @@ def _read_bytes_to_non_framed_body(self, b):
516529
:rtype: bytes
517530
"""
518531
_LOGGER.debug("Reading %d bytes", b)
519-
plaintext = self.source_stream.read(b)
532+
plaintext = self.__unframed_plaintext_cache.read(b)
520533
plaintext_length = len(plaintext)
521534
if self.tell() + len(plaintext) > MAX_NON_FRAMED_SIZE:
522535
raise SerializationError("Source too large for non-framed message")
@@ -529,6 +542,7 @@ def _read_bytes_to_non_framed_body(self, b):
529542
if len(plaintext) < b:
530543
_LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b)
531544
self.source_stream.close()
545+
self.__unframed_plaintext_cache.close()
532546
closing = self.encryptor.finalize()
533547

534548
if self.signer is not None:

test/functional/test_f_aws_encryption_sdk_client.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -771,7 +771,7 @@ def close(self):
771771
@pytest.mark.parametrize(
772772
"wrapping_class",
773773
(
774-
pytest.param(NoTell, marks=pytest.mark.xfail),
774+
NoTell,
775775
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
776776
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
777777
),
@@ -792,7 +792,7 @@ def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
792792
@pytest.mark.parametrize(
793793
"wrapping_class",
794794
(
795-
pytest.param(NoTell, marks=pytest.mark.xfail),
795+
NoTell,
796796
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
797797
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
798798
),

test/unit/test_streaming_client_stream_encryptor.py

+5
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,10 @@ def test_read_bytes_to_non_framed_body(self):
382382
test_encryptor.encryptor = MagicMock()
383383
test_encryptor._encryption_materials = self.mock_encryption_materials
384384
test_encryptor.encryptor.update.return_value = sentinel.ciphertext
385+
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream
386+
385387
test = test_encryptor._read_bytes_to_non_framed_body(5)
388+
386389
test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5])
387390
test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext)
388391
assert not test_encryptor.source_stream.closed
@@ -392,6 +395,8 @@ def test_read_bytes_to_non_framed_body_too_large(self):
392395
pt_stream = io.BytesIO(self.plaintext)
393396
test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider)
394397
test_encryptor.bytes_read = aws_encryption_sdk.internal.defaults.MAX_NON_FRAMED_SIZE
398+
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream
399+
395400
with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"):
396401
test_encryptor._read_bytes_to_non_framed_body(5)
397402

0 commit comments

Comments
 (0)