diff --git a/lambda_proxy/proxy.py b/lambda_proxy/proxy.py index c16a363..a1bda9d 100644 --- a/lambda_proxy/proxy.py +++ b/lambda_proxy/proxy.py @@ -90,6 +90,7 @@ def __init__( binary_b64encode: bool = False, ttl=None, cache_control=None, + custom_headers: dict = None, description: str = None, tag: Tuple = None, ) -> None: @@ -105,6 +106,7 @@ def __init__( self.b64encode = binary_b64encode self.ttl = ttl self.cache_control = cache_control + self.custom_headers = custom_headers self.description = description or self.endpoint.__doc__ self.tag = tag if self.compression and self.compression not in ["gzip", "zlib", "deflate"]: @@ -367,6 +369,7 @@ def _add_route(self, path: str, endpoint: Callable, **kwargs) -> None: binary_encode = kwargs.pop("binary_b64encode", False) ttl = kwargs.pop("ttl", None) cache_control = kwargs.pop("cache_control", None) + custom_headers = kwargs.pop("custom_headers", None) description = kwargs.pop("description", None) tag = kwargs.pop("tag", None) @@ -400,6 +403,7 @@ def _add_route(self, path: str, endpoint: Callable, **kwargs) -> None: binary_encode, ttl, cache_control, + custom_headers, description, tag, ) @@ -494,7 +498,7 @@ def new_func(*args, **kwargs) -> Callable: def setup_docs(self) -> None: """Add default documentation routes.""" - openapi_url = f"/openapi.json" + openapi_url = "/openapi.json" def _openapi() -> Tuple[str, str, str]: """Return OpenAPI json.""" @@ -534,7 +538,7 @@ def _redoc_ui_html() -> Tuple[str, str, str]: self._add_route("/redoc", _redoc_ui_html, cors=True, tag=["documentation"]) - def response( + def response( # noqa: C901 self, status: Union[int, str], content_type: str, @@ -546,6 +550,7 @@ def response( b64encode: bool = False, ttl: int = None, cache_control: str = None, + custom_headers: dict = None, ): """Return HTTP response. @@ -582,6 +587,9 @@ def response( "headers": {"Content-Type": content_type}, } + if custom_headers: + messageData["headers"].update(custom_headers) + if cors: messageData["headers"]["Access-Control-Allow-Origin"] = "*" messageData["headers"]["Access-Control-Allow-Methods"] = ",".join( @@ -715,4 +723,5 @@ def __call__(self, event, context): b64encode=route_entry.b64encode, ttl=route_entry.ttl, cache_control=route_entry.cache_control, + custom_headers=route_entry.custom_headers, )