Skip to content

Remove requirement for tell() on source_stream #110

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged 2 commits on Dec 11, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-
):
raise SerializationError("Source too large for non-framed message")

self.__unframed_plaintext_cache = io.BytesIO()

def ciphertext_length(self):
"""Returns the length of the resulting ciphertext message in bytes.

Expand Down Expand Up @@ -486,14 +488,25 @@ def _write_header(self):

def _prep_non_framed(self):
"""Prepare the opening data for a non-framed message."""
try:
plaintext_length = self.stream_length
self.__unframed_plaintext_cache = self.source_stream
except NotSupportedError:
# We need to know the plaintext length before we can start processing the data.
# If we cannot seek on the source then we need to read the entire source into memory.
self.__unframed_plaintext_cache = io.BytesIO()
self.__unframed_plaintext_cache.write(self.source_stream.read())
plaintext_length = self.__unframed_plaintext_cache.tell()
self.__unframed_plaintext_cache.seek(0)

aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
content_type=self.content_type, is_final_frame=True
)
associated_data = assemble_content_aad(
message_id=self._header.message_id,
aad_content_string=aad_content_string,
seq_num=1,
length=self.stream_length,
length=plaintext_length,
)
self.encryptor = Encryptor(
algorithm=self._encryption_materials.algorithm,
Expand All @@ -504,7 +517,7 @@ def _prep_non_framed(self):
self.output_buffer += serialize_non_framed_open(
algorithm=self._encryption_materials.algorithm,
iv=self.encryptor.iv,
plaintext_length=self.stream_length,
plaintext_length=plaintext_length,
signer=self.signer,
)

Expand All @@ -516,7 +529,7 @@ def _read_bytes_to_non_framed_body(self, b):
:rtype: bytes
"""
_LOGGER.debug("Reading %d bytes", b)
plaintext = self.source_stream.read(b)
plaintext = self.__unframed_plaintext_cache.read(b)
plaintext_length = len(plaintext)
if self.tell() + len(plaintext) > MAX_NON_FRAMED_SIZE:
raise SerializationError("Source too large for non-framed message")
Expand All @@ -529,6 +542,7 @@ def _read_bytes_to_non_framed_body(self, b):
if len(plaintext) < b:
_LOGGER.debug("Closing encryptor after receiving only %d bytes of %d bytes requested", plaintext_length, b)
self.source_stream.close()
self.__unframed_plaintext_cache.close()
closing = self.encryptor.finalize()

if self.signer is not None:
Expand Down
56 changes: 45 additions & 11 deletions test/functional/test_f_aws_encryption_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding
from mock import MagicMock
from wrapt import ObjectProxy

import aws_encryption_sdk
from aws_encryption_sdk import KMSMasterKeyProvider
Expand Down Expand Up @@ -749,31 +750,57 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):

class NothingButRead(object):
def __init__(self, data):
self._data = io.BytesIO(data)
self._data = data

def read(self, size=-1):
return self._data.read(size)


@pytest.mark.xfail
class NoTell(ObjectProxy):
def tell(self):
raise NotImplementedError("NoTell does not tell().")


class NoClose(ObjectProxy):
closed = NotImplemented

def close(self):
raise NotImplementedError("NoClose does not close().")


@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_cycle_nothing_but_read(frame_length):
def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
raw_plaintext = exact_length_plaintext(100)
plaintext = NothingButRead(raw_plaintext)
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
key_provider = fake_kms_key_provider()
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
)
ciphertext = NothingButRead(raw_ciphertext)
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
assert raw_plaintext == decrypted


@pytest.mark.xfail
@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_encrypt_nothing_but_read(frame_length):
def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
raw_plaintext = exact_length_plaintext(100)
plaintext = NothingButRead(raw_plaintext)
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
key_provider = fake_kms_key_provider()
ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
Expand All @@ -782,15 +809,22 @@ def test_encrypt_nothing_but_read(frame_length):
assert raw_plaintext == decrypted


@pytest.mark.xfail
@pytest.mark.parametrize(
"wrapping_class",
(
NoTell,
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
),
)
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_decrypt_nothing_but_read(frame_length):
def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class):
plaintext = exact_length_plaintext(100)
key_provider = fake_kms_key_provider()
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
)
ciphertext = NothingButRead(raw_ciphertext)
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
assert plaintext == decrypted

Expand Down
5 changes: 5 additions & 0 deletions test/unit/test_streaming_client_stream_encryptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,10 @@ def test_read_bytes_to_non_framed_body(self):
test_encryptor.encryptor = MagicMock()
test_encryptor._encryption_materials = self.mock_encryption_materials
test_encryptor.encryptor.update.return_value = sentinel.ciphertext
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream

test = test_encryptor._read_bytes_to_non_framed_body(5)

test_encryptor.encryptor.update.assert_called_once_with(self.plaintext[:5])
test_encryptor.signer.update.assert_called_once_with(sentinel.ciphertext)
assert not test_encryptor.source_stream.closed
Expand All @@ -392,6 +395,8 @@ def test_read_bytes_to_non_framed_body_too_large(self):
pt_stream = io.BytesIO(self.plaintext)
test_encryptor = StreamEncryptor(source=pt_stream, key_provider=self.mock_key_provider)
test_encryptor.bytes_read = aws_encryption_sdk.internal.defaults.MAX_NON_FRAMED_SIZE
test_encryptor._StreamEncryptor__unframed_plaintext_cache = pt_stream

with six.assertRaisesRegex(self, SerializationError, "Source too large for non-framed message"):
test_encryptor._read_bytes_to_non_framed_body(5)

Expand Down