diff --git a/example/handler.py b/example/handler.py index 8a0c073..e9c7b53 100644 --- a/example/handler.py +++ b/example/handler.py @@ -1,83 +1,92 @@ """app: handle requests.""" -from typing import Dict, Tuple -import typing.io - -import json +from typing import Dict from lambda_proxy.proxy import API +from lambda_proxy.responses import PlainTextResponse, Response -app = API(name="app", debug=True) - - -@app.get("/", cors=True) -def main() -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", "Yo") - - -@app.get("/", cors=True) -def _re_one(regex1: str) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", regex1) - +app = API(name="app") -@app.get("/", cors=True) -def _re_two(regex2: str) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", regex2) +@app.get("/", cors=True, response_class=PlainTextResponse) +def main(): + """Return String.""" + return "Yo" -@app.post("/people", cors=True) -def people_post(body) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", body) +@app.post("/people", cors=True, response_class=PlainTextResponse) +def people_post(body): + """Return String.""" + return body -@app.get("/people", cors=True) -def people_get() -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", "Nope") - -@app.get("/", cors=True) -@app.get("//", cors=True) -def double(user: str, num: int = 0) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", f"{user}-{num}") +@app.get("/people", cors=True, response_class=PlainTextResponse) +def people_get(): + """Return String.""" + return "Nope" -@app.get("/kw/", cors=True) -def kw_method(user: str, **kwargs: Dict) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", f"{user}") +@app.get("/kw/", cors=True, response_class=PlainTextResponse) +def kw_method(user: str, **kwargs: Dict): + """Return String.""" + return f"{user}" -@app.get("/ctx/", cors=True) +@app.get("/ctx/", cors=True, response_class=PlainTextResponse) @app.pass_context @app.pass_event -def ctx_method(evt: Dict, ctx: Dict, user: str, num: int = 0) -> Tuple[str, str, str]: - """Return JSON Object.""" - return ("OK", "text/plain", f"{user}-{num}") +def ctx_method(evt: Dict, ctx: Dict, user: str, num: int = 0): + """Return String.""" + return f"{user}-{num}" -@app.get("/json", cors=True) -def json_handler() -> Tuple[str, str, str]: +@app.get("/json/itworks", cors=True) +def json_handler(): """Return JSON Object.""" - return ("OK", "application/json", json.dumps({"app": "it works"})) + return {"app": "it works"} @app.get("/binary", cors=True, payload_compression_method="gzip") -def bin() -> Tuple[str, str, typing.io.BinaryIO]: +def bin(): """Return image.""" with open("./rpix.png", "rb") as f: - return ("OK", "image/png", f.read()) + return Response(f.read(), media_type="image/png") @app.get( "/b64binary", cors=True, payload_compression_method="gzip", binary_b64encode=True, ) -def b64bin() -> Tuple[str, str, typing.io.BinaryIO]: +def b64bin(): """Return base64 encoded image.""" with open("./rpix.png", "rb") as f: - return ("OK", "image/png", f.read()) + return Response(f.read(), media_type="image/png") + + +@app.get("/header/json", cors=True) +def addHeader_handler(resp: Response): + """Return JSON Object.""" + resp.headers["Cache-Control"] = "max-age=3600" + return {"app": "it works"} + + +@app.get("/", cors=True, response_class=PlainTextResponse) +@app.get("//", cors=True, response_class=PlainTextResponse) +def double(user: str, num: int = 0): + """Return String.""" + return f"{user}-{num}" + + +@app.get( + "/", cors=True, response_class=PlainTextResponse +) +def _re_one(regex1: str): + """Return String.""" + return regex1 + + +@app.get( + "/", cors=True, response_class=PlainTextResponse +) +def _re_two(regex2: str): + """Return String.""" + return regex2 diff --git a/lambda_proxy/errors.py b/lambda_proxy/errors.py new file mode 100644 index 0000000..c69e8b8 --- /dev/null +++ b/lambda_proxy/errors.py @@ -0,0 +1,29 @@ +""" +lambda-proxy Errors. + +Original code from + - https://github.com/encode/starlette/blob/master/starlette/exceptions.py + - https://github.com/tiangolo/fastapi/blob/master/fastapi/exceptions.py +""" + +import http + + +class HTTPException(Exception): + """Base HTTP Execption for lambda-proxy.""" + + def __init__( + self, status_code: int, detail: str = None, headers: dict = None + ) -> None: + """Set Exception.""" + if detail is None: + detail = http.HTTPStatus(status_code).phrase + + self.status_code = status_code + self.detail = detail + self.headers = headers + + def __repr__(self) -> str: + """Exception repr.""" + class_name = self.__class__.__name__ + return f"{class_name}(status_code={self.status_code!r}, detail={self.detail!r})" diff --git a/lambda_proxy/proxy.py b/lambda_proxy/proxy.py index c16a363..e03dad9 100644 --- a/lambda_proxy/proxy.py +++ b/lambda_proxy/proxy.py @@ -3,7 +3,7 @@ Freely adapted from https://github.com/aws/chalice """ -from typing import Any, Callable, Dict, List, Optional, Tuple, Sequence, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Type import inspect @@ -14,10 +14,12 @@ import zlib import base64 import logging -import warnings from functools import wraps from lambda_proxy import templates +from lambda_proxy.responses import Response, JSONResponse, HTMLResponse +from lambda_proxy.errors import HTTPException + params_expr = re.compile(r"(<[^>]*>)") proxy_pattern = re.compile(r"/{(?P.+)\+}$") @@ -27,6 +29,18 @@ regex_pattern = re.compile( r"^<(?Pregex)\((?P.+)\):(?P[a-zA-Z0-9_]+)>$" ) +binary_types = [ + "application/octet-stream", + "application/x-protobuf", + "application/x-tar", + "application/zip", + "image/png", + "image/jpeg", + "image/jpg", + "image/tiff", + "image/webp", + "image/jp2", +] def _path_to_regex(path: str) -> str: @@ -83,15 +97,14 @@ def __init__( self, endpoint: Callable, path: str, - methods: List = ["GET"], + methods: Sequence[str] = ["GET"], cors: bool = False, token: bool = False, payload_compression_method: str = "", binary_b64encode: bool = False, - ttl=None, - cache_control=None, - description: str = None, - tag: Tuple = None, + description: Optional[str] = None, + tag: Optional[Sequence[str]] = None, + response_class: Type[Response] = JSONResponse, ) -> None: """Initialize route object.""" self.endpoint = endpoint @@ -103,15 +116,17 @@ def __init__( self.token = token self.compression = payload_compression_method self.b64encode = binary_b64encode - self.ttl = ttl - self.cache_control = cache_control self.description = description or self.endpoint.__doc__ self.tag = tag + self.response_class = response_class + if self.compression and self.compression not in ["gzip", "zlib", "deflate"]: raise ValueError( f"'{payload_compression_method}' is not a supported compression" ) + self.response_param: Optional[str] = None + def __eq__(self, other) -> bool: """Check for equality.""" return self.__dict__ == other.__dict__ @@ -256,11 +271,21 @@ def _get_parameters(self, route: RouteEntry) -> List[Dict]: for name, arg in endpoint_args.items(): if name not in endpoint_args_names: continue + parameter = {"name": name, "in": "query", "schema": {}} - if arg.default is not inspect.Parameter.empty: + + if isinstance(arg.annotation, type) and issubclass( + arg.annotation, Response + ): + route.response_param = name + continue + + elif arg.default is not inspect.Parameter.empty: parameter["schema"]["default"] = arg.default + elif arg.kind == inspect.Parameter.VAR_KEYWORD: parameter["schema"]["format"] = "dict" + else: parameter["schema"]["format"] = "string" parameter["required"] = True @@ -359,24 +384,20 @@ def _already_configured(self, log) -> bool: return False - def _add_route(self, path: str, endpoint: Callable, **kwargs) -> None: - methods = kwargs.pop("methods", ["GET"]) - cors = kwargs.pop("cors", False) - token = kwargs.pop("token", "") - payload_compression = kwargs.pop("payload_compression_method", "") - binary_encode = kwargs.pop("binary_b64encode", False) - ttl = kwargs.pop("ttl", None) - cache_control = kwargs.pop("cache_control", None) - description = kwargs.pop("description", None) - tag = kwargs.pop("tag", None) - - if ttl: - warnings.warn( - "ttl will be deprecated in 6.0.0, please use 'cache-control'", - DeprecationWarning, - stacklevel=2, - ) - + def _add_route( + self, + path: str, + endpoint: Callable, + methods: Sequence[str] = ["GET"], + cors: bool = False, + payload_compression_method: str = "", + binary_b64encode: bool = False, + description: Optional[str] = None, + tag: Optional[Sequence[str]] = None, + response_class: Type[Response] = JSONResponse, + token: bool = False, + **kwargs, + ) -> None: if kwargs: raise TypeError( "TypeError: route() got unexpected keyword " @@ -393,15 +414,14 @@ def _add_route(self, path: str, endpoint: Callable, **kwargs) -> None: route = RouteEntry( endpoint, path, - methods, - cors, - token, - payload_compression, - binary_encode, - ttl, - cache_control, - description, - tag, + methods=methods, + cors=cors, + token=token, + payload_compression_method=payload_compression_method, + binary_b64encode=binary_b64encode, + description=description, + tag=tag, + response_class=response_class, ) self.routes.append(route) @@ -409,6 +429,7 @@ def _checkroute(self, path: str, method: str) -> bool: for route in self.routes: if method in route.methods and path == route.path: return True + return False def _url_matching(self, url: str, method: str) -> Optional[RouteEntry]: @@ -496,90 +517,60 @@ def setup_docs(self) -> None: """Add default documentation routes.""" openapi_url = f"/openapi.json" - def _openapi() -> Tuple[str, str, str]: + def _openapi(): """Return OpenAPI json.""" - return ( - "OK", - "application/json", - json.dumps(self._get_openapi(openapi_prefix=self.request_path.prefix)), - ) + return self._get_openapi(openapi_prefix=self.request_path.prefix) self._add_route(openapi_url, _openapi, cors=True, tag=["documentation"]) - def _swagger_ui_html() -> Tuple[str, str, str]: + def _swagger_ui_html(): """Display Swagger HTML UI.""" openapi_prefix = self.request_path.prefix - return ( - "OK", - "text/html", - templates.swagger( - openapi_url=f"{openapi_prefix}{openapi_url}", - title=self.name + " - Swagger UI", - ), + return templates.swagger( + openapi_url=f"{openapi_prefix}{openapi_url}", + title=self.name + " - Swagger UI", ) - self._add_route("/docs", _swagger_ui_html, cors=True, tag=["documentation"]) + self._add_route( + "/docs", + _swagger_ui_html, + cors=True, + tag=["documentation"], + response_class=HTMLResponse, + ) - def _redoc_ui_html() -> Tuple[str, str, str]: + def _redoc_ui_html(): """Display Redoc HTML UI.""" openapi_prefix = self.request_path.prefix - return ( - "OK", - "text/html", - templates.redoc( - openapi_url=f"{openapi_prefix}{openapi_url}", - title=self.name + " - ReDoc", - ), + return templates.redoc( + openapi_url=f"{openapi_prefix}{openapi_url}", + title=self.name + " - ReDoc", ) - self._add_route("/redoc", _redoc_ui_html, cors=True, tag=["documentation"]) + self._add_route( + "/redoc", + _redoc_ui_html, + cors=True, + tag=["documentation"], + response_class=HTMLResponse, + ) def response( self, - status: Union[int, str], - content_type: str, - response_body: Any, + response: Response, cors: bool = False, accepted_methods: Sequence = [], - accepted_compression: str = "", compression: str = "", b64encode: bool = False, - ttl: int = None, - cache_control: str = None, ): """Return HTTP response. including response code (status), headers and body """ - statusCode = { - "OK": 200, - "EMPTY": 204, - "NOK": 400, - "FOUND": 302, - "NOT_FOUND": 404, - "CONFLICT": 409, - "ERROR": 500, - } - - binary_types = [ - "application/octet-stream", - "application/x-protobuf", - "application/x-tar", - "application/zip", - "image/png", - "image/jpeg", - "image/jpg", - "image/tiff", - "image/webp", - "image/jp2", - ] - - status = statusCode[status] if isinstance(status, str) else status - messageData: Dict[str, Any] = { - "statusCode": status, - "headers": {"Content-Type": content_type}, + "statusCode": response.status_code, + "headers": response.headers.copy(), } if cors: @@ -589,47 +580,32 @@ def response( ) messageData["headers"]["Access-Control-Allow-Credentials"] = "true" + response_body = response.body + accepted_compression = self.event["headers"].get("accept-encoding", "") if compression and compression in accepted_compression: messageData["headers"]["Content-Encoding"] = compression - if isinstance(response_body, str): - response_body = bytes(response_body, "utf-8") - if compression == "gzip": gzip_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS | 16) response_body = ( gzip_compress.compress(response_body) + gzip_compress.flush() ) + elif compression == "zlib": zlib_compress = zlib.compressobj(9, zlib.DEFLATED, zlib.MAX_WBITS) response_body = ( zlib_compress.compress(response_body) + zlib_compress.flush() ) + elif compression == "deflate": deflate_compress = zlib.compressobj(9, zlib.DEFLATED, -zlib.MAX_WBITS) response_body = ( deflate_compress.compress(response_body) + deflate_compress.flush() ) - else: - return self.response( - "ERROR", - "application/json", - json.dumps( - {"errorMessage": f"Unsupported compression mode: {compression}"} - ), - ) - if ttl: - messageData["headers"]["Cache-Control"] = ( - f"max-age={ttl}" if status == 200 else "no-cache" - ) - elif cache_control: - messageData["headers"]["Cache-Control"] = ( - cache_control if status == 200 else "no-cache" - ) + else: + raise Exception(f"Unsupported compression mode: {compression}") - if ( - content_type in binary_types or not isinstance(response_body, str) - ) and b64encode: + if response.media_type in binary_types and b64encode: messageData["isBase64Encoded"] = True messageData["body"] = base64.b64encode(response_body).decode() else: @@ -651,68 +627,65 @@ def __call__(self, event, context): (key.lower(), value) for key, value in headers.items() ) - self.request_path = ApigwPath(self.event) - if self.request_path.path is None: - return self.response( - "NOK", - "application/json", - json.dumps({"errorMessage": "Missing or invalid path"}), - ) + try: + self.request_path = ApigwPath(self.event) + if self.request_path.path is None: + raise HTTPException(404, "Missing or invalid path") + + http_method = event["httpMethod"] + route_entry = self._url_matching(self.request_path.path, http_method) + if not route_entry: + raise HTTPException( + 501, + f"No view function for: {http_method} - {self.request_path.path}", + ) - http_method = event["httpMethod"] - route_entry = self._url_matching(self.request_path.path, http_method) - if not route_entry: - return self.response( - "NOK", - "application/json", - json.dumps( - { - "errorMessage": "No view function for: {} - {}".format( - http_method, self.request_path.path - ) - } - ), - ) + request_params = event.get("queryStringParameters", {}) or {} + if route_entry.token: + if not self._validate_token(request_params.get("access_token")): + raise HTTPException(401, "Invalid access token") - request_params = event.get("queryStringParameters", {}) or {} - if route_entry.token: - if not self._validate_token(request_params.get("access_token")): - return self.response( - "ERROR", - "application/json", - json.dumps({"message": "Invalid access token"}), - ) + # remove access_token from kwargs + request_params.pop("access_token", False) - # remove access_token from kwargs - request_params.pop("access_token", False) + function_kwargs = self._get_matching_args( + route_entry, self.request_path.path + ) - function_kwargs = self._get_matching_args(route_entry, self.request_path.path) - function_kwargs.update(request_params.copy()) - if http_method in ["POST", "PUT", "PATCH"] and event.get("body"): - body = event["body"] - if event.get("isBase64Encoded"): - body = base64.b64decode(body).decode() - function_kwargs.update(dict(body=body)) + function_kwargs.update(request_params.copy()) + if http_method in ["POST", "PUT", "PATCH"] and event.get("body"): + body = event["body"] + if event.get("isBase64Encoded"): + body = base64.b64decode(body).decode() + function_kwargs.update(dict(body=body)) + + # if route has a response we add it back into the args + if route_entry.response_param: + function_kwargs[route_entry.response_param] = Response() - try: response = route_entry.endpoint(**function_kwargs) - except Exception as err: - self.log.error(str(err)) - response = ( - "ERROR", - "application/json", - json.dumps({"errorMessage": str(err)}), + if not isinstance(response, Response): + response = route_entry.response_class(response) + + if route_entry.response_param: + sub_response = function_kwargs[route_entry.response_param] + response.headers.update(sub_response.headers) + if sub_response.status_code: + response.status_code = sub_response.status_code + + return self.response( + response, + cors=route_entry.cors, + accepted_methods=route_entry.methods, + compression=route_entry.compression, + b64encode=route_entry.b64encode, ) - return self.response( - response[0], - response[1], - response[2], - cors=route_entry.cors, - accepted_methods=route_entry.methods, - accepted_compression=self.event["headers"].get("accept-encoding", ""), - compression=route_entry.compression, - b64encode=route_entry.b64encode, - ttl=route_entry.ttl, - cache_control=route_entry.cache_control, - ) + except HTTPException as exc: + return self.response( + JSONResponse( + {"detail": exc.detail}, + status_code=exc.status_code, + headers=exc.headers, + ) + ) diff --git a/lambda_proxy/responses.py b/lambda_proxy/responses.py new file mode 100644 index 0000000..9cf95d1 --- /dev/null +++ b/lambda_proxy/responses.py @@ -0,0 +1,86 @@ +""" +Common response models. + +Freely adapted from https://github.com/encode/starlette/blob/master/starlette/responses.py +""" + +from typing import Any, Dict +import json + + +class Response: + """Response Base Class.""" + + media_type = None + charset = "utf-8" + + def __init__( + self, + content: Any = None, + status_code: int = 200, + headers: Dict[str, str] = {}, + media_type: str = None, + ) -> None: + """Initiate Response.""" + self.body = self.render(content) + self.status_code = status_code + if media_type is not None: + self.media_type = media_type + + self.init_headers(headers) + + def render(self, content: Any) -> bytes: + """Encode content.""" + if content is None: + return b"" + if isinstance(content, bytes): + return content + return content.encode(self.charset) + + def init_headers(self, headers: Dict[str, str] = {}) -> None: + """Create headers.""" + self._headers = headers.copy() + if self.body: + self._headers.update({"content-length": str(len(self.body))}) + + if self.media_type: + self._headers.update({"Content-Type": self.media_type}) + + @property + def headers(self) -> Dict: + """Return response headers.""" + return self._headers + + +class XMLResponse(Response): + """XML Response.""" + + media_type = "application/xml" + + +class HTMLResponse(Response): + """HTML Response.""" + + media_type = "text/html" + + +class JSONResponse(Response): + """JSON Response.""" + + media_type = "application/json" + + def render(self, content: Any) -> bytes: + """Dump dict to JSON string.""" + return json.dumps( + content, + ensure_ascii=False, + allow_nan=False, + indent=None, + separators=(",", ":"), + ).encode("utf-8") + + +class PlainTextResponse(Response): + """Plain Text Response.""" + + media_type = "text/plain" diff --git a/tox.ini b/tox.ini index d64486d..33393b4 100644 --- a/tox.ini +++ b/tox.ini @@ -5,7 +5,7 @@ envlist = py36,py37 [flake8] ignore = D203 exclude = .git,__pycache__,docs/source/conf.py,old,build,dist -max-complexity = 10 +max-complexity = 15 max-line-length = 90 [mypy]