Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f943f45

Browse files
committedMar 17, 2025·
Merging from develop
1 parent 26c12d4 commit f943f45

File tree

9 files changed

+484
-10
lines changed

9 files changed

+484
-10
lines changed
 

‎aws_lambda_powertools/event_handler/api_gateway.py

Lines changed: 72 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@
2424
DEFAULT_OPENAPI_TITLE,
2525
DEFAULT_OPENAPI_VERSION,
2626
)
27-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, SchemaValidationError
27+
from aws_lambda_powertools.event_handler.openapi.exceptions import (
28+
RequestValidationError,
29+
ResponseValidationError,
30+
SchemaValidationError,
31+
)
2832
from aws_lambda_powertools.event_handler.openapi.types import (
2933
COMPONENT_REF_PREFIX,
3034
METHODS_WITH_BODY,
@@ -1501,6 +1505,7 @@ def __init__(
15011505
serializer: Callable[[dict], str] | None = None,
15021506
strip_prefixes: list[str | Pattern] | None = None,
15031507
enable_validation: bool = False,
1508+
response_validation_error_http_code: HTTPStatus | int | None = None,
15041509
):
15051510
"""
15061511
Parameters
@@ -1520,6 +1525,8 @@ def __init__(
15201525
Each prefix can be a static string or a compiled regex pattern
15211526
enable_validation: bool | None
15221527
Enables validation of the request body against the route schema, by default False.
1528+
response_validation_error_http_code
1529+
Sets the returned status code if response is not validated. enable_validation must be True.
15231530
"""
15241531
self._proxy_type = proxy_type
15251532
self._dynamic_routes: list[Route] = []
@@ -1536,6 +1543,11 @@ def __init__(
15361543
self.processed_stack_frames = []
15371544
self._response_builder_class = ResponseBuilder[BaseProxyEvent]
15381545
self.openapi_config = OpenAPIConfig() # starting an empty dataclass
1546+
self._has_response_validation_error = response_validation_error_http_code is not None
1547+
self._response_validation_error_http_code = self._validate_response_validation_error_http_code(
1548+
response_validation_error_http_code,
1549+
enable_validation,
1550+
)
15391551

15401552
# Allow for a custom serializer or a concise json serialization
15411553
self._serializer = serializer or partial(json.dumps, separators=(",", ":"), cls=Encoder)
@@ -1545,7 +1557,36 @@ def __init__(
15451557

15461558
# Note the serializer argument: only use custom serializer if provided by the caller
15471559
# Otherwise, fully rely on the internal Pydantic based mechanism to serialize responses for validation.
1548-
self.use([OpenAPIValidationMiddleware(validation_serializer=serializer)])
1560+
self.use(
1561+
[
1562+
OpenAPIValidationMiddleware(
1563+
validation_serializer=serializer,
1564+
has_response_validation_error=self._has_response_validation_error,
1565+
),
1566+
],
1567+
)
1568+
1569+
def _validate_response_validation_error_http_code(
1570+
self,
1571+
response_validation_error_http_code: HTTPStatus | int | None,
1572+
enable_validation: bool,
1573+
) -> HTTPStatus:
1574+
if response_validation_error_http_code and not enable_validation:
1575+
msg = "'response_validation_error_http_code' cannot be set when enable_validation is False."
1576+
raise ValueError(msg)
1577+
1578+
if (
1579+
not isinstance(response_validation_error_http_code, HTTPStatus)
1580+
and response_validation_error_http_code is not None
1581+
):
1582+
1583+
try:
1584+
response_validation_error_http_code = HTTPStatus(response_validation_error_http_code)
1585+
except ValueError:
1586+
msg = f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code."
1587+
raise ValueError(msg) from None
1588+
1589+
return response_validation_error_http_code or HTTPStatus.UNPROCESSABLE_ENTITY
15491590

15501591
def get_openapi_schema(
15511592
self,
@@ -2484,6 +2525,21 @@ def _call_exception_handler(self, exp: Exception, route: Route) -> ResponseBuild
24842525
route=route,
24852526
)
24862527

2528+
# OpenAPIValidationMiddleware will only raise ResponseValidationError when
2529+
# 'self._response_validation_error_http_code' is not None
2530+
if isinstance(exp, ResponseValidationError):
2531+
http_code = self._response_validation_error_http_code
2532+
errors = [{"loc": e["loc"], "type": e["type"]} for e in exp.errors()]
2533+
return self._response_builder_class(
2534+
response=Response(
2535+
status_code=http_code.value,
2536+
content_type=content_types.APPLICATION_JSON,
2537+
body={"statusCode": self._response_validation_error_http_code, "detail": errors},
2538+
),
2539+
serializer=self._serializer,
2540+
route=route,
2541+
)
2542+
24872543
if isinstance(exp, ServiceError):
24882544
return self._response_builder_class(
24892545
response=Response(
@@ -2696,6 +2752,7 @@ def __init__(
26962752
serializer: Callable[[dict], str] | None = None,
26972753
strip_prefixes: list[str | Pattern] | None = None,
26982754
enable_validation: bool = False,
2755+
response_validation_error_http_code: HTTPStatus | int | None = None,
26992756
):
27002757
"""Amazon API Gateway REST and HTTP API v1 payload resolver"""
27012758
super().__init__(
@@ -2705,6 +2762,7 @@ def __init__(
27052762
serializer,
27062763
strip_prefixes,
27072764
enable_validation,
2765+
response_validation_error_http_code,
27082766
)
27092767

27102768
def _get_base_path(self) -> str:
@@ -2778,6 +2836,7 @@ def __init__(
27782836
serializer: Callable[[dict], str] | None = None,
27792837
strip_prefixes: list[str | Pattern] | None = None,
27802838
enable_validation: bool = False,
2839+
response_validation_error_http_code: HTTPStatus | int | None = None,
27812840
):
27822841
"""Amazon API Gateway HTTP API v2 payload resolver"""
27832842
super().__init__(
@@ -2787,6 +2846,7 @@ def __init__(
27872846
serializer,
27882847
strip_prefixes,
27892848
enable_validation,
2849+
response_validation_error_http_code,
27902850
)
27912851

27922852
def _get_base_path(self) -> str:
@@ -2815,9 +2875,18 @@ def __init__(
28152875
serializer: Callable[[dict], str] | None = None,
28162876
strip_prefixes: list[str | Pattern] | None = None,
28172877
enable_validation: bool = False,
2878+
response_validation_error_http_code: HTTPStatus | int | None = None,
28182879
):
28192880
"""Amazon Application Load Balancer (ALB) resolver"""
2820-
super().__init__(ProxyEventType.ALBEvent, cors, debug, serializer, strip_prefixes, enable_validation)
2881+
super().__init__(
2882+
ProxyEventType.ALBEvent,
2883+
cors,
2884+
debug,
2885+
serializer,
2886+
strip_prefixes,
2887+
enable_validation,
2888+
response_validation_error_http_code,
2889+
)
28212890

28222891
def _get_base_path(self) -> str:
28232892
# ALB doesn't have a stage variable, so we just return an empty string

‎aws_lambda_powertools/event_handler/lambda_function_url.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import LambdaFunctionUrlEvent
1315

@@ -57,6 +59,7 @@ def __init__(
5759
serializer: Callable[[dict], str] | None = None,
5860
strip_prefixes: list[str | Pattern] | None = None,
5961
enable_validation: bool = False,
62+
response_validation_error_http_code: HTTPStatus | int | None = None,
6063
):
6164
super().__init__(
6265
ProxyEventType.LambdaFunctionUrlEvent,
@@ -65,6 +68,7 @@ def __init__(
6568
serializer,
6669
strip_prefixes,
6770
enable_validation,
71+
response_validation_error_http_code,
6872
)
6973

7074
def _get_base_path(self) -> str:

‎aws_lambda_powertools/event_handler/middlewares/openapi_validation.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
)
1818
from aws_lambda_powertools.event_handler.openapi.dependant import is_scalar_field
1919
from aws_lambda_powertools.event_handler.openapi.encoders import jsonable_encoder
20-
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import RequestValidationError, ResponseValidationError
2121
from aws_lambda_powertools.event_handler.openapi.params import Param
2222

2323
if TYPE_CHECKING:
@@ -58,7 +58,11 @@ def get_todos(): list[Todo]:
5858
```
5959
"""
6060

61-
def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
61+
def __init__(
62+
self,
63+
validation_serializer: Callable[[Any], str] | None = None,
64+
has_response_validation_error: bool = False,
65+
):
6266
"""
6367
Initialize the OpenAPIValidationMiddleware.
6468
@@ -67,8 +71,13 @@ def __init__(self, validation_serializer: Callable[[Any], str] | None = None):
6771
validation_serializer : Callable, optional
6872
Optional serializer to use when serializing the response for validation.
6973
Use it when you have a custom type that cannot be serialized by the default jsonable_encoder.
74+
75+
has_response_validation_error: bool, optional
76+
Optional flag used to distinguish between payload and validation errors.
77+
By setting this flag to True, ResponseValidationError will be raised if response could not be validated.
7078
"""
7179
self._validation_serializer = validation_serializer
80+
self._has_response_validation_error = has_response_validation_error
7281

7382
def handler(self, app: EventHandlerInstance, next_middleware: NextMiddleware) -> Response:
7483
logger.debug("OpenAPIValidationMiddleware handler")
@@ -164,6 +173,8 @@ def _serialize_response(
164173
errors: list[dict[str, Any]] = []
165174
value = _validate_field(field=field, value=response_content, loc=("response",), existing_errors=errors)
166175
if errors:
176+
if self._has_response_validation_error:
177+
raise ResponseValidationError(errors=_normalize_errors(errors), body=response_content)
167178
raise RequestValidationError(errors=_normalize_errors(errors), body=response_content)
168179

169180
if hasattr(field, "serialize"):

‎aws_lambda_powertools/event_handler/openapi/exceptions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@ def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
2323
self.body = body
2424

2525

26+
class ResponseValidationError(ValidationException):
27+
"""
28+
Raised when the response body does not match the OpenAPI schema
29+
"""
30+
31+
def __init__(self, errors: Sequence[Any], *, body: Any = None) -> None:
32+
super().__init__(errors)
33+
self.body = body
34+
35+
2636
class SerializationError(Exception):
2737
"""
2838
Base exception for all encoding errors

‎aws_lambda_powertools/event_handler/vpc_lattice.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
)
99

1010
if TYPE_CHECKING:
11+
from http import HTTPStatus
12+
1113
from aws_lambda_powertools.event_handler import CORSConfig
1214
from aws_lambda_powertools.utilities.data_classes import VPCLatticeEvent, VPCLatticeEventV2
1315

@@ -53,9 +55,18 @@ def __init__(
5355
serializer: Callable[[dict], str] | None = None,
5456
strip_prefixes: list[str | Pattern] | None = None,
5557
enable_validation: bool = False,
58+
response_validation_error_http_code: HTTPStatus | int | None = None,
5659
):
5760
"""Amazon VPC Lattice resolver"""
58-
super().__init__(ProxyEventType.VPCLatticeEvent, cors, debug, serializer, strip_prefixes, enable_validation)
61+
super().__init__(
62+
ProxyEventType.VPCLatticeEvent,
63+
cors,
64+
debug,
65+
serializer,
66+
strip_prefixes,
67+
enable_validation,
68+
response_validation_error_http_code,
69+
)
5970

6071
def _get_base_path(self) -> str:
6172
return ""
@@ -102,9 +113,18 @@ def __init__(
102113
serializer: Callable[[dict], str] | None = None,
103114
strip_prefixes: list[str | Pattern] | None = None,
104115
enable_validation: bool = False,
116+
response_validation_error_http_code: HTTPStatus | int | None = None,
105117
):
106118
"""Amazon VPC Lattice resolver"""
107-
super().__init__(ProxyEventType.VPCLatticeEventV2, cors, debug, serializer, strip_prefixes, enable_validation)
119+
super().__init__(
120+
ProxyEventType.VPCLatticeEventV2,
121+
cors,
122+
debug,
123+
serializer,
124+
strip_prefixes,
125+
enable_validation,
126+
response_validation_error_http_code,
127+
)
108128

109129
def _get_base_path(self) -> str:
110130
return ""

‎docs/core/event_handler/api_gateway.md

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ Let's rewrite the previous examples to signal our resolver what shape we expect
309309

310310
!!! info "By default, we hide extended error details for security reasons _(e.g., pydantic url, Pydantic code)_."
311311

312-
Any incoming request that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
312+
Any incoming request or and outgoing response that fails validation will lead to a `HTTP 422: Unprocessable Entity error` response that will look similar to this:
313313

314314
```json hl_lines="2 3" title="data_validation_error_unsanitized_output.json"
315315
--8<-- "examples/event_handler_rest/src/data_validation_error_unsanitized_output.json"
@@ -321,8 +321,6 @@ Here's an example where we catch validation errors, log all details for further
321321

322322
=== "data_validation_sanitized_error.py"
323323

324-
Note that Pydantic versions [1](https://docs.pydantic.dev/1.10/usage/models/#error-handling){target="_blank" rel="nofollow"} and [2](https://docs.pydantic.dev/latest/errors/errors/){target="_blank" rel="nofollow"} report validation detailed errors differently.
325-
326324
```python hl_lines="8 24-25 31"
327325
--8<-- "examples/event_handler_rest/src/data_validation_sanitized_error.py"
328326
```
@@ -398,6 +396,27 @@ We use the `Annotated` and OpenAPI `Body` type to instruct Event Handler that ou
398396
--8<-- "examples/event_handler_rest/src/validating_payload_subset_output.json"
399397
```
400398

399+
#### Validating responses
400+
401+
You can use `response_validation_error_http_code` to set a custom HTTP code for failed response validation. When this field is set, we will raise a `ResponseValidationError` instead of a `RequestValidationError`.
402+
403+
=== "customizing_response_validation.py"
404+
405+
```python hl_lines="1 16 29 33"
406+
--8<-- "examples/event_handler_rest/src/customizing_response_validation.py"
407+
```
408+
409+
1. A response with status code set here will be returned if response data is not valid.
410+
2. Operation returns a string as oppose to a `Todo` object. This will lead to a `500` response as set in line 18.
411+
412+
=== "customizing_response_validation_exception.py"
413+
414+
```python hl_lines="1 18 38 39"
415+
--8<-- "examples/event_handler_rest/src/customizing_response_validation_exception.py"
416+
```
417+
418+
1. The distinct `ResponseValidationError` exception can be caught to customise the response.
419+
401420
#### Validating query strings
402421

403422
!!! info "We will automatically validate and inject incoming query strings via type annotation."
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
from http import HTTPStatus
2+
from typing import Optional
3+
4+
import requests
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools import Logger, Tracer
8+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver
9+
from aws_lambda_powertools.logging import correlation_paths
10+
from aws_lambda_powertools.utilities.typing import LambdaContext
11+
12+
tracer = Tracer()
13+
logger = Logger()
14+
app = APIGatewayRestResolver(
15+
enable_validation=True,
16+
response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR, # (1)!
17+
)
18+
19+
20+
class Todo(BaseModel):
21+
userId: int
22+
id_: Optional[int] = Field(alias="id", default=None)
23+
title: str
24+
completed: bool
25+
26+
27+
@app.get("/todos_bad_response/<todo_id>")
28+
@tracer.capture_method
29+
def get_todo_by_id(todo_id: int) -> Todo:
30+
todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
31+
todo.raise_for_status()
32+
33+
return todo.json()["title"] # (2)!
34+
35+
36+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
37+
@tracer.capture_lambda_handler
38+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
39+
return app.resolve(event, context)
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from http import HTTPStatus
2+
from typing import Optional
3+
4+
import requests
5+
from pydantic import BaseModel, Field
6+
7+
from aws_lambda_powertools import Logger, Tracer
8+
from aws_lambda_powertools.event_handler import APIGatewayRestResolver, content_types
9+
from aws_lambda_powertools.event_handler.api_gateway import Response
10+
from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError
11+
from aws_lambda_powertools.logging import correlation_paths
12+
from aws_lambda_powertools.utilities.typing import LambdaContext
13+
14+
tracer = Tracer()
15+
logger = Logger()
16+
app = APIGatewayRestResolver(
17+
enable_validation=True,
18+
response_validation_error_http_code=HTTPStatus.INTERNAL_SERVER_ERROR,
19+
)
20+
21+
22+
class Todo(BaseModel):
23+
userId: int
24+
id_: Optional[int] = Field(alias="id", default=None)
25+
title: str
26+
completed: bool
27+
28+
29+
@app.get("/todos_bad_response/<todo_id>")
30+
@tracer.capture_method
31+
def get_todo_by_id(todo_id: int) -> Todo:
32+
todo = requests.get(f"https://jsonplaceholder.typicode.com/todos/{todo_id}")
33+
todo.raise_for_status()
34+
35+
return todo.json()["title"]
36+
37+
38+
@app.exception_handler(ResponseValidationError) # (1)!
39+
def handle_response_validation_error(ex: ResponseValidationError):
40+
logger.error("Request failed validation", path=app.current_event.path, errors=ex.errors())
41+
42+
return Response(
43+
status_code=500,
44+
content_type=content_types.APPLICATION_JSON,
45+
body="Unexpected response.",
46+
)
47+
48+
49+
@logger.inject_lambda_context(correlation_id_path=correlation_paths.API_GATEWAY_HTTP)
50+
@tracer.capture_lambda_handler
51+
def lambda_handler(event: dict, context: LambdaContext) -> dict:
52+
return app.resolve(event, context)

‎tests/functional/event_handler/_pydantic/test_openapi_validation_middleware.py

Lines changed: 250 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
VPCLatticeResolver,
1818
VPCLatticeV2Resolver,
1919
)
20+
from aws_lambda_powertools.event_handler.openapi.exceptions import ResponseValidationError
2021
from aws_lambda_powertools.event_handler.openapi.params import Body, Header, Query
2122

2223

@@ -1128,3 +1129,252 @@ def handler(user_id: int = 123):
11281129
# THEN the handler should be invoked and return 200
11291130
result = app(minimal_event, {})
11301131
assert result["statusCode"] == 200
1132+
1133+
1134+
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
1135+
def test_validation_error_none_returned_non_optional_type(gw_event):
1136+
# GIVEN an APIGatewayRestResolver with validation enabled
1137+
app = APIGatewayRestResolver(enable_validation=True)
1138+
1139+
class Model(BaseModel):
1140+
name: str
1141+
age: int
1142+
1143+
@app.get("/none_not_allowed")
1144+
def handler_none_not_allowed() -> Model:
1145+
return None # type: ignore
1146+
1147+
# WHEN returning None for a non-Optional type
1148+
gw_event["path"] = "/none_not_allowed"
1149+
result = app(gw_event, {})
1150+
1151+
# THEN it should return a validation error
1152+
assert result["statusCode"] == 422
1153+
body = json.loads(result["body"])
1154+
assert body["detail"][0]["type"] == "model_attributes_type"
1155+
assert body["detail"][0]["loc"] == ["response"]
1156+
1157+
1158+
def test_validation_error_incomplete_model_returned_non_optional_type(gw_event):
1159+
# GIVEN an APIGatewayRestResolver with validation enabled
1160+
app = APIGatewayRestResolver(enable_validation=True)
1161+
1162+
class Model(BaseModel):
1163+
name: str
1164+
age: int
1165+
1166+
@app.get("/incomplete_model_not_allowed")
1167+
def handler_incomplete_model_not_allowed() -> Model:
1168+
return {"age": 18} # type: ignore
1169+
1170+
# WHEN returning incomplete model for a non-Optional type
1171+
gw_event["path"] = "/incomplete_model_not_allowed"
1172+
result = app(gw_event, {})
1173+
1174+
# THEN it should return a validation error
1175+
assert result["statusCode"] == 422
1176+
body = json.loads(result["body"])
1177+
assert "missing" in body["detail"][0]["type"]
1178+
assert "name" in body["detail"][0]["loc"]
1179+
1180+
1181+
def test_none_returned_for_optional_type(gw_event):
1182+
# GIVEN an APIGatewayRestResolver with validation enabled
1183+
app = APIGatewayRestResolver(enable_validation=True)
1184+
1185+
class Model(BaseModel):
1186+
name: str
1187+
age: int
1188+
1189+
@app.get("/none_allowed")
1190+
def handler_none_allowed() -> Optional[Model]:
1191+
return None
1192+
1193+
# WHEN returning None for an Optional type
1194+
gw_event["path"] = "/none_allowed"
1195+
result = app(gw_event, {})
1196+
1197+
# THEN it should succeed
1198+
assert result["statusCode"] == 200
1199+
assert result["body"] == "null"
1200+
1201+
1202+
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
1203+
@pytest.mark.parametrize(
1204+
"path, body",
1205+
[
1206+
("/empty_dict", {}),
1207+
("/empty_list", []),
1208+
("/none", "null"),
1209+
("/empty_string", ""),
1210+
],
1211+
ids=["empty_dict", "empty_list", "none", "empty_string"],
1212+
)
1213+
def test_none_returned_for_falsy_return(gw_event, path, body):
1214+
# GIVEN an APIGatewayRestResolver with validation enabled
1215+
app = APIGatewayRestResolver(enable_validation=True)
1216+
1217+
class Model(BaseModel):
1218+
name: str
1219+
age: int
1220+
1221+
@app.get(path)
1222+
def handler_none_allowed() -> Model:
1223+
return body
1224+
1225+
# WHEN returning None for an Optional type
1226+
gw_event["path"] = path
1227+
result = app(gw_event, {})
1228+
1229+
# THEN it should succeed
1230+
assert result["statusCode"] == 422
1231+
1232+
1233+
def test_custom_response_validation_error_http_code_valid_response(gw_event):
1234+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1235+
app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=422)
1236+
1237+
class Model(BaseModel):
1238+
name: str
1239+
age: int
1240+
1241+
@app.get("/valid_response")
1242+
def handler_valid_response() -> Model:
1243+
return {
1244+
"name": "Joe",
1245+
"age": 18,
1246+
} # type: ignore
1247+
1248+
# WHEN returning the expected type
1249+
gw_event["path"] = "/valid_response"
1250+
result = app(gw_event, {})
1251+
1252+
# THEN it should return a 200 OK
1253+
assert result["statusCode"] == 200
1254+
body = json.loads(result["body"])
1255+
assert body == {"name": "Joe", "age": 18}
1256+
1257+
1258+
@pytest.mark.skipif(reason="Test temporarily disabled until falsy return is fixed")
1259+
@pytest.mark.parametrize(
1260+
"http_code",
1261+
(422, 500, 510),
1262+
)
1263+
def test_custom_response_validation_error_http_code_invalid_response_none(
1264+
http_code,
1265+
gw_event,
1266+
):
1267+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1268+
app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code)
1269+
1270+
class Model(BaseModel):
1271+
name: str
1272+
age: int
1273+
1274+
@app.get("/none_not_allowed")
1275+
def handler_none_not_allowed() -> Model:
1276+
return None # type: ignore
1277+
1278+
# WHEN returning None for a non-Optional type
1279+
gw_event["path"] = "/none_not_allowed"
1280+
result = app(gw_event, {})
1281+
1282+
# THEN it should return a validation error with the custom status code provided
1283+
assert result["statusCode"] == http_code
1284+
body = json.loads(result["body"])
1285+
assert body["detail"][0]["type"] == "model_attributes_type"
1286+
assert body["detail"][0]["loc"] == ["response"]
1287+
1288+
1289+
@pytest.mark.parametrize(
1290+
"http_code",
1291+
(422, 500, 510),
1292+
)
1293+
def test_custom_response_validation_error_http_code_invalid_response_incomplete_model(
1294+
http_code,
1295+
gw_event,
1296+
):
1297+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1298+
app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code)
1299+
1300+
class Model(BaseModel):
1301+
name: str
1302+
age: int
1303+
1304+
@app.get("/incomplete_model_not_allowed")
1305+
def handler_incomplete_model_not_allowed() -> Model:
1306+
return {"age": 18} # type: ignore
1307+
1308+
# WHEN returning incomplete model for a non-Optional type
1309+
gw_event["path"] = "/incomplete_model_not_allowed"
1310+
result = app(gw_event, {})
1311+
1312+
# THEN it should return a validation error with the custom status code provided
1313+
assert result["statusCode"] == http_code
1314+
body = json.loads(result["body"])
1315+
assert body["detail"][0]["type"] == "missing"
1316+
assert body["detail"][0]["loc"] == ["response", "name"]
1317+
1318+
1319+
@pytest.mark.parametrize(
1320+
"http_code",
1321+
(422, 500, 510),
1322+
)
1323+
def test_custom_response_validation_error_sanitized_response(
1324+
http_code,
1325+
gw_event,
1326+
):
1327+
# GIVEN an APIGatewayRestResolver with custom response validation enabled
1328+
# with a sanitized response validation error response
1329+
app = APIGatewayRestResolver(enable_validation=True, response_validation_error_http_code=http_code)
1330+
1331+
class Model(BaseModel):
1332+
name: str
1333+
age: int
1334+
1335+
@app.get("/incomplete_model_not_allowed")
1336+
def handler_incomplete_model_not_allowed() -> Model:
1337+
return {"age": 18} # type: ignore
1338+
1339+
@app.exception_handler(ResponseValidationError)
1340+
def handle_response_validation_error(ex: ResponseValidationError):
1341+
return Response(
1342+
status_code=500,
1343+
body="Unexpected response.",
1344+
)
1345+
1346+
# WHEN returning incomplete model for a non-Optional type
1347+
gw_event["path"] = "/incomplete_model_not_allowed"
1348+
result = app(gw_event, {})
1349+
1350+
# THEN it should return the sanitized response
1351+
assert result["statusCode"] == 500
1352+
assert result["body"] == "Unexpected response."
1353+
1354+
1355+
def test_custom_response_validation_error_no_validation():
1356+
# GIVEN an APIGatewayRestResolver with validation not enabled
1357+
# setting a custom http status code for response validation must raise a ValueError
1358+
with pytest.raises(ValueError) as exception_info:
1359+
APIGatewayRestResolver(response_validation_error_http_code=500)
1360+
1361+
assert (
1362+
str(exception_info.value)
1363+
== "'response_validation_error_http_code' cannot be set when enable_validation is False."
1364+
)
1365+
1366+
1367+
@pytest.mark.parametrize("response_validation_error_http_code", [(20), ("hi"), (1.21)])
1368+
def test_custom_response_validation_error_bad_http_code(response_validation_error_http_code):
1369+
# GIVEN an APIGatewayRestResolver with validation enabled
1370+
# setting custom status code for response validation that is not a valid HTTP code must raise a ValueError
1371+
with pytest.raises(ValueError) as exception_info:
1372+
APIGatewayRestResolver(
1373+
enable_validation=True,
1374+
response_validation_error_http_code=response_validation_error_http_code,
1375+
)
1376+
1377+
assert (
1378+
str(exception_info.value)
1379+
== f"'{response_validation_error_http_code}' must be an integer representing an HTTP status code."
1380+
)

0 commit comments

Comments
 (0)
Please sign in to comment.