diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py new file mode 100644 index 00000000000..fc744055e6c --- /dev/null +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -0,0 +1,220 @@ +import base64 +import json +import re +import zlib +from enum import Enum +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union + +from aws_lambda_powertools.shared.json_encoder import Encoder +from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 +from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent +from aws_lambda_powertools.utilities.typing import LambdaContext + + +class ProxyEventType(Enum): + http_api_v1 = "APIGatewayProxyEvent" + http_api_v2 = "APIGatewayProxyEventV2" + alb_event = "ALBEvent" + api_gateway = http_api_v1 + + +class CORSConfig(object): + """CORS Config""" + + _REQUIRED_HEADERS = ["Authorization", "Content-Type", "X-Amz-Date", "X-Api-Key", "X-Amz-Security-Token"] + + def __init__( + self, + allow_origin: str = "*", + allow_headers: List[str] = None, + expose_headers: List[str] = None, + max_age: int = None, + allow_credentials: bool = False, + ): + """ + Parameters + ---------- + allow_origin: str + The value of the `Access-Control-Allow-Origin` to send in the response. Defaults to "*", but should + only be used during development. + allow_headers: str + The list of additional allowed headers. This list is added to list of + built in allowed headers: `Authorization`, `Content-Type`, `X-Amz-Date`, + `X-Api-Key`, `X-Amz-Security-Token`. + expose_headers: str + A list of values to return for the Access-Control-Expose-Headers + max_age: int + The value for the `Access-Control-Max-Age` + allow_credentials: bool + A boolean value that sets the value of `Access-Control-Allow-Credentials` + """ + self.allow_origin = allow_origin + self.allow_headers = set(self._REQUIRED_HEADERS + (allow_headers or [])) + self.expose_headers = expose_headers or [] + self.max_age = max_age + self.allow_credentials = allow_credentials + + def to_dict(self) -> Dict[str, str]: + headers = { + "Access-Control-Allow-Origin": self.allow_origin, + "Access-Control-Allow-Headers": ",".join(sorted(self.allow_headers)), + } + if self.expose_headers: + headers["Access-Control-Expose-Headers"] = ",".join(self.expose_headers) + if self.max_age is not None: + headers["Access-Control-Max-Age"] = str(self.max_age) + if self.allow_credentials is True: + headers["Access-Control-Allow-Credentials"] = "true" + return headers + + +class Route: + def __init__( + self, method: str, rule: Any, func: Callable, cors: bool, compress: bool, cache_control: Optional[str] + ): + self.method = method.upper() + self.rule = rule + self.func = func + self.cors = cors + self.compress = compress + self.cache_control = cache_control + + +class Response: + def __init__( + self, status_code: int, content_type: Optional[str], body: Union[str, bytes, None], headers: Dict = None + ): + self.status_code = status_code + self.body = body + self.base64_encoded = False + self.headers: Dict = headers or {} + if content_type: + self.headers.setdefault("Content-Type", content_type) + + def add_cors(self, cors: CORSConfig): + self.headers.update(cors.to_dict()) + + def add_cache_control(self, cache_control: str): + self.headers["Cache-Control"] = cache_control if self.status_code == 200 else "no-cache" + + def compress(self): + self.headers["Content-Encoding"] = "gzip" + if isinstance(self.body, str): + self.body = bytes(self.body, "utf-8") + gzip = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) + self.body = gzip.compress(self.body) + gzip.flush() + + def to_dict(self) -> Dict[str, Any]: + if isinstance(self.body, bytes): + self.base64_encoded = True + self.body = base64.b64encode(self.body).decode() + return { + "statusCode": self.status_code, + "headers": self.headers, + "body": self.body, + "isBase64Encoded": self.base64_encoded, + } + + +class ApiGatewayResolver: + current_event: BaseProxyEvent + lambda_context: LambdaContext + + def __init__(self, proxy_type: Enum = ProxyEventType.http_api_v1, cors: CORSConfig = None): + self._proxy_type = proxy_type + self._routes: List[Route] = [] + self._cors = cors + self._cors_methods: Set[str] = {"OPTIONS"} + + def get(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "GET", cors, compress, cache_control) + + def post(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "POST", cors, compress, cache_control) + + def put(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "PUT", cors, compress, cache_control) + + def delete(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "DELETE", cors, compress, cache_control) + + def patch(self, rule: str, cors: bool = False, compress: bool = False, cache_control: str = None): + return self.route(rule, "PATCH", cors, compress, cache_control) + + def route(self, rule: str, method: str, cors: bool = False, compress: bool = False, cache_control: str = None): + def register_resolver(func: Callable): + self._routes.append(Route(method, self._compile_regex(rule), func, cors, compress, cache_control)) + if cors: + self._cors_methods.add(method.upper()) + return func + + return register_resolver + + def resolve(self, event, context) -> Dict[str, Any]: + self.current_event = self._to_data_class(event) + self.lambda_context = context + route, response = self._find_route(self.current_event.http_method.upper(), self.current_event.path) + if route is None: # No matching route was found + return response.to_dict() + + if route.cors: + response.add_cors(self._cors or CORSConfig()) + if route.cache_control: + response.add_cache_control(route.cache_control) + if route.compress and "gzip" in (self.current_event.get_header_value("accept-encoding") or ""): + response.compress() + + return response.to_dict() + + @staticmethod + def _compile_regex(rule: str): + rule_regex: str = re.sub(r"(<\w+>)", r"(?P\1.+)", rule) + return re.compile("^{}$".format(rule_regex)) + + def _to_data_class(self, event: Dict) -> BaseProxyEvent: + if self._proxy_type == ProxyEventType.http_api_v1: + return APIGatewayProxyEvent(event) + if self._proxy_type == ProxyEventType.http_api_v2: + return APIGatewayProxyEventV2(event) + return ALBEvent(event) + + def _find_route(self, method: str, path: str) -> Tuple[Optional[Route], Response]: + for route in self._routes: + if method != route.method: + continue + match: Optional[re.Match] = route.rule.match(path) + if match: + return self._call_route(route, match.groupdict()) + + headers = {} + if self._cors: + headers.update(self._cors.to_dict()) + if method == "OPTIONS": # Preflight + headers["Access-Control-Allow-Methods"] = ",".join(sorted(self._cors_methods)) + return None, Response(status_code=204, content_type=None, body=None, headers=headers) + + return None, Response( + status_code=404, + content_type="application/json", + body=json.dumps({"message": f"No route found for '{method}.{path}'"}), + headers=headers, + ) + + def _call_route(self, route: Route, args: Dict[str, str]) -> Tuple[Route, Response]: + return route, self._to_response(route.func(**args)) + + @staticmethod + def _to_response(result: Union[Tuple[int, str, Union[bytes, str]], Dict, Response]) -> Response: + if isinstance(result, Response): + return result + elif isinstance(result, dict): + return Response( + status_code=200, + content_type="application/json", + body=json.dumps(result, separators=(",", ":"), cls=Encoder), + ) + else: # Tuple[int, str, Union[bytes, str]] + return Response(*result) + + def __call__(self, event, context) -> Any: + return self.resolve(event, context) diff --git a/aws_lambda_powertools/utilities/data_classes/alb_event.py b/aws_lambda_powertools/utilities/data_classes/alb_event.py index 6c7cb9e60c3..73e064d0f26 100644 --- a/aws_lambda_powertools/utilities/data_classes/alb_event.py +++ b/aws_lambda_powertools/utilities/data_classes/alb_event.py @@ -21,14 +21,6 @@ class ALBEvent(BaseProxyEvent): def request_context(self) -> ALBEventRequestContext: return ALBEventRequestContext(self._data) - @property - def http_method(self) -> str: - return self["httpMethod"] - - @property - def path(self) -> str: - return self["path"] - @property def multi_value_query_string_parameters(self) -> Optional[Dict[str, List[str]]]: return self.get("multiValueQueryStringParameters") diff --git a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py index 6c06e48e63e..20cbfa58fd2 100644 --- a/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py +++ b/aws_lambda_powertools/utilities/data_classes/api_gateway_proxy_event.py @@ -217,15 +217,6 @@ def version(self) -> str: def resource(self) -> str: return self["resource"] - @property - def path(self) -> str: - return self["path"] - - @property - def http_method(self) -> str: - """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" - return self["httpMethod"] - @property def multi_value_headers(self) -> Dict[str, List[str]]: return self["multiValueHeaders"] @@ -446,3 +437,12 @@ def path_parameters(self) -> Optional[Dict[str, str]]: @property def stage_variables(self) -> Optional[Dict[str, str]]: return self.get("stageVariables") + + @property + def path(self) -> str: + return self.raw_path + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self.request_context.http.method diff --git a/aws_lambda_powertools/utilities/data_classes/common.py b/aws_lambda_powertools/utilities/data_classes/common.py index 6f393cccb60..a6b975c6072 100644 --- a/aws_lambda_powertools/utilities/data_classes/common.py +++ b/aws_lambda_powertools/utilities/data_classes/common.py @@ -1,3 +1,4 @@ +import json from typing import Any, Dict, Optional @@ -57,8 +58,23 @@ def is_base64_encoded(self) -> Optional[bool]: @property def body(self) -> Optional[str]: + """Submitted body of the request as a string""" return self.get("body") + @property + def json_body(self) -> Any: + """Parses the submitted body as json""" + return json.loads(self["body"]) + + @property + def path(self) -> str: + return self["path"] + + @property + def http_method(self) -> str: + """The HTTP method used. Valid values include: DELETE, GET, HEAD, OPTIONS, PATCH, POST, and PUT.""" + return self["httpMethod"] + def get_query_string_value(self, name: str, default_value: Optional[str] = None) -> Optional[str]: """Get query string value by name diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py new file mode 100644 index 00000000000..df13b047d0d --- /dev/null +++ b/tests/functional/event_handler/test_api_gateway.py @@ -0,0 +1,474 @@ +import base64 +import json +import zlib +from decimal import Decimal +from pathlib import Path +from typing import Dict, Tuple + +from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver, CORSConfig, ProxyEventType, Response +from aws_lambda_powertools.shared.json_encoder import Encoder +from aws_lambda_powertools.utilities.data_classes import ALBEvent, APIGatewayProxyEvent, APIGatewayProxyEventV2 +from tests.functional.utils import load_event + + +def read_media(file_name: str) -> bytes: + path = Path(str(Path(__file__).parent.parent.parent.parent) + "/docs/media/" + file_name) + return path.read_bytes() + + +LOAD_GW_EVENT = load_event("apiGatewayProxyEvent.json") +TEXT_HTML = "text/html" +APPLICATION_JSON = "application/json" + + +def test_alb_event(): + # GIVEN a Application Load Balancer proxy type event + app = ApiGatewayResolver(proxy_type=ProxyEventType.alb_event) + + @app.get("/lambda") + def foo() -> Tuple[int, str, str]: + assert isinstance(app.current_event, ALBEvent) + assert app.lambda_context == {} + return 200, TEXT_HTML, "foo" + + # WHEN calling the event handler + result = app(load_event("albEvent.json"), {}) + + # THEN process event correctly + # AND set the current_event type as ALBEvent + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["body"] == "foo" + + +def test_api_gateway_v1(): + # GIVEN a Http API V1 proxy type event + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + + @app.get("/my/path") + def get_lambda() -> Tuple[int, str, str]: + assert isinstance(app.current_event, APIGatewayProxyEvent) + assert app.lambda_context == {} + return 200, APPLICATION_JSON, json.dumps({"foo": "value"}) + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == APPLICATION_JSON + + +def test_api_gateway(): + # GIVEN a Rest API Gateway proxy type event + app = ApiGatewayResolver(proxy_type=ProxyEventType.api_gateway) + + @app.get("/my/path") + def get_lambda() -> Tuple[int, str, str]: + assert isinstance(app.current_event, APIGatewayProxyEvent) + return 200, TEXT_HTML, "foo" + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEvent + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == TEXT_HTML + assert result["body"] == "foo" + + +def test_api_gateway_v2(): + # GIVEN a Http API V2 proxy type event + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v2) + + @app.post("/my/path") + def my_path() -> Tuple[int, str, str]: + assert isinstance(app.current_event, APIGatewayProxyEventV2) + post_data = app.current_event.json_body + return 200, "plain/text", post_data["username"] + + # WHEN calling the event handler + result = app(load_event("apiGatewayProxyV2Event.json"), {}) + + # THEN process event correctly + # AND set the current_event type as APIGatewayProxyEventV2 + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/text" + assert result["body"] == "tom" + + +def test_include_rule_matching(): + # GIVEN + app = ApiGatewayResolver() + + @app.get("//") + def get_lambda(my_id: str, name: str) -> Tuple[int, str, str]: + assert name == "my" + return 200, "plain/html", my_id + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == "plain/html" + assert result["body"] == "path" + + +def test_no_matches(): + # GIVEN an event that does not match any of the given routes + app = ApiGatewayResolver() + + @app.get("/not_matching_get") + def get_func(): + raise RuntimeError() + + @app.post("/no_matching_post") + def post_func(): + raise RuntimeError() + + @app.put("/no_matching_put") + def put_func(): + raise RuntimeError() + + @app.delete("/no_matching_delete") + def delete_func(): + raise RuntimeError() + + @app.patch("/no_matching_patch") + def patch_func(): + raise RuntimeError() + + def handler(event, context): + return app.resolve(event, context) + + # Also check check the route configurations + routes = app._routes + assert len(routes) == 5 + for route in routes: + if route.func == get_func: + assert route.method == "GET" + elif route.func == post_func: + assert route.method == "POST" + elif route.func == put_func: + assert route.method == "PUT" + elif route.func == delete_func: + assert route.method == "DELETE" + elif route.func == patch_func: + assert route.method == "PATCH" + + # WHEN calling the handler + # THEN return a 404 + result = handler(LOAD_GW_EVENT, None) + assert result["statusCode"] == 404 + # AND cors headers are not returned + assert "Access-Control-Allow-Origin" not in result["headers"] + + +def test_cors(): + # GIVEN a function with cors=True + # AND http method set to GET + app = ApiGatewayResolver() + + @app.get("/my/path", cors=True) + def with_cors() -> Tuple[int, str, str]: + return 200, TEXT_HTML, "test" + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + result = handler(LOAD_GW_EVENT, None) + + # THEN the headers should include cors headers + assert "headers" in result + headers = result["headers"] + assert headers["Content-Type"] == TEXT_HTML + assert headers["Access-Control-Allow-Origin"] == "*" + assert "Access-Control-Allow-Credentials" not in headers + assert headers["Access-Control-Allow-Headers"] == ",".join(sorted(CORSConfig._REQUIRED_HEADERS)) + + +def test_compress(): + # GIVEN a function that has compress=True + # AND an event with a "Accept-Encoding" that include gzip + app = ApiGatewayResolver() + mock_event = {"path": "/my/request", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + expected_value = '{"test": "value"}' + + @app.get("/my/request", compress=True) + def with_compression() -> Tuple[int, str, str]: + return 200, APPLICATION_JSON, expected_value + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + result = handler(mock_event, None) + + # THEN then gzip the response and base64 encode as a string + assert result["isBase64Encoded"] is True + body = result["body"] + assert isinstance(body, str) + decompress = zlib.decompress(base64.b64decode(body), wbits=zlib.MAX_WBITS | 16).decode("UTF-8") + assert decompress == expected_value + headers = result["headers"] + assert headers["Content-Encoding"] == "gzip" + + +def test_base64_encode(): + # GIVEN a function that returns bytes + app = ApiGatewayResolver() + mock_event = {"path": "/my/path", "httpMethod": "GET", "headers": {"Accept-Encoding": "deflate, gzip"}} + + @app.get("/my/path", compress=True) + def read_image() -> Tuple[int, str, bytes]: + return 200, "image/png", read_media("idempotent_sequence_exception.png") + + # WHEN calling the event handler + result = app(mock_event, None) + + # THEN return the body and a base64 encoded string + assert result["isBase64Encoded"] is True + body = result["body"] + assert isinstance(body, str) + headers = result["headers"] + assert headers["Content-Encoding"] == "gzip" + + +def test_compress_no_accept_encoding(): + # GIVEN a function with compress=True + # AND the request has no "Accept-Encoding" set to include gzip + app = ApiGatewayResolver() + expected_value = "Foo" + + @app.get("/my/path", compress=True) + def return_text() -> Tuple[int, str, str]: + return 200, "text/plain", expected_value + + # WHEN calling the event handler + result = app({"path": "/my/path", "httpMethod": "GET", "headers": {}}, None) + + # THEN don't perform any gzip compression + assert result["isBase64Encoded"] is False + assert result["body"] == expected_value + + +def test_cache_control_200(): + # GIVEN a function with cache_control set + app = ApiGatewayResolver() + + @app.get("/success", cache_control="max-age=600") + def with_cache_control() -> Tuple[int, str, str]: + return 200, TEXT_HTML, "has 200 response" + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + # AND the function returns a 200 status code + result = handler({"path": "/success", "httpMethod": "GET"}, None) + + # THEN return the set Cache-Control + headers = result["headers"] + assert headers["Content-Type"] == TEXT_HTML + assert headers["Cache-Control"] == "max-age=600" + + +def test_cache_control_non_200(): + # GIVEN a function with cache_control set + app = ApiGatewayResolver() + + @app.delete("/fails", cache_control="max-age=600") + def with_cache_control_has_500() -> Tuple[int, str, str]: + return 503, TEXT_HTML, "has 503 response" + + def handler(event, context): + return app.resolve(event, context) + + # WHEN calling the event handler + # AND the function returns a 503 status code + result = handler({"path": "/fails", "httpMethod": "DELETE"}, None) + + # THEN return a Cache-Control of "no-cache" + headers = result["headers"] + assert headers["Content-Type"] == TEXT_HTML + assert headers["Cache-Control"] == "no-cache" + + +def test_rest_api(): + # GIVEN a function that returns a Dict + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + expected_dict = {"foo": "value", "second": Decimal("100.01")} + + @app.get("/my/path") + def rest_func() -> Dict: + return expected_dict + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN automatically process this as a json rest api response + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == APPLICATION_JSON + expected_str = json.dumps(expected_dict, separators=(",", ":"), indent=None, cls=Encoder) + assert result["body"] == expected_str + + +def test_handling_response_type(): + # GIVEN a function that returns Response + app = ApiGatewayResolver(proxy_type=ProxyEventType.http_api_v1) + + @app.get("/my/path") + def rest_func() -> Response: + return Response( + status_code=404, + content_type="used-if-not-set-in-header", + body="Not found", + headers={"Content-Type": "header-content-type-wins", "custom": "value"}, + ) + + # WHEN calling the event handler + result = app(LOAD_GW_EVENT, {}) + + # THEN the result can include some additional field control like overriding http headers + assert result["statusCode"] == 404 + assert result["headers"]["Content-Type"] == "header-content-type-wins" + assert result["headers"]["custom"] == "value" + assert result["body"] == "Not found" + + +def test_custom_cors_config(): + # GIVEN a custom cors configuration + allow_header = ["foo2"] + cors_config = CORSConfig( + allow_origin="https://foo1", + expose_headers=["foo1"], + allow_headers=allow_header, + max_age=100, + allow_credentials=True, + ) + app = ApiGatewayResolver(cors=cors_config) + event = {"path": "/cors", "httpMethod": "GET"} + + @app.get("/cors", cors=True) + def get_with_cors(): + return {} + + @app.get("/another-one") + def another_one(): + return {} + + # WHEN calling the event handler + result = app(event, None) + + # THEN return the custom cors headers + assert "headers" in result + headers = result["headers"] + assert headers["Content-Type"] == APPLICATION_JSON + assert headers["Access-Control-Allow-Origin"] == cors_config.allow_origin + expected_allows_headers = ",".join(sorted(set(allow_header + cors_config._REQUIRED_HEADERS))) + assert headers["Access-Control-Allow-Headers"] == expected_allows_headers + assert headers["Access-Control-Expose-Headers"] == ",".join(cors_config.expose_headers) + assert headers["Access-Control-Max-Age"] == str(cors_config.max_age) + assert "Access-Control-Allow-Credentials" in headers + assert headers["Access-Control-Allow-Credentials"] == "true" + + # AND custom cors was set on the app + assert isinstance(app._cors, CORSConfig) + assert app._cors is cors_config + # AND routes without cors don't include "Access-Control" headers + event = {"path": "/another-one", "httpMethod": "GET"} + result = app(event, None) + headers = result["headers"] + assert "Access-Control-Allow-Origin" not in headers + + +def test_no_content_response(): + # GIVEN a response with no content-type or body + response = Response(status_code=204, content_type=None, body=None, headers=None) + + # WHEN calling to_dict + result = response.to_dict() + + # THEN return an None body and no Content-Type header + assert result["body"] is None + assert result["statusCode"] == 204 + assert "Content-Type" not in result["headers"] + + +def test_no_matches_with_cors(): + # GIVEN an event that does not match any of the given routes + # AND cors enabled + app = ApiGatewayResolver(cors=CORSConfig()) + + # WHEN calling the handler + result = app({"path": "/another-one", "httpMethod": "GET"}, None) + + # THEN return a 404 + # AND cors headers are returned + assert result["statusCode"] == 404 + assert "Access-Control-Allow-Origin" in result["headers"] + + +def test_preflight(): + # GIVEN an event for an OPTIONS call that does not match any of the given routes + # AND cors is enabled + app = ApiGatewayResolver(cors=CORSConfig()) + + @app.get("/foo", cors=True) + def foo_cors(): + ... + + @app.route(method="delete", rule="/foo", cors=True) + def foo_delete_cors(): + ... + + @app.post("/foo") + def post_no_cors(): + ... + + # WHEN calling the handler + result = app({"path": "/foo", "httpMethod": "OPTIONS"}, None) + + # THEN return no content + # AND include Access-Control-Allow-Methods of the cors methods used + assert result["statusCode"] == 204 + assert result["body"] is None + headers = result["headers"] + assert "Content-Type" not in headers + assert "Access-Control-Allow-Origin" in result["headers"] + assert headers["Access-Control-Allow-Methods"] == "DELETE,GET,OPTIONS" + + +def test_custom_preflight_response(): + # GIVEN cors is enabled + # AND we have a custom preflight method + # AND the request matches this custom preflight route + app = ApiGatewayResolver(cors=CORSConfig()) + + @app.route(method="OPTIONS", rule="/some-call", cors=True) + def custom_preflight(): + return Response( + status_code=200, + content_type=TEXT_HTML, + body="Foo", + headers={"Access-Control-Allow-Methods": "CUSTOM"}, + ) + + @app.route(method="CUSTOM", rule="/some-call", cors=True) + def custom_method(): + ... + + # WHEN calling the handler + result = app({"path": "/some-call", "httpMethod": "OPTIONS"}, None) + + # THEN return the custom preflight response + assert result["statusCode"] == 200 + assert result["body"] == "Foo" + headers = result["headers"] + assert headers["Content-Type"] == TEXT_HTML + assert "Access-Control-Allow-Origin" in result["headers"] + assert headers["Access-Control-Allow-Methods"] == "CUSTOM" diff --git a/tests/functional/event_handler/test_appsync.py b/tests/functional/event_handler/test_appsync.py index c72331c32f1..e260fef89ab 100644 --- a/tests/functional/event_handler/test_appsync.py +++ b/tests/functional/event_handler/test_appsync.py @@ -1,18 +1,12 @@ import asyncio -import json import sys -from pathlib import Path import pytest from aws_lambda_powertools.event_handler import AppSyncResolver from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext - - -def load_event(file_name: str) -> dict: - path = Path(str(Path(__file__).parent.parent.parent) + "/events/" + file_name) - return json.loads(path.read_text()) +from tests.functional.utils import load_event def test_direct_resolver(): diff --git a/tests/functional/idempotency/conftest.py b/tests/functional/idempotency/conftest.py index d34d5da7d12..e100957dee7 100644 --- a/tests/functional/idempotency/conftest.py +++ b/tests/functional/idempotency/conftest.py @@ -1,7 +1,6 @@ import datetime import hashlib import json -import os from collections import namedtuple from decimal import Decimal from unittest import mock @@ -17,6 +16,7 @@ from aws_lambda_powertools.utilities.idempotency.idempotency import IdempotencyConfig from aws_lambda_powertools.utilities.validation import envelopes from aws_lambda_powertools.utilities.validation.base import unwrap_event_from_envelope +from tests.functional.utils import load_event TABLE_NAME = "TEST_TABLE" @@ -28,11 +28,7 @@ def config() -> Config: @pytest.fixture(scope="module") def lambda_apigw_event(): - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + "apiGatewayProxyV2Event.json" - with open(full_file_name) as fp: - event = json.load(fp) - - return event + return load_event("apiGatewayProxyV2Event.json") @pytest.fixture diff --git a/tests/functional/parser/test_alb.py b/tests/functional/parser/test_alb.py index 88631c7194c..d48e39f1bab 100644 --- a/tests/functional/parser/test_alb.py +++ b/tests/functional/parser/test_alb.py @@ -3,7 +3,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, event_parser from aws_lambda_powertools.utilities.parser.models import AlbModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=AlbModel) diff --git a/tests/functional/parser/test_apigw.py b/tests/functional/parser/test_apigw.py index 333654f3f89..fc679d5dc37 100644 --- a/tests/functional/parser/test_apigw.py +++ b/tests/functional/parser/test_apigw.py @@ -2,7 +2,7 @@ from aws_lambda_powertools.utilities.parser.models import APIGatewayProxyEventModel from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyApiGatewayBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyApiGatewayBusiness, envelope=envelopes.ApiGatewayEnvelope) diff --git a/tests/functional/parser/test_cloudwatch.py b/tests/functional/parser/test_cloudwatch.py index 9a61f339140..7290d0bffcb 100644 --- a/tests/functional/parser/test_cloudwatch.py +++ b/tests/functional/parser/test_cloudwatch.py @@ -9,7 +9,7 @@ from aws_lambda_powertools.utilities.parser.models import CloudWatchLogsLogEvent, CloudWatchLogsModel from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyCloudWatchBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyCloudWatchBusiness, envelope=envelopes.CloudWatchLogsEnvelope) diff --git a/tests/functional/parser/test_dynamodb.py b/tests/functional/parser/test_dynamodb.py index bd7e0795f42..9917fac234b 100644 --- a/tests/functional/parser/test_dynamodb.py +++ b/tests/functional/parser/test_dynamodb.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedDynamoBusiness, MyDynamoBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyDynamoBusiness, envelope=envelopes.DynamoDBStreamEnvelope) diff --git a/tests/functional/parser/test_eventbridge.py b/tests/functional/parser/test_eventbridge.py index 7a3066d7b04..6242403ab35 100644 --- a/tests/functional/parser/test_eventbridge.py +++ b/tests/functional/parser/test_eventbridge.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedEventbridgeBusiness, MyEventbridgeBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyEventbridgeBusiness, envelope=envelopes.EventBridgeEnvelope) diff --git a/tests/functional/parser/test_kinesis.py b/tests/functional/parser/test_kinesis.py index 5a7a94e0dac..632a7463805 100644 --- a/tests/functional/parser/test_kinesis.py +++ b/tests/functional/parser/test_kinesis.py @@ -6,7 +6,7 @@ from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamModel, KinesisDataStreamRecordPayload from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyKinesisBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=MyKinesisBusiness, envelope=envelopes.KinesisDataStreamEnvelope) diff --git a/tests/functional/parser/test_s3 object_event.py b/tests/functional/parser/test_s3 object_event.py index da015338cf4..90c2555360d 100644 --- a/tests/functional/parser/test_s3 object_event.py +++ b/tests/functional/parser/test_s3 object_event.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser from aws_lambda_powertools.utilities.parser.models import S3ObjectLambdaEvent from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=S3ObjectLambdaEvent) diff --git a/tests/functional/parser/test_s3.py b/tests/functional/parser/test_s3.py index a9c325f3a97..71a5dc6afe3 100644 --- a/tests/functional/parser/test_s3.py +++ b/tests/functional/parser/test_s3.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser, parse from aws_lambda_powertools.utilities.parser.models import S3Model, S3RecordModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=S3Model) diff --git a/tests/functional/parser/test_ses.py b/tests/functional/parser/test_ses.py index f96da7bad66..d434e2350f8 100644 --- a/tests/functional/parser/test_ses.py +++ b/tests/functional/parser/test_ses.py @@ -1,7 +1,7 @@ from aws_lambda_powertools.utilities.parser import event_parser from aws_lambda_powertools.utilities.parser.models import SesModel, SesRecordModel from aws_lambda_powertools.utilities.typing import LambdaContext -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event @event_parser(model=SesModel) diff --git a/tests/functional/parser/test_sns.py b/tests/functional/parser/test_sns.py index 015af3693fa..81158a4419e 100644 --- a/tests/functional/parser/test_sns.py +++ b/tests/functional/parser/test_sns.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedSnsBusiness, MySnsBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event from tests.functional.validator.conftest import sns_event # noqa: F401 diff --git a/tests/functional/parser/test_sqs.py b/tests/functional/parser/test_sqs.py index 0cea8246b50..7ca883616f2 100644 --- a/tests/functional/parser/test_sqs.py +++ b/tests/functional/parser/test_sqs.py @@ -5,7 +5,7 @@ from aws_lambda_powertools.utilities.parser import ValidationError, envelopes, event_parser from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.parser.schemas import MyAdvancedSqsBusiness, MySqsBusiness -from tests.functional.parser.utils import load_event +from tests.functional.utils import load_event from tests.functional.validator.conftest import sqs_event # noqa: F401 diff --git a/tests/functional/parser/utils.py b/tests/functional/parser/utils.py deleted file mode 100644 index 7cb949b1289..00000000000 --- a/tests/functional/parser/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -import json -import os -from typing import Any - - -def get_event_file_path(file_name: str) -> str: - return os.path.dirname(os.path.realpath(__file__)) + "/../../events/" + file_name - - -def load_event(file_name: str) -> Any: - full_file_name = get_event_file_path(file_name) - with open(full_file_name) as fp: - return json.load(fp) diff --git a/tests/functional/test_data_classes.py b/tests/functional/test_data_classes.py index 0221acc6853..d346eca480a 100644 --- a/tests/functional/test_data_classes.py +++ b/tests/functional/test_data_classes.py @@ -1,7 +1,6 @@ import base64 import datetime import json -import os from secrets import compare_digest from urllib.parse import quote_plus @@ -58,12 +57,7 @@ StreamViewType, ) from aws_lambda_powertools.utilities.data_classes.s3_object_event import S3ObjectLambdaEvent - - -def load_event(file_name: str) -> dict: - full_file_name = os.path.dirname(os.path.realpath(__file__)) + "/../events/" + file_name - with open(full_file_name) as fp: - return json.load(fp) +from tests.functional.utils import load_event def test_dict_wrapper_equals(): diff --git a/tests/functional/utils.py b/tests/functional/utils.py new file mode 100644 index 00000000000..a58d27f3526 --- /dev/null +++ b/tests/functional/utils.py @@ -0,0 +1,8 @@ +import json +from pathlib import Path +from typing import Any + + +def load_event(file_name: str) -> Any: + path = Path(str(Path(__file__).parent.parent) + "/events/" + file_name) + return json.loads(path.read_text())