diff --git a/test/unit/test_streaming_client_stream_decryptor.py b/test/unit/test_streaming_client_stream_decryptor.py index c59ae4beb..69cd72c73 100644 --- a/test/unit/test_streaming_client_stream_decryptor.py +++ b/test/unit/test_streaming_client_stream_decryptor.py @@ -12,10 +12,8 @@ # language governing permissions and limitations under the License. """Unit test suite for aws_encryption_sdk.streaming_client.StreamDecryptor""" import io -import unittest import pytest -import six from mock import MagicMock, call, patch, sentinel from aws_encryption_sdk.exceptions import CustomMaximumValueExceeded, NotSupportedError, SerializationError @@ -29,8 +27,9 @@ pytestmark = [pytest.mark.unit, pytest.mark.local] -class TestStreamDecryptor(unittest.TestCase): - def setUp(self): +class TestStreamDecryptor(object): + @pytest.fixture(autouse=True) + def apply_fixtures(self): self.mock_key_provider = MagicMock(__class__=MasterKeyProvider) self.mock_materials_manager = MagicMock(__class__=CryptoMaterialsManager) self.mock_materials_manager.decrypt_materials.return_value = MagicMock( @@ -92,8 +91,8 @@ def setUp(self): # Set up decrypt patch self.mock_decrypt_patcher = patch("aws_encryption_sdk.streaming_client.decrypt") self.mock_decrypt = self.mock_decrypt_patcher.start() - - def tearDown(self): + yield + # Run tearDown self.mock_deserialize_header_patcher.stop() self.mock_deserialize_header_auth_patcher.stop() self.mock_validate_header_patcher.stop() @@ -186,12 +185,11 @@ def test_read_header_frame_too_large(self, mock_derive_datakey): test_decryptor.key_provider = self.mock_key_provider test_decryptor.source_stream = ct_stream test_decryptor._stream_length = len(VALUES["data_128"]) - with six.assertRaisesRegex( - self, - CustomMaximumValueExceeded, - "Frame Size in header found larger than custom value: {found} > {custom}".format(found=1024, custom=10), - ): + with pytest.raises(CustomMaximumValueExceeded) as excinfo: test_decryptor._read_header() + excinfo.match( + "Frame Size in header found larger than custom value: {found} > {custom}".format(found=1024, custom=10) + ) @patch("aws_encryption_sdk.streaming_client.Verifier") @patch("aws_encryption_sdk.streaming_client.DecryptionMaterialsRequest") @@ -220,14 +218,13 @@ def test_prep_non_framed_content_length_too_large(self): mock_data_key = MagicMock() test_decryptor.data_key = mock_data_key - with six.assertRaisesRegex( - self, - CustomMaximumValueExceeded, + with pytest.raises(CustomMaximumValueExceeded) as excinfo: + test_decryptor._prep_non_framed() + excinfo.match( "Non-framed message content length found larger than custom value: {found} > {custom}".format( found=len(VALUES["data_128"]), custom=len(VALUES["data_128"]) // 2 - ), - ): - test_decryptor._prep_non_framed() + ) + ) def test_prep_non_framed(self): test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=self.mock_input_stream) @@ -288,10 +285,9 @@ def test_read_bytes_from_non_framed_message_body_too_small(self): test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) test_decryptor.body_length = len(VALUES["data_128"] * 2) test_decryptor._header = self.mock_header - with six.assertRaisesRegex( - self, SerializationError, "Total message body contents less than specified in body description" - ): + with pytest.raises(SerializationError) as excinfo: test_decryptor._read_bytes_from_non_framed_body(1) + excinfo.match("Total message body contents less than specified in body description") def test_read_bytes_from_non_framed_no_verifier(self): ct_stream = io.BytesIO(VALUES["data_128"]) @@ -497,8 +493,9 @@ def test_read_bytes_from_framed_body_bad_sequence_number(self): frame_data.final_frame = False frame_data.ciphertext = b"asdfzxcv" self.mock_deserialize_frame.return_value = (frame_data, False) - with six.assertRaisesRegex(self, SerializationError, "Malformed message: frames out of order"): + with pytest.raises(SerializationError) as excinfo: test_decryptor._read_bytes_from_framed_body(4) + excinfo.match("Malformed message: frames out of order") @patch("aws_encryption_sdk.streaming_client.StreamDecryptor._read_bytes_from_non_framed_body") @patch("aws_encryption_sdk.streaming_client.StreamDecryptor._read_bytes_from_framed_body") @@ -549,8 +546,9 @@ def test_read_bytes_unknown(self, mock_read_frame, mock_read_block): test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=ct_stream) test_decryptor._header = MagicMock() test_decryptor._header.content_type = None - with six.assertRaisesRegex(self, NotSupportedError, "Unsupported content type"): + with pytest.raises(NotSupportedError) as excinfo: test_decryptor._read_bytes(5) + excinfo.match("Unsupported content type") @patch("aws_encryption_sdk.streaming_client._EncryptionStream.close") def test_close(self, mock_close): @@ -565,5 +563,6 @@ def test_close(self, mock_close): def test_close_no_footer(self, mock_close): self.mock_header.content_type = ContentType.FRAMED_DATA test_decryptor = StreamDecryptor(key_provider=self.mock_key_provider, source=self.mock_input_stream) - with six.assertRaisesRegex(self, SerializationError, "Footer not read"): + with pytest.raises(SerializationError) as excinfo: test_decryptor.close() + excinfo.match("Footer not read")