Skip to content

Commit fe6b335

Browse files
fix(parser): fallback to validate_python when using type[Model] and nested models (#5313)
* Fix Pydantic limitation * Add e2e tests * Reverting change in e2e layer
1 parent b271c17 commit fe6b335

File tree

6 files changed

+138
-16
lines changed

6 files changed

+138
-16
lines changed

aws_lambda_powertools/utilities/parser/envelopes/base.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
from abc import ABC, abstractmethod
55
from typing import TYPE_CHECKING, Any, TypeVar
66

7-
from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache
7+
from aws_lambda_powertools.utilities.parser.functions import (
8+
_parse_and_validate_event,
9+
_retrieve_or_set_model_from_cache,
10+
)
811

912
if TYPE_CHECKING:
1013
from aws_lambda_powertools.utilities.parser.types import T
@@ -38,11 +41,7 @@ def _parse(data: dict[str, Any] | Any | None, model: type[T]) -> T | None:
3841
adapter = _retrieve_or_set_model_from_cache(model=model)
3942

4043
logger.debug("parsing event against model")
41-
if isinstance(data, str):
42-
logger.debug("parsing event as string")
43-
return adapter.validate_json(data)
44-
45-
return adapter.validate_python(data)
44+
return _parse_and_validate_event(data=data, adapter=adapter)
4645

4746
@abstractmethod
4847
def parse(self, data: dict[str, Any] | Any | None, model: type[T]):

aws_lambda_powertools/utilities/parser/functions.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING
3+
import json
4+
import logging
5+
from typing import TYPE_CHECKING, Any
46

57
from pydantic import TypeAdapter
68

@@ -11,6 +13,8 @@
1113

1214
CACHE_TYPE_ADAPTER = LRUDict(max_items=1024)
1315

16+
logger = logging.getLogger(__name__)
17+
1418

1519
def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter:
1620
"""
@@ -38,3 +42,38 @@ def _retrieve_or_set_model_from_cache(model: type[T]) -> TypeAdapter:
3842

3943
CACHE_TYPE_ADAPTER[id_model] = TypeAdapter(model)
4044
return CACHE_TYPE_ADAPTER[id_model]
45+
46+
47+
def _parse_and_validate_event(data: dict[str, Any] | Any, adapter: TypeAdapter):
48+
"""
49+
Parse and validate the event data using the provided adapter.
50+
51+
Params
52+
------
53+
data: dict | Any
54+
The event data to be parsed and validated.
55+
adapter: TypeAdapter
56+
The adapter object used for validation.
57+
58+
Returns:
59+
dict: The validated event data.
60+
61+
Raises:
62+
ValidationError: If the data is invalid or cannot be parsed.
63+
"""
64+
logger.debug("Parsing event against model")
65+
66+
if isinstance(data, str):
67+
logger.debug("Parsing event as string")
68+
try:
69+
return adapter.validate_json(data)
70+
except NotImplementedError:
71+
# See: https://github.com/aws-powertools/powertools-lambda-python/issues/5303
72+
# See: https://github.com/pydantic/pydantic/issues/8890
73+
logger.debug(
74+
"Falling back to Python validation due to Pydantic implementation."
75+
"See issue: https://github.com/aws-powertools/powertools-lambda-python/issues/5303",
76+
)
77+
data = json.loads(data)
78+
79+
return adapter.validate_python(data)

aws_lambda_powertools/utilities/parser/parser.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@
88

99
from aws_lambda_powertools.middleware_factory import lambda_handler_decorator
1010
from aws_lambda_powertools.utilities.parser.exceptions import InvalidEnvelopeError, InvalidModelTypeError
11-
from aws_lambda_powertools.utilities.parser.functions import _retrieve_or_set_model_from_cache
11+
from aws_lambda_powertools.utilities.parser.functions import (
12+
_parse_and_validate_event,
13+
_retrieve_or_set_model_from_cache,
14+
)
1215

1316
if TYPE_CHECKING:
1417
from aws_lambda_powertools.utilities.parser.envelopes.base import Envelope
@@ -189,10 +192,7 @@ def handler(event: Order, context: LambdaContext):
189192
adapter = _retrieve_or_set_model_from_cache(model=model)
190193

191194
logger.debug("Parsing and validating event model; no envelope used")
192-
if isinstance(event, str):
193-
return adapter.validate_json(event)
194-
195-
return adapter.validate_python(event)
195+
return _parse_and_validate_event(data=event, adapter=adapter)
196196

197197
# Pydantic raises PydanticSchemaGenerationError when the model is not a Pydantic model
198198
# This is seen in the tests where we pass a non-Pydantic model type to the parser or
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import json
2+
from typing import Any, Dict, Type, Union
3+
4+
from pydantic import BaseModel
5+
6+
from aws_lambda_powertools.utilities.parser import parse
7+
from aws_lambda_powertools.utilities.typing import LambdaContext
8+
9+
AnyInheritedModel = Union[Type[BaseModel], BaseModel]
10+
RawDictOrModel = Union[Dict[str, Any], AnyInheritedModel]
11+
12+
13+
class ModelWithUnionType(BaseModel):
14+
name: str
15+
profile: RawDictOrModel
16+
17+
18+
def lambda_handler(event: ModelWithUnionType, context: LambdaContext):
19+
event = json.dumps(event)
20+
21+
event_parsed = parse(event=event, model=ModelWithUnionType)
22+
23+
return {"name": event_parsed.name}

tests/e2e/parser/test_parser.py

+21
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ def handler_with_dataclass_arn(infrastructure: dict) -> str:
2020
return infrastructure.get("HandlerWithDataclass", "")
2121

2222

23+
@pytest.fixture
24+
def handler_with_type_model_class(infrastructure: dict) -> str:
25+
return infrastructure.get("HandlerWithModelTypeClass", "")
26+
27+
2328
@pytest.mark.xdist_group(name="parser")
2429
def test_parser_with_basic_model(handler_with_basic_model_arn):
2530
# GIVEN
@@ -66,3 +71,19 @@ def test_parser_with_dataclass(handler_with_dataclass_arn):
6671
ret = parser_execution["Payload"].read().decode("utf-8")
6772

6873
assert "powertools" in ret
74+
75+
76+
@pytest.mark.xdist_group(name="parser")
77+
def test_parser_with_type_model(handler_with_type_model_class):
78+
# GIVEN
79+
payload = json.dumps({"name": "powertools", "profile": {"description": "python", "size": "XXL"}})
80+
81+
# WHEN
82+
parser_execution, _ = data_fetcher.get_lambda_response(
83+
lambda_arn=handler_with_type_model_class,
84+
payload=payload,
85+
)
86+
87+
ret = parser_execution["Payload"].read().decode("utf-8")
88+
89+
assert "powertools" in ret

tests/functional/parser/test_parser.py

+44-4
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import json
2+
from datetime import datetime
23
from typing import Any, Dict, Literal, Union
34

45
import pydantic
56
import pytest
67
from pydantic import ValidationError
78
from typing_extensions import Annotated
89

9-
from aws_lambda_powertools.utilities.parser import (
10-
event_parser,
11-
exceptions,
12-
)
10+
from aws_lambda_powertools.utilities.parser import event_parser, exceptions, parse
11+
from aws_lambda_powertools.utilities.parser.envelopes.sqs import SqsEnvelope
12+
from aws_lambda_powertools.utilities.parser.models import SqsModel
13+
from aws_lambda_powertools.utilities.parser.models.event_bridge import EventBridgeModel
1314
from aws_lambda_powertools.utilities.typing import LambdaContext
1415

1516

@@ -161,3 +162,42 @@ def handler(event: test_input, _: Any) -> str:
161162

162163
ret = handler(test_input, None)
163164
assert ret == expected
165+
166+
167+
def test_parser_with_model_type_model_and_envelope():
168+
event = {
169+
"Records": [
170+
{
171+
"messageId": "19dd0b57-b21e-4ac1-bd88-01bbb068cb78",
172+
"receiptHandle": "MessageReceiptHandle",
173+
"body": EventBridgeModel(
174+
version="version",
175+
id="id",
176+
source="source",
177+
account="account",
178+
time=datetime.now(),
179+
region="region",
180+
resources=[],
181+
detail={"key": "value"},
182+
).model_dump_json(),
183+
"attributes": {
184+
"ApproximateReceiveCount": "1",
185+
"SentTimestamp": "1523232000000",
186+
"SenderId": "123456789012",
187+
"ApproximateFirstReceiveTimestamp": "1523232000001",
188+
},
189+
"messageAttributes": {},
190+
"md5OfBody": "{{{md5_of_body}}}",
191+
"eventSource": "aws:sqs",
192+
"eventSourceARN": "arn:aws:sqs:us-east-1:123456789012:MyQueue",
193+
"awsRegion": "us-east-1",
194+
},
195+
],
196+
}
197+
198+
def handler(event: SqsModel, _: LambdaContext):
199+
parsed_event: EventBridgeModel = parse(event, model=EventBridgeModel, envelope=SqsEnvelope)
200+
print(parsed_event)
201+
assert parsed_event[0].version == "version"
202+
203+
handler(event, LambdaContext())

0 commit comments

Comments
 (0)