Skip to content

Commit 1942f59

Browse files
authored
Merge pull request #110 from mattsb42-aws/dev-103b
Remove requirement for tell() on source_stream
2 parents 6c9aaed + e3768c5 commit 1942f59

File tree

3 files changed

+67
-14
lines changed

3 files changed

+67
-14
lines changed

src/aws_encryption_sdk/streaming_client.py

+17-3
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-
399399
):
400400
raise SerializationError("Source too large for non-framed message")
401401

402+
self.__unframed_plaintext_cache = io.BytesIO()
403+
402404
def ciphertext_length(self):
403405
"""Returns the length of the resulting ciphertext message in bytes.
404406
@@ -486,14 +488,25 @@ def _write_header(self):
486488

487489
def _prep_non_framed(self):
488490
"""Prepare the opening data for a non-framed message."""
491+
try:
492+
plaintext_length = self.stream_length
493+
self.__unframed_plaintext_cache = self.source_stream
494+
except NotSupportedError:
495+
# We need to know the plaintext length before we can start processing the data.
496+
# If we cannot seek on the source then we need to read the entire source into memory.
497+
self.__unframed_plaintext_cache = io.BytesIO()
498+
self.__unframed_plaintext_cache.write(self.source_stream.read())
499+
plaintext_length = self.__unframed_plaintext_cache.tell()
500+
self.__unframed_plaintext_cache.seek(0)
501+
489502
aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
490503
content_type=self.content_type, is_final_frame=True
491504
)
492505
associated_data = assemble_content_aad(
493506
message_id=self._header.message_id,
494507
aad_content_string=aad_content_string,
495508
seq_num=1,
496-
length=self.stream_length,
509+
length=plaintext_length,
497510
)
498511
self.encryptor = Encryptor(
499512
algorithm=self._encryption_materials.algorithm,
@@ -504,7 +517,7 @@ def _prep_non_framed(self):
504517
self.output_buffer += serialize_non_framed_open(
505518
algorithm=self._encryption_materials.algorithm,
506519
iv=self.encryptor.iv,
507-
plaintext_length=self.stream_length,
520+
plaintext_length=plaintext_length,
508521
signer=self.signer,
509522
)
510523

@@ -516,7 +529,7 @@ def _read_bytes_to_non_framed_body(self, b):
516529
:rtype: bytes
517530
"""
518531
_LOGGER.debug("Reading %d bytes", b)
519-
plaintext = self.source_stream.read(b)
532+
plaintext = self.__unframed_plaintext_cache.read(b)
520533
plaintext_length = len(plaintext)
521534
if self.tell() + len(plaintext) > MAX_NON_FRAMED_SIZE:
522535
raise SerializationError("Source too large for non-framed message")
@@ -529,6 +542,7 @@ def _read_bytes_to_non_framed_body(self, b):
529542
if len(plaintext) < b:
530543
_LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b)
531544
self.source_stream.close()
545+
self.__unframed_plaintext_cache.close()
532546
closing = self.encryptor.finalize()
533547

534548
if self.signer is not None:

test/functional/test_f_aws_encryption_sdk_client.py

+45-11
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from cryptography.hazmat.primitives import hashes, serialization
2525
from cryptography.hazmat.primitives.asymmetric import padding
2626
from mock import MagicMock
27+
from wrapt import ObjectProxy
2728

2829
import aws_encryption_sdk
2930
from aws_encryption_sdk import KMSMasterKeyProvider
@@ -749,31 +750,57 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):
749750

750751
class NothingButRead(object):
751752
def __init__(self, data):
752-
self._data = io.BytesIO(data)
753+
self._data = data
753754

754755
def read(self, size=-1):
755756
return self._data.read(size)
756757

757758

758-
@pytest.mark.xfail
759+
class NoTell(ObjectProxy):
760+
def tell(self):
761+
raise NotImplementedError("NoTell does not tell().")
762+
763+
764+
class NoClose(ObjectProxy):
765+
closed = NotImplemented
766+
767+
def close(self):
768+
raise NotImplementedError("NoClose does not close().")
769+
770+
771+
@pytest.mark.parametrize(
772+
"wrapping_class",
773+
(
774+
NoTell,
775+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
776+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
777+
),
778+
)
759779
@pytest.mark.parametrize("frame_length", (0, 1024))
760-
def test_cycle_nothing_but_read(frame_length):
780+
def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
761781
raw_plaintext = exact_length_plaintext(100)
762-
plaintext = NothingButRead(raw_plaintext)
782+
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
763783
key_provider = fake_kms_key_provider()
764784
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
765785
source=plaintext, key_provider=key_provider, frame_length=frame_length
766786
)
767-
ciphertext = NothingButRead(raw_ciphertext)
787+
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
768788
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
769789
assert raw_plaintext == decrypted
770790

771791

772-
@pytest.mark.xfail
792+
@pytest.mark.parametrize(
793+
"wrapping_class",
794+
(
795+
NoTell,
796+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
797+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
798+
),
799+
)
773800
@pytest.mark.parametrize("frame_length", (0, 1024))
774-
def test_encrypt_nothing_but_read(frame_length):
801+
def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
775802
raw_plaintext = exact_length_plaintext(100)
776-
plaintext = NothingButRead(raw_plaintext)
803+
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
777804
key_provider = fake_kms_key_provider()
778805
ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
779806
source=plaintext, key_provider=key_provider, frame_length=frame_length
@@ -782,15 +809,22 @@ def test_encrypt_nothing_but_read(frame_length):
782809
assert raw_plaintext == decrypted
783810

784811

785-
@pytest.mark.xfail
812+
@pytest.mark.parametrize(
813+
"wrapping_class",
814+
(
815+
NoTell,
816+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
817+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
818+
),
819+
)
786820
@pytest.mark.parametrize("frame_length", (0, 1024))
787-
def test_decrypt_nothing_but_read(frame_length):
821+
def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class):
788822
plaintext = exact_length_plaintext(100)
789823
key_provider = fake_kms_key_provider()
790824
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
791825
source=plaintext, key_provider=key_provider, frame_length=frame_length
792826
)
793-
ciphertext = NothingButRead(raw_ciphertext)
827+
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
794828
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
795829
assert plaintext == decrypted
796830

test/unit/test_streaming_client_stream_encryptor.py

+5
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,10 @@ def test_read_bytes_to_non_framed_body(self):
382382
test_encryptor.encryptor = MagicMock()
383383
test_encryptor._encryption_materials = self.mock_encryption_materials
384384
test_encryptor.encryptor.update.return_value = sentinel.ciphertext
385+
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream
386+
385387
test = test_encryptor._read_bytes_to_non_framed_body(5)
388+
386389
test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5])
387390
test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext)
388391
assert not test_encryptor.source_stream.closed
@@ -392,6 +395,8 @@ def test_read_bytes_to_non_framed_body_too_large(self):
392395
pt_stream = io.BytesIO(self.plaintext)
393396
test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider)
394397
test_encryptor.bytes_read = aws_encryption_sdk.internal.defaults.MAX_NON_FRAMED_SIZE
398+
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream
399+
395400
with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"):
396401
test_encryptor._read_bytes_to_non_framed_body(5)
397402

0 commit comments

Comments
 (0)