diff --git a/aws_lambda_powertools/shared/constants.py b/aws_lambda_powertools/shared/constants.py index 6061462a051..622ffbce47b 100644 --- a/aws_lambda_powertools/shared/constants.py +++ b/aws_lambda_powertools/shared/constants.py @@ -16,5 +16,8 @@ XRAY_TRACE_ID_ENV: str = "_X_AMZN_TRACE_ID" LAMBDA_TASK_ROOT_ENV: str = "LAMBDA_TASK_ROOT" + +LAMBDA_FUNCTION_NAME_ENV: str = "AWS_LAMBDA_FUNCTION_NAME" + XRAY_SDK_MODULE: str = "aws_xray_sdk" XRAY_SDK_CORE_MODULE: str = "aws_xray_sdk.core" diff --git a/aws_lambda_powertools/shared/types.py b/aws_lambda_powertools/shared/types.py new file mode 100644 index 00000000000..c5c91535bd3 --- /dev/null +++ b/aws_lambda_powertools/shared/types.py @@ -0,0 +1,3 @@ +from typing import Any, Callable, TypeVar + +AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index e57d24044dc..dc010a3712f 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -5,11 +5,12 @@ import logging import numbers import os -from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload +from typing import Any, Callable, Dict, Optional, Sequence, Union, cast, overload from ..shared import constants from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice from ..shared.lazy_import import LazyLoader +from ..shared.types import AnyCallableT from .base import BaseProvider, BaseSegment is_cold_start = True @@ -18,9 +19,6 @@ aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE) aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE) -AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any]) # noqa: VNE001 -AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable) - class Tracer: """Tracer using AWS-XRay to provide decorators with known 
defaults for Lambda functions diff --git a/aws_lambda_powertools/utilities/idempotency/__init__.py b/aws_lambda_powertools/utilities/idempotency/__init__.py index b46d0855a93..4461453a8be 100644 --- a/aws_lambda_powertools/utilities/idempotency/__init__.py +++ b/aws_lambda_powertools/utilities/idempotency/__init__.py @@ -5,6 +5,6 @@ from aws_lambda_powertools.utilities.idempotency.persistence.base import BasePersistenceLayer from aws_lambda_powertools.utilities.idempotency.persistence.dynamodb import DynamoDBPersistenceLayer -from .idempotency import IdempotencyConfig, idempotent +from .idempotency import IdempotencyConfig, idempotent, idempotent_function -__all__ = ("DynamoDBPersistenceLayer", "BasePersistenceLayer", "idempotent", "IdempotencyConfig") +__all__ = ("DynamoDBPersistenceLayer", "BasePersistenceLayer", "idempotent", "idempotent_function", "IdempotencyConfig") diff --git a/aws_lambda_powertools/utilities/idempotency/base.py b/aws_lambda_powertools/utilities/idempotency/base.py new file mode 100644 index 00000000000..4b82c923a70 --- /dev/null +++ b/aws_lambda_powertools/utilities/idempotency/base.py @@ -0,0 +1,181 @@ +import logging +from typing import Any, Callable, Dict, Optional, Tuple + +from aws_lambda_powertools.utilities.idempotency.config import IdempotencyConfig +from aws_lambda_powertools.utilities.idempotency.exceptions import ( + IdempotencyAlreadyInProgressError, + IdempotencyInconsistentStateError, + IdempotencyItemAlreadyExistsError, + IdempotencyItemNotFoundError, + IdempotencyKeyError, + IdempotencyPersistenceLayerError, + IdempotencyValidationError, +) +from aws_lambda_powertools.utilities.idempotency.persistence.base import ( + STATUS_CONSTANTS, + BasePersistenceLayer, + DataRecord, +) + +MAX_RETRIES = 2 +logger = logging.getLogger(__name__) + + +class IdempotencyHandler: + """ + Base class to orchestrate calls to persistence layer. 
+ """ + + def __init__( + self, + function: Callable, + function_payload: Any, + config: IdempotencyConfig, + persistence_store: BasePersistenceLayer, + function_args: Optional[Tuple] = None, + function_kwargs: Optional[Dict] = None, + ): + """ + Initialize the IdempotencyHandler + + Parameters + ---------- + function_payload: Any + JSON Serializable payload to be hashed + config: IdempotencyConfig + Idempotency Configuration + persistence_store : BasePersistenceLayer + Instance of persistence layer to store idempotency records + function_args: Optional[Tuple] + Function arguments + function_kwargs: Optional[Dict] + Function keyword arguments + """ + self.function = function + self.data = function_payload + self.fn_args = function_args + self.fn_kwargs = function_kwargs + + persistence_store.configure(config) + self.persistence_store = persistence_store + + def handle(self) -> Any: + """ + Main entry point for handling idempotent execution of a function. + + Returns + ------- + Any + Function response + + """ + # IdempotencyInconsistentStateError can happen under rare but expected cases + # when persistent state changes in the small time between put & get requests. + # In most cases we can retry successfully on this exception. + for i in range(MAX_RETRIES + 1): # pragma: no cover + try: + return self._process_idempotency() + except IdempotencyInconsistentStateError: + if i == MAX_RETRIES: + raise # Bubble up when exceeded max tries + + def _process_idempotency(self): + try: + # We call save_inprogress first as an optimization for the most common case where no idempotent record + # already exists. If it succeeds, there's no need to call get_record. 
+ self.persistence_store.save_inprogress(data=self.data) + except IdempotencyKeyError: + raise + except IdempotencyItemAlreadyExistsError: + # Now we know the item already exists, we can retrieve it + record = self._get_idempotency_record() + return self._handle_for_status(record) + except Exception as exc: + raise IdempotencyPersistenceLayerError("Failed to save in progress record to idempotency store") from exc + + return self._get_function_response() + + def _get_idempotency_record(self) -> DataRecord: + """ + Retrieve the idempotency record from the persistence layer. + + Raises + ---------- + IdempotencyInconsistentStateError + + """ + try: + data_record = self.persistence_store.get_record(data=self.data) + except IdempotencyItemNotFoundError: + # This code path will only be triggered if the record is removed between save_inprogress and get_record. + logger.debug( + f"An existing idempotency record was deleted before we could fetch it. Proceeding with {self.function}" + ) + raise IdempotencyInconsistentStateError("save_inprogress and get_record return inconsistent results.") + + # Allow this exception to bubble up + except IdempotencyValidationError: + raise + + # Wrap remaining unhandled exceptions with IdempotencyPersistenceLayerError to ease exception handling for + # clients + except Exception as exc: + raise IdempotencyPersistenceLayerError("Failed to get record from idempotency store") from exc + + return data_record + + def _handle_for_status(self, data_record: DataRecord) -> Optional[Dict[Any, Any]]: + """ + Take appropriate action based on data_record's status + + Parameters + ---------- + data_record: DataRecord + + Returns + ------- + Optional[Dict[Any, Any] + Function's response previously used for this idempotency key, if it has successfully executed already. 
+ + Raises + ------ + AlreadyInProgressError + A function execution is already in progress + IdempotencyInconsistentStateError + The persistence store reports inconsistent states across different requests. Retryable. + """ + # This code path will only be triggered if the record becomes expired between the save_inprogress call and here + if data_record.status == STATUS_CONSTANTS["EXPIRED"]: + raise IdempotencyInconsistentStateError("save_inprogress and get_record return inconsistent results.") + + if data_record.status == STATUS_CONSTANTS["INPROGRESS"]: + raise IdempotencyAlreadyInProgressError( + f"Execution already in progress with idempotency key: " + f"{self.persistence_store.event_key_jmespath}={data_record.idempotency_key}" + ) + + return data_record.response_json_as_dict() + + def _get_function_response(self): + try: + response = self.function(*self.fn_args, **self.fn_kwargs) + except Exception as handler_exception: + # We need these nested blocks to preserve function's exception in case the persistence store operation + # also raises an exception + try: + self.persistence_store.delete_record(data=self.data, exception=handler_exception) + except Exception as delete_exception: + raise IdempotencyPersistenceLayerError( + "Failed to delete record from idempotency store" + ) from delete_exception + raise + + else: + try: + self.persistence_store.save_success(data=self.data, result=response) + except Exception as save_exception: + raise IdempotencyPersistenceLayerError( + "Failed to update record state to success in idempotency store" + ) from save_exception + + return response diff --git a/aws_lambda_powertools/utilities/idempotency/idempotency.py b/aws_lambda_powertools/utilities/idempotency/idempotency.py index fc1d4d47d55..06c9a578aa2 100644 --- a/aws_lambda_powertools/utilities/idempotency/idempotency.py +++ b/aws_lambda_powertools/utilities/idempotency/idempotency.py @@ -1,25 +1,15 @@ """ Primary interface for idempotent Lambda functions utility """ +import 
functools import logging -from typing import Any, Callable, Dict, Optional +from typing import Any, Callable, Dict, Optional, cast from aws_lambda_powertools.middleware_factory import lambda_handler_decorator +from aws_lambda_powertools.shared.types import AnyCallableT +from aws_lambda_powertools.utilities.idempotency.base import IdempotencyHandler from aws_lambda_powertools.utilities.idempotency.config import IdempotencyConfig -from aws_lambda_powertools.utilities.idempotency.exceptions import ( - IdempotencyAlreadyInProgressError, - IdempotencyInconsistentStateError, - IdempotencyItemAlreadyExistsError, - IdempotencyItemNotFoundError, - IdempotencyKeyError, - IdempotencyPersistenceLayerError, - IdempotencyValidationError, -) -from aws_lambda_powertools.utilities.idempotency.persistence.base import ( - STATUS_CONSTANTS, - BasePersistenceLayer, - DataRecord, -) +from aws_lambda_powertools.utilities.idempotency.persistence.base import BasePersistenceLayer from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -32,9 +22,10 @@ def idempotent( context: LambdaContext, persistence_store: BasePersistenceLayer, config: Optional[IdempotencyConfig] = None, + **kwargs, ) -> Any: """ - Middleware to handle idempotency + Decorator to handle idempotency Parameters ---------- @@ -66,174 +57,88 @@ def idempotent( """ config = config or IdempotencyConfig() + args = event, context idempotency_handler = IdempotencyHandler( - lambda_handler=handler, event=event, context=context, persistence_store=persistence_store, config=config + function=handler, + function_payload=event, + config=config, + persistence_store=persistence_store, + function_args=args, + function_kwargs=kwargs, ) - # IdempotencyInconsistentStateError can happen under rare but expected cases when persistent state changes in the - # small time between put & get requests. In most cases we can retry successfully on this exception. 
- # Maintenance: Allow customers to specify number of retries - max_handler_retries = 2 - for i in range(max_handler_retries + 1): - try: - return idempotency_handler.handle() - except IdempotencyInconsistentStateError: - if i == max_handler_retries: - # Allow the exception to bubble up after max retries exceeded - raise + return idempotency_handler.handle() -class IdempotencyHandler: +def idempotent_function( + function: Optional[AnyCallableT] = None, + *, + data_keyword_argument: str, + persistence_store: BasePersistenceLayer, + config: Optional[IdempotencyConfig] = None, +) -> Any: """ - Class to orchestrate calls to persistence layer. + Decorator to handle idempotency of any function + + Parameters + ---------- + function: Callable + Function to be decorated + data_keyword_argument: str + Keyword parameter name in function's signature that we should hash as idempotency key, e.g. "order" + persistence_store: BasePersistenceLayer + Instance of BasePersistenceLayer to store data + config: IdempotencyConfig + Configuration + + Examples + -------- + **Processes an order in an idempotent manner** + + from aws_lambda_powertools.utilities.idempotency import ( + idempotent_function, DynamoDBPersistenceLayer, IdempotencyConfig + ) + + idem_config=IdempotencyConfig(event_key_jmespath="order_id") + persistence_layer = DynamoDBPersistenceLayer(table_name="idempotency_store") + + @idempotent_function(data_keyword_argument="order", config=idem_config, persistence_store=persistence_layer) + def process_order(customer_id: str, order: dict, **kwargs): + return {"StatusCode": 200} """ - def __init__( - self, - lambda_handler: Callable[[Any, LambdaContext], Any], - event: Dict[str, Any], - context: LambdaContext, - config: IdempotencyConfig, - persistence_store: BasePersistenceLayer, - ): - """ - Initialize the IdempotencyHandler - - Parameters - ---------- - lambda_handler : Callable[[Any, LambdaContext], Any] - Lambda function handler - event : Dict[str, Any] - Event payload 
lambda handler will be called with - context : LambdaContext - Context object which will be passed to lambda handler - persistence_store : BasePersistenceLayer - Instance of persistence layer to store idempotency records - """ - persistence_store.configure(config) - self.persistence_store = persistence_store - self.context = context - self.event = event - self.lambda_handler = lambda_handler - - def handle(self) -> Any: - """ - Main entry point for handling idempotent execution of lambda handler. - - Returns - ------- - Any - lambda handler response - - """ - try: - # We call save_inprogress first as an optimization for the most common case where no idempotent record - # already exists. If it succeeds, there's no need to call get_record. - self.persistence_store.save_inprogress(event=self.event, context=self.context) - except IdempotencyKeyError: - raise - except IdempotencyItemAlreadyExistsError: - # Now we know the item already exists, we can retrieve it - record = self._get_idempotency_record() - return self._handle_for_status(record) - except Exception as exc: - raise IdempotencyPersistenceLayerError("Failed to save in progress record to idempotency store") from exc - - return self._call_lambda_handler() - - def _get_idempotency_record(self) -> DataRecord: - """ - Retrieve the idempotency record from the persistence layer. - - Raises - ---------- - IdempotencyInconsistentStateError - - """ - try: - event_record = self.persistence_store.get_record(event=self.event, context=self.context) - except IdempotencyItemNotFoundError: - # This code path will only be triggered if the record is removed between save_inprogress and get_record. - logger.debug( - "An existing idempotency record was deleted before we could retrieve it. 
Proceeding with lambda " - "handler" - ) - raise IdempotencyInconsistentStateError("save_inprogress and get_record return inconsistent results.") - - # Allow this exception to bubble up - except IdempotencyValidationError: - raise - - # Wrap remaining unhandled exceptions with IdempotencyPersistenceLayerError to ease exception handling for - # clients - except Exception as exc: - raise IdempotencyPersistenceLayerError("Failed to get record from idempotency store") from exc - - return event_record - - def _handle_for_status(self, event_record: DataRecord) -> Optional[Dict[Any, Any]]: - """ - Take appropriate action based on event_record's status - - Parameters - ---------- - event_record: DataRecord - - Returns - ------- - Optional[Dict[Any, Any] - Lambda response previously used for this idempotency key, if it has successfully executed already. - - Raises - ------ - AlreadyInProgressError - A lambda execution is already in progress - IdempotencyInconsistentStateError - The persistence store reports inconsistent states across different requests. Retryable. 
- """ - # This code path will only be triggered if the record becomes expired between the save_inprogress call and here - if event_record.status == STATUS_CONSTANTS["EXPIRED"]: - raise IdempotencyInconsistentStateError("save_inprogress and get_record return inconsistent results.") - - if event_record.status == STATUS_CONSTANTS["INPROGRESS"]: - raise IdempotencyAlreadyInProgressError( - f"Execution already in progress with idempotency key: " - f"{self.persistence_store.event_key_jmespath}={event_record.idempotency_key}" + if function is None: + return cast( + AnyCallableT, + functools.partial( + idempotent_function, + data_keyword_argument=data_keyword_argument, + persistence_store=persistence_store, + config=config, + ), + ) + + config = config or IdempotencyConfig() + + @functools.wraps(function) + def decorate(*args, **kwargs): + payload = kwargs.get(data_keyword_argument) + + if payload is None: + raise RuntimeError( + f"Unable to extract '{data_keyword_argument}' from keyword arguments." 
+ f" Ensure this exists in your function's signature as well as the caller used it as a keyword argument" ) - return event_record.response_json_as_dict() - - def _call_lambda_handler(self) -> Any: - """ - Call the lambda handler function and update the persistence store appropriate depending on the output - - Returns - ------- - Any - lambda handler response - - """ - try: - handler_response = self.lambda_handler(self.event, self.context) - except Exception as handler_exception: - # We need these nested blocks to preserve lambda handler exception in case the persistence store operation - # also raises an exception - try: - self.persistence_store.delete_record( - event=self.event, context=self.context, exception=handler_exception - ) - except Exception as delete_exception: - raise IdempotencyPersistenceLayerError( - "Failed to delete record from idempotency store" - ) from delete_exception - raise - - else: - try: - self.persistence_store.save_success(event=self.event, context=self.context, result=handler_response) - except Exception as save_exception: - raise IdempotencyPersistenceLayerError( - "Failed to update record state to success in idempotency store" - ) from save_exception - - return handler_response + idempotency_handler = IdempotencyHandler( + function=function, + function_payload=payload, + config=config, + persistence_store=persistence_store, + function_args=args, + function_kwargs=kwargs, + ) + + return idempotency_handler.handle() + + return cast(AnyCallableT, decorate) diff --git a/aws_lambda_powertools/utilities/idempotency/persistence/base.py b/aws_lambda_powertools/utilities/idempotency/persistence/base.py index 0388adfbf55..2f5dd512ac6 100644 --- a/aws_lambda_powertools/utilities/idempotency/persistence/base.py +++ b/aws_lambda_powertools/utilities/idempotency/persistence/base.py @@ -6,6 +6,7 @@ import hashlib import json import logging +import os import warnings from abc import ABC, abstractmethod from types import MappingProxyType @@ -13,6 
+14,7 @@ import jmespath +from aws_lambda_powertools.shared import constants from aws_lambda_powertools.shared.cache_dict import LRUDict from aws_lambda_powertools.shared.jmespath_utils import PowertoolsFunctions from aws_lambda_powertools.shared.json_encoder import Encoder @@ -23,7 +25,6 @@ IdempotencyKeyError, IdempotencyValidationError, ) -from aws_lambda_powertools.utilities.typing import LambdaContext logger = logging.getLogger(__name__) @@ -153,16 +154,14 @@ def configure(self, config: IdempotencyConfig) -> None: self._cache = LRUDict(max_items=config.local_cache_max_items) self.hash_function = getattr(hashlib, config.hash_function) - def _get_hashed_idempotency_key(self, event: Dict[str, Any], context: LambdaContext) -> str: + def _get_hashed_idempotency_key(self, data: Dict[str, Any]) -> str: """ - Extract data from lambda event using event key jmespath, and return a hashed representation + Extract idempotency key and return a hashed representation Parameters ---------- - event: Dict[str, Any] - Lambda event - context: LambdaContext - Lambda context + data: Dict[str, Any] + Incoming data Returns ------- @@ -170,18 +169,17 @@ def _get_hashed_idempotency_key(self, event: Dict[str, Any], context: LambdaCont Hashed representation of the data extracted by the jmespath expression """ - data = event - if self.event_key_jmespath: - data = self.event_key_compiled_jmespath.search(event, options=jmespath.Options(**self.jmespath_options)) + data = self.event_key_compiled_jmespath.search(data, options=jmespath.Options(**self.jmespath_options)) - if self.is_missing_idempotency_key(data): + if self.is_missing_idempotency_key(data=data): if self.raise_on_no_idempotency_key: raise IdempotencyKeyError("No data found to create a hashed idempotency_key") warnings.warn(f"No value found for idempotency_key. 
jmespath: {self.event_key_jmespath}") - generated_hash = self._generate_hash(data) - return f"{context.function_name}#{generated_hash}" + generated_hash = self._generate_hash(data=data) + function_name = os.getenv(constants.LAMBDA_FUNCTION_NAME_ENV, "test-func") + return f"{function_name}#{generated_hash}" @staticmethod def is_missing_idempotency_key(data) -> bool: @@ -189,14 +187,14 @@ def is_missing_idempotency_key(data) -> bool: return all(x is None for x in data) return not data - def _get_hashed_payload(self, lambda_event: Dict[str, Any]) -> str: + def _get_hashed_payload(self, data: Dict[str, Any]) -> str: """ - Extract data from lambda event using validation key jmespath, and return a hashed representation + Extract payload using validation key jmespath and return a hashed representation Parameters ---------- - lambda_event: Dict[str, Any] - Lambda event + data: Dict[str, Any] + Payload Returns ------- @@ -206,8 +204,8 @@ def _get_hashed_payload(self, lambda_event: Dict[str, Any]) -> str: """ if not self.payload_validation_enabled: return "" - data = self.validation_key_jmespath.search(lambda_event) - return self._generate_hash(data) + data = self.validation_key_jmespath.search(data) + return self._generate_hash(data=data) def _generate_hash(self, data: Any) -> str: """ @@ -228,26 +226,26 @@ def _generate_hash(self, data: Any) -> str: hashed_data = self.hash_function(json.dumps(data, cls=Encoder).encode()) return hashed_data.hexdigest() - def _validate_payload(self, lambda_event: Dict[str, Any], data_record: DataRecord) -> None: + def _validate_payload(self, data: Dict[str, Any], data_record: DataRecord) -> None: """ - Validate that the hashed payload matches in the lambda event and stored data record + Validate that the hashed payload matches data provided and stored data record Parameters ---------- - lambda_event: Dict[str, Any] - Lambda event + data: Dict[str, Any] + Payload data_record: DataRecord DataRecord instance Raises ---------- 
IdempotencyValidationError - Event payload doesn't match the stored record for the given idempotency key + Payload doesn't match the stored record for the given idempotency key """ if self.payload_validation_enabled: - lambda_payload_hash = self._get_hashed_payload(lambda_event) - if data_record.payload_hash != lambda_payload_hash: + data_hash = self._get_hashed_payload(data=data) + if data_record.payload_hash != data_hash: raise IdempotencyValidationError("Payload does not match stored record for this event key") def _get_expiry_timestamp(self) -> int: @@ -288,12 +286,12 @@ def _save_to_cache(self, data_record: DataRecord): def _retrieve_from_cache(self, idempotency_key: str): if not self.use_local_cache: return - cached_record = self._cache.get(idempotency_key) + cached_record = self._cache.get(key=idempotency_key) if cached_record: if not cached_record.is_expired: return cached_record logger.debug(f"Removing expired local cache record for idempotency key: {idempotency_key}") - self._delete_from_cache(idempotency_key) + self._delete_from_cache(idempotency_key=idempotency_key) def _delete_from_cache(self, idempotency_key: str): if not self.use_local_cache: @@ -301,52 +299,48 @@ def _delete_from_cache(self, idempotency_key: str): if idempotency_key in self._cache: del self._cache[idempotency_key] - def save_success(self, event: Dict[str, Any], context: LambdaContext, result: dict) -> None: + def save_success(self, data: Dict[str, Any], result: dict) -> None: """ Save record of function's execution completing successfully Parameters ---------- - event: Dict[str, Any] - Lambda event - context: LambdaContext - Lambda context + data: Dict[str, Any] + Payload result: dict - The response from lambda handler + The response from function """ response_data = json.dumps(result, cls=Encoder) data_record = DataRecord( - idempotency_key=self._get_hashed_idempotency_key(event, context), + idempotency_key=self._get_hashed_idempotency_key(data=data), 
status=STATUS_CONSTANTS["COMPLETED"], expiry_timestamp=self._get_expiry_timestamp(), response_data=response_data, - payload_hash=self._get_hashed_payload(event), + payload_hash=self._get_hashed_payload(data=data), ) logger.debug( - f"Lambda successfully executed. Saving record to persistence store with " + f"Function successfully executed. Saving record to persistence store with " f"idempotency key: {data_record.idempotency_key}" ) self._update_record(data_record=data_record) - self._save_to_cache(data_record) + self._save_to_cache(data_record=data_record) - def save_inprogress(self, event: Dict[str, Any], context: LambdaContext) -> None: + def save_inprogress(self, data: Dict[str, Any]) -> None: """ Save record of function's execution being in progress Parameters ---------- - event: Dict[str, Any] - Lambda event - context: LambdaContext - Lambda context + data: Dict[str, Any] + Payload """ data_record = DataRecord( - idempotency_key=self._get_hashed_idempotency_key(event, context), + idempotency_key=self._get_hashed_idempotency_key(data=data), status=STATUS_CONSTANTS["INPROGRESS"], expiry_timestamp=self._get_expiry_timestamp(), - payload_hash=self._get_hashed_payload(event), + payload_hash=self._get_hashed_payload(data=data), ) logger.debug(f"Saving in progress record for idempotency key: {data_record.idempotency_key}") @@ -354,42 +348,37 @@ def save_inprogress(self, event: Dict[str, Any], context: LambdaContext) -> None if self._retrieve_from_cache(idempotency_key=data_record.idempotency_key): raise IdempotencyItemAlreadyExistsError - self._put_record(data_record) + self._put_record(data_record=data_record) - def delete_record(self, event: Dict[str, Any], context: LambdaContext, exception: Exception): + def delete_record(self, data: Dict[str, Any], exception: Exception): """ Delete record from the persistence store Parameters ---------- - event: Dict[str, Any] - Lambda event - context: LambdaContext - Lambda context + data: Dict[str, Any] + Payload exception - 
The exception raised by the lambda handler + The exception raised by the function """ - data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(event, context)) + data_record = DataRecord(idempotency_key=self._get_hashed_idempotency_key(data=data)) logger.debug( - f"Lambda raised an exception ({type(exception).__name__}). Clearing in progress record in persistence " + f"Function raised an exception ({type(exception).__name__}). Clearing in progress record in persistence " f"store for idempotency key: {data_record.idempotency_key}" ) - self._delete_record(data_record) + self._delete_record(data_record=data_record) - self._delete_from_cache(data_record.idempotency_key) + self._delete_from_cache(idempotency_key=data_record.idempotency_key) - def get_record(self, event: Dict[str, Any], context: LambdaContext) -> DataRecord: + def get_record(self, data: Dict[str, Any]) -> DataRecord: """ - Calculate idempotency key for lambda_event, then retrieve item from persistence store using idempotency key - and return it as a DataRecord instance.and return it as a DataRecord instance. + Retrieve idempotency key for data provided, fetch from persistence store, and convert to DataRecord. 
Parameters ---------- - event: Dict[str, Any] - Lambda event - context: LambdaContext - Lambda context + data: Dict[str, Any] + Payload Returns ------- @@ -401,22 +390,22 @@ def get_record(self, event: Dict[str, Any], context: LambdaContext) -> DataRecor IdempotencyItemNotFoundError Exception raised if no record exists in persistence store with the idempotency key IdempotencyValidationError - Event payload doesn't match the stored record for the given idempotency key + Payload doesn't match the stored record for the given idempotency key """ - idempotency_key = self._get_hashed_idempotency_key(event, context) + idempotency_key = self._get_hashed_idempotency_key(data=data) cached_record = self._retrieve_from_cache(idempotency_key=idempotency_key) if cached_record: logger.debug(f"Idempotency record found in cache with idempotency key: {idempotency_key}") - self._validate_payload(event, cached_record) + self._validate_payload(data=data, data_record=cached_record) return cached_record - record = self._get_record(idempotency_key) + record = self._get_record(idempotency_key=idempotency_key) self._save_to_cache(data_record=record) - self._validate_payload(event, record) + self._validate_payload(data=data, data_record=record) return record @abstractmethod diff --git a/docs/utilities/idempotency.md b/docs/utilities/idempotency.md index 8a0d1c81d5a..d941946b681 100644 --- a/docs/utilities/idempotency.md +++ b/docs/utilities/idempotency.md @@ -121,7 +121,83 @@ You can quickly start by initializing the `DynamoDBPersistenceLayer` class and u } ``` -#### Choosing a payload subset for idempotency +### Idempotent_function decorator + +Similar to the [idempotent decorator](#idempotent-decorator), you can use the `idempotent_function` decorator for any synchronous Python function. + +When using `idempotent_function`, you must tell us which keyword parameter in your function signature holds the data we should use, via **`data_keyword_argument`**. This data must be JSON serializable. + + + +!!! 
warning "Make sure to call your decorated function using keyword arguments" + +=== "app.py" + + This example also demonstrates how you can integrate with the [Batch utility](batch.md), so you can process each record in an idempotent manner. + + ```python hl_lines="4 13 18 25" + import uuid + + from aws_lambda_powertools.utilities.batch import sqs_batch_processor + from aws_lambda_powertools.utilities.idempotency import idempotent_function, DynamoDBPersistenceLayer, IdempotencyConfig + + + dynamodb = DynamoDBPersistenceLayer(table_name="idem") + config = IdempotencyConfig( + event_key_jmespath="messageId", # see "Choosing a payload subset for idempotency" section + use_local_cache=True, + ) + + @idempotent_function(data_keyword_argument="data", config=config, persistence_store=dynamodb) + def dummy(arg_one, arg_two, data: dict, **kwargs): + return {"data": data} + + + @idempotent_function(data_keyword_argument="record", config=config, persistence_store=dynamodb) + def record_handler(record): + return {"message": record["body"]} + + + @sqs_batch_processor(record_handler=record_handler) + def lambda_handler(event, context): + # `data` must be passed as a keyword argument for idempotency to work + dummy("hello", "universe", data={"messageId": "1234"}) + return {"statusCode": 200} + ``` + +=== "Example event" + + ```json hl_lines="4" + { + "Records": [ + { + "messageId": "059f36b4-87a3-44ab-83d2-661975830a7d", + "receiptHandle": "AQEBwJnKyrHigUMZj6rYigCgxlaS3SLy0a...", + "body": "Test message.", + "attributes": { + "ApproximateReceiveCount": "1", + "SentTimestamp": "1545082649183", + "SenderId": "AIDAIENQZJOLO23YVJ4VO", + "ApproximateFirstReceiveTimestamp": "1545082649185" + }, + "messageAttributes": { + "testAttr": { + "stringValue": "100", + "binaryValue": "base64Str", + "dataType": "Number" + } + }, + "md5OfBody": "e4e68fb7bd0e697a0ae8f1bb342846b3", + "eventSource": "aws:sqs", + "eventSourceARN": "arn:aws:sqs:us-east-2:123456789012:my-queue", + "awsRegion": "us-east-2" + } + ] + } + ``` + + 
+### Choosing a payload subset for idempotency !!! tip "Dealing with always changing payloads" When dealing with a more elaborate payload, where parts of the payload always change, you should use **`event_key_jmespath`** parameter. @@ -198,7 +274,7 @@ Imagine the function executes successfully, but the client never receives the re } ``` -#### Idempotency request flow +### Idempotency request flow This sequence diagram shows an example flow of what happens in the payment scenario: diff --git a/pyproject.toml b/pyproject.toml index 5a92641e26c..3c9053f71fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ pydantic = ["pydantic", "email-validator"] [tool.coverage.run] source = ["aws_lambda_powertools"] -omit = ["tests/*", "aws_lambda_powertools/exceptions/*"] +omit = ["tests/*", "aws_lambda_powertools/exceptions/*", "aws_lambda_powertools/utilities/parser/types.py"] branch = true [tool.coverage.html] diff --git a/tests/functional/idempotency/test_idempotency.py b/tests/functional/idempotency/test_idempotency.py index 0ecc84b7f9c..5505a7dc5c9 100644 --- a/tests/functional/idempotency/test_idempotency.py +++ b/tests/functional/idempotency/test_idempotency.py @@ -18,7 +18,7 @@ IdempotencyPersistenceLayerError, IdempotencyValidationError, ) -from aws_lambda_powertools.utilities.idempotency.idempotency import idempotent +from aws_lambda_powertools.utilities.idempotency.idempotency import idempotent, idempotent_function from aws_lambda_powertools.utilities.idempotency.persistence.base import BasePersistenceLayer, DataRecord from aws_lambda_powertools.utilities.validation import envelopes, validator from tests.functional.utils import load_event @@ -221,7 +221,6 @@ def lambda_handler(event, context): stubber.deactivate() -@pytest.mark.skipif(sys.version_info < (3, 8), reason="issue with pytest mock lib for < 3.8") @pytest.mark.parametrize("idempotency_config", [{"use_local_cache": True}], indirect=True) def test_idempotent_lambda_first_execution_cached( 
idempotency_config: IdempotencyConfig, @@ -255,7 +254,7 @@ def lambda_handler(event, context): retrieve_from_cache_spy.assert_called_once() save_to_cache_spy.assert_called_once() - assert save_to_cache_spy.call_args[0][0].status == "COMPLETED" + assert save_to_cache_spy.call_args[1]["data_record"].status == "COMPLETED" assert persistence_store._cache.get(hashed_idempotency_key).status == "COMPLETED" # This lambda call should not call AWS API @@ -739,7 +738,7 @@ def test_default_no_raise_on_missing_idempotency_key( assert "body" in persistence_store.event_key_jmespath # WHEN getting the hashed idempotency key for an event with no `body` key - hashed_key = persistence_store._get_hashed_idempotency_key({}, lambda_context) + hashed_key = persistence_store._get_hashed_idempotency_key({}) # THEN return the hash of None expected_value = "test-func#" + md5(json.dumps(None).encode()).hexdigest() @@ -760,7 +759,7 @@ def test_raise_on_no_idempotency_key( # WHEN getting the hashed idempotency key for an event with no `body` key with pytest.raises(IdempotencyKeyError) as excinfo: - persistence_store._get_hashed_idempotency_key({}, lambda_context) + persistence_store._get_hashed_idempotency_key({}) # THEN raise IdempotencyKeyError error assert "No data found to create a hashed idempotency_key" in str(excinfo.value) @@ -790,7 +789,7 @@ def test_jmespath_with_powertools_json( } # WHEN calling _get_hashed_idempotency_key - result = persistence_store._get_hashed_idempotency_key(api_gateway_proxy_event, lambda_context) + result = persistence_store._get_hashed_idempotency_key(api_gateway_proxy_event) # THEN the hashed idempotency key should match the extracted values generated hash assert result == "test-func#" + persistence_store._generate_hash(expected_value) @@ -807,7 +806,7 @@ def test_custom_jmespath_function_overrides_builtin_functions( with pytest.raises(jmespath.exceptions.UnknownFunctionError, match="Unknown function: powertools_json()"): # WHEN calling 
_get_hashed_idempotency_key # THEN raise unknown function - persistence_store._get_hashed_idempotency_key({}, lambda_context) + persistence_store._get_hashed_idempotency_key({}) def test_idempotent_lambda_save_inprogress_error(persistence_store: DynamoDBPersistenceLayer, lambda_context): @@ -885,3 +884,95 @@ def lambda_handler(event, _): result = lambda_handler(mock_event, lambda_context) # THEN we expect the handler to execute successfully assert result == expected_result + + +def test_idempotent_function(): + # Scenario to validate we can use idempotent_function with any function + mock_event = {"data": "value"} + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + expected_result = {"message": "Foo"} + + @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") + def record_handler(record): + return expected_result + + # WHEN calling the function + result = record_handler(record=mock_event) + # THEN we expect the function to execute successfully + assert result == expected_result + + +def test_idempotent_function_arbitrary_args_kwargs(): + # Scenario to validate we can use idempotent_function with a function + # with an arbitrary number of args and kwargs + mock_event = {"data": "value"} + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + expected_result = {"message": "Foo"} + + @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") + def record_handler(arg_one, arg_two, record, is_record): + return expected_result + + # WHEN calling the function + result = record_handler("foo", "bar", record=mock_event, is_record=True) + # THEN we expect the function to execute successfully + assert result == expected_result + + +def test_idempotent_function_invalid_data_kwarg(): + mock_event = {"data": "value"} + persistence_layer = MockPersistenceLayer("test-func#" + 
hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + expected_result = {"message": "Foo"} + keyword_argument = "payload" + + # GIVEN data_keyword_argument does not match fn signature + @idempotent_function(persistence_store=persistence_layer, data_keyword_argument=keyword_argument) + def record_handler(record): + return expected_result + + # WHEN calling the function + # THEN we expect to receive a Runtime error + with pytest.raises(RuntimeError, match=f"Unable to extract '{keyword_argument}'"): + record_handler(record=mock_event) + + +def test_idempotent_function_arg_instead_of_kwarg(): + mock_event = {"data": "value"} + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + expected_result = {"message": "Foo"} + keyword_argument = "record" + + # GIVEN data_keyword_argument matches fn signature + @idempotent_function(persistence_store=persistence_layer, data_keyword_argument=keyword_argument) + def record_handler(record): + return expected_result + + # WHEN calling the function without named argument + # THEN we expect to receive a Runtime error + with pytest.raises(RuntimeError, match=f"Unable to extract '{keyword_argument}'"): + record_handler(mock_event) + + +def test_idempotent_function_and_lambda_handler(lambda_context): + # Scenario to validate we can use both idempotent_function and idempotent decorators + mock_event = {"data": "value"} + persistence_layer = MockPersistenceLayer("test-func#" + hashlib.md5(json.dumps(mock_event).encode()).hexdigest()) + expected_result = {"message": "Foo"} + + @idempotent_function(persistence_store=persistence_layer, data_keyword_argument="record") + def record_handler(record): + return expected_result + + @idempotent(persistence_store=persistence_layer) + def lambda_handler(event, _): + return expected_result + + # WHEN calling the function + fn_result = record_handler(record=mock_event) + + # WHEN calling lambda handler + handler_result = 
lambda_handler(mock_event, lambda_context) + + # THEN we expect the function and lambda handler to execute successfully + assert fn_result == expected_result + assert handler_result == expected_result