diff --git a/aws_lambda_powertools/utilities/batch/base.py b/aws_lambda_powertools/utilities/batch/base.py index 1122bea4c03..80191503055 100644 --- a/aws_lambda_powertools/utilities/batch/base.py +++ b/aws_lambda_powertools/utilities/batch/base.py @@ -4,6 +4,7 @@ Batch processing utilities """ import copy +import inspect import logging import sys from abc import ABC, abstractmethod @@ -15,6 +16,7 @@ from aws_lambda_powertools.utilities.data_classes.dynamo_db_stream_event import DynamoDBRecord from aws_lambda_powertools.utilities.data_classes.kinesis_stream_event import KinesisStreamRecord from aws_lambda_powertools.utilities.data_classes.sqs_event import SQSRecord +from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -55,6 +57,8 @@ class BasePartialProcessor(ABC): Abstract class for batch processors. """ + lambda_context: LambdaContext + def __init__(self): self.success_messages: List[BatchEventTypes] = [] self.fail_messages: List[BatchEventTypes] = [] @@ -94,7 +98,7 @@ def __enter__(self): def __exit__(self, exception_type, exception_value, traceback): self._clean() - def __call__(self, records: List[dict], handler: Callable): + def __call__(self, records: List[dict], handler: Callable, lambda_context: Optional[LambdaContext] = None): """ Set instance attributes before execution @@ -107,6 +111,31 @@ def __call__(self, records: List[dict], handler: Callable): """ self.records = records self.handler = handler + + # NOTE: If a record handler has `lambda_context` parameter in its function signature, we inject it. + # This is the earliest we can inspect for signature to prevent impacting performance. + # + # Mechanism: + # + # 1. When using the `@batch_processor` decorator, this happens automatically. + # 2. When using the context manager, customers have to include `lambda_context` param. + # + # Scenario: Injects Lambda context + # + # def record_handler(record, lambda_context): ... # noqa: E800 + # with processor(records=batch, handler=record_handler, lambda_context=context): ... # noqa: E800 + # + # Scenario: Does NOT inject Lambda context (default) + # + # def record_handler(record): pass # noqa: E800 + # with processor(records=batch, handler=record_handler): ... # noqa: E800 + # + if lambda_context is None: + self._handler_accepts_lambda_context = False + else: + self.lambda_context = lambda_context + self._handler_accepts_lambda_context = "lambda_context" in inspect.signature(self.handler).parameters + return self def success_handler(self, record, result: Any) -> SuccessResponse: @@ -155,7 +184,7 @@ def failure_handler(self, record, exception: ExceptionInfo) -> FailureResponse: @lambda_handler_decorator def batch_processor( - handler: Callable, event: Dict, context: Dict, record_handler: Callable, processor: BasePartialProcessor + handler: Callable, event: Dict, context: LambdaContext, record_handler: Callable, processor: BasePartialProcessor ): """ Middleware to handle batch event processing @@ -166,7 +195,7 @@ def batch_processor( Lambda's handler event: Dict Lambda's Event - context: Dict + context: LambdaContext Lambda's Context record_handler: Callable Callable to process each record from the batch @@ -193,7 +222,7 @@ def batch_processor( """ records = event["Records"] - with processor(records, record_handler): + with processor(records, record_handler, lambda_context=context): processor.process() return handler(event, context) @@ -365,7 +394,11 @@ def _process_record(self, record: dict) -> Union[SuccessResponse, FailureRespons """ data = self._to_batch_type(record=record, event_type=self.event_type, model=self.model) try: - result = self.handler(record=data) + if self._handler_accepts_lambda_context: + result = self.handler(record=data, lambda_context=self.lambda_context) + else: + result = self.handler(record=data) + return self.success_handler(record=record, result=result) except Exception: return self.failure_handler(record=data, exception=sys.exc_info()) diff --git a/tests/functional/test_utilities_batch.py b/tests/functional/test_utilities_batch.py index a5e1e706437..8654b96e9b1 100644 --- a/tests/functional/test_utilities_batch.py +++ b/tests/functional/test_utilities_batch.py @@ -3,6 +3,7 @@ from random import randint from typing import Callable, Dict, Optional from unittest.mock import patch +from uuid import uuid4 import pytest from botocore.config import Config @@ -24,6 +25,7 @@ from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecord as KinesisDataStreamRecordModel from aws_lambda_powertools.utilities.parser.models import KinesisDataStreamRecordPayload, SqsRecordModel from aws_lambda_powertools.utilities.parser.types import Literal +from aws_lambda_powertools.utilities.typing import LambdaContext from tests.functional.utils import b64_to_str, str_to_b64 @@ -167,6 +169,18 @@ def factory(item: Dict) -> str: return factory +@pytest.fixture(scope="module") +def lambda_context() -> LambdaContext: + class DummyLambdaContext: + def __init__(self): + self.function_name = "test-func" + self.memory_limit_in_mb = 128 + self.invoked_function_arn = "arn:aws:lambda:eu-west-1:809313241234:function:test-func" + self.aws_request_id = f"{uuid4()}" + + return DummyLambdaContext + + @pytest.mark.parametrize( "success_messages_count", ([1, 18, 34]), @@ -908,3 +922,41 @@ def lambda_handler(event, context): # THEN raise BatchProcessingError assert "All records failed processing. " in str(e.value) + + +def test_batch_processor_handler_receives_lambda_context(sqs_event_factory, lambda_context: LambdaContext): + # GIVEN + def record_handler(record, lambda_context: LambdaContext = None): + return lambda_context.function_name == "test-func" + + first_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + @batch_processor(record_handler=record_handler, processor=processor) + def lambda_handler(event, context): + return processor.response() + + # WHEN/THEN + lambda_handler(event, lambda_context()) + + +def test_batch_processor_context_manager_handler_receives_lambda_context( + sqs_event_factory, lambda_context: LambdaContext +): + # GIVEN + def record_handler(record, lambda_context: LambdaContext = None): + return lambda_context.function_name == "test-func" + + first_record = SQSRecord(sqs_event_factory("success")) + event = {"Records": [first_record.raw_event]} + + processor = BatchProcessor(event_type=EventType.SQS) + + def lambda_handler(event, context): + with processor(records=event["Records"], handler=record_handler, lambda_context=context) as batch: + batch.process() + + # WHEN/THEN + lambda_handler(event, lambda_context())