diff --git a/README.md b/README.md index 234ba219a32..68090913dbe 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,7 @@ A suite of utilities for AWS Lambda Functions that makes tracing with AWS X-Ray, * **[Logging](https://awslabs.github.io/aws-lambda-powertools-python/core/logger/)** - Structured logging made easier, and decorator to enrich structured logging with key Lambda context details * **[Metrics](https://awslabs.github.io/aws-lambda-powertools-python/core/metrics/)** - Custom Metrics created asynchronously via CloudWatch Embedded Metric Format (EMF) * **[Bring your own middleware](https://awslabs.github.io/aws-lambda-powertools-python/utilities/middleware_factory/)** - Decorator factory to create your own middleware to run logic before, and after each Lambda invocation +* **[Parameters utility](https://awslabs.github.io/aws-lambda-powertools-python/utilities/parameters/)** - Retrieve and cache parameter values from Parameter Store, Secrets Manager, or DynamoDB ### Installation diff --git a/aws_lambda_powertools/utilities/__init__.py b/aws_lambda_powertools/utilities/__init__.py new file mode 100644 index 00000000000..67be909187a --- /dev/null +++ b/aws_lambda_powertools/utilities/__init__.py @@ -0,0 +1,3 @@ +# -*- coding: utf-8 -*- + +"""General utilities for Powertools""" diff --git a/aws_lambda_powertools/utilities/parameters/__init__.py b/aws_lambda_powertools/utilities/parameters/__init__.py new file mode 100644 index 00000000000..07f9cb2ca76 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/__init__.py @@ -0,0 +1,23 @@ +# -*- coding: utf-8 -*- + +""" +Parameter retrieval and caching utility +""" + +from .base import BaseProvider +from .dynamodb import DynamoDBProvider +from .exceptions import GetParameterError, TransformParameterError +from .secrets import SecretsProvider, get_secret +from .ssm import SSMProvider, get_parameter, get_parameters + +__all__ = [ + "BaseProvider", + "GetParameterError", + "DynamoDBProvider", + "SecretsProvider", + "SSMProvider", + "TransformParameterError", + "get_parameter", + "get_parameters", + "get_secret", +] diff --git a/aws_lambda_powertools/utilities/parameters/base.py b/aws_lambda_powertools/utilities/parameters/base.py new file mode 100644 index 00000000000..8a552b53bcb --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/base.py @@ -0,0 +1,190 @@ +""" +Base for Parameter providers +""" + +import base64 +import json +from abc import ABC, abstractmethod +from collections import namedtuple +from datetime import datetime, timedelta +from typing import Dict, Optional, Union + +from .exceptions import GetParameterError, TransformParameterError + +DEFAULT_MAX_AGE_SECS = 5 +ExpirableValue = namedtuple("ExpirableValue", ["value", "ttl"]) +# These providers will be dynamically initialized on first use of the helper functions +DEFAULT_PROVIDERS = {} + + +class BaseProvider(ABC): + """ + Abstract Base Class for Parameter providers + """ + + store = None + + def __init__(self): + """ + Initialize the base provider + """ + + self.store = {} + + def get( + self, name: str, max_age: int = DEFAULT_MAX_AGE_SECS, transform: Optional[str] = None, **sdk_options + ) -> Union[str, list, dict, bytes]: + """ + Retrieve a parameter value or return the cached value + + Parameters + ---------- + name: str + Parameter name + max_age: int + Maximum age of the cached value + transform: str + Optional transformation of the parameter value. Supported values + are "json" for JSON strings and "binary" for base 64 encoded + values. + sdk_options: dict, optional + Arguments that will be passed directly to the underlying API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + """ + + # If there are multiple calls to the same parameter but in a different + # transform, they will be stored multiple times. This allows us to + # optimize by transforming the data only once per retrieval, thus there + # is no need to transform cached values multiple times. However, this + # means that we need to make multiple calls to the underlying parameter + # store if we need to return it in different transforms. Since the number + # of supported transform is small and the probability that a given + # parameter will always be used in a specific transform, this should be + # an acceptable tradeoff. + key = (name, transform) + + if key not in self.store or self.store[key].ttl < datetime.now(): + try: + value = self._get(name, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) + + if transform is not None: + value = transform_value(value, transform) + + self.store[key] = ExpirableValue(value, datetime.now() + timedelta(seconds=max_age),) + + return self.store[key].value + + @abstractmethod + def _get(self, name: str, **sdk_options) -> str: + """ + Retrieve paramater value from the underlying parameter store + """ + raise NotImplementedError() + + def get_multiple( + self, + path: str, + max_age: int = DEFAULT_MAX_AGE_SECS, + transform: Optional[str] = None, + raise_on_transform_error: bool = False, + **sdk_options, + ) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: + """ + Retrieve multiple parameters based on a path prefix + + Parameters + ---------- + path: str + Parameter path used to retrieve multiple parameters + max_age: int, optional + Maximum age of the cached value + transform: str, optional + Optional transformation of the parameter value. Supported values + are "json" for JSON strings and "binary" for base 64 encoded + values. + raise_on_transform_error: bool, optional + Raises an exception if any transform fails, otherwise this will + return a None value for each transform that failed + sdk_options: dict, optional + Arguments that will be passed directly to the underlying API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve parameter values for + a given path. + TransformParameterError + When the parameter provider fails to transform a parameter value. + """ + + key = (path, transform) + + if key not in self.store or self.store[key].ttl < datetime.now(): + try: + values = self._get_multiple(path, **sdk_options) + # Encapsulate all errors into a generic GetParameterError + except Exception as exc: + raise GetParameterError(str(exc)) + + if transform is not None: + new_values = {} + for key, value in values.items(): + try: + new_values[key] = transform_value(value, transform) + except Exception as exc: + if raise_on_transform_error: + raise exc + else: + new_values[key] = None + + values = new_values + + self.store[key] = ExpirableValue(values, datetime.now() + timedelta(seconds=max_age),) + + return self.store[key].value + + @abstractmethod + def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: + """ + Retrieve multiple parameter values from the underlying parameter store + """ + raise NotImplementedError() + + +def transform_value(value: str, transform: str) -> Union[dict, bytes]: + """ + Apply a transform to a value + + Parameters + --------- + value: str + Parameter alue to transform + transform: str + Type of transform, supported values are "json" and "binary" + + Raises + ------ + TransformParameterError: + When the parameter value could not be transformed + """ + + try: + if transform == "json": + return json.loads(value) + elif transform == "binary": + return base64.b64decode(value) + else: + raise ValueError(f"Invalid transform type '{transform}'") + + except Exception as exc: + raise TransformParameterError(str(exc)) diff --git a/aws_lambda_powertools/utilities/parameters/dynamodb.py b/aws_lambda_powertools/utilities/parameters/dynamodb.py new file mode 100644 index 00000000000..4132697f0b9 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/dynamodb.py @@ -0,0 +1,213 @@ +""" +Amazon DynamoDB parameter retrieval and caching utility +""" + + +from typing import Dict, Optional + +import boto3 +from boto3.dynamodb.conditions import Key +from botocore.config import Config + +from .base import BaseProvider + + +class DynamoDBProvider(BaseProvider): + """ + Amazon DynamoDB Parameter Provider + + Parameters + ---------- + table_name: str + Name of the DynamoDB table that stores parameters + key_attr: str, optional + Hash key for the DynamoDB table (default to 'id') + sort_attr: str, optional + Name of the DynamoDB table sort key (defaults to 'sk'), used only for get_multiple + value_attr: str, optional + Attribute that contains the values in the DynamoDB table (defaults to 'value') + config: botocore.config.Config, optional + Botocore configuration to pass during client initialization + + Example + ------- + **Retrieves a parameter value from a DynamoDB table** + + In this example, the DynamoDB table uses `id` as hash key and stores the value in the `value` + attribute. The parameter item looks like this: + + { "id": "my-parameters", "value": "Parameter value a" } + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider("ParametersTable") + >>> + >>> value = ddb_provider.get("my-parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from a DynamoDB table that has custom attribute names** + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider( + ... "ParametersTable", + ... key_attr="my-id", + ... value_attr="my-value" + ... ) + >>> + >>> value = ddb_provider.get("my-parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from a DynamoDB table in another AWS region** + + >>> from botocore.config import Config + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> + >>> config = Config(region_name="us-west-1") + >>> ddb_provider = DynamoDBProvider("ParametersTable", config=config) + >>> + >>> value = ddb_provider.get("my-parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from a DynamoDB table passing options to the SDK call** + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider("ParametersTable") + >>> + >>> value = ddb_provider.get("my-parameter", ConsistentRead=True) + >>> + >>> print(value) + My parameter value + + **Retrieves multiple values from a DynamoDB table** + + In this case, the provider will use a sort key to retrieve multiple values using a query under + the hood. This expects that the sort key is named `sk`. The DynamoDB table contains three items + looking like this: + + { "id": "my-parameters", "sk": "a", "value": "Parameter value a" } + { "id": "my-parameters", "sk": "b", "value": "Parameter value b" } + { "id": "my-parameters", "sk": "c", "value": "Parameter value c" } + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider("ParametersTable") + >>> + >>> values = ddb_provider.get_multiple("my-parameters") + >>> + >>> for key, value in values.items(): + ... print(key, value) + a Parameter value a + b Parameter value b + c Parameter value c + + **Retrieves multiple values from a DynamoDB table that has custom attribute names** + + In this case, the provider will use a sort key to retrieve multiple values using a query under + the hood. + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider( + ... "ParametersTable", + ... key_attr="my-id", + ... sort_attr="my-sort-key", + ... value_attr="my-value" + ... ) + >>> + >>> values = ddb_provider.get_multiple("my-parameters") + >>> + >>> for key, value in values.items(): + ... print(key, value) + a Parameter value a + b Parameter value b + c Parameter value c + + **Retrieves multiple values from a DynamoDB table passing options to the SDK calls** + + >>> from aws_lambda_powertools.utilities.parameters import DynamoDBProvider + >>> ddb_provider = DynamoDBProvider("ParametersTable") + >>> + >>> values = ddb_provider.get_multiple("my-parameters", ConsistentRead=True) + >>> + >>> for key, value in values.items(): + ... print(key, value) + a Parameter value a + b Parameter value b + c Parameter value c + """ + + table = None + key_attr = None + sort_attr = None + value_attr = None + + def __init__( + self, + table_name: str, + key_attr: str = "id", + sort_attr: str = "sk", + value_attr: str = "value", + config: Optional[Config] = None, + ): + """ + Initialize the DynamoDB client + """ + + config = config or Config() + self.table = boto3.resource("dynamodb", config=config).Table(table_name) + + self.key_attr = key_attr + self.sort_attr = sort_attr + self.value_attr = value_attr + + super().__init__() + + def _get(self, name: str, **sdk_options) -> str: + """ + Retrieve a parameter value from Amazon DynamoDB + + Parameters + ---------- + name: str + Name of the parameter + sdk_options: dict, optional + Dictionary of options that will be passed to the DynamoDB get_item API call + """ + + # Explicit arguments will take precedence over keyword arguments + sdk_options["Key"] = {self.key_attr: name} + + return self.table.get_item(**sdk_options)["Item"][self.value_attr] + + def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: + """ + Retrieve multiple parameter values from Amazon DynamoDB + + Parameters + ---------- + path: str + Path to retrieve the parameters + sdk_options: dict, optional + Dictionary of options that will be passed to the DynamoDB query API call + """ + + # Explicit arguments will take precedence over keyword arguments + sdk_options["KeyConditionExpression"] = Key(self.key_attr).eq(path) + + response = self.table.query(**sdk_options) + items = response.get("Items", []) + + # Keep querying while there are more items matching the partition key + while "LastEvaluatedKey" in response: + sdk_options["ExclusiveStartKey"] = response["LastEvaluatedKey"] + response = self.table.query(**sdk_options) + items.extend(response.get("Items", [])) + + retval = {} + for item in items: + retval[item[self.sort_attr]] = item[self.value_attr] + + return retval diff --git a/aws_lambda_powertools/utilities/parameters/exceptions.py b/aws_lambda_powertools/utilities/parameters/exceptions.py new file mode 100644 index 00000000000..1287568b463 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/exceptions.py @@ -0,0 +1,11 @@ +""" +Parameter retrieval exceptions +""" + + +class GetParameterError(Exception): + """When a provider raises an exception on parameter retrieval""" + + +class TransformParameterError(Exception): + """When a provider fails to transform a parameter value""" diff --git a/aws_lambda_powertools/utilities/parameters/secrets.py b/aws_lambda_powertools/utilities/parameters/secrets.py new file mode 100644 index 00000000000..ee4585309fe --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/secrets.py @@ -0,0 +1,142 @@ +""" +AWS Secrets Manager parameter retrieval and caching utility +""" + + +from typing import Dict, Optional, Union + +import boto3 +from botocore.config import Config + +from .base import DEFAULT_PROVIDERS, BaseProvider + + +class SecretsProvider(BaseProvider): + """ + AWS Secrets Manager Parameter Provider + + Parameters + ---------- + config: botocore.config.Config, optional + Botocore configuration to pass during client initialization + + Example + ------- + **Retrieves a parameter value from Secrets Manager** + + >>> from aws_lambda_powertools.utilities.parameters import SecretsProvider + >>> secrets_provider = SecretsProvider() + >>> + >>> value secrets_provider.get("my-parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from Secrets Manager in another AWS region** + + >>> from botocore.config import Config + >>> from aws_lambda_powertools.utilities.parameters import SecretsProvider + >>> + >>> config = Config(region_name="us-west-1") + >>> secrets_provider = SecretsProvider(config=config) + >>> + >>> value = secrets_provider.get("my-parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from Secrets Manager passing options to the SDK call** + + >>> from aws_lambda_powertools.utilities.parameters import SecretsProvider + >>> secrets_provider = SecretsProvider() + >>> + >>> value = secrets_provider.get("my-parameter", VersionId="f658cac0-98a5-41d9-b993-8a76a7799194") + >>> + >>> print(value) + My parameter value + """ + + client = None + + def __init__(self, config: Optional[Config] = None): + """ + Initialize the Secrets Manager client + """ + + config = config or Config() + + self.client = boto3.client("secretsmanager", config=config) + + super().__init__() + + def _get(self, name: str, **sdk_options) -> str: + """ + Retrieve a parameter value from AWS Systems Manager Parameter Store + + Parameters + ---------- + name: str + Name of the parameter + sdk_options: dict + Dictionary of options that will be passed to the Secrets Manager get_secret_value API call + """ + + # Explicit arguments will take precedence over keyword arguments + sdk_options["SecretId"] = name + + return self.client.get_secret_value(**sdk_options)["SecretString"] + + def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: + """ + Retrieving multiple parameter values is not supported with AWS Secrets Manager + """ + raise NotImplementedError() + + +def get_secret(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, dict, bytes]: + """ + Retrieve a parameter value from AWS Secrets Manager + + Parameters + ---------- + name: str + Name of the parameter + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + sdk_options: dict, optional + Dictionary of options that will be passed to the get_secret_value call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + + Example + ------- + **Retrieves a secret*** + + >>> from aws_lambda_powertools.utilities.parameters import get_secret + >>> + >>> get_secret("my-secret") + + **Retrieves a secret and transforms using a JSON deserializer*** + + >>> from aws_lambda_powertools.utilities.parameters import get_secret + >>> + >>> get_secret("my-secret", transform="json") + + **Retrieves a secret and passes custom arguments to the SDK** + + >>> from aws_lambda_powertools.utilities.parameters import get_secret + >>> + >>> get_secret("my-secret", VersionId="f658cac0-98a5-41d9-b993-8a76a7799194") + """ + + # Only create the provider if this function is called at least once + if "secrets" not in DEFAULT_PROVIDERS: + DEFAULT_PROVIDERS["secrets"] = SecretsProvider() + + return DEFAULT_PROVIDERS["secrets"].get(name, transform=transform, **sdk_options) diff --git a/aws_lambda_powertools/utilities/parameters/ssm.py b/aws_lambda_powertools/utilities/parameters/ssm.py new file mode 100644 index 00000000000..b458f8690d0 --- /dev/null +++ b/aws_lambda_powertools/utilities/parameters/ssm.py @@ -0,0 +1,248 @@ +""" +AWS SSM Parameter retrieval and caching utility +""" + + +from typing import Dict, Optional, Union + +import boto3 +from botocore.config import Config + +from .base import DEFAULT_PROVIDERS, BaseProvider + + +class SSMProvider(BaseProvider): + """ + AWS Systems Manager Parameter Store Provider + + Parameters + ---------- + config: botocore.config.Config, optional + Botocore configuration to pass during client initialization + + Example + ------- + **Retrieves a parameter value from Systems Manager Parameter Store** + + >>> from aws_lambda_powertools.utilities.parameters import SSMProvider + >>> ssm_provider = SSMProvider() + >>> + >>> value = ssm_provider.get("/my/parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value from Systems Manager Parameter Store in another AWS region** + + >>> from botocore.config import Config + >>> from aws_lambda_powertools.utilities.parameters import SSMProvider + >>> + >>> config = Config(region_name="us-west-1") + >>> ssm_provider = SSMProvider(config=config) + >>> + >>> value = ssm_provider.get("/my/parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves multiple parameter values from Systems Manager Parameter Store using a path prefix** + + >>> from aws_lambda_powertools.utilities.parameters import SSMProvider + >>> ssm_provider = SSMProvider() + >>> + >>> values = ssm_provider.get_multiple("/my/path/prefix") + >>> + >>> for key, value in values.items(): + ... print(key, value) + /my/path/prefix/a Parameter value a + /my/path/prefix/b Parameter value b + /my/path/prefix/c Parameter value c + + **Retrieves multiple parameter values from Systems Manager Parameter Store passing options to the SDK call** + + >>> from aws_lambda_powertools.utilities.parameters import SSMProvider + >>> ssm_provider = SSMProvider() + >>> + >>> values = ssm_provider.get_multiple("/my/path/prefix", MaxResults=10) + >>> + >>> for key, value in values.items(): + ... print(key, value) + /my/path/prefix/a Parameter value a + /my/path/prefix/b Parameter value b + /my/path/prefix/c Parameter value c + """ + + client = None + + def __init__( + self, config: Optional[Config] = None, + ): + """ + Initialize the SSM Parameter Store client + """ + + config = config or Config() + self.client = boto3.client("ssm", config=config) + + super().__init__() + + def _get(self, name: str, decrypt: bool = False, **sdk_options) -> str: + """ + Retrieve a parameter value from AWS Systems Manager Parameter Store + + Parameters + ---------- + name: str + Parameter name + decrypt: bool, optional + If the parameter value should be decrypted + sdk_options: dict, optional + Dictionary of options that will be passed to the Parameter Store get_parameter API call + """ + + # Explicit arguments will take precedence over keyword arguments + sdk_options["Name"] = name + sdk_options["WithDecryption"] = decrypt + + return self.client.get_parameter(**sdk_options)["Parameter"]["Value"] + + def _get_multiple(self, path: str, decrypt: bool = False, recursive: bool = False, **sdk_options) -> Dict[str, str]: + """ + Retrieve multiple parameter values from AWS Systems Manager Parameter Store + + Parameters + ---------- + path: str + Path to retrieve the parameters + decrypt: bool, optional + If the parameter values should be decrypted + recursive: bool, optional + If this should retrieve the parameter values recursively or not + sdk_options: dict, optional + Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call + """ + + # Explicit arguments will take precedence over keyword arguments + sdk_options["Path"] = path + sdk_options["WithDecryption"] = decrypt + sdk_options["Recursive"] = recursive + + parameters = {} + for page in self.client.get_paginator("get_parameters_by_path").paginate(**sdk_options): + for parameter in page.get("Parameters", []): + # Standardize the parameter name + # The parameter name returned by SSM will contained the full path. + # However, for readability, we should return only the part after + # the path. + name = parameter["Name"] + if name.startswith(path): + name = name[len(path) :] + name = name.lstrip("/") + + parameters[name] = parameter["Value"] + + return parameters + + +def get_parameter(name: str, transform: Optional[str] = None, **sdk_options) -> Union[str, list, dict, bytes]: + """ + Retrieve a parameter value from AWS Systems Manager (SSM) Parameter Store + + Parameters + ---------- + name: str + Name of the parameter + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + sdk_options: dict, optional + Dictionary of options that will be passed to the Parameter Store get_parameter API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve a parameter value for + a given name. + TransformParameterError + When the parameter provider fails to transform a parameter value. + + Example + ------- + **Retrieves a parameter value from Systems Manager Parameter Store** + + >>> from aws_lambda_powertools.utilities.parameters import get_parameter + >>> + >>> value = get_parameter("/my/parameter") + >>> + >>> print(value) + My parameter value + + **Retrieves a parameter value and decodes it using a Base64 decoder** + + >>> from aws_lambda_powertools.utilities.parameters import get_parameter + >>> + >>> value = get_parameter("/my/parameter", transform='binary') + >>> + >>> print(value) + My parameter value + """ + + # Only create the provider if this function is called at least once + if "ssm" not in DEFAULT_PROVIDERS: + DEFAULT_PROVIDERS["ssm"] = SSMProvider() + + return DEFAULT_PROVIDERS["ssm"].get(name, transform=transform) + + +def get_parameters( + path: str, transform: Optional[str] = None, recursive: bool = True, decrypt: bool = False, **sdk_options +) -> Union[Dict[str, str], Dict[str, dict], Dict[str, bytes]]: + """ + Retrieve multiple parameter values from AWS Systems Manager (SSM) Parameter Store + + Parameters + ---------- + path: str + Path to retrieve the parameters + transform: str, optional + Transforms the content from a JSON object ('json') or base64 binary string ('binary') + decrypt: bool, optional + If the parameter values should be decrypted + recursive: bool, optional + If this should retrieve the parameter values recursively or not, defaults to True + sdk_options: dict, optional + Dictionary of options that will be passed to the Parameter Store get_parameters_by_path API call + + Raises + ------ + GetParameterError + When the parameter provider fails to retrieve parameter values for + a given path. + TransformParameterError + When the parameter provider fails to transform a parameter value. + + Example + ------- + **Retrieves parameter values from Systems Manager Parameter Store** + + >>> from aws_lambda_powertools.utilities.parameters import get_parameter + >>> + >>> values = get_parameters("/my/path/prefix") + >>> + >>> for key, value in values.items(): + ... print(key, value) + /my/path/prefix/a Parameter value a + /my/path/prefix/b Parameter value b + /my/path/prefix/c Parameter value c + + **Retrieves parameter values and decodes them using a Base64 decoder** + + >>> from aws_lambda_powertools.utilities.parameters import get_parameter + >>> + >>> values = get_parameters("/my/path/prefix", transform='binary') + """ + + # Only create the provider if this function is called at least once + if "ssm" not in DEFAULT_PROVIDERS: + DEFAULT_PROVIDERS["ssm"] = SSMProvider() + + return DEFAULT_PROVIDERS["ssm"].get_multiple(path, transform=transform, recursive=recursive, decrypt=decrypt) diff --git a/docs/content/index.mdx b/docs/content/index.mdx index 9966a08deb3..e5c2688ecc7 100644 --- a/docs/content/index.mdx +++ b/docs/content/index.mdx @@ -18,6 +18,7 @@ Powertools is available in PyPi. You can use your favourite dependency managemen * [Logging](./core/logger) - Structured logging made easier, and decorator to enrich structured logging with key Lambda context details * [Metrics](./core/metrics) - Custom Metrics created asynchronously via CloudWatch Embedded Metric Format (EMF) * [Bring your own middleware](./utilities/middleware_factory) - Decorator factory to create your own middleware to run logic before, and after each Lambda invocation +* [Parameters utility](./utilities/parameters) - Retrieve parameter values from AWS Systems Manager Parameter Store, AWS Secrets Manager, or Amazon DynamoDB, and cache them for a specific amount of time ## Tenets diff --git a/docs/content/utilities/parameters.mdx b/docs/content/utilities/parameters.mdx new file mode 100644 index 00000000000..b40bfe2c885 --- /dev/null +++ b/docs/content/utilities/parameters.mdx @@ -0,0 +1,346 @@ +--- +title: Parameters +description: Utility +--- + +import Note from "../../src/components/Note" + +The parameters utility provides a way to retrieve parameter values from [AWS Systems Manager Parameter Store](https://docs.aws.amazon.com/systems-manager/latest/userguide/systems-manager-parameter-store.html), [AWS Secrets Manager](https://aws.amazon.com/secrets-manager/) or [Amazon DynamoDB](https://aws.amazon.com/dynamodb/). It also provides a base class to create your parameter provider implementation. + +**Key features** + +* Retrieve one or multiple parameters from the underlying provider +* Cache parameter values for a given amount of time (defaults to 5 seconds) +* Transform parameter values from JSON or base 64 encoded strings + +**IAM Permissions** + +This utility requires additional permissions to work as expected. See the table below: + +Provider | Function/Method | IAM Permission +------------------------------------------------- | ------------------------------------------------- | --------------------------------------------------------------------------------- +SSM Parameter Store | `get_parameter`, `SSMProvider.get` | `ssm:GetParameter` +SSM Parameter Store | `get_parameters`, `SSMProvider.get_multiple` | `ssm:GetParametersByPath` +Secrets Manager | `get_secret`, `SecretsManager.get` | `secretsmanager:GetSecretValue` +DynamoDB | `DynamoDBProvider.get` | `dynamodb:GetItem` +DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb:Query` + +## SSM Parameter Store + +You can retrieve a single parameter using `get_parameter` high-level function. For multiple parameters, you can use `get_parameters` and pass a path to retrieve them recursively. + +```python:title=ssm_parameter_store.py +from aws_lambda_powertools.utilities import parameters + +def handler(event, context): + # Retrieve a single parameter + value = parameters.get_parameter("/my/parameter") + + # Retrieve multiple parameters from a path prefix recursively + # This returns a dict with the parameter name as key + values = parameters.get_parameters("/my/path/prefix") + for k, v in values.items(): + print(f"{k}: {v}") +``` + +### SSMProvider class + +Alternatively, you can use the `SSMProvider` class, which give more flexibility, such as the ability to configure the underlying SDK client. + +This can be used to retrieve values from other regions, change the retry behavior, etc. + +```python:title=ssm_parameter_store.py +from aws_lambda_powertools.utilities import parameters +from botocore.config import Config + +config = Config(region_name="us-west-1") +ssm_provider = parameters.SSMProvider(config=config) + +def handler(event, context): + # Retrieve a single parameter + value = ssm_provider.get("/my/parameter") + + # Retrieve multiple parameters from a path prefix + values = ssm_provider.get_multiple("/my/path/prefix") + for k, v in values.items(): + print(f"{k}: {v}") +``` + +**Additional arguments** + +The AWS Systems Manager Parameter Store provider supports two additional arguments for the `get()` and `get_multiple()` methods: + +| Parameter | Default | Description | +|---------------|---------|-------------| +| **decrypt** | `False` | Will automatically decrypt the parameter. | +| **recursive** | `True` | For `get_multiple()` only, will fetch all parameter values recursively based on a path prefix. | + +**Example:** + +```python:title=ssm_parameter_store.py +from aws_lambda_powertools.utilities import parameters + +ssm_provider = parameters.SSMProvider() + +def handler(event, context): + decrypted_value = ssm_provider.get("/my/encrypted/parameter", decrypt=True) + + no_recursive_values = ssm_provider.get_multiple("/my/path/prefix", recursive=False) +``` + +## Secrets Manager + +For secrets stored in Secrets Manager, use `get_secret`. + +```python:title=secrets_manager.py +from aws_lambda_powertools.utilities import parameters + +def handler(event, context): + # Retrieve a single secret + value = parameters.get_secret("my-secret") +``` + +### SecretsProvider class + +Alternatively, you can use the `SecretsProvider` class, which give more flexibility, such as the ability to configure the underlying SDK client. + +This can be used to retrieve values from other regions, change the retry behavior, etc. + +```python:title=secrets_manager.py +from aws_lambda_powertools.utilities import parameters +from botocore.config import Config + +config = Config(region_name="us-west-1") +secrets_provider = parameters.SecretsProvider(config=config) + +def handler(event, context): + # Retrieve a single secret + value = secrets_provider.get("my-secret") +``` + +## DynamoDB + +To use the DynamoDB provider, you need to import and instantiate the `DynamoDBProvider` class. + +The DynamoDB Provider does not have any high-level functions, as it needs to know the name of the DynamoDB table containing the parameters. + +**DynamoDB table structure** + +When using the default options, if you want to retrieve only single parameters, your table should be structured as such, assuming a parameter named **my-parameter** with a value of **my-value**. The `id` attribute should be the [partition key](https://docs.aws.amazon.com/amazondynamodb/latest/developerguide/HowItWorks.CoreComponents.html#HowItWorks.CoreComponents.PrimaryKey) for that table. + +| `id` | `value` | +|--------------|----------| +| my-parameter | my-value | + +With this table, when you do a `dynamodb_provider.get("my-param")` call, this will return `my-value`. + +```python:title=dynamodb.py +from aws_lambda_powertools.utilities import parameters + +dynamodb_provider = parameters.DynamoDBProvider(table_name="my-table") + +def handler(event, context): + # Retrieve a value from DynamoDB + value = dynamodb_provider.get("my-parameter") +``` + +**Retrieve multiple values** + +If you want to be able to retrieve multiple parameters at once sharing the same `id`, your table needs to contain a sort key name `sk`. For example, if you want to retrieve multiple parameters having `my-hash-key` as ID: + +| `id` | `sk` | `value` | +|-------------|---------|------------| +| my-hash-key | param-a | my-value-a | +| my-hash-key | param-b | my-value-b | +| my-hash-key | param-c | my-value-c | + +With this table, when you do a `dynamodb_provider.get_multiple("my-hash-key")` call, you will receive the following dict as a response: + +``` +{ + "param-a": "my-value-a", + "param-b": "my-value-b", + "param-c": "my-value-c" +} +``` + +**Example:** + +```python:title="dynamodb_multiple.py +from aws_lambda_powertools.utilities import parameters + +dynamodb_provider = parameters.DynamoDBProvider(table_name="my-table") + +def handler(event, context): + # Retrieve multiple values by performing a Query on the DynamoDB table + # This returns a dict with the sort key attribute as dict key. + values = dynamodb_provider.get_multiple("my-hash-key") + for k, v in values.items(): + print(f"{k}: {v}") +``` + +**Additional arguments** + +The Amazon DynamoDB provider supports four additional arguments at initialization: + +| Parameter | Mandatory | Default | Description | +|----------------|-----------|---------|-------------| +| **table_name** | **Yes** | *(N/A)* | Name of the DynamoDB table containing the parameter values. +| **key_attr** | No | `id` | Hash key for the DynamoDB table. +| **sort_attr** | No | `sk` | Range key for the DynamoDB table. You don't need to set this if you don't use the `get_multiple()` method. +| **value_attr** | No | `value` | Name of the attribute containing the parameter value. + +```python:title=dynamodb.py +from aws_lambda_powertools.utilities import parameters + +dynamodb_provider = parameters.DynamoDBProvider( + table_name="my-table", + key_attr="MyKeyAttr", + sort_attr="MySortAttr", + value_attr="MyvalueAttr" +) + +def handler(event, context): + value = dynamodb_provider.get("my-parameter") +``` + +## Create your own provider + +You can create your own custom parameter store provider by inheriting the `BaseProvider` class, and implementing both `_get()` and `_get_multiple()` methods to retrieve a single, or multiple parameters from your custom store. + +All transformation and caching logic is handled by the `get()` and `get_multiple()` methods from the base provider class. + +Here is an example implementation using S3 as a custom parameter store: + +```python:title=custom_provider.py +import copy + +from aws_lambda_powertools.utilities import BaseProvider +import boto3 + +class S3Provider(BaseProvider): + bucket_name = None + client = None + + def __init__(self, bucket_name: str): + # Initialize the client to your custom parameter store + # E.g.: + + self.bucket_name = bucket_name + self.client = boto3.client("s3") + + def _get(self, name: str, **sdk_options) -> str: + # Retrieve a single value + # E.g.: + + sdk_options["Bucket"] = self.bucket_name + sdk_options["Key"] = name + + response = self.client.get_object(**sdk_options) + return + + def _get_multiple(self, path: str, **sdk_options) -> Dict[str, str]: + # Retrieve multiple values + # E.g.: + + list_sdk_options = copy.deepcopy(sdk_options) + + list_sdk_options["Bucket"] = self.bucket_name + list_sdk_options["Prefix"] = path + + list_response = self.client.list_objects_v2(**list_sdk_options) + + parameters = {} + + for obj in list_response.get("Contents", []): + get_sdk_options = copy.deepcopy(sdk_options) + + get_sdk_options["Bucket"] = self.bucket_name + get_sdk_options["Key"] = obj["Key"] + + get_response = self.client.get_object(**get_sdk_options) + + parameters[obj["Key"]] = get_response["Body"].read().decode() + + return parameters + +``` + +## Transform values + +For parameters stored in JSON or Base64 format, you can use the `transform` argument for deserialization - The `transform` argument is available across all providers, including the high level functions. + +```python:title=transform.py +from aws_lambda_powertools.utilities import parameters + +ssm_provider = parameters.SSMProvider() + +def handler(event, context): + # Transform a JSON string + value_from_json = ssm_provider.get("/my/json/parameter", transform="json") + + # Transform a Base64 encoded string + value_from_binary = ssm_provider.get("/my/binary/parameter", transform="binary") +``` + +You can also use the `transform` argument with high-level functions: + +```python:title=transform.py +from aws_lambda_powertools.utilities import parameters + +def handler(event, context): + value_from_json = parameters.get_parameter("/my/json/parameter", transform="json") +``` + +### Partial transform failures with `get_multiple()` + +If you use `transform` with `get_multiple()`, you can have a single malformed parameter value. To prevent failing the entire request, the method will return a `None` value for the parameters that failed to transform. + +You can override this by setting the `raise_on_transform_error` argument to `True`. If you do so, a single transform error will raise a `TransformParameterError` exception. + +For example, if you have three parameters (*/param/a*, */param/b* and */param/c*) but */param/c* is malformed: + +```python:title=partial_failures.py +from aws_lambda_powertools.utilities import parameters + +ssm_provider = parameters.SSMProvider() + +def handler(event, context): + # This will display: + # /param/a: [some value] + # /param/b: [some value] + # /param/c: None + values = ssm_provider.get_multiple("/param", transform="json") + for k, v in values.items(): + print(f"{k}: {v}") + + # This will raise a TransformParameterError exception + values = ssm_provider.get_multiple("/param", transform="json", raise_on_transform_error=True) +``` + +## Additional SDK arguments + +You can use arbitrary keyword arguments to pass it directly to the underlying SDK method. + +```python:title=ssm_parameter_store.py +from aws_lambda_powertools.utilities import parameters + +secrets_provider = parameters.SecretsProvider() + +def handler(event, context): + # The 'VersionId' argument will be passed to the underlying get_secret_value() call. + value = secrets_provider.get("my-secret", VersionId="e62ec170-6b01-48c7-94f3-d7497851a8d2") +``` + +Here is the mapping between this utility's functions and methods and the underlying SDK: + +| Provider | Function/Method | Client name | Function name | +|---------------------|---------------------------------|-------------|---------------| +| SSM Parameter Store | `get_parameter` | `ssm` | [get_parameter](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.get_parameter) | +| SSM Parameter Store | `get_parameters` | `ssm` | [get_parameters_by_path](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.get_parameters_by_path) | +| SSM Parameter Store | `SSMProvider.get` | `ssm` | [get_parameter](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.get_parameter) | +| SSM Parameter Store | `SSMProvider.get_multiple` | `ssm` | [get_parameters_by_path](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/ssm.html#SSM.Client.get_parameters_by_path) | +| Secrets Manager | `get_secret` | `secretsmanager` | [get_secret_value](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager.html#SecretsManager.Client.get_secret_value) | +| Secrets Manager | `SecretsManager.get` | `secretsmanager` | [get_secret_value](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/secretsmanager.html#SecretsManager.Client.get_secret_value) | +| DynamoDB | `DynamoDBProvider.get` | `dynamodb` ([Table resource](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#table)) | [get_item](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.get_item) +| DynamoDB | `DynamoDBProvider.get_multiple` | `dynamodb` ([Table resource](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#table)) | [query](https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/dynamodb.html#DynamoDB.Table.query) diff --git a/docs/gatsby-config.js b/docs/gatsby-config.js index f62c2d4e30e..d518ee8e715 100644 --- a/docs/gatsby-config.js +++ b/docs/gatsby-config.js @@ -31,6 +31,7 @@ module.exports = { ], 'Utilities': [ 'utilities/middleware_factory', + 'utilities/parameters', ], }, navConfig: { diff --git a/poetry.lock b/poetry.lock index a5570d1445c..e7b7cdff1db 100644 --- a/poetry.lock +++ b/poetry.lock @@ -162,7 +162,7 @@ typed-ast = ">=1.4.0" d = ["aiohttp (>=3.3.2)", "aiohttp-cors"] [[package]] -category = "dev" +category = "main" description = "The AWS SDK for Python" name = "boto3" optional = false @@ -816,7 +816,7 @@ security = ["pyOpenSSL (>=0.14)", "cryptography (>=1.3.4)"] socks = ["PySocks (>=1.5.6,<1.5.7 || >1.5.7)", "win-inet-pton"] [[package]] -category = "dev" +category = "main" description = "An Amazon S3 Transfer Manager" name = "s3transfer" optional = false @@ -951,7 +951,6 @@ multidict = ">=4.0" [[package]] category = "main" description = "Backport of pathlib-compatible object wrapper for zip files" -marker = "python_version < \"3.8\"" name = "zipp" optional = false python-versions = ">=3.6" @@ -962,7 +961,8 @@ docs = ["sphinx", "jaraco.packaging (>=3.2)", "rst.linker (>=1.9)"] testing = ["jaraco.itertools", "func-timeout"] [metadata] -content-hash = "a2760fd5f04b7f1841509fcbcb4ccdaf35d92d1395627787e4a11f391a0597d2" +content-hash = "18607a712e4a4a05de7350ecbcf26327a4fb45bb8609dc7f3d19b7610c2faafc" +lock-version = "1.0" python-versions = "^3.6" [metadata.files] diff --git a/pyproject.toml b/pyproject.toml index 0f0d9ac6ebe..3f823717a29 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ license = "MIT-0" python = "^3.6" aws-xray-sdk = "^2.5.0" fastjsonschema = "~=2.14.4" +boto3 = "^1.12" [tool.poetry.dev-dependencies] coverage = {extras = ["toml"], version = "^5.0.3"} diff --git a/tests/functional/test_utilities_parameters.py b/tests/functional/test_utilities_parameters.py new file mode 100644 index 00000000000..7a0677b2197 --- /dev/null +++ b/tests/functional/test_utilities_parameters.py @@ -0,0 +1,1470 @@ +import base64 +import json +import random +import string +from datetime import datetime, timedelta +from typing import Dict + +import pytest +from boto3.dynamodb.conditions import Key +from botocore import stub +from botocore.config import Config + +from aws_lambda_powertools.utilities import parameters +from aws_lambda_powertools.utilities.parameters.base import BaseProvider, ExpirableValue + + +@pytest.fixture(scope="function") +def mock_name(): + # Parameter name must match [a-zA-Z0-9_.-/]+ + return "".join(random.choices(string.ascii_letters + string.digits + "_.-/", k=random.randrange(3, 200))) + + +@pytest.fixture(scope="function") +def mock_value(): + # Standard parameters can be up to 4 KB + return "".join(random.choices(string.printable, k=random.randrange(100, 4000))) + + +@pytest.fixture(scope="function") +def mock_version(): + return random.randrange(1, 1000) + + +@pytest.fixture(scope="module") +def config(): + return Config(region_name="us-east-1") + + +def test_dynamodb_provider_get(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get() with a non-cached value + """ + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = {"Item": {"id": {"S": mock_name}, "value": {"S": mock_value}}} + expected_params = {"TableName": table_name, "Key": {"id": mock_name}} + stubber.add_response("get_item", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_default_config(monkeypatch, mock_name, mock_value): + """ + Test DynamoDBProvider.get() without setting a config + """ + + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = {"Item": {"id": {"S": mock_name}, "value": {"S": mock_value}}} + expected_params = {"TableName": table_name, "Key": {"id": mock_name}} + stubber.add_response("get_item", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_cached(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get() with a cached value + """ + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_expired(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get() with a cached but expired value + """ + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() - timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = {"Item": {"id": {"S": mock_name}, "value": {"S": mock_value}}} + expected_params = {"TableName": table_name, "Key": {"id": mock_name}} + stubber.add_response("get_item", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_sdk_options(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get() with SDK options + """ + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = {"Item": {"id": {"S": mock_name}, "value": {"S": mock_value}}} + expected_params = {"TableName": table_name, "Key": {"id": mock_name}, "ConsistentRead": True} + stubber.add_response("get_item", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name, ConsistentRead=True) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_sdk_options_overwrite(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get() with SDK options that should be overwritten + """ + + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = {"Item": {"id": {"S": mock_name}, "value": {"S": mock_value}}} + expected_params = {"TableName": table_name, "Key": {"id": mock_name}} + stubber.add_response("get_item", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name, Key="THIS_SHOULD_BE_OVERWRITTEN") + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_multiple(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with a non-cached path + """ + + mock_param_names = ["A", "B", "C"] + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": f"{mock_value}/{name}"}} + for name in mock_param_names + ] + } + expected_params = {"TableName": table_name, "KeyConditionExpression": Key("id").eq(mock_name)} + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_multiple_next_token(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with a non-cached path + """ + + mock_param_names = ["A", "B", "C"] + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + + # First call + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": f"{mock_value}/{name}"}} + for name in mock_param_names[:1] + ], + "LastEvaluatedKey": {"id": {"S": mock_name}, "sk": {"S": mock_param_names[0]}}, + } + expected_params = {"TableName": table_name, "KeyConditionExpression": Key("id").eq(mock_name)} + stubber.add_response("query", response, expected_params) + + # Second call + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": f"{mock_value}/{name}"}} + for name in mock_param_names[1:] + ] + } + expected_params = { + "TableName": table_name, + "KeyConditionExpression": Key("id").eq(mock_name), + "ExclusiveStartKey": {"id": mock_name, "sk": mock_param_names[0]}, + } + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_multiple_sdk_options(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with custom SDK options + """ + + mock_param_names = ["A", "B", "C"] + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": f"{mock_value}/{name}"}} + for name in mock_param_names + ] + } + expected_params = { + "TableName": table_name, + "KeyConditionExpression": Key("id").eq(mock_name), + "ConsistentRead": True, + } + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name, ConsistentRead=True) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_dynamodb_provider_get_multiple_sdk_options_overwrite(mock_name, mock_value, config): + """ + Test DynamoDBProvider.get_multiple() with custom SDK options that should be overwritten + """ + + mock_param_names = ["A", "B", "C"] + table_name = "TEST_TABLE" + + # Create a new provider + provider = parameters.DynamoDBProvider(table_name, config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.table.meta.client) + response = { + "Items": [ + {"id": {"S": mock_name}, "sk": {"S": name}, "value": {"S": f"{mock_value}/{name}"}} + for name in mock_param_names + ] + } + expected_params = { + "TableName": table_name, + "KeyConditionExpression": Key("id").eq(mock_name), + } + stubber.add_response("query", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name, KeyConditionExpression="THIS_SHOULD_BE_OVERWRITTEN") + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_ssm_provider_get(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get() with a non-cached value + """ + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameter": { + "Name": mock_name, + "Type": "String", + "Value": mock_value, + "Version": mock_version, + "Selector": f"{mock_name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}", + } + } + expected_params = {"Name": mock_name, "WithDecryption": False} + stubber.add_response("get_parameter", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_default_config(monkeypatch, mock_name, mock_value, mock_version): + """ + Test SSMProvider.get() without specifying the config + """ + + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + # Create a new provider + provider = parameters.SSMProvider() + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameter": { + "Name": mock_name, + "Type": "String", + "Value": mock_value, + "Version": mock_version, + "Selector": f"{mock_name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}", + } + } + expected_params = {"Name": mock_name, "WithDecryption": False} + stubber.add_response("get_parameter", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_cached(mock_name, mock_value, config): + """ + Test SSMProvider.get() with a cached value + """ + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_expired(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get() with a cached but expired value + """ + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() - timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameter": { + "Name": mock_name, + "Type": "String", + "Value": mock_value, + "Version": mock_version, + "Selector": f"{mock_name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}", + } + } + expected_params = {"Name": mock_name, "WithDecryption": False} + stubber.add_response("get_parameter", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_sdk_options_overwrite(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get() with custom SDK options overwritten + """ + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameter": { + "Name": mock_name, + "Type": "String", + "Value": mock_value, + "Version": mock_version, + "Selector": f"{mock_name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}", + } + } + expected_params = {"Name": mock_name, "WithDecryption": False} + stubber.add_response("get_parameter", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name, Name="THIS_SHOULD_BE_OVERWRITTEN", WithDecryption=True) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_ssm_provider_get_multiple(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get_multiple() with a non-cached path + """ + + mock_param_names = ["A", "B", "C"] + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameters": [ + { + "Name": f"{mock_name}/{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names + ] + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False} + stubber.add_response("get_parameters_by_path", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_ssm_provider_get_multiple_different_path(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get_multiple() with a non-cached path and names that don't start with the path + """ + + mock_param_names = ["A", "B", "C"] + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameters": [ + { + "Name": f"{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names + ] + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False} + stubber.add_response("get_parameters_by_path", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_ssm_provider_get_multiple_next_token(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get_multiple() with a non-cached path with multiple calls + """ + + mock_param_names = ["A", "B", "C"] + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + + # First call + response = { + "Parameters": [ + { + "Name": f"{mock_name}/{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names[:1] + ], + "NextToken": "next_token", + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False} + stubber.add_response("get_parameters_by_path", response, expected_params) + + # Second call + response = { + "Parameters": [ + { + "Name": f"{mock_name}/{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names[1:] + ] + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False, "NextToken": "next_token"} + stubber.add_response("get_parameters_by_path", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_ssm_provider_get_multiple_sdk_options(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get_multiple() with SDK options + """ + + mock_param_names = ["A", "B", "C"] + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameters": [ + { + "Name": f"{mock_name}/{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names + ] + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False, "MaxResults": 10} + stubber.add_response("get_parameters_by_path", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple(mock_name, MaxResults=10) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_ssm_provider_get_multiple_sdk_options_overwrite(mock_name, mock_value, mock_version, config): + """ + Test SSMProvider.get_multiple() with SDK options overwritten + """ + + mock_param_names = ["A", "B", "C"] + + # Create a new provider + provider = parameters.SSMProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "Parameters": [ + { + "Name": f"{mock_name}/{name}", + "Type": "String", + "Value": f"{mock_value}/{name}", + "Version": mock_version, + "Selector": f"{mock_name}/{name}:{mock_version}", + "SourceResult": "string", + "LastModifiedDate": datetime(2015, 1, 1), + "ARN": f"arn:aws:ssm:us-east-2:111122223333:parameter/{mock_name}/{name}", + } + for name in mock_param_names + ] + } + expected_params = {"Path": mock_name, "Recursive": False, "WithDecryption": False} + stubber.add_response("get_parameters_by_path", response, expected_params) + stubber.activate() + + try: + values = provider.get_multiple( + mock_name, Path="THIS_SHOULD_BE_OVERWRITTEN", Recursive=False, WithDecryption=True, + ) + + stubber.assert_no_pending_responses() + + assert len(values) == len(mock_param_names) + for name in mock_param_names: + assert name in values + assert values[name] == f"{mock_value}/{name}" + finally: + stubber.deactivate() + + +def test_secrets_provider_get(mock_name, mock_value, config): + """ + Test SecretsProvider.get() with a non-cached value + """ + + # Create a new provider + provider = parameters.SecretsProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "ARN": f"arn:aws:secretsmanager:us-east-1:132456789012:secret/{mock_name}", + "Name": mock_name, + "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d", + "SecretString": mock_value, + "CreatedDate": datetime(2015, 1, 1), + } + expected_params = {"SecretId": mock_name} + stubber.add_response("get_secret_value", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_secrets_provider_get_default_config(monkeypatch, mock_name, mock_value): + """ + Test SecretsProvider.get() without specifying a config + """ + + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + # Create a new provider + provider = parameters.SecretsProvider() + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "ARN": f"arn:aws:secretsmanager:us-east-1:132456789012:secret/{mock_name}", + "Name": mock_name, + "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d", + "SecretString": mock_value, + "CreatedDate": datetime(2015, 1, 1), + } + expected_params = {"SecretId": mock_name} + stubber.add_response("get_secret_value", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_secrets_provider_get_cached(mock_name, mock_value, config): + """ + Test SecretsProvider.get() with a cached value + """ + + # Create a new provider + provider = parameters.SecretsProvider(config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() + timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_secrets_provider_get_expired(mock_name, mock_value, config): + """ + Test SecretsProvider.get() with a cached but expired value + """ + + # Create a new provider + provider = parameters.SecretsProvider(config=config) + + # Inject value in the internal store + provider.store[(mock_name, None)] = ExpirableValue(mock_value, datetime.now() - timedelta(seconds=60)) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "ARN": f"arn:aws:secretsmanager:us-east-1:132456789012:secret/{mock_name}", + "Name": mock_name, + "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d", + "SecretString": mock_value, + "CreatedDate": datetime(2015, 1, 1), + } + expected_params = {"SecretId": mock_name} + stubber.add_response("get_secret_value", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name) + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_secrets_provider_get_sdk_options(mock_name, mock_value, config): + """ + Test SecretsProvider.get() with custom SDK options + """ + + # Create a new provider + provider = parameters.SecretsProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "ARN": f"arn:aws:secretsmanager:us-east-1:132456789012:secret/{mock_name}", + "Name": mock_name, + "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d", + "SecretString": mock_value, + "CreatedDate": datetime(2015, 1, 1), + } + expected_params = {"SecretId": mock_name, "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d"} + stubber.add_response("get_secret_value", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name, VersionId="7a9155b8-2dc9-466e-b4f6-5bc46516c84d") + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_secrets_provider_get_sdk_options_overwrite(mock_name, mock_value, config): + """ + Test SecretsProvider.get() with custom SDK options overwritten + """ + + # Create a new provider + provider = parameters.SecretsProvider(config=config) + + # Stub the boto3 client + stubber = stub.Stubber(provider.client) + response = { + "ARN": f"arn:aws:secretsmanager:us-east-1:132456789012:secret/{mock_name}", + "Name": mock_name, + "VersionId": "7a9155b8-2dc9-466e-b4f6-5bc46516c84d", + "SecretString": mock_value, + "CreatedDate": datetime(2015, 1, 1), + } + expected_params = {"SecretId": mock_name} + stubber.add_response("get_secret_value", response, expected_params) + stubber.activate() + + try: + value = provider.get(mock_name, SecretId="THIS_SHOULD_BE_OVERWRITTEN") + + assert value == mock_value + stubber.assert_no_pending_responses() + finally: + stubber.deactivate() + + +def test_base_provider_get_exception(mock_name): + """ + Test BaseProvider.get() that raises an exception + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + raise Exception("test exception raised") + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + with pytest.raises(parameters.GetParameterError) as excinfo: + provider.get(mock_name) + + assert "test exception raised" in str(excinfo) + + +def test_base_provider_get_multiple_exception(mock_name): + """ + Test BaseProvider.get_multiple() that raises an exception + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + raise Exception("test exception raised") + + provider = TestProvider() + + with pytest.raises(parameters.GetParameterError) as excinfo: + provider.get_multiple(mock_name) + + assert "test exception raised" in str(excinfo) + + +def test_base_provider_get_transform_json(mock_name, mock_value): + """ + Test BaseProvider.get() with a json transform + """ + + mock_data = json.dumps({mock_name: mock_value}) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + value = provider.get(mock_name, transform="json") + + assert isinstance(value, dict) + assert mock_name in value + assert value[mock_name] == mock_value + + +def test_base_provider_get_transform_json_exception(mock_name, mock_value): + """ + Test BaseProvider.get() with a json transform that raises an exception + """ + + mock_data = json.dumps({mock_name: mock_value}) + "{" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get(mock_name, transform="json") + + assert "Extra data" in str(excinfo) + + +def test_base_provider_get_transform_binary(mock_name, mock_value): + """ + Test BaseProvider.get() with a binary transform + """ + + mock_binary = mock_value.encode() + mock_data = base64.b64encode(mock_binary).decode() + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + value = provider.get(mock_name, transform="binary") + + assert isinstance(value, bytes) + assert value == mock_binary + + +def test_base_provider_get_transform_binary_exception(mock_name): + """ + Test BaseProvider.get() with a binary transform that raises an exception + """ + + mock_data = "qw" + print(mock_data) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_data + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get(mock_name, transform="binary") + + assert "Incorrect padding" in str(excinfo) + + +def test_base_provider_get_multiple_transform_json(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a json transform + """ + + mock_data = json.dumps({mock_name: mock_value}) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform="json") + + assert isinstance(value, dict) + assert value["A"][mock_name] == mock_value + + +def test_base_provider_get_multiple_transform_json_partial_failure(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a json transform that contains a partial failure + """ + + mock_data = json.dumps({mock_name: mock_value}) + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data, "B": mock_data + "{"} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform="json") + + assert isinstance(value, dict) + assert value["A"][mock_name] == mock_value + assert value["B"] is None + + +def test_base_provider_get_multiple_transform_json_exception(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a json transform that raises an exception + """ + + mock_data = json.dumps({mock_name: mock_value}) + "{" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get_multiple(mock_name, transform="json", raise_on_transform_error=True) + + assert "Extra data" in str(excinfo) + + +def test_base_provider_get_multiple_transform_binary(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a binary transform + """ + + mock_binary = mock_value.encode() + mock_data = base64.b64encode(mock_binary).decode() + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform="binary") + + assert isinstance(value, dict) + assert value["A"] == mock_binary + + +def test_base_provider_get_multiple_transform_binary_partial_failure(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with a binary transform that contains a partial failure + """ + + mock_binary = mock_value.encode() + mock_data_a = base64.b64encode(mock_binary).decode() + mock_data_b = "qw" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data_a, "B": mock_data_b} + + provider = TestProvider() + + value = provider.get_multiple(mock_name, transform="binary") + + assert isinstance(value, dict) + assert value["A"] == mock_binary + assert value["B"] is None + + +def test_base_provider_get_multiple_transform_binary_exception(mock_name): + """ + Test BaseProvider.get_multiple() with a binary transform that raises an exception + """ + + mock_data = "qw" + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_data} + + provider = TestProvider() + + with pytest.raises(parameters.TransformParameterError) as excinfo: + provider.get_multiple(mock_name, transform="binary", raise_on_transform_error=True) + + assert "Incorrect padding" in str(excinfo) + + +def test_base_provider_get_multiple_cached(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with cached values + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + provider = TestProvider() + + provider.store[(mock_name, None)] = ExpirableValue({"A": mock_value}, datetime.now() + timedelta(seconds=60)) + + value = provider.get_multiple(mock_name) + + assert isinstance(value, dict) + assert value["A"] == mock_value + + +def test_base_provider_get_multiple_expired(mock_name, mock_value): + """ + Test BaseProvider.get_multiple() with expired values + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_value} + + provider = TestProvider() + + provider.store[(mock_name, None)] = ExpirableValue({"B": mock_value}, datetime.now() - timedelta(seconds=60)) + + value = provider.get_multiple(mock_name) + + assert isinstance(value, dict) + assert value["A"] == mock_value + + +def test_get_parameter(monkeypatch, mock_name, mock_value): + """ + Test get_parameter() + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_value + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + value = parameters.get_parameter(mock_name) + + assert value == mock_value + + +def test_get_parameter_new(monkeypatch, mock_name, mock_value): + """ + Test get_parameter() without a default provider + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_value + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) + monkeypatch.setattr(parameters.ssm, "SSMProvider", TestProvider) + + value = parameters.get_parameter(mock_name) + + assert value == mock_value + + +def test_get_parameters(monkeypatch, mock_name, mock_value): + """ + Test get_parameters() + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return {"A": mock_value} + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "ssm", TestProvider()) + + values = parameters.get_parameters(mock_name) + + assert len(values) == 1 + assert values["A"] == mock_value + + +def test_get_parameters_new(monkeypatch, mock_name, mock_value): + """ + Test get_parameters() without a default provider + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + raise NotImplementedError() + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + assert path == mock_name + return mock_value + + monkeypatch.setattr(parameters.ssm, "DEFAULT_PROVIDERS", {}) + monkeypatch.setattr(parameters.ssm, "SSMProvider", TestProvider) + + value = parameters.get_parameters(mock_name) + + assert value == mock_value + + +def test_get_secret(monkeypatch, mock_name, mock_value): + """ + Test get_secret() + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_value + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + monkeypatch.setitem(parameters.base.DEFAULT_PROVIDERS, "secrets", TestProvider()) + + value = parameters.get_secret(mock_name) + + assert value == mock_value + + +def test_get_secret_new(monkeypatch, mock_name, mock_value): + """ + Test get_secret() without a default provider + """ + + class TestProvider(BaseProvider): + def _get(self, name: str, **kwargs) -> str: + assert name == mock_name + return mock_value + + def _get_multiple(self, path: str, **kwargs) -> Dict[str, str]: + raise NotImplementedError() + + monkeypatch.setattr(parameters.secrets, "DEFAULT_PROVIDERS", {}) + monkeypatch.setattr(parameters.secrets, "SecretsProvider", TestProvider) + + value = parameters.get_secret(mock_name) + + assert value == mock_value + + +def test_transform_value_json(mock_value): + """ + Test transform_value() with a json transform + """ + + mock_data = json.dumps({"A": mock_value}) + + value = parameters.base.transform_value(mock_data, "json") + + assert isinstance(value, dict) + assert value["A"] == mock_value + + +def test_transform_value_json_exception(mock_value): + """ + Test transform_value() with a json transform that fails + """ + + mock_data = json.dumps({"A": mock_value}) + "{" + + with pytest.raises(parameters.TransformParameterError) as excinfo: + parameters.base.transform_value(mock_data, "json") + + assert "Extra data" in str(excinfo) + + +def test_transform_value_binary(mock_value): + """ + Test transform_value() with a binary transform + """ + + mock_binary = mock_value.encode() + mock_data = base64.b64encode(mock_binary).decode() + + value = parameters.base.transform_value(mock_data, "binary") + + assert isinstance(value, bytes) + assert value == mock_binary + + +def test_transform_value_binary_exception(): + """ + Test transform_value() with a binary transform that fails + """ + + mock_data = "qw" + + with pytest.raises(parameters.TransformParameterError) as excinfo: + parameters.base.transform_value(mock_data, "binary") + + assert "Incorrect padding" in str(excinfo) + + +def test_transform_value_wrong(mock_value): + """ + Test transform_value() with an incorrect transform + """ + + with pytest.raises(parameters.TransformParameterError) as excinfo: + parameters.base.transform_value(mock_value, "INCORRECT") + + assert "Invalid transform type" in str(excinfo)