diff --git a/aws_lambda_powertools/event_handler/appsync.py b/aws_lambda_powertools/event_handler/appsync.py index fb53b71c77d..4598f800d78 100644 --- a/aws_lambda_powertools/event_handler/appsync.py +++ b/aws_lambda_powertools/event_handler/appsync.py @@ -49,10 +49,6 @@ def __init__(self): super().__init__() self.context = {} # early init as customers might add context before event resolution - self.current_batch_event: List[AppSyncResolverEvent] = [] - self.current_event: Optional[AppSyncResolverEvent] = None - self.lambda_context: Optional[LambdaContext] = None - def __call__( self, event: dict, @@ -139,10 +135,13 @@ def lambda_handler(event, context): """ self.lambda_context = context + Router.lambda_context = context if isinstance(event, list): + Router.current_batch_event = [data_model(e) for e in event] response = self._call_batch_resolver(event=event, data_model=data_model) else: + Router.current_event = data_model(event) response = self._call_single_resolver(event=event, data_model=data_model) self.clear_context() diff --git a/aws_lambda_powertools/event_handler/graphql_appsync/router.py b/aws_lambda_powertools/event_handler/graphql_appsync/router.py index 2046f09e03c..56830bf9f8d 100644 --- a/aws_lambda_powertools/event_handler/graphql_appsync/router.py +++ b/aws_lambda_powertools/event_handler/graphql_appsync/router.py @@ -1,11 +1,16 @@ -from typing import Callable, Optional +from typing import Callable, List, Optional from aws_lambda_powertools.event_handler.graphql_appsync._registry import ResolverRegistry from aws_lambda_powertools.event_handler.graphql_appsync.base import BaseRouter +from aws_lambda_powertools.utilities.data_classes.appsync_resolver_event import AppSyncResolverEvent +from aws_lambda_powertools.utilities.typing.lambda_context import LambdaContext class Router(BaseRouter): context: dict + current_batch_event: List[AppSyncResolverEvent] = [] + current_event: Optional[AppSyncResolverEvent] = None + lambda_context: Optional[LambdaContext] = None def __init__(self): self.context = {} # early init as customers might add context before event resolution diff --git a/tests/events/appSyncBatchEvent.json b/tests/events/appSyncBatchEvent.json new file mode 100644 index 00000000000..49f98eecc55 --- /dev/null +++ b/tests/events/appSyncBatchEvent.json @@ -0,0 +1,46 @@ +[ + { + "arguments": { + "user_id": "1" + }, + "identity": { + "sub": "192879fc-a240-4bf1-ab5a-d6a00f3063f9", + "username": "jdoe" + }, + "prev": null, + "info": { + "selectionSetList": [ + "id", + "field1", + "field2" + ], + "selectionSetGraphQL": "{\n id\n field1\n field2\n}", + "parentTypeName": "Mutation", + "fieldName": "createSomething", + "variables": {} + }, + "stash": {} + }, + { + "arguments": { + "user_id": "2" + }, + "identity": { + "sub": "192879fc-a240-4bf1-ab5a-d6a00f3063f9", + "username": "jdoe" + }, + "prev": null, + "info": { + "selectionSetList": [ + "id", + "field1", + "field2" + ], + "selectionSetGraphQL": "{\n id\n field1\n field2\n}", + "parentTypeName": "Mutation", + "fieldName": "createSomething", + "variables": {} + }, + "stash": {} + } +] diff --git a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_batch_resolvers.py b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_batch_resolvers.py index cc78cbf0f9c..a6452ee683d 100644 --- a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_batch_resolvers.py +++ b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_batch_resolvers.py @@ -8,6 +8,7 @@ from aws_lambda_powertools.utilities.data_classes import AppSyncResolverEvent from aws_lambda_powertools.utilities.typing import LambdaContext from aws_lambda_powertools.warnings import PowertoolsUserWarning +from tests.functional.utils import load_event # TESTS RECEIVING THE EVENT PARTIALLY AND PROCESS EACH RECORD PER TIME. @@ -710,7 +711,6 @@ def create_something(event: List[AppSyncResolverEvent]) -> List: # noqa AA03 VN assert result == [appsync_event["source"]["id"] for appsync_event in event] assert app.current_batch_event and len(app.current_batch_event) == len(event) - assert not app.current_event def test_resolve_async_batch_processing_with_simple_queries_with_aggregate(): @@ -772,7 +772,6 @@ async def create_something(event: List[AppSyncResolverEvent]) -> List: # noqa A assert result == [appsync_event["source"]["id"] for appsync_event in event] assert app.current_batch_event and len(app.current_batch_event) == len(event) - assert not app.current_event def test_resolve_batch_processing_with_aggregate_and_returning_a_non_list(): @@ -907,3 +906,40 @@ def do_something_with_post_id(post_id): ... # THEN the resolver should raise a InvalidBatchResponse when processing the batch of queries with pytest.raises(InvalidBatchResponse): app.resolve(event, LambdaContext()) + + +def test_include_router_access_batch_current_event(): + mock_event = load_event("appSyncBatchEvent.json") + + # GIVEN An instance of AppSyncResolver, a Router instance, and a resolver function registered with the router + app = AppSyncResolver() + router = Router() + + @router.batch_resolver(field_name="createSomething") + def get_user(event: List) -> List: + return [router.current_batch_event[0].identity.sub] + + app.include_router(router) + + # WHEN we resolve the event + ret = app.resolve(mock_event, {}) + + # THEN the resolver must be able to return a field in the batch_current_event + assert ret[0] == mock_event[0]["identity"]["sub"] + + +def test_app_access_batch_current_event(): + mock_event = load_event("appSyncBatchEvent.json") + + # GIVEN An instance of AppSyncResolver and a resolver function registered with the app + app = AppSyncResolver() + + @app.batch_resolver(field_name="createSomething") + def get_user(event: List) -> List: + return [app.current_batch_event[0].identity.sub] + + # WHEN we resolve the event + ret = app.resolve(mock_event, {}) + + # THEN the resolver must be able to return a field in the batch_current_event + assert ret[0] == mock_event[0]["identity"]["sub"] diff --git a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_single_resolvers.py b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_single_resolvers.py index 1ace266e731..43443323750 100644 --- a/tests/functional/event_handler/required_dependencies/appsync/test_appsync_single_resolvers.py +++ b/tests/functional/event_handler/required_dependencies/appsync/test_appsync_single_resolvers.py @@ -251,3 +251,41 @@ def test_include_router_merges_context(): app.include_router(router) assert app.context == router.context + + +def test_include_router_access_current_event(): + mock_event = load_event("appSyncDirectResolver.json") + + # GIVEN An instance of AppSyncResolver, a Router instance, and a resolver function registered with the router + app = AppSyncResolver() + router = Router() + + @router.resolver(field_name="createSomething") + def get_user(id: str) -> dict: # noqa AA03 VNE003 + return router.current_event.identity.sub + + app.include_router(router) + + # WHEN we resolve the event + ret = app.resolve(mock_event, {}) + + # THEN the resolver must be able to return a field in the current_event + assert ret == mock_event["identity"]["sub"] + + +def test_app_access_current_event(): + # Check whether we can handle an example appsync direct resolver + mock_event = load_event("appSyncDirectResolver.json") + + # GIVEN An instance of AppSyncResolver and a resolver function registered with the app + app = AppSyncResolver() + + @app.resolver(field_name="createSomething") + def get_user(id: str) -> dict: # noqa AA03 VNE003 + return app.current_event.identity.sub + + # WHEN we resolve the event + ret = app.resolve(mock_event, {}) + + # THEN the resolver must be able to return a field in the current_event + assert ret == mock_event["identity"]["sub"]