diff --git a/aws_lambda_powertools/utilities/data_masking/base.py b/aws_lambda_powertools/utilities/data_masking/base.py index 3eed26045c2..d961ce55690 100644 --- a/aws_lambda_powertools/utilities/data_masking/base.py +++ b/aws_lambda_powertools/utilities/data_masking/base.py @@ -26,6 +26,18 @@ logger = logging.getLogger(__name__) +def prepare_data(data: Any) -> Any: + if hasattr(data, "__dataclass_fields__"): + import dataclasses + return dataclasses.asdict(data) + + if callable(getattr(data, "model_dump", None)): + return data.model_dump() + + if callable(getattr(data, "dict", None)): + return data.dict() + + return data class DataMasking: """ @@ -93,6 +105,7 @@ def encrypt( data_masker = DataMasking(provider=encryption_provider) encrypted = data_masker.encrypt({"secret": "value"}) """ + data = prepare_data(data) return self._apply_action( data=data, fields=None, @@ -135,7 +148,7 @@ def decrypt( data_masker = DataMasking(provider=encryption_provider) encrypted = data_masker.decrypt(encrypted_data) """ - + data = prepare_data(data) return self._apply_action( data=data, fields=None, @@ -184,6 +197,7 @@ def erase( Any The data with sensitive information erased or masked. """ + data = prepare_data(data) if masking_rules: return self._apply_masking_rules(data=data, masking_rules=masking_rules) else: diff --git a/tests/unit/data_masking/test_data_masking_input_types.py b/tests/unit/data_masking/test_data_masking_input_types.py new file mode 100644 index 00000000000..b9bf41dd875 --- /dev/null +++ b/tests/unit/data_masking/test_data_masking_input_types.py @@ -0,0 +1,74 @@ +import dataclasses +import pytest +from pydantic import BaseModel + +from aws_lambda_powertools.utilities.data_masking.base import DataMasking +from aws_lambda_powertools.utilities.data_masking.constants import DATA_MASKING_STRING + +@pytest.fixture +def data_masker() -> DataMasking: + return DataMasking() + +# --------------------------- +# Test with a Pydantic model +# --------------------------- +class MyPydanticModel(BaseModel): + name: str + age: int + +def test_erase_on_pydantic_model(data_masker): + # GIVEN a Pydantic model instance + model_instance = MyPydanticModel(name="powertools", age=5) + + # WHEN calling erase with fields=["age"] + result = data_masker.erase(model_instance, fields=["age"]) + + # THEN the result should be a dict with the "age" field masked + assert isinstance(result, dict) + assert result["age"] == DATA_MASKING_STRING + assert result["name"] == "powertools" + + +# --------------------------- +# Test with a dataclass +# --------------------------- +@dataclasses.dataclass +class MyDataClass: + name: str + age: int + +def test_erase_on_dataclass(data_masker): + # GIVEN a dataclass instance + dc_instance = MyDataClass(name="powertools", age=5) + + # WHEN calling erase with fields=["age"] + result = data_masker.erase(dc_instance, fields=["age"]) + + # THEN the result should be a dict with the "age" field masked + assert isinstance(result, dict) + assert result["age"] == DATA_MASKING_STRING + assert result["name"] == "powertools" + + +# --------------------------- +# Test with a custom class that implements dict() +# --------------------------- +class MyCustomClass: + def __init__(self, name, age): + self.name = name + self.age = age + + def dict(self): + return {"name": self.name, "age": self.age} + +def test_erase_on_custom_class(data_masker): + # GIVEN an instance of a custom class with a dict() method + custom_instance = MyCustomClass("powertools", 5) + + # WHEN calling erase with fields=["age"] + result = data_masker.erase(custom_instance, fields=["age"]) + + # THEN the result should be a dict with the "age" field masked + assert isinstance(result, dict) + assert result["age"] == DATA_MASKING_STRING + assert result["name"] == "powertools" \ No newline at end of file