24
24
from cryptography .hazmat .primitives import hashes , serialization
25
25
from cryptography .hazmat .primitives .asymmetric import padding
26
26
from mock import MagicMock
27
+ from wrapt import ObjectProxy
27
28
28
29
import aws_encryption_sdk
29
30
from aws_encryption_sdk import KMSMasterKeyProvider
@@ -749,31 +750,57 @@ def test_plaintext_logs_stream(caplog, capsys, plaintext_length, frame_size):
749
750
750
751
class NothingButRead (object ):
751
752
def __init__ (self , data ):
752
- self ._data = io . BytesIO ( data )
753
+ self ._data = data
753
754
754
755
def read (self , size = - 1 ):
755
756
return self ._data .read (size )
756
757
757
758
758
- @pytest .mark .xfail
759
+ class NoTell (ObjectProxy ):
760
+ def tell (self ):
761
+ raise NotImplementedError ("NoTell does not tell()." )
762
+
763
+
764
+ class NoClose (ObjectProxy ):
765
+ closed = NotImplemented
766
+
767
+ def close (self ):
768
+ raise NotImplementedError ("NoClose does not close()." )
769
+
770
+
771
+ @pytest .mark .parametrize (
772
+ "wrapping_class" ,
773
+ (
774
+ NoTell ,
775
+ pytest .param (NoClose , marks = pytest .mark .xfail (strict = True )),
776
+ pytest .param (NothingButRead , marks = pytest .mark .xfail (strict = True )),
777
+ ),
778
+ )
759
779
@pytest .mark .parametrize ("frame_length" , (0 , 1024 ))
760
- def test_cycle_nothing_but_read (frame_length ):
780
+ def test_cycle_minimal_source_stream_api (frame_length , wrapping_class ):
761
781
raw_plaintext = exact_length_plaintext (100 )
762
- plaintext = NothingButRead ( raw_plaintext )
782
+ plaintext = wrapping_class ( io . BytesIO ( raw_plaintext ) )
763
783
key_provider = fake_kms_key_provider ()
764
784
raw_ciphertext , _encrypt_header = aws_encryption_sdk .encrypt (
765
785
source = plaintext , key_provider = key_provider , frame_length = frame_length
766
786
)
767
- ciphertext = NothingButRead ( raw_ciphertext )
787
+ ciphertext = wrapping_class ( io . BytesIO ( raw_ciphertext ) )
768
788
decrypted , _decrypt_header = aws_encryption_sdk .decrypt (source = ciphertext , key_provider = key_provider )
769
789
assert raw_plaintext == decrypted
770
790
771
791
772
- @pytest .mark .xfail
792
+ @pytest .mark .parametrize (
793
+ "wrapping_class" ,
794
+ (
795
+ NoTell ,
796
+ pytest .param (NoClose , marks = pytest .mark .xfail (strict = True )),
797
+ pytest .param (NothingButRead , marks = pytest .mark .xfail (strict = True )),
798
+ ),
799
+ )
773
800
@pytest .mark .parametrize ("frame_length" , (0 , 1024 ))
774
- def test_encrypt_nothing_but_read (frame_length ):
801
+ def test_encrypt_minimal_source_stream_api (frame_length , wrapping_class ):
775
802
raw_plaintext = exact_length_plaintext (100 )
776
- plaintext = NothingButRead ( raw_plaintext )
803
+ plaintext = wrapping_class ( io . BytesIO ( raw_plaintext ) )
777
804
key_provider = fake_kms_key_provider ()
778
805
ciphertext , _encrypt_header = aws_encryption_sdk .encrypt (
779
806
source = plaintext , key_provider = key_provider , frame_length = frame_length
@@ -782,15 +809,22 @@ def test_encrypt_nothing_but_read(frame_length):
782
809
assert raw_plaintext == decrypted
783
810
784
811
785
- @pytest .mark .xfail
812
+ @pytest .mark .parametrize (
813
+ "wrapping_class" ,
814
+ (
815
+ NoTell ,
816
+ pytest .param (NoClose , marks = pytest .mark .xfail (strict = True )),
817
+ pytest .param (NothingButRead , marks = pytest .mark .xfail (strict = True )),
818
+ ),
819
+ )
786
820
@pytest .mark .parametrize ("frame_length" , (0 , 1024 ))
787
- def test_decrypt_nothing_but_read (frame_length ):
821
+ def test_decrypt_minimal_source_stream_api (frame_length , wrapping_class ):
788
822
plaintext = exact_length_plaintext (100 )
789
823
key_provider = fake_kms_key_provider ()
790
824
raw_ciphertext , _encrypt_header = aws_encryption_sdk .encrypt (
791
825
source = plaintext , key_provider = key_provider , frame_length = frame_length
792
826
)
793
- ciphertext = NothingButRead ( raw_ciphertext )
827
+ ciphertext = wrapping_class ( io . BytesIO ( raw_ciphertext ) )
794
828
decrypted , _decrypt_header = aws_encryption_sdk .decrypt (source = ciphertext , key_provider = key_provider )
795
829
assert plaintext == decrypted
796
830
0 commit comments