Skip to content

Commit 7a7ba0a

Browse files
author
Michael Brewer
authored
refactor(apigateway): Add BaseRouter and duplicate route check (#757)
1 parent a9e2067 commit 7a7ba0a

File tree

4 files changed

+149
-129
lines changed

4 files changed

+149
-129
lines changed

aws_lambda_powertools/event_handler/api_gateway.py

+115-125
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
import os
55
import re
66
import traceback
7+
import warnings
78
import zlib
9+
from abc import ABC, abstractmethod
810
from enum import Enum
9-
from functools import partial, wraps
11+
from functools import partial
1012
from http import HTTPStatus
11-
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union
13+
from typing import Any, Callable, Dict, List, Optional, Set, Union
1214

1315
from aws_lambda_powertools.event_handler import content_types
1416
from aws_lambda_powertools.event_handler.exceptions import ServiceError
@@ -227,78 +229,20 @@ def build(self, event: BaseProxyEvent, cors: Optional[CORSConfig] = None) -> Dic
227229
}
228230

229231

230-
class ApiGatewayResolver:
231-
"""API Gateway and ALB proxy resolver
232-
233-
Examples
234-
--------
235-
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
236-
237-
```python
238-
from aws_lambda_powertools import Tracer
239-
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver
240-
241-
tracer = Tracer()
242-
app = ApiGatewayResolver()
243-
244-
@app.get("/get-call")
245-
def simple_get():
246-
return {"message": "Foo"}
247-
248-
@app.post("/post-call")
249-
def simple_post():
250-
post_data: dict = app.current_event.json_body
251-
return {"message": post_data["value"]}
252-
253-
@tracer.capture_lambda_handler
254-
def lambda_handler(event, context):
255-
return app.resolve(event, context)
256-
```
257-
"""
258-
232+
class BaseRouter(ABC):
259233
current_event: BaseProxyEvent
260234
lambda_context: LambdaContext
261235

262-
def __init__(
236+
@abstractmethod
237+
def route(
263238
self,
264-
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
265-
cors: Optional[CORSConfig] = None,
266-
debug: Optional[bool] = None,
267-
serializer: Optional[Callable[[Dict], str]] = None,
268-
strip_prefixes: Optional[List[str]] = None,
239+
rule: str,
240+
method: Any,
241+
cors: Optional[bool] = None,
242+
compress: bool = False,
243+
cache_control: Optional[str] = None,
269244
):
270-
"""
271-
Parameters
272-
----------
273-
proxy_type: ProxyEventType
274-
Proxy request type, defaults to API Gateway V1
275-
cors: CORSConfig
276-
Optionally configure and enabled CORS. Not each route will need to have to cors=True
277-
debug: Optional[bool]
278-
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
279-
environment variable
280-
serializer : Callable, optional
281-
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
282-
strip_prefixes: List[str], optional
283-
optional list of prefixes to be removed from the request path before doing the routing. This is often used
284-
with api gateways with multiple custom mappings.
285-
"""
286-
self._proxy_type = proxy_type
287-
self._routes: List[Route] = []
288-
self._cors = cors
289-
self._cors_enabled: bool = cors is not None
290-
self._cors_methods: Set[str] = {"OPTIONS"}
291-
self._debug = resolve_truthy_env_var_choice(
292-
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
293-
)
294-
self._strip_prefixes = strip_prefixes
295-
296-
# Allow for a custom serializer or a concise json serialization
297-
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
298-
299-
if self._debug:
300-
# Always does a pretty print when in debug mode
301-
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
245+
raise NotImplementedError()
302246

303247
def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
304248
"""Get route decorator with GET `method`
@@ -434,6 +378,78 @@ def lambda_handler(event, context):
434378
"""
435379
return self.route(rule, "PATCH", cors, compress, cache_control)
436380

381+
382+
class ApiGatewayResolver(BaseRouter):
383+
"""API Gateway and ALB proxy resolver
384+
385+
Examples
386+
--------
387+
Simple example with a custom lambda handler using the Tracer capture_lambda_handler decorator
388+
389+
```python
390+
from aws_lambda_powertools import Tracer
391+
from aws_lambda_powertools.event_handler.api_gateway import ApiGatewayResolver
392+
393+
tracer = Tracer()
394+
app = ApiGatewayResolver()
395+
396+
@app.get("/get-call")
397+
def simple_get():
398+
return {"message": "Foo"}
399+
400+
@app.post("/post-call")
401+
def simple_post():
402+
post_data: dict = app.current_event.json_body
403+
return {"message": post_data["value"]}
404+
405+
@tracer.capture_lambda_handler
406+
def lambda_handler(event, context):
407+
return app.resolve(event, context)
408+
```
409+
"""
410+
411+
def __init__(
412+
self,
413+
proxy_type: Enum = ProxyEventType.APIGatewayProxyEvent,
414+
cors: Optional[CORSConfig] = None,
415+
debug: Optional[bool] = None,
416+
serializer: Optional[Callable[[Dict], str]] = None,
417+
strip_prefixes: Optional[List[str]] = None,
418+
):
419+
"""
420+
Parameters
421+
----------
422+
proxy_type: ProxyEventType
423+
Proxy request type, defaults to API Gateway V1
424+
cors: CORSConfig
425+
Optionally configure and enabled CORS. Not each route will need to have to cors=True
426+
debug: Optional[bool]
427+
Enables debug mode, by default False. Can be also be enabled by "POWERTOOLS_EVENT_HANDLER_DEBUG"
428+
environment variable
429+
serializer : Callable, optional
430+
function to serialize `obj` to a JSON formatted `str`, by default json.dumps
431+
strip_prefixes: List[str], optional
432+
optional list of prefixes to be removed from the request path before doing the routing. This is often used
433+
with api gateways with multiple custom mappings.
434+
"""
435+
self._proxy_type = proxy_type
436+
self._routes: List[Route] = []
437+
self._route_keys: List[str] = []
438+
self._cors = cors
439+
self._cors_enabled: bool = cors is not None
440+
self._cors_methods: Set[str] = {"OPTIONS"}
441+
self._debug = resolve_truthy_env_var_choice(
442+
env=os.getenv(constants.EVENT_HANDLER_DEBUG_ENV, "false"), choice=debug
443+
)
444+
self._strip_prefixes = strip_prefixes
445+
446+
# Allow for a custom serializer or a concise json serialization
447+
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
448+
449+
if self._debug:
450+
# Always does a pretty print when in debug mode
451+
self._serializer = partial(json.dumps, indent=4, cls=Encoder)
452+
437453
def route(
438454
self,
439455
rule: str,
@@ -451,6 +467,10 @@ def register_resolver(func: Callable):
451467
else:
452468
cors_enabled = cors
453469
self._routes.append(Route(method, self._compile_regex(rule), func, cors_enabled, compress, cache_control))
470+
route_key = method + rule
471+
if route_key in self._route_keys:
472+
warnings.warn(f"A route like this was already registered. method: '{method}' rule: '{rule}'")
473+
self._route_keys.append(route_key)
454474
if cors_enabled:
455475
logger.debug(f"Registering method {method.upper()} to Allow Methods in CORS")
456476
self._cors_methods.add(method.upper())
@@ -474,8 +494,8 @@ def resolve(self, event, context) -> Dict[str, Any]:
474494
"""
475495
if self._debug:
476496
print(self._json_dump(event))
477-
self.current_event = self._to_proxy_event(event)
478-
self.lambda_context = context
497+
BaseRouter.current_event = self._to_proxy_event(event)
498+
BaseRouter.lambda_context = context
479499
return self._resolve().build(self.current_event, self._cors)
480500

481501
def __call__(self, event, context) -> Any:
@@ -632,71 +652,41 @@ def _json_dump(self, obj: Any) -> str:
632652
return self._serializer(obj)
633653

634654
def include_router(self, router: "Router", prefix: Optional[str] = None) -> None:
635-
"""Adds all routes defined in a router"""
636-
router._app = self
637-
for route, func in router.api.items():
638-
if prefix and route[0] == "/":
639-
route = (prefix, *route[1:])
640-
elif prefix:
641-
route = (f"{prefix}{route[0]}", *route[1:])
642-
self.route(*route)(func())
643-
655+
"""Adds all routes defined in a router
644656
645-
class Router:
646-
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
657+
Parameters
658+
----------
659+
router : Router
660+
The Router containing a list of routes to be registered after the existing routes
661+
prefix : str, optional
662+
An optional prefix to be added to the originally defined rule
663+
"""
664+
for route, func in router._routes.items():
665+
if prefix:
666+
rule = route[0]
667+
rule = prefix if rule == "/" else f"{prefix}{rule}"
668+
route = (rule, *route[1:])
647669

648-
_app: ApiGatewayResolver
670+
self.route(*route)(func)
649671

650-
def __init__(self):
651-
self.api: Dict[tuple, Callable] = {}
652672

653-
@property
654-
def current_event(self) -> BaseProxyEvent:
655-
return self._app.current_event
673+
class Router(BaseRouter):
674+
"""Router helper class to allow splitting ApiGatewayResolver into multiple files"""
656675

657-
@property
658-
def lambda_context(self) -> LambdaContext:
659-
return self._app.lambda_context
676+
def __init__(self):
677+
self._routes: Dict[tuple, Callable] = {}
660678

661679
def route(
662680
self,
663681
rule: str,
664-
method: Union[str, Tuple[str], List[str]],
682+
method: Union[str, List[str]],
665683
cors: Optional[bool] = None,
666684
compress: bool = False,
667685
cache_control: Optional[str] = None,
668686
):
669-
def actual_decorator(func: Callable):
670-
@wraps(func)
671-
def wrapper():
672-
def inner_wrapper(**kwargs):
673-
return func(**kwargs)
674-
675-
return inner_wrapper
676-
677-
if isinstance(method, (list, tuple)):
678-
for item in method:
679-
self.api[(rule, item, cors, compress, cache_control)] = wrapper
680-
else:
681-
self.api[(rule, method, cors, compress, cache_control)] = wrapper
682-
683-
return actual_decorator
684-
685-
def get(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
686-
return self.route(rule, "GET", cors, compress, cache_control)
687-
688-
def post(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
689-
return self.route(rule, "POST", cors, compress, cache_control)
690-
691-
def put(self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None):
692-
return self.route(rule, "PUT", cors, compress, cache_control)
693-
694-
def delete(
695-
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
696-
):
697-
return self.route(rule, "DELETE", cors, compress, cache_control)
687+
def register_route(func: Callable):
688+
methods = method if isinstance(method, list) else [method]
689+
for item in methods:
690+
self._routes[(rule, item, cors, compress, cache_control)] = func
698691

699-
def patch(
700-
self, rule: str, cors: Optional[bool] = None, compress: bool = False, cache_control: Optional[str] = None
701-
):
702-
return self.route(rule, "PATCH", cors, compress, cache_control)
692+
return register_route

aws_lambda_powertools/utilities/validation/exceptions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ class SchemaValidationError(Exception):
88

99
def __init__(
1010
self,
11-
message: str,
11+
message: Optional[str] = None,
1212
validation_message: Optional[str] = None,
1313
name: Optional[str] = None,
1414
path: Optional[List] = None,
@@ -21,7 +21,7 @@ def __init__(
2121
2222
Parameters
2323
----------
24-
message : str
24+
message : str, optional
2525
Powertools formatted error message
2626
validation_message : str, optional
2727
Containing human-readable information what is wrong

tests/functional/event_handler/test_api_gateway.py

+27
Original file line numberDiff line numberDiff line change
@@ -994,3 +994,30 @@ def patch_func():
994994
assert result["statusCode"] == 404
995995
# AND cors headers are not returned
996996
assert "Access-Control-Allow-Origin" not in result["headers"]
997+
998+
999+
def test_duplicate_routes():
1000+
# GIVEN a duplicate routes
1001+
app = ApiGatewayResolver()
1002+
router = Router()
1003+
1004+
@router.get("/my/path")
1005+
def get_func_duplicate():
1006+
raise RuntimeError()
1007+
1008+
@app.get("/my/path")
1009+
def get_func():
1010+
return {}
1011+
1012+
@router.get("/my/path")
1013+
def get_func_another_duplicate():
1014+
raise RuntimeError()
1015+
1016+
app.include_router(router)
1017+
1018+
# WHEN calling the handler
1019+
result = app(LOAD_GW_EVENT, None)
1020+
1021+
# THEN only execute the first registered route
1022+
# AND print warnings
1023+
assert result["statusCode"] == 200

tests/functional/test_logger.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -537,11 +537,11 @@ def format(self, record: logging.LogRecord) -> str: # noqa: A003
537537
logger = Logger(service=service_name, stream=stdout, logger_formatter=custom_formatter)
538538

539539
# WHEN a lambda function is decorated with logger
540-
@logger.inject_lambda_context
540+
@logger.inject_lambda_context(correlation_id_path="foo")
541541
def handler(event, context):
542542
logger.info("Hello")
543543

544-
handler({}, lambda_context)
544+
handler({"foo": "value"}, lambda_context)
545545

546546
lambda_context_keys = (
547547
"function_name",
@@ -554,8 +554,11 @@ def handler(event, context):
554554

555555
# THEN custom key should always be present
556556
# and lambda contextual info should also be in the logs
557+
# and get_correlation_id should return None
557558
assert "my_default_key" in log
558559
assert all(k in log for k in lambda_context_keys)
560+
assert log["correlation_id"] == "value"
561+
assert logger.get_correlation_id() is None
559562

560563

561564
def test_logger_custom_handler(lambda_context, service_name, tmp_path):

0 commit comments

Comments
 (0)