diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py
index 56ba74dfdcf..0ce55e60837 100644
--- a/aws_lambda_powertools/tracing/tracer.py
+++ b/aws_lambda_powertools/tracing/tracer.py
@@ -1,3 +1,4 @@
+import contextlib
import copy
import functools
import inspect
@@ -320,6 +321,39 @@ def lambda_handler(event: dict, context: Any) -> Dict:
booking_id = event.get("booking_id")
asyncio.run(confirm_booking(booking_id=booking_id))
+ **Custom generator function using capture_method decorator**
+
+ from aws_lambda_powertools import Tracer
+ tracer = Tracer(service="booking")
+
+ @tracer.capture_method
+ def bookings_generator(booking_id):
+ resp = call_to_booking_service()
+ yield resp[0]
+ yield resp[1]
+
+ def lambda_handler(event: dict, context: Any) -> Dict:
+ gen = bookings_generator(booking_id=booking_id)
+ result = list(gen)
+
+ **Custom generator context manager using capture_method decorator**
+
+ from aws_lambda_powertools import Tracer
+ tracer = Tracer(service="booking")
+
+ @tracer.capture_method
+ @contextlib.contextmanager
+ def booking_actions(booking_id):
+ resp = call_to_booking_service()
+ yield "example result"
+ cleanup_stuff()
+
+ def lambda_handler(event: dict, context: Any) -> Dict:
+ booking_id = event.get("booking_id")
+
+ with booking_actions(booking_id=booking_id) as booking:
+ result = booking
+
**Tracing nested async calls**
from aws_lambda_powertools import Tracer
@@ -392,43 +426,93 @@ async def async_tasks():
err
Exception raised by method
"""
- method_name = f"{method.__name__}"
if inspect.iscoroutinefunction(method):
+ decorate = self._decorate_async_function(method=method)
+ elif inspect.isgeneratorfunction(method):
+ decorate = self._decorate_generator_function(method=method)
+ elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__):
+ decorate = self._decorate_generator_function_with_context_manager(method=method)
+ else:
+ decorate = self._decorate_sync_function(method=method)
- @functools.wraps(method)
- async def decorate(*args, **kwargs):
- async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment:
- try:
- logger.debug(f"Calling method: {method_name}")
- response = await method(*args, **kwargs)
- self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment)
- except Exception as err:
- logger.exception(f"Exception received from '{method_name}' method")
- self._add_full_exception_as_metadata(
- function_name=method_name, error=err, subsegment=subsegment
- )
- raise
-
- return response
+ return decorate
- else:
+ def _decorate_async_function(self, method: Callable = None):
+ method_name = f"{method.__name__}"
+
+ @functools.wraps(method)
+ async def decorate(*args, **kwargs):
+ async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment:
+ try:
+ logger.debug(f"Calling method: {method_name}")
+ response = await method(*args, **kwargs)
+ self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment)
+ except Exception as err:
+ logger.exception(f"Exception received from '{method_name}' method")
+ self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment)
+ raise
- @functools.wraps(method)
- def decorate(*args, **kwargs):
- with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
- try:
- logger.debug(f"Calling method: {method_name}")
- response = method(*args, **kwargs)
- self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment)
- except Exception as err:
- logger.exception(f"Exception received from '{method_name}' method")
- self._add_full_exception_as_metadata(
- function_name=method_name, error=err, subsegment=subsegment
- )
- raise
-
- return response
+ return response
+
+ return decorate
+
+ def _decorate_generator_function(self, method: Callable = None):
+ method_name = f"{method.__name__}"
+
+ @functools.wraps(method)
+ def decorate(*args, **kwargs):
+ with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
+ try:
+ logger.debug(f"Calling method: {method_name}")
+ result = yield from method(*args, **kwargs)
+ self._add_response_as_metadata(function_name=method_name, data=result, subsegment=subsegment)
+ except Exception as err:
+ logger.exception(f"Exception received from '{method_name}' method")
+ self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment)
+ raise
+
+ return result
+
+ return decorate
+
+ def _decorate_generator_function_with_context_manager(self, method: Callable = None):
+ method_name = f"{method.__name__}"
+
+ @functools.wraps(method)
+ @contextlib.contextmanager
+ def decorate(*args, **kwargs):
+ with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
+ try:
+ logger.debug(f"Calling method: {method_name}")
+ with method(*args, **kwargs) as return_val:
+ result = return_val
+ self._add_response_as_metadata(function_name=method_name, data=result, subsegment=subsegment)
+ except Exception as err:
+ logger.exception(f"Exception received from '{method_name}' method")
+ self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment)
+ raise
+
+ yield result
+
+ return decorate
+
+ def _decorate_sync_function(self, method: Callable = None):
+ method_name = f"{method.__name__}"
+
+ @functools.wraps(method)
+ def decorate(*args, **kwargs):
+ with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
+ try:
+ logger.debug(f"Calling method: {method_name}")
+ response = method(*args, **kwargs)
+ self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment)
+ except Exception as err:
+ logger.exception(f"Exception received from '{method_name}' method")
+ self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment)
+ raise
+
+ return response
return decorate
diff --git a/docs/content/core/tracer.mdx b/docs/content/core/tracer.mdx
index 17b1b540f0b..677ad4ccae0 100644
--- a/docs/content/core/tracer.mdx
+++ b/docs/content/core/tracer.mdx
@@ -14,7 +14,7 @@ Tracer is an opinionated thin wrapper for [AWS X-Ray Python SDK](https://github.
* Capture cold start as annotation, and responses as well as full exceptions as metadata
* Run functions locally with SAM CLI without code change to disable tracing
* Explicitly disable tracing via env var `POWERTOOLS_TRACE_DISABLED="true"`
-* Support tracing async methods
+* Support tracing async methods, generators, and context managers
* Auto patch supported modules, or a tuple of explicit modules supported by AWS X-Ray
## Initialization
@@ -111,16 +111,18 @@ def collect_payment(charge_id):
...
```
-## Asynchronous functions
+## Asynchronous and generator functions
We do not support async Lambda handler - Lambda handler itself must be synchronous
-You can trace an asynchronous function using the `capture_method`. The decorator will detect whether your function is asynchronous, and adapt its behaviour accordingly.
+You can trace asynchronous functions and generator functions (including context managers) using `capture_method`.
+The decorator will detect whether your function is asynchronous, a generator, or a context manager and adapt its behaviour accordingly.
```python:title=lambda_handler_with_async_code.py
import asyncio
+import contextlib
from aws_lambda_powertools import Tracer
tracer = Tracer()
@@ -130,9 +132,29 @@ async def collect_payment():
...
# highlight-end
+# highlight-start
+@contextlib.contextmanager
+@tracer.capture_method
+def collect_payment_ctxman():
+ yield result
+ ...
+# highlight-end
+
+# highlight-start
+@tracer.capture_method
+def collect_payment_gen():
+ yield result
+ ...
+# highlight-end
+
@tracer.capture_lambda_handler
def handler(evt, ctx): # highlight-line
asyncio.run(collect_payment())
+
+ with collect_payment_ctxman as result:
+ do_something_with(result)
+
+ another_result = list(collect_payment_gen())
```
## Tracing aiohttp requests
diff --git a/example/tests/test_handler.py b/example/tests/test_handler.py
index cbd46ae6f3f..f95e1fda12c 100644
--- a/example/tests/test_handler.py
+++ b/example/tests/test_handler.py
@@ -4,6 +4,14 @@
import pytest
+from aws_lambda_powertools import Tracer
+
+
+@pytest.fixture(scope="function", autouse=True)
+def reset_tracing_config():
+ Tracer._reset_config()
+ yield
+
@pytest.fixture()
def env_vars(monkeypatch):
diff --git a/tests/functional/test_tracing.py b/tests/functional/test_tracing.py
index cda0a85cc4d..59b93789907 100644
--- a/tests/functional/test_tracing.py
+++ b/tests/functional/test_tracing.py
@@ -1,3 +1,5 @@
+import contextlib
+
import pytest
from aws_lambda_powertools import Tracer
@@ -150,3 +152,36 @@ def sums_values():
return func_1() + func_2()
sums_values()
+
+
+def test_tracer_yield_with_capture():
+ # GIVEN tracer method decorator is used
+ tracer = Tracer(disabled=True)
+
+ # WHEN capture_method decorator is applied to a context manager
+ @tracer.capture_method
+ @contextlib.contextmanager
+ def yield_with_capture():
+ yield "testresult"
+
+ # Or WHEN capture_method decorator is applied to a generator function
+ @tracer.capture_method
+ def generator_func():
+ yield "testresult2"
+
+ @tracer.capture_lambda_handler
+ def handler(event, context):
+ result = []
+ with yield_with_capture() as yielded_value:
+ result.append(yielded_value)
+
+ gen = generator_func()
+
+ result.append(next(gen))
+
+ return result
+
+ # THEN no exception is thrown, and the functions properly return values
+ result = handler({}, {})
+ assert "testresult" in result
+ assert "testresult2" in result
diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py
index d1e5408bb77..8f7d9a646dd 100644
--- a/tests/unit/test_tracing.py
+++ b/tests/unit/test_tracing.py
@@ -1,3 +1,4 @@
+import contextlib
import sys
from typing import NamedTuple
from unittest import mock
@@ -348,3 +349,65 @@ async def greeting(name, message):
put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1]
assert put_metadata_mock_args["key"] == "greeting error"
assert put_metadata_mock_args["namespace"] == "booking"
+
+
+def test_tracer_yield_from_context_manager(mocker, provider_stub, in_subsegment_mock):
+ # GIVEN tracer is initialized
+ provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment)
+ tracer = Tracer(provider=provider, service="booking")
+
+ # WHEN capture_method decorator is used on a context manager
+ @tracer.capture_method
+ @contextlib.contextmanager
+ def yield_with_capture():
+ yield "test result"
+
+ @tracer.capture_lambda_handler
+ def handler(event, context):
+ response = []
+ with yield_with_capture() as yielded_value:
+ response.append(yielded_value)
+
+ return response
+
+ result = handler({}, {})
+
+ # THEN we should have a subsegment named after the method name
+ # and add its response as trace metadata
+ handler_trace, yield_function_trace = in_subsegment_mock.in_subsegment.call_args_list
+
+ assert "test result" in in_subsegment_mock.put_metadata.call_args[1]["value"]
+ assert in_subsegment_mock.in_subsegment.call_count == 2
+ assert handler_trace == mocker.call(name="## handler")
+ assert yield_function_trace == mocker.call(name="## yield_with_capture")
+ assert "test result" in result
+
+
+def test_tracer_yield_from_generator(mocker, provider_stub, in_subsegment_mock):
+ # GIVEN tracer is initialized
+ provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment)
+ tracer = Tracer(provider=provider, service="booking")
+
+ # WHEN capture_method decorator is used on a generator function
+ @tracer.capture_method
+ def generator_fn():
+ yield "test result"
+
+ @tracer.capture_lambda_handler
+ def handler(event, context):
+ gen = generator_fn()
+ response = list(gen)
+
+ return response
+
+ result = handler({}, {})
+
+ # THEN we should have a subsegment named after the method name
+ # and add its response as trace metadata
+ handler_trace, generator_fn_trace = in_subsegment_mock.in_subsegment.call_args_list
+
+ assert "test result" in in_subsegment_mock.put_metadata.call_args[1]["value"]
+ assert in_subsegment_mock.in_subsegment.call_count == 2
+ assert handler_trace == mocker.call(name="## handler")
+ assert generator_fn_trace == mocker.call(name="## generator_fn")
+ assert "test result" in result