diff --git a/aws_lambda_powertools/event_handler/api_gateway.py b/aws_lambda_powertools/event_handler/api_gateway.py index 754cc24710d..d3a79761556 100644 --- a/aws_lambda_powertools/event_handler/api_gateway.py +++ b/aws_lambda_powertools/event_handler/api_gateway.py @@ -6,9 +6,9 @@ import traceback import zlib from enum import Enum -from functools import partial +from functools import partial, wraps from http import HTTPStatus -from typing import Any, Callable, Dict, List, Optional, Set, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union from aws_lambda_powertools.event_handler import content_types from aws_lambda_powertools.event_handler.exceptions import ServiceError @@ -630,3 +630,73 @@ def _to_response(self, result: Union[Dict, Response]) -> Response: def _json_dump(self, obj: Any) -> str: return self._serializer(obj) + + def include_router(self, router: "Router", prefix: Optional[str] = None) -> None: + """Adds all routes defined in a router""" + router._app = self + for route, func in router.api.items(): + if prefix and route[0] == "/": + route = (prefix, *route[1:]) + elif prefix: + route = (f"{prefix}{route[0]}", *route[1:]) + self.route(*route)(func()) + + +class Router: + """Router helper class to allow splitting ApiGatewayResolver into multiple files""" + + _app: ApiGatewayResolver + + def __init__(self): + self.api: Dict[tuple, Callable] = {} + + @property + def current_event(self) -> BaseProxyEvent: + return self._app.current_event + + @property + def lambda_context(self) -> LambdaContext: + return self._app.lambda_context + + def route( + self, + rule: str, + method: Union[str, Tuple[str], List[str]], + cors: Optional[bool] = None, + compress: bool = False, + cache_control: Optional[str] = None, + ): + def actual_decorator(func: Callable): + @wraps(func) + def wrapper(): + def inner_wrapper(**kwargs): + return func(**kwargs) + + return inner_wrapper + + if isinstance(method, (list, tuple)): + for item in method: + self.api[(rule, item, cors, compress, cache_control)] = wrapper + else: + self.api[(rule, method, cors, compress, cache_control)] = wrapper + + return actual_decorator + + def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + return self.route(rule, "GET", cors, compress, cache_control) + + def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + return self.route(rule, "POST", cors, compress, cache_control) + + def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None): + return self.route(rule, "PUT", cors, compress, cache_control) + + def delete( + self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None + ): + return self.route(rule, "DELETE", cors, compress, cache_control) + + def patch( + self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None + ): + return self.route(rule, "PATCH", cors, compress, cache_control) diff --git a/tests/functional/event_handler/test_api_gateway.py b/tests/functional/event_handler/test_api_gateway.py index 21700ec09dd..afc979065f8 100644 --- a/tests/functional/event_handler/test_api_gateway.py +++ b/tests/functional/event_handler/test_api_gateway.py @@ -17,6 +17,7 @@ ProxyEventType, Response, ResponseBuilder, + Router, ) from aws_lambda_powertools.event_handler.exceptions import ( BadRequestError, @@ -860,3 +861,136 @@ def base(): # THEN process event correctly assert result["statusCode"] == 200 assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router(): + # GIVEN a Router with registered routes + app = ApiGatewayResolver() + router = Router() + + @router.get("/my/path") + def foo(): + return {} + + app.include_router(router) + # WHEN calling the event handler after applying routes from router object + result = app(LOAD_GW_EVENT, {}) + + # THEN process event correctly + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router_with_params(): + # GIVEN a Router with registered routes + app = ApiGatewayResolver() + router = Router() + req = "foo" + event = deepcopy(LOAD_GW_EVENT) + event["resource"] = "/accounts/{account_id}" + event["path"] = f"/accounts/{req}" + lambda_context = {} + + @router.route(rule="/accounts/", method=["GET", "POST"]) + def foo(account_id): + assert router.current_event.raw_event == event + assert router.lambda_context == lambda_context + assert account_id == f"{req}" + return {} + + app.include_router(router) + # WHEN calling the event handler after applying routes from router object + result = app(event, lambda_context) + + # THEN process event correctly + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router_with_prefix(): + # GIVEN a Router with registered routes + # AND a prefix is defined during the registration + app = ApiGatewayResolver() + router = Router() + + @router.get(rule="/path") + def foo(): + return {} + + app.include_router(router, prefix="/my") + # WHEN calling the event handler after applying routes from router object + result = app(LOAD_GW_EVENT, {}) + + # THEN process event correctly + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router_with_prefix_equals_path(): + # GIVEN a Router with registered routes + # AND a prefix is defined during the registration + app = ApiGatewayResolver() + router = Router() + + @router.get(rule="/") + def foo(): + return {} + + app.include_router(router, prefix="/my/path") + # WHEN calling the event handler after applying routes from router object + # WITH the request path matching the registration prefix + result = app(LOAD_GW_EVENT, {}) + + # THEN process event correctly + assert result["statusCode"] == 200 + assert result["headers"]["Content-Type"] == content_types.APPLICATION_JSON + + +def test_api_gateway_app_router_with_different_methods(): + # GIVEN a Router with all the possible HTTP methods + app = ApiGatewayResolver() + router = Router() + + @router.get("/not_matching_get") + def get_func(): + raise RuntimeError() + + @router.post("/no_matching_post") + def post_func(): + raise RuntimeError() + + @router.put("/no_matching_put") + def put_func(): + raise RuntimeError() + + @router.delete("/no_matching_delete") + def delete_func(): + raise RuntimeError() + + @router.patch("/no_matching_patch") + def patch_func(): + raise RuntimeError() + + app.include_router(router) + + # Also check check the route configurations + routes = app._routes + assert len(routes) == 5 + for route in routes: + if route.func == get_func: + assert route.method == "GET" + elif route.func == post_func: + assert route.method == "POST" + elif route.func == put_func: + assert route.method == "PUT" + elif route.func == delete_func: + assert route.method == "DELETE" + elif route.func == patch_func: + assert route.method == "PATCH" + + # WHEN calling the handler + # THEN return a 404 + result = app(LOAD_GW_EVENT, None) + assert result["statusCode"] == 404 + # AND cors headers are not returned + assert "Access-Control-Allow-Origin" not in result["headers"]