diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py index 56ba74dfdcf..0ce55e60837 100644 --- a/aws_lambda_powertools/tracing/tracer.py +++ b/aws_lambda_powertools/tracing/tracer.py @@ -1,3 +1,4 @@ +import contextlib import copy import functools import inspect @@ -320,6 +321,39 @@ def lambda_handler(event: dict, context: Any) -> Dict: booking_id = event.get("booking_id") asyncio.run(confirm_booking(booking_id=booking_id)) + **Custom generator function using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + def bookings_generator(booking_id): + resp = call_to_booking_service() + yield resp[0] + yield resp[1] + + def lambda_handler(event: dict, context: Any) -> Dict: + gen = bookings_generator(booking_id=booking_id) + result = list(gen) + + **Custom generator context manager using capture_method decorator** + + from aws_lambda_powertools import Tracer + tracer = Tracer(service="booking") + + @tracer.capture_method + @contextlib.contextmanager + def booking_actions(booking_id): + resp = call_to_booking_service() + yield "example result" + cleanup_stuff() + + def lambda_handler(event: dict, context: Any) -> Dict: + booking_id = event.get("booking_id") + + with booking_actions(booking_id=booking_id) as booking: + result = booking + **Tracing nested async calls** from aws_lambda_powertools import Tracer @@ -392,43 +426,93 @@ async def async_tasks(): err Exception raised by method """ - method_name = f"{method.__name__}" if inspect.iscoroutinefunction(method): + decorate = self._decorate_async_function(method=method) + elif inspect.isgeneratorfunction(method): + decorate = self._decorate_generator_function(method=method) + elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__): + decorate = self._decorate_generator_function_with_context_manager(method=method) + else: + decorate = self._decorate_sync_function(method=method) - @functools.wraps(method) - async def decorate(*args, **kwargs): - async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment: - try: - logger.debug(f"Calling method: {method_name}") - response = await method(*args, **kwargs) - self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment) - except Exception as err: - logger.exception(f"Exception received from '{method_name}' method") - self._add_full_exception_as_metadata( - function_name=method_name, error=err, subsegment=subsegment - ) - raise - - return response + return decorate - else: + def _decorate_async_function(self, method: Callable = None): + method_name = f"{method.__name__}" + + @functools.wraps(method) + async def decorate(*args, **kwargs): + async with self.provider.in_subsegment_async(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = await method(*args, **kwargs) + self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment) + raise - @functools.wraps(method) - def decorate(*args, **kwargs): - with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: - try: - logger.debug(f"Calling method: {method_name}") - response = method(*args, **kwargs) - self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment) - except Exception as err: - logger.exception(f"Exception received from '{method_name}' method") - self._add_full_exception_as_metadata( - function_name=method_name, error=err, subsegment=subsegment - ) - raise - - return response + return response + + return decorate + + def _decorate_generator_function(self, method: Callable = None): + method_name = f"{method.__name__}" + + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + result = yield from method(*args, **kwargs) + self._add_response_as_metadata(function_name=method_name, data=result, subsegment=subsegment) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment) + raise + + return result + + return decorate + + def _decorate_generator_function_with_context_manager(self, method: Callable = None): + method_name = f"{method.__name__}" + + @functools.wraps(method) + @contextlib.contextmanager + def decorate(*args, **kwargs): + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + with method(*args, **kwargs) as return_val: + result = return_val + self._add_response_as_metadata(function_name=method_name, data=result, subsegment=subsegment) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment) + raise + + yield result + + return decorate + + def _decorate_sync_function(self, method: Callable = None): + method_name = f"{method.__name__}" + + @functools.wraps(method) + def decorate(*args, **kwargs): + with self.provider.in_subsegment(name=f"## {method_name}") as subsegment: + try: + logger.debug(f"Calling method: {method_name}") + response = method(*args, **kwargs) + self._add_response_as_metadata(function_name=method_name, data=response, subsegment=subsegment) + except Exception as err: + logger.exception(f"Exception received from '{method_name}' method") + self._add_full_exception_as_metadata(function_name=method_name, error=err, subsegment=subsegment) + raise + + return response return decorate diff --git a/docs/content/core/tracer.mdx b/docs/content/core/tracer.mdx index 17b1b540f0b..677ad4ccae0 100644 --- a/docs/content/core/tracer.mdx +++ b/docs/content/core/tracer.mdx @@ -14,7 +14,7 @@ Tracer is an opinionated thin wrapper for [AWS X-Ray Python SDK](https://github. * Capture cold start as annotation, and responses as well as full exceptions as metadata * Run functions locally with SAM CLI without code change to disable tracing * Explicitly disable tracing via env var `POWERTOOLS_TRACE_DISABLED="true"` -* Support tracing async methods +* Support tracing async methods, generators, and context managers * Auto patch supported modules, or a tuple of explicit modules supported by AWS X-Ray ## Initialization @@ -111,16 +111,18 @@ def collect_payment(charge_id): ... ``` -## Asynchronous functions +## Asynchronous and generator functions We do not support async Lambda handler - Lambda handler itself must be synchronous
-You can trace an asynchronous function using the `capture_method`. The decorator will detect whether your function is asynchronous, and adapt its behaviour accordingly. +You can trace asynchronous functions and generator functions (including context managers) using `capture_method`. +The decorator will detect whether your function is asynchronous, a generator, or a context manager and adapt its behaviour accordingly. ```python:title=lambda_handler_with_async_code.py import asyncio +import contextlib from aws_lambda_powertools import Tracer tracer = Tracer() @@ -130,9 +132,29 @@ async def collect_payment(): ... # highlight-end +# highlight-start +@contextlib.contextmanager +@tracer.capture_method +def collect_payment_ctxman(): + yield result + ... +# highlight-end + +# highlight-start +@tracer.capture_method +def collect_payment_gen(): + yield result + ... +# highlight-end + @tracer.capture_lambda_handler def handler(evt, ctx): # highlight-line asyncio.run(collect_payment()) + + with collect_payment_ctxman as result: + do_something_with(result) + + another_result = list(collect_payment_gen()) ``` ## Tracing aiohttp requests diff --git a/example/tests/test_handler.py b/example/tests/test_handler.py index cbd46ae6f3f..f95e1fda12c 100644 --- a/example/tests/test_handler.py +++ b/example/tests/test_handler.py @@ -4,6 +4,14 @@ import pytest +from aws_lambda_powertools import Tracer + + +@pytest.fixture(scope="function", autouse=True) +def reset_tracing_config(): + Tracer._reset_config() + yield + @pytest.fixture() def env_vars(monkeypatch): diff --git a/tests/functional/test_tracing.py b/tests/functional/test_tracing.py index cda0a85cc4d..59b93789907 100644 --- a/tests/functional/test_tracing.py +++ b/tests/functional/test_tracing.py @@ -1,3 +1,5 @@ +import contextlib + import pytest from aws_lambda_powertools import Tracer @@ -150,3 +152,36 @@ def sums_values(): return func_1() + func_2() sums_values() + + +def test_tracer_yield_with_capture(): + # GIVEN tracer method decorator is used + tracer = Tracer(disabled=True) + + # WHEN capture_method decorator is applied to a context manager + @tracer.capture_method + @contextlib.contextmanager + def yield_with_capture(): + yield "testresult" + + # Or WHEN capture_method decorator is applied to a generator function + @tracer.capture_method + def generator_func(): + yield "testresult2" + + @tracer.capture_lambda_handler + def handler(event, context): + result = [] + with yield_with_capture() as yielded_value: + result.append(yielded_value) + + gen = generator_func() + + result.append(next(gen)) + + return result + + # THEN no exception is thrown, and the functions properly return values + result = handler({}, {}) + assert "testresult" in result + assert "testresult2" in result diff --git a/tests/unit/test_tracing.py b/tests/unit/test_tracing.py index d1e5408bb77..8f7d9a646dd 100644 --- a/tests/unit/test_tracing.py +++ b/tests/unit/test_tracing.py @@ -1,3 +1,4 @@ +import contextlib import sys from typing import NamedTuple from unittest import mock @@ -348,3 +349,65 @@ async def greeting(name, message): put_metadata_mock_args = in_subsegment_mock.put_metadata.call_args[1] assert put_metadata_mock_args["key"] == "greeting error" assert put_metadata_mock_args["namespace"] == "booking" + + +def test_tracer_yield_from_context_manager(mocker, provider_stub, in_subsegment_mock): + # GIVEN tracer is initialized + provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) + tracer = Tracer(provider=provider, service="booking") + + # WHEN capture_method decorator is used on a context manager + @tracer.capture_method + @contextlib.contextmanager + def yield_with_capture(): + yield "test result" + + @tracer.capture_lambda_handler + def handler(event, context): + response = [] + with yield_with_capture() as yielded_value: + response.append(yielded_value) + + return response + + result = handler({}, {}) + + # THEN we should have a subsegment named after the method name + # and add its response as trace metadata + handler_trace, yield_function_trace = in_subsegment_mock.in_subsegment.call_args_list + + assert "test result" in in_subsegment_mock.put_metadata.call_args[1]["value"] + assert in_subsegment_mock.in_subsegment.call_count == 2 + assert handler_trace == mocker.call(name="## handler") + assert yield_function_trace == mocker.call(name="## yield_with_capture") + assert "test result" in result + + +def test_tracer_yield_from_generator(mocker, provider_stub, in_subsegment_mock): + # GIVEN tracer is initialized + provider = provider_stub(in_subsegment=in_subsegment_mock.in_subsegment) + tracer = Tracer(provider=provider, service="booking") + + # WHEN capture_method decorator is used on a generator function + @tracer.capture_method + def generator_fn(): + yield "test result" + + @tracer.capture_lambda_handler + def handler(event, context): + gen = generator_fn() + response = list(gen) + + return response + + result = handler({}, {}) + + # THEN we should have a subsegment named after the method name + # and add its response as trace metadata + handler_trace, generator_fn_trace = in_subsegment_mock.in_subsegment.call_args_list + + assert "test result" in in_subsegment_mock.put_metadata.call_args[1]["value"] + assert in_subsegment_mock.in_subsegment.call_count == 2 + assert handler_trace == mocker.call(name="## handler") + assert generator_fn_trace == mocker.call(name="## generator_fn") + assert "test result" in result