Skip to content

Commit 5b6c04b

Browse files
committed
expand minimal source stream API testing
1 parent 6c9aaed commit 5b6c04b

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

test/functional/test_f_aws_encryption_sdk_client.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from cryptography.hazmat.primitives import hashes, serialization
2525
from cryptography.hazmat.primitives.asymmetric import padding
2626
from mock import MagicMock
27+
from wrapt import ObjectProxy
2728

2829
import aws_encryption_sdk
2930
from aws_encryption_sdk import KMSMasterKeyProvider
@@ -749,31 +750,57 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):
749750

750751
class NothingButRead(object):
751752
def __init__(self, data):
752-
self._data = io.BytesIO(data)
753+
self._data = data
753754

754755
def read(self, size=-1):
755756
return self._data.read(size)
756757

757758

758-
@pytest.mark.xfail
759+
class NoTell(ObjectProxy):
760+
def tell(self):
761+
raise NotImplementedError("NoTell does not tell().")
762+
763+
764+
class NoClose(ObjectProxy):
765+
closed = NotImplemented
766+
767+
def close(self):
768+
raise NotImplementedError("NoClose does not close().")
769+
770+
771+
@pytest.mark.parametrize(
772+
"wrapping_class",
773+
(
774+
pytest.param(NoTell, marks=pytest.mark.xfail),
775+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
776+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
777+
),
778+
)
759779
@pytest.mark.parametrize("frame_length", (0, 1024))
760-
def test_cycle_nothing_but_read(frame_length):
780+
def test_cycle_minimal_source_stream_api(frame_length, wrapping_class):
761781
raw_plaintext = exact_length_plaintext(100)
762-
plaintext = NothingButRead(raw_plaintext)
782+
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
763783
key_provider = fake_kms_key_provider()
764784
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
765785
source=plaintext, key_provider=key_provider, frame_length=frame_length
766786
)
767-
ciphertext = NothingButRead(raw_ciphertext)
787+
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
768788
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
769789
assert raw_plaintext == decrypted
770790

771791

772-
@pytest.mark.xfail
792+
@pytest.mark.parametrize(
793+
"wrapping_class",
794+
(
795+
pytest.param(NoTell, marks=pytest.mark.xfail),
796+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
797+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
798+
),
799+
)
773800
@pytest.mark.parametrize("frame_length", (0, 1024))
774-
def test_encrypt_nothing_but_read(frame_length):
801+
def test_encrypt_minimal_source_stream_api(frame_length, wrapping_class):
775802
raw_plaintext = exact_length_plaintext(100)
776-
plaintext = NothingButRead(raw_plaintext)
803+
plaintext = wrapping_class(io.BytesIO(raw_plaintext))
777804
key_provider = fake_kms_key_provider()
778805
ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
779806
source=plaintext, key_provider=key_provider, frame_length=frame_length
@@ -782,15 +809,22 @@ def test_encrypt_nothing_but_read(frame_length):
782809
assert raw_plaintext == decrypted
783810

784811

785-
@pytest.mark.xfail
812+
@pytest.mark.parametrize(
813+
"wrapping_class",
814+
(
815+
NoTell,
816+
pytest.param(NoClose, marks=pytest.mark.xfail(strict=True)),
817+
pytest.param(NothingButRead, marks=pytest.mark.xfail(strict=True)),
818+
),
819+
)
786820
@pytest.mark.parametrize("frame_length", (0, 1024))
787-
def test_decrypt_nothing_but_read(frame_length):
821+
def test_decrypt_minimal_source_stream_api(frame_length, wrapping_class):
788822
plaintext = exact_length_plaintext(100)
789823
key_provider = fake_kms_key_provider()
790824
raw_ciphertext, _encrypt_header = aws_encryption_sdk.encrypt(
791825
source=plaintext, key_provider=key_provider, frame_length=frame_length
792826
)
793-
ciphertext = NothingButRead(raw_ciphertext)
827+
ciphertext = wrapping_class(io.BytesIO(raw_ciphertext))
794828
decrypted, _decrypt_header = aws_encryption_sdk.decrypt(source=ciphertext, key_provider=key_provider)
795829
assert plaintext == decrypted
796830

0 commit comments

Comments
 (0)