Skip to content

refactor(event_handler): use standard collections for types + refactor code #6495

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 17, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions aws_lambda_powertools/event_handler/api_gateway.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
from functools import partial
from http import HTTPStatus
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Mapping, Match, Pattern, Sequence, TypeVar, cast
from typing import TYPE_CHECKING, Any, Generic, Literal, Match, Pattern, TypeVar, cast

from typing_extensions import override

@@ -59,6 +59,9 @@
)
from aws_lambda_powertools.utilities.data_classes.common import BaseProxyEvent

if TYPE_CHECKING:
from collections.abc import Callable, Mapping, Sequence

logger = logging.getLogger(__name__)

_DYNAMIC_ROUTE_PATTERN = r"(<\w+>)"
@@ -68,6 +71,7 @@
_NAMED_GROUP_BOUNDARY_PATTERN = rf"(?P\1[{_SAFE_URI}{_UNSAFE_URI}\\w]+)"
_DEFAULT_OPENAPI_RESPONSE_DESCRIPTION = "Successful Response"
_ROUTE_REGEX = "^{}$"
_JSON_DUMP_CALL = partial(json.dumps, separators=(",", ":"), cls=Encoder)

ResponseEventT = TypeVar("ResponseEventT", bound=BaseProxyEvent)
ResponseT = TypeVar("ResponseT")
@@ -830,7 +834,7 @@ class ResponseBuilder(Generic[ResponseEventT]):
def __init__(
self,
response: Response,
serializer: Callable[[Any], str] = partial(json.dumps, separators=(",", ":"), cls=Encoder),
serializer: Callable[[Any], str] = _JSON_DUMP_CALL,
route: Route | None = None,
):
self.response = response
@@ -1723,8 +1727,9 @@ def get_openapi_schema(
security = security or self.openapi_config.security
openapi_extensions = openapi_extensions or self.openapi_config.openapi_extensions

from pydantic.json_schema import GenerateJsonSchema

from aws_lambda_powertools.event_handler.openapi.compat import (
GenerateJsonSchema,
get_compat_model_name_map,
get_definitions,
)
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/appsync.py
Original file line number Diff line number Diff line change
@@ -3,13 +3,15 @@
import asyncio
import logging
import warnings
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError
from aws_lambda_powertools.event_handler.graphql_appsync.router import Router
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.typing import LambdaContext

from aws_lambda_powertools.warnings import PowertoolsUserWarning
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/bedrock_agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import json
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any

from typing_extensions import override

@@ -14,6 +14,7 @@
from aws_lambda_powertools.event_handler.openapi.constants import DEFAULT_API_VERSION, DEFAULT_OPENAPI_VERSION

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus
from re import Match

Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import logging
from typing import Any, Callable
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Callable

logger = logging.getLogger(__name__)

Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Callable
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable


class BaseRouter(ABC):
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING

from aws_lambda_powertools.event_handler.graphql_appsync._registry import ResolverRegistry
from aws_lambda_powertools.event_handler.graphql_appsync.base import BaseRouter

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext

Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Pattern
from typing import TYPE_CHECKING, Pattern

from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
ProxyEventType,
)

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
Original file line number Diff line number Diff line change
@@ -365,7 +365,7 @@ def _validate_field(
"""
Validate a field, and append any errors to the existing_errors list.
"""
validated_value, errors = field.validate(value, value, loc=loc)
validated_value, errors = field.validate(value=value, loc=loc)

if isinstance(errors, list):
processed_errors = _regenerate_error_with_loc(errors=errors, loc_prefix=())
34 changes: 16 additions & 18 deletions aws_lambda_powertools/event_handler/openapi/compat.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,30 @@
# mypy: ignore-errors
# flake8: noqa
from __future__ import annotations

from collections import deque
from copy import copy
from collections.abc import Mapping, Sequence

# MAINTENANCE: remove when deprecating Pydantic v1. Mypy doesn't handle two different code paths that import different
# versions of a module, so we need to ignore errors here.

from dataclasses import dataclass, is_dataclass
from enum import Enum
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Mapping, Sequence, Set, Tuple, Union

from typing_extensions import Annotated, Literal, get_origin, get_args

from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from typing import TYPE_CHECKING, Any, Deque, FrozenSet, List, Set, Tuple, Union

from aws_lambda_powertools.event_handler.openapi.types import COMPONENT_REF_PREFIX, UnionType

from pydantic import TypeAdapter, ValidationError
from pydantic import BaseModel, TypeAdapter, ValidationError, create_model

# Importing from internal libraries in Pydantic may introduce potential risks, as these internal libraries
# are not part of the public API and may change without notice in future releases.
# We use this for forward reference, as it allows us to handle forward references in type annotations.
from pydantic._internal._typing_extra import eval_type_lenient
from pydantic.fields import FieldInfo
from pydantic._internal._utils import lenient_issubclass
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue
from pydantic_core import PydanticUndefined, PydanticUndefinedType
from typing_extensions import Annotated, Literal, get_args, get_origin

from aws_lambda_powertools.event_handler.openapi.types import UnionType

if TYPE_CHECKING:
from pydantic.fields import FieldInfo
from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue

from aws_lambda_powertools.event_handler.openapi.types import IncEx, ModelNameMap

Undefined = PydanticUndefined
@@ -119,7 +113,10 @@ def serialize(
)

def validate(
self, value: Any, values: dict[str, Any] = {}, *, loc: tuple[int | str, ...] = ()
self,
value: Any,
*,
loc: tuple[int | str, ...] = (),
) -> tuple[Any, list[dict[str, Any]] | None]:
try:
return (self._type_adapter.validate_python(value, from_attributes=True), None)
@@ -184,7 +181,8 @@ def copy_field_info(*, field_info: FieldInfo, annotation: Any) -> FieldInfo:

def get_missing_field_error(loc: tuple[str, ...]) -> dict[str, Any]:
error = ValidationError.from_exception_data(
"Field required", [{"type": "missing", "loc": loc, "input": {}}]
"Field required",
[{"type": "missing", "loc": loc, "input": {}}],
).errors()[0]
error["input"] = None
return error
@@ -308,7 +306,7 @@ def value_is_sequence(value: Any) -> bool:

def _annotation_is_complex(annotation: type[Any] | None) -> bool:
return (
lenient_issubclass(annotation, (BaseModel, Mapping)) # TODO: UploadFile
lenient_issubclass(annotation, (BaseModel, Mapping)) # Keep it to UploadFile
or _annotation_is_sequence(annotation)
or is_dataclass(annotation)
)
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/dependant.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import inspect
import re
from typing import TYPE_CHECKING, Any, Callable, ForwardRef, cast
from typing import TYPE_CHECKING, Any, ForwardRef, cast

from aws_lambda_powertools.event_handler.openapi.compat import (
ModelField,
@@ -27,6 +27,8 @@
from aws_lambda_powertools.event_handler.openapi.types import OpenAPIResponse, OpenAPIResponseContentModel

if TYPE_CHECKING:
from collections.abc import Callable

from pydantic import BaseModel

"""
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/encoders.py
Original file line number Diff line number Diff line change
@@ -8,7 +8,7 @@
from pathlib import Path, PurePath
from re import Pattern
from types import GeneratorType
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any
from uuid import UUID

from pydantic import BaseModel
@@ -17,6 +17,8 @@
from aws_lambda_powertools.event_handler.openapi.compat import _model_dump

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.event_handler.openapi.types import IncEx

from aws_lambda_powertools.event_handler.openapi.exceptions import SerializationError
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Any, Literal, Sequence
from collections.abc import Sequence
from typing import Any, Literal


class ValidationException(Exception):
4 changes: 3 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/params.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,7 @@

import inspect
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Literal
from typing import TYPE_CHECKING, Any, Literal

from pydantic import BaseConfig
from pydantic.fields import FieldInfo
@@ -20,6 +20,8 @@
)

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.event_handler.openapi.models import Example
from aws_lambda_powertools.event_handler.openapi.types import CacheKey

3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/openapi/types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from __future__ import annotations

import types
from typing import TYPE_CHECKING, Any, Callable, Dict, Set, Type, TypedDict, Union
from typing import TYPE_CHECKING, Any, Dict, Set, Type, TypedDict, Union

if TYPE_CHECKING:
from collections.abc import Callable
from enum import Enum

from pydantic import BaseModel
7 changes: 5 additions & 2 deletions aws_lambda_powertools/event_handler/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from __future__ import annotations

from typing import Any, Dict, List, Mapping
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
from collections.abc import Mapping


class _FrozenDict(dict):
@@ -18,7 +21,7 @@ def __hash__(self):
return hash(frozenset(self.keys()))


class _FrozenListDict(List[Dict[str, List[str]]]):
class _FrozenListDict(list[dict[str, list[str]]]):
"""
Freezes a list of dictionaries containing lists of strings.
3 changes: 2 additions & 1 deletion aws_lambda_powertools/event_handler/vpc_lattice.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Pattern
from typing import TYPE_CHECKING, Pattern

from aws_lambda_powertools.event_handler.api_gateway import (
ApiGatewayResolver,
ProxyEventType,
)

if TYPE_CHECKING:
from collections.abc import Callable
from http import HTTPStatus

from aws_lambda_powertools.event_handler import CORSConfig
2 changes: 2 additions & 0 deletions tests/functional/event_handler/_pydantic/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

import fastjsonschema
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from pydantic import BaseModel

from aws_lambda_powertools.event_handler import content_types
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import math
from collections import deque
from dataclasses import dataclass
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import json

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver, Router
Original file line number Diff line number Diff line change
@@ -278,7 +278,7 @@ class User(BaseModel):

@app.get("/")
def handler() -> User:
return User(name="Ruben Fonseca")
return User(name="Powertools")

schema = app.get_openapi_schema()
assert len(schema.paths.keys()) == 1
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import pytest

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import (
APIKey,
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler.api_gateway import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import Server

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import json
import warnings
from typing import Dict

import pytest

@@ -73,7 +74,7 @@ def test_openapi_swagger_json_view_with_default_path():

assert result["statusCode"] == 200
assert result["multiValueHeaders"]["Content-Type"] == ["application/json"]
assert isinstance(json.loads(result["body"]), Dict)
assert isinstance(json.loads(result["body"]), dict)
assert "OpenAPI JSON View" in result["body"]


@@ -87,7 +88,7 @@ def test_openapi_swagger_json_view_with_custom_path():

assert result["statusCode"] == 200
assert result["multiValueHeaders"]["Content-Type"] == ["application/json"]
assert isinstance(json.loads(result["body"]), Dict)
assert isinstance(json.loads(result["body"]), dict)
assert "OpenAPI JSON View" in result["body"]


Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import APIGatewayRestResolver
from aws_lambda_powertools.event_handler.openapi.models import Tag

Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from typing import List, Optional
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from aws_lambda_powertools.event_handler import AppSyncResolver
from aws_lambda_powertools.event_handler.graphql_appsync.exceptions import InvalidBatchResponse, ResolverNotFoundError
from aws_lambda_powertools.event_handler.graphql_appsync.router import Router
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.warnings import PowertoolsUserWarning
from tests.functional.utils import load_event

if TYPE_CHECKING:
from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent


# TESTS RECEIVING THE EVENT PARTIALLY AND PROCESS EACH RECORD PER TIME.
def test_resolve_batch_processing_with_related_events_one_at_time():
@@ -95,7 +99,7 @@ def test_resolve_batch_processing_with_related_events_one_at_time():
app = AppSyncResolver()

@app.batch_resolver(type_name="Post", field_name="relatedPosts", aggregate=False)
def related_posts(event: AppSyncResolverEvent) -> Optional[list]:
def related_posts(event: AppSyncResolverEvent) -> list | None:
return posts_related[event.source["post_id"]]

# WHEN related_posts function, which is the batch resolver, is called with the event.
@@ -155,7 +159,7 @@ def test_resolve_batch_processing_with_simple_queries_one_at_time():

# WHEN the batch resolver for the listLocations field is defined
@app.batch_resolver(field_name="listLocations", aggregate=False)
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
return event.source["id"] if event.source else None

# THEN the resolver should correctly process the batch of queries
@@ -211,7 +215,7 @@ def test_resolve_batch_processing_with_raise_on_exception_one_at_time():

# WHEN the sync batch resolver for the 'listLocations' field is defined with raise_on_error=True
@app.batch_resolver(field_name="listLocations", raise_on_error=True, aggregate=False)
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise RuntimeError

# THEN the resolver should raise a RuntimeError when processing the batch of queries
@@ -264,7 +268,7 @@ def test_async_resolve_batch_processing_with_raise_on_exception_one_at_time():

# WHEN the async batch resolver for the 'listLocations' field is defined with raise_on_error=True
@app.async_batch_resolver(field_name="listLocations", raise_on_error=True, aggregate=False)
async def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
async def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise RuntimeError

# THEN the resolver should raise a RuntimeError when processing the batch of queries
@@ -315,7 +319,7 @@ def test_resolve_batch_processing_without_exception_one_at_time():
app = AppSyncResolver()

@app.batch_resolver(field_name="listLocations", raise_on_error=False, aggregate=False)
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise RuntimeError

# Call the implicit handler
@@ -371,7 +375,7 @@ def test_resolve_async_batch_processing_without_exception_one_at_time():

# WHEN the batch resolver for the 'listLocations' field is defined with raise_on_error=False
@app.async_batch_resolver(field_name="listLocations", raise_on_error=False, aggregate=False)
async def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
async def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise RuntimeError

result = app.resolve(event, LambdaContext())
@@ -548,7 +552,7 @@ def test_resolve_async_batch_processing():

# WHEN the async batch resolver for the 'listLocations' field is defined
@app.async_batch_resolver(field_name="listLocations", aggregate=False)
async def create_something(event: AppSyncResolverEvent) -> Optional[list]:
async def create_something(event: AppSyncResolverEvent) -> list | None:
return event.source["id"] if event.source else None

# THEN the resolver should correctly process the batch of queries asynchronously
@@ -699,7 +703,7 @@ def test_resolve_batch_processing_with_simple_queries_with_aggregate():
# WHEN using an aggregated event
# WHEN function returns a List
@app.batch_resolver(field_name="listLocations")
def create_something(event: List[AppSyncResolverEvent]) -> List: # noqa AA03 VNE003
def create_something(event: list[AppSyncResolverEvent]) -> list: # noqa AA03 VNE003
results = []
for record in event:
results.append(record.source.get("id") if record.source else None)
@@ -760,7 +764,7 @@ def test_resolve_async_batch_processing_with_simple_queries_with_aggregate():
# WHEN using an aggregated event
# WHEN function returns a List
@app.async_batch_resolver(field_name="listLocations")
async def create_something(event: List[AppSyncResolverEvent]) -> List: # noqa AA03 VNE003
async def create_something(event: list[AppSyncResolverEvent]) -> list: # noqa AA03 VNE003
results = []
for record in event:
results.append(record.source.get("id") if record.source else None)
@@ -797,7 +801,7 @@ def test_resolve_batch_processing_with_aggregate_and_returning_a_non_list():
# WHEN using an aggregated event
# WHEN function return something different than a List
@app.batch_resolver(field_name="listLocations")
def create_something(event: List[AppSyncResolverEvent]) -> Optional[List]: # noqa AA03 VNE003
def create_something(event: list[AppSyncResolverEvent]) -> list | None: # noqa AA03 VNE003
return event[0].source.get("id") if event[0].source else None

# THEN the resolver should raise a InvalidBatchResponse when processing the batch of queries
@@ -828,7 +832,7 @@ def test_resolve_async_batch_processing_with_aggregate_and_returning_a_non_list(
# WHEN using an aggregated event
# WHEN function return something different than a List
@app.async_batch_resolver(field_name="listLocations")
async def create_something(event: List[AppSyncResolverEvent]) -> Optional[List]: # noqa AA03 VNE003
async def create_something(event: list[AppSyncResolverEvent]) -> list | None: # noqa AA03 VNE003
return event[0].source.get("id") if event[0].source else None

# THEN the resolver should raise a InvalidBatchResponse when processing the batch of queries
@@ -859,7 +863,7 @@ def test_resolve_sync_batch_processing_with_aggregate_and_without_return():
# WHEN using an aggregated event
# WHEN function there is no return statement
@app.batch_resolver(field_name="listLocations")
def create_something(event: List[AppSyncResolverEvent]) -> Optional[List]: # noqa AA03 VNE003
def create_something(event: list[AppSyncResolverEvent]) -> list | None: # noqa AA03 VNE003
def do_something_with_post_id(post_id): ...

post_id = event[0].source.get("id") if event[0].source else None
@@ -895,7 +899,7 @@ def test_resolve_async_batch_processing_with_aggregate_and_without_return():
# WHEN using an aggregated event
# WHEN function there is no return statement
@app.async_batch_resolver(field_name="listLocations")
async def create_something(event: List[AppSyncResolverEvent]) -> Optional[List]: # noqa AA03 VNE003
async def create_something(event: list[AppSyncResolverEvent]) -> list | None: # noqa AA03 VNE003
def do_something_with_post_id(post_id): ...

post_id = event[0].source.get("id") if event[0].source else None
@@ -916,7 +920,7 @@ def test_include_router_access_batch_current_event():
router = Router()

@router.batch_resolver(field_name="createSomething")
def get_user(event: List) -> List:
def get_user(event: list) -> list:
return [router.current_batch_event[0].identity.sub]

app.include_router(router)
@@ -935,7 +939,7 @@ def test_app_access_batch_current_event():
app = AppSyncResolver()

@app.batch_resolver(field_name="createSomething")
def get_user(event: List) -> List:
def get_user(event: list) -> list:
return [app.current_batch_event[0].identity.sub]

# WHEN we resolve the event
@@ -952,7 +956,7 @@ def test_context_is_accessible_in_sync_batch_resolver():
app = AppSyncResolver()

@app.batch_resolver(field_name="createSomething")
def get_user(event: List) -> List:
def get_user(event: list) -> list:
return [app.context.get("project_name")]

# WHEN we resolve the event
@@ -971,7 +975,7 @@ def test_context_is_accessible_in_async_batch_resolver():
app = AppSyncResolver()

@app.async_batch_resolver(field_name="createSomething")
async def get_user(event: List) -> List:
async def get_user(event: list) -> list:
return [app.context.get("project_name")]

# WHEN we resolve the event
@@ -1034,7 +1038,7 @@ def handle_value_error(ex: ValueError):

# WHEN the sync batch resolver for the 'listLocations' field is defined with raise_on_error=True
@app.batch_resolver(field_name="listLocations", raise_on_error=True, aggregate=False)
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise ValueError

# Call the implicit handler
@@ -1095,7 +1099,7 @@ def handle_value_error(ex: ValueError):

# WHEN the sync batch resolver for the 'listLocations' field is defined with raise_on_error=False
@app.batch_resolver(field_name="listLocations", raise_on_error=False, aggregate=False)
def create_something(event: AppSyncResolverEvent) -> Optional[list]: # noqa AA03 VNE003
def create_something(event: AppSyncResolverEvent) -> list | None: # noqa AA03 VNE003
raise ValueError

# Call the implicit handler
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio

import pytest
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import base64
import json
import re
@@ -8,7 +10,6 @@
from enum import Enum
from json import JSONEncoder
from pathlib import Path
from typing import Dict

import pytest

@@ -48,7 +49,7 @@


def read_media(file_name: str) -> bytes:
path = Path(str(Path(__file__).parent.parent.parent.parent) + "/../docs/media/" + file_name)
path = Path(f"{str(Path(__file__).parent.parent.parent.parent)}/../docs/media/{file_name}")
return path.read_bytes()


@@ -642,7 +643,7 @@ def test_rest_api():
expected_dict = {"foo": "value", "second": Decimal("100.01")}

@app.get("/my/path")
def rest_func() -> Dict:
def rest_func() -> dict:
return expected_dict

# WHEN calling the event handler
@@ -1187,7 +1188,7 @@ def custom_serializer(data) -> str:
app = ApiGatewayResolver(serializer=custom_serializer)

@app.get("/custom_serializer")
def get_custom_values() -> Dict:
def get_custom_values() -> dict:
return {"values": deque(["name", "age"])}

# WHEN calling handler
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import List
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

@@ -20,9 +22,12 @@
from aws_lambda_powertools.event_handler.middlewares.schema_validation import (
SchemaValidationMiddleware,
)
from aws_lambda_powertools.event_handler.types import EventHandlerInstance
from tests.functional.utils import load_event

if TYPE_CHECKING:
from aws_lambda_powertools.event_handler.types import EventHandlerInstance


API_REST_EVENT = load_event("apiGatewayProxyEvent.json")
API_RESTV2_EVENT = load_event("apiGatewayProxyV2Event_GET.json")

@@ -362,14 +367,14 @@ def test_api_gateway_middleware_order_with_include_router_last(app: EventHandler
router = Router()

def global_app_middleware(app: EventHandlerInstance, next_middleware: NextMiddleware):
middleware_order: List[str] = router.context.get("middleware_order", [])
middleware_order: list[str] = router.context.get("middleware_order", [])
middleware_order.append("app")

app.append_context(middleware_order=middleware_order)
return next_middleware(app)

def global_router_middleware(router: EventHandlerInstance, next_middleware: NextMiddleware):
middleware_order: List[str] = router.context.get("middleware_order", [])
middleware_order: list[str] = router.context.get("middleware_order", [])
middleware_order.append("router")

router.append_context(middleware_order=middleware_order)
@@ -439,14 +444,14 @@ def test_api_gateway_middleware_order_with_include_router_first(app: EventHandle
router = Router()

def global_app_middleware(app: EventHandlerInstance, next_middleware: NextMiddleware):
middleware_order: List[str] = router.context.get("middleware_order", [])
middleware_order: list[str] = router.context.get("middleware_order", [])
middleware_order.append("app")

app.append_context(middleware_order=middleware_order)
return next_middleware(app)

def global_router_middleware(router: EventHandlerInstance, next_middleware: NextMiddleware):
middleware_order: List[str] = router.context.get("middleware_order", [])
middleware_order: list[str] = router.context.get("middleware_order", [])
middleware_order.append("router")

router.append_context(middleware_order=middleware_order)
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import (
ALBResolver,
APIGatewayHttpResolver,
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import (
LambdaFunctionUrlResolver,
Response,
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import (
ALBResolver,
APIGatewayHttpResolver,
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import (
Response,
VPCLatticeResolver,
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

from aws_lambda_powertools.event_handler import (
Response,
VPCLatticeV2Resolver,