diff --git a/integration/combination/test_api_with_cors.py b/integration/combination/test_api_with_cors.py index 938566aaa..414ed5d75 100644 --- a/integration/combination/test_api_with_cors.py +++ b/integration/combination/test_api_with_cors.py @@ -1,3 +1,4 @@ +from integration.helpers.base_test import BaseTest import requests from unittest.case import skipIf @@ -7,8 +8,6 @@ from integration.helpers.deployer.utils.retry import retry from parameterized import parameterized -from integration.helpers.exception import StatusCodeError - ALL_METHODS = "DELETE,GET,HEAD,OPTIONS,PATCH,POST,PUT" @@ -30,8 +29,8 @@ def test_cors(self, file_name): allow_headers = "headers" max_age = "600" - self.verify_options_request(base_url + "/apione", allow_methods, allow_origin, allow_headers, max_age) - self.verify_options_request(base_url + "/apitwo", allow_methods, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apione", allow_methods, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apitwo", allow_methods, allow_origin, allow_headers, max_age) def test_cors_with_shorthand_notation(self): self.create_and_verify_stack("combination/api_with_cors_shorthand") @@ -42,8 +41,8 @@ def test_cors_with_shorthand_notation(self): allow_headers = None # This should be absent from response max_age = None # This should be absent from response - self.verify_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) - self.verify_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) def test_cors_with_only_methods(self): self.create_and_verify_stack("combination/api_with_cors_only_methods") @@ -55,8 +54,8 @@ def test_cors_with_only_methods(self): allow_headers = None # This should be absent from response max_age = None # This should be absent from response - self.verify_options_request(base_url + "/apione", allow_methods, allow_origin, allow_headers, max_age) - self.verify_options_request(base_url + "/apitwo", allow_methods, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apione", allow_methods, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apitwo", allow_methods, allow_origin, allow_headers, max_age) def test_cors_with_only_headers(self): self.create_and_verify_stack("combination/api_with_cors_only_headers") @@ -67,8 +66,8 @@ def test_cors_with_only_headers(self): allow_headers = "headers" max_age = None # This should be absent from response - self.verify_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) - self.verify_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) def test_cors_with_only_max_age(self): self.create_and_verify_stack("combination/api_with_cors_only_max_age") @@ -79,17 +78,12 @@ def test_cors_with_only_max_age(self): allow_headers = None max_age = "600" - self.verify_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) - self.verify_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apione", ALL_METHODS, allow_origin, allow_headers, max_age) + self.verify_cors_options_request(base_url + "/apitwo", "OPTIONS,POST", allow_origin, allow_headers, max_age) - @retry(StatusCodeError, 3) - def verify_options_request(self, url, allow_methods, allow_origin, allow_headers, max_age): - response = requests.options(url) - status = response.status_code - if status != 200: - raise StatusCodeError("Request to {} failed with status: {}, expected status: 200".format(url, status)) + def verify_cors_options_request(self, url, allow_methods, allow_origin, allow_headers, max_age): + response = self.verify_options_request(url, 200) - self.assertEqual(status, 200, "Options request must be successful and return HTTP 200") headers = response.headers self.assertEqual( headers.get("Access-Control-Allow-Methods"), allow_methods, "Allow-Methods header must have proper value" diff --git a/integration/combination/test_api_with_gateway_responses.py b/integration/combination/test_api_with_gateway_responses.py index c3efdc10f..39979e016 100644 --- a/integration/combination/test_api_with_gateway_responses.py +++ b/integration/combination/test_api_with_gateway_responses.py @@ -1,10 +1,14 @@ +import logging from unittest.case import skipIf +from tenacity import stop_after_attempt, retry_if_exception_type, after_log, wait_exponential, retry, wait_random + from integration.helpers.base_test import BaseTest -from integration.helpers.deployer.utils.retry import retry from integration.helpers.resource import current_region_does_not_support from integration.config.service_names import GATEWAY_RESPONSES +LOG = logging.getLogger(__name__) + @skipIf( current_region_does_not_support([GATEWAY_RESPONSES]), "GatewayResponses is not supported in this testing region" @@ -33,7 +37,13 @@ def test_gateway_responses(self): base_url = stack_outputs["ApiUrl"] self._verify_request_response_and_cors(base_url + "iam", 403) - @retry(AssertionError, exc_raise=AssertionError, exc_raise_msg="Unable to verify GatewayResponse request.") + @retry( + stop=stop_after_attempt(5), + wait=wait_exponential(multiplier=1, min=4, max=10) + wait_random(0, 1), + retry=retry_if_exception_type(AssertionError), + after=after_log(LOG, logging.WARNING), + reraise=True, + ) def _verify_request_response_and_cors(self, url, expected_response): response = self.verify_get_request_response(url, expected_response) access_control_allow_origin = response.headers.get("Access-Control-Allow-Origin", "") diff --git a/integration/combination/test_function_with_cwe_dlq_generated.py b/integration/combination/test_function_with_cwe_dlq_generated.py index 138da7476..de05fbd00 100644 --- a/integration/combination/test_function_with_cwe_dlq_generated.py +++ b/integration/combination/test_function_with_cwe_dlq_generated.py @@ -2,6 +2,7 @@ from unittest.case import skipIf from integration.helpers.base_test import BaseTest +from integration.helpers.common_api import get_queue_policy from integration.helpers.resource import first_item_in_dict, current_region_does_not_support from integration.config.service_names import CWE_CWS_DLQ @@ -30,9 +31,7 @@ def test_function_with_cwe(self): # checking if the generated dead-letter queue has necessary resource based policy attached to it sqs_client = self.client_provider.sqs_client - dlq_policy_result = sqs_client.get_queue_attributes(QueueUrl=lambda_target_dlq_url, AttributeNames=["Policy"]) - dlq_policy_doc = dlq_policy_result["Attributes"]["Policy"] - dlq_policy = json.loads(dlq_policy_doc)["Statement"] + dlq_policy = get_queue_policy(queue_url=lambda_target_dlq_url, sqs_client=sqs_client) self.assertEqual(len(dlq_policy), 1, "Only one statement must be in Dead-letter queue policy") dlq_policy_statement = dlq_policy[0] diff --git a/integration/helpers/base_test.py b/integration/helpers/base_test.py index 2dfd2a99d..0a84c9b33 100644 --- a/integration/helpers/base_test.py +++ b/integration/helpers/base_test.py @@ -10,12 +10,22 @@ from integration.helpers.client_provider import ClientProvider from integration.helpers.deployer.exceptions.exceptions import ThrottlingError from integration.helpers.deployer.utils.retry import retry_with_exponential_backoff_and_jitter +from integration.helpers.exception import StatusCodeError from integration.helpers.request_utils import RequestUtils from integration.helpers.resource import generate_suffix, create_bucket, verify_stack_resources from integration.helpers.s3_uploader import S3Uploader from integration.helpers.yaml_utils import dump_yaml, load_yaml from samtranslator.yaml_helper import yaml_parse +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + after_log, + wait_random, +) + try: from pathlib import Path except ImportError: @@ -502,6 +512,13 @@ def verify_stack(self, end_state="CREATE_COMPLETE"): if error: self.fail(error) + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10) + wait_random(0, 1), + retry=retry_if_exception_type(StatusCodeError), + after=after_log(LOG, logging.WARNING), + reraise=True, + ) def verify_get_request_response(self, url, expected_status_code, headers=None): """ Verify if the get request to a certain url return the expected status code @@ -510,13 +527,47 @@ def verify_get_request_response(self, url, expected_status_code, headers=None): ---------- url : string the url for the get request - expected_status_code : string + expected_status_code : int the expected status code headers : dict headers to use in request """ response = BaseTest.do_get_request_with_logging(url, headers) - self.assertEqual(response.status_code, expected_status_code, " must return HTTP " + str(expected_status_code)) + if response.status_code != expected_status_code: + raise StatusCodeError( + "Request to {} failed with status: {}, expected status: {}".format( + url, response.status_code, expected_status_code + ) + ) + return response + + @retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10) + wait_random(0, 1), + retry=retry_if_exception_type(StatusCodeError), + after=after_log(LOG, logging.WARNING), + reraise=True, + ) + def verify_options_request(self, url, expected_status_code, headers=None): + """ + Verify if the option request to a certain url return the expected status code + + Parameters + ---------- + url : string + the url for the get request + expected_status_code : int + the expected status code + headers : dict + headers to use in request + """ + response = BaseTest.do_options_request_with_logging(url, headers) + if response.status_code != expected_status_code: + raise StatusCodeError( + "Request to {} failed with status: {}, expected status: {}".format( + url, response.status_code, expected_status_code + ) + ) return response def get_default_test_template_parameters(self): @@ -570,3 +621,19 @@ def do_get_request_with_logging(url, headers=None): amazon_headers = RequestUtils(response).get_amazon_headers() REQUEST_LOGGER.info("Request made to " + url, extra={"status": response.status_code, "headers": amazon_headers}) return response + + @staticmethod + def do_options_request_with_logging(url, headers=None): + """ + Perform a options request to an APIGW endpoint and log relevant info + Parameters + ---------- + url : string + the url for the get request + headers : dict + headers to use in request + """ + response = requests.options(url, headers=headers) if headers else requests.get(url) + amazon_headers = RequestUtils(response).get_amazon_headers() + REQUEST_LOGGER.info("Request made to " + url, extra={"status": response.status_code, "headers": amazon_headers}) + return response diff --git a/integration/helpers/common_api.py b/integration/helpers/common_api.py index c261fed5d..d418f6472 100644 --- a/integration/helpers/common_api.py +++ b/integration/helpers/common_api.py @@ -1,6 +1,25 @@ import json +import logging +from tenacity import ( + retry, + stop_after_attempt, + wait_exponential, + retry_if_exception_type, + after_log, + wait_random, +) +LOG = logging.getLogger(__name__) + + +@retry( + stop=stop_after_attempt(3), + wait=wait_exponential(multiplier=1, min=4, max=10) + wait_random(0, 1), + retry=retry_if_exception_type(KeyError), + after=after_log(LOG, logging.WARNING), + reraise=True, +) def get_queue_policy(queue_url, sqs_client): result = sqs_client.get_queue_attributes(QueueUrl=queue_url, AttributeNames=["Policy"]) policy_document = result["Attributes"]["Policy"] diff --git a/integration/single/test_basic_function.py b/integration/single/test_basic_function.py index 35865262d..112e2043f 100644 --- a/integration/single/test_basic_function.py +++ b/integration/single/test_basic_function.py @@ -42,7 +42,7 @@ def test_function_with_http_api_events(self, file_name): endpoint = self.get_api_v2_endpoint("MyHttpApi") - self.assertEqual(BaseTest.do_get_request_with_logging(endpoint).text, self.FUNCTION_OUTPUT) + self._verify_get_request(endpoint, self.FUNCTION_OUTPUT) @parameterized.expand( [ @@ -283,3 +283,7 @@ def _assert_invoke(self, lambda_client, function_name, qualifier=None, expected_ response = lambda_client.invoke(**request_params) self.assertEqual(response.get("StatusCode"), expected_status_code) + + def _verify_get_request(self, url, expected_text): + response = self.verify_get_request_response(url, 200) + self.assertEqual(response.text, expected_text) diff --git a/requirements/dev.txt b/requirements/dev.txt index 1c674a098..95f8dfd65 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -15,6 +15,7 @@ parameterized~=0.7.4 click~=7.1 dateparser~=0.7 boto3>=1.23,<2 +tenacity~=7.0.0 # Requirements for examples requests~=2.24.0