Skip to content

Remove requirement for tell() on decrypt #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Dec 6, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 20 additions & 8 deletions src/aws_encryption_sdk/internal/formatting/deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,26 +282,38 @@ def deserialize_header_auth(stream, algorithm, verifier=None):


def deserialize_non_framed_values(stream, header, verifier=None):
"""Deserializes the IV and Tag from a non-framed stream.
"""Deserializes the IV and body length from a non-framed stream.

:param stream: Source data stream
:type stream: io.BytesIO
:param header: Deserialized header
:type header: aws_encryption_sdk.structures.MessageHeader
:param verifier: Signature verifier object (optional)
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
:returns: IV, Tag, and Data Length values for body
:rtype: tuple of bytes, bytes, and int
:returns: IV and Data Length values for body
:rtype: tuple of bytes and int
"""
_LOGGER.debug("Starting non-framed body iv/tag deserialization")
(data_iv, data_length) = unpack_values(">{}sQ".format(header.algorithm.iv_len), stream, verifier)
body_start = stream.tell()
stream.seek(data_length, 1)
return data_iv, data_length


def deserialize_tag(stream, header, verifier=None):
"""Deserialize the Tag value from a non-framed stream.

:param stream: Source data stream
:type stream: io.BytesIO
:param header: Deserialized header
:type header: aws_encryption_sdk.structures.MessageHeader
:param verifier: Signature verifier object (optional)
:type verifier: aws_encryption_sdk.internal.crypto.Verifier
:returns: Tag value for body
:rtype: bytes
"""
(data_tag,) = unpack_values(
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=None
format_string=">{auth_len}s".format(auth_len=header.algorithm.auth_len), stream=stream, verifier=verifier
)
stream.seek(body_start, 0)
return data_iv, data_tag, data_length
return data_tag


def update_verifier_with_tag(stream, header, verifier):
Expand Down
66 changes: 43 additions & 23 deletions src/aws_encryption_sdk/streaming_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ class StreamDecryptor(_EncryptionStream): # pylint: disable=too-many-instance-a
def __init__(self, **kwargs): # pylint: disable=unused-argument,super-init-not-called
"""Prepares necessary initial values."""
self.last_sequence_number = 0
self.__unframed_bytes_read = 0

def _prep_message(self):
"""Performs initial message setup."""
Expand All @@ -713,6 +714,7 @@ def _read_header(self):
:raises CustomMaximumValueExceeded: if frame length is greater than the custom max value
"""
header, raw_header = aws_encryption_sdk.internal.formatting.deserialize.deserialize_header(self.source_stream)
self.__unframed_bytes_read += len(raw_header)

if (
self.config.max_body_length is not None
Expand Down Expand Up @@ -751,9 +753,21 @@ def _read_header(self):
)
return header, header_auth

@property
def body_start(self):
"""Log deprecation warning when body_start is accessed."""
_LOGGER.warning("StreamDecryptor.body_start is deprecated and will be removed in 1.4.0")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would adding this deprecation message be released through a minor rev or a patch rev?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Per our versioning statement, the warnings should be added in a Z rev and the removal should happen in a Y rev.

https://github.com/aws/aws-encryption-sdk-python/blob/master/VERSIONING.rst

return self._body_start

@property
def body_end(self):
"""Log deprecation warning when body_end is accessed."""
_LOGGER.warning("StreamDecryptor.body_end is deprecated and will be removed in 1.4.0")
return self._body_end

def _prep_non_framed(self):
"""Prepare the opening data for a non-framed message."""
iv, tag, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
self._unframed_body_iv, self.body_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values( # noqa # pylint: disable=line-too-long
stream=self.source_stream, header=self._header, verifier=self.verifier
)

Expand All @@ -764,24 +778,10 @@ def _prep_non_framed(self):
)
)

aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
content_type=self._header.content_type, is_final_frame=True
)
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
message_id=self._header.message_id,
aad_content_string=aad_content_string,
seq_num=1,
length=self.body_length,
)
self.decryptor = Decryptor(
algorithm=self._header.algorithm,
key=self._derived_data_key,
associated_data=associated_data,
iv=iv,
tag=tag,
)
self.body_start = self.source_stream.tell()
self.body_end = self.body_start + self.body_length
self.__unframed_bytes_read += self._header.algorithm.iv_len
self.__unframed_bytes_read += 8 # encrypted content length field
self._body_start = self.__unframed_bytes_read
self._body_end = self._body_start + self.body_length

def _read_bytes_from_non_framed_body(self, b):
"""Reads the requested number of bytes from a streaming non-framed message body.
Expand All @@ -792,7 +792,8 @@ def _read_bytes_from_non_framed_body(self, b):
"""
_LOGGER.debug("starting non-framed body read")
# Always read the entire message for non-framed message bodies.
bytes_to_read = self.body_end - self.source_stream.tell()
bytes_to_read = self.body_length

_LOGGER.debug("%d bytes requested; reading %d bytes", b, bytes_to_read)
ciphertext = self.source_stream.read(bytes_to_read)

Expand All @@ -802,11 +803,30 @@ def _read_bytes_from_non_framed_body(self, b):
if self.verifier is not None:
self.verifier.update(ciphertext)

plaintext = self.decryptor.update(ciphertext)
plaintext += self.decryptor.finalize()
aws_encryption_sdk.internal.formatting.deserialize.update_verifier_with_tag(
tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
stream=self.source_stream, header=self._header, verifier=self.verifier
)

aad_content_string = aws_encryption_sdk.internal.utils.get_aad_content_string(
content_type=self._header.content_type, is_final_frame=True
)
associated_data = aws_encryption_sdk.internal.formatting.encryption_context.assemble_content_aad(
message_id=self._header.message_id,
aad_content_string=aad_content_string,
seq_num=1,
length=self.body_length,
)
self.decryptor = Decryptor(
algorithm=self._header.algorithm,
key=self._derived_data_key,
associated_data=associated_data,
iv=self._unframed_body_iv,
tag=tag,
)

plaintext = self.decryptor.update(ciphertext)
plaintext += self.decryptor.finalize()

self.footer = aws_encryption_sdk.internal.formatting.deserialize.deserialize_footer(
stream=self.source_stream, verifier=self.verifier
)
Expand Down
66 changes: 66 additions & 0 deletions test/functional/test_f_aws_encryption_sdk_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -745,3 +745,69 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):

_look_in_logs(caplog, plaintext)
_error_check(capsys)


class NothingButRead(object):
def __init__(self, data):
self._data = io.BytesIO(data)

def read(self, size=-1):
return self._data.read(size)


@pytest.mark.xfail
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_cycle_nothing_but_read(frame_length):
raw_plaintext = exact_length_plaintext(100)
plaintext = NothingButRead(raw_plaintext)
key_provider = fake_kms_key_provider()
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
)
ciphertext = NothingButRead(raw_ciphertext)
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
assert raw_plaintext == decrypted


@pytest.mark.xfail
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_encrypt_nothing_but_read(frame_length):
raw_plaintext = exact_length_plaintext(100)
plaintext = NothingButRead(raw_plaintext)
key_provider = fake_kms_key_provider()
ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
)
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
assert raw_plaintext == decrypted


@pytest.mark.xfail
@pytest.mark.parametrize("frame_length", (0, 1024))
def test_decrypt_nothing_but_read(frame_length):
plaintext = exact_length_plaintext(100)
key_provider = fake_kms_key_provider()
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
source=plaintext, key_provider=key_provider, frame_length=frame_length
)
ciphertext = NothingButRead(raw_ciphertext)
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
assert plaintext == decrypted


@pytest.mark.parametrize("attribute, no_later_than", (("body_start", "1.4.0"), ("body_end", "1.4.0")))
def test_decryptor_deprecated_attributes(caplog, attribute, no_later_than):
caplog.set_level(logging.WARNING)
plaintext = exact_length_plaintext(100)
key_provider = fake_kms_key_provider()
ciphertext, _header = aws_encryption_sdk.encrypt(source=plaintext, key_provider=key_provider, frame_length=0)
with aws_encryption_sdk.stream(mode="decrypt", source=ciphertext, key_provider=key_provider) as decryptor:
decrypted = decryptor.read()

assert decrypted == plaintext
assert hasattr(decryptor, attribute)
watch_string = "StreamDecryptor.{name} is deprecated and will be removed in {version}".format(
name=attribute, version=no_later_than
)
assert watch_string in caplog.text
assert aws_encryption_sdk.__version__ < no_later_than
27 changes: 27 additions & 0 deletions test/unit/test_deserialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# language governing permissions and limitations under the License.
"""Unit test suite for aws_encryption_sdk.deserialize"""
import io
import struct
import unittest

import pytest
Expand All @@ -29,6 +30,32 @@
pytestmark = [pytest.mark.unit, pytest.mark.local]


def test_deserialize_non_framed_values():
iv = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11"
length = 42
packed = struct.pack(">12sQ", iv, length)
mock_header = MagicMock(algorithm=MagicMock(iv_len=12))

parsed_iv, parsed_length = aws_encryption_sdk.internal.formatting.deserialize.deserialize_non_framed_values(
stream=io.BytesIO(packed), header=mock_header
)

assert parsed_iv == iv
assert parsed_length == length


def test_deserialize_tag():
tag = b"\x00\x01\x02\x03\x04\x05\x06\x07\x08\x09\x10\x11\x12\x13\x14\x15"
packed = struct.pack(">16s", tag)
mock_header = MagicMock(algorithm=MagicMock(auth_len=16))

parsed_tag = aws_encryption_sdk.internal.formatting.deserialize.deserialize_tag(
stream=io.BytesIO(packed), header=mock_header
)

assert parsed_tag == tag


class TestDeserialize(unittest.TestCase):
def setUp(self):
self.mock_wrapping_algorithm = MagicMock()
Expand Down
6 changes: 3 additions & 3 deletions test/unit/test_streaming_client_encryption_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@

import aws_encryption_sdk.exceptions
from aws_encryption_sdk.internal.defaults import LINE_LENGTH
from aws_encryption_sdk.internal.utils.streams import InsistentReaderBytesIO
from aws_encryption_sdk.key_providers.base import MasterKeyProvider
from aws_encryption_sdk.streaming_client import _ClientConfig, _EncryptionStream

from .test_values import VALUES
from .unit_test_utils import assert_prepped_stream_identity

pytestmark = [pytest.mark.unit, pytest.mark.local]

Expand Down Expand Up @@ -110,7 +110,7 @@ def test_new_with_params(self):
)

assert mock_stream.config.source == self.mock_source_stream
assert isinstance(mock_stream.config.source, InsistentReaderBytesIO)
assert_prepped_stream_identity(mock_stream.config.source, object)
assert mock_stream.config.key_provider is self.mock_key_provider
assert mock_stream.config.mock_read_bytes is sentinel.read_bytes
assert mock_stream.config.line_length == io.DEFAULT_BUFFER_SIZE
Expand All @@ -120,7 +120,7 @@ def test_new_with_params(self):
assert mock_stream.output_buffer == b""
assert not mock_stream._message_prepped
assert mock_stream.source_stream == self.mock_source_stream
assert isinstance(mock_stream.source_stream, InsistentReaderBytesIO)
assert_prepped_stream_identity(mock_stream.source_stream, object)
assert mock_stream._stream_length is mock_int_sentinel
assert mock_stream.line_length == io.DEFAULT_BUFFER_SIZE

Expand Down
Loading