From 8e1b3524d376bf76da9b11acbf0d9dacc3f1763c Mon Sep 17 00:00:00 2001
From: heitorlessa <lessa@amazon.co.uk>
Date: Fri, 16 Jul 2021 22:31:59 +0200
Subject: [PATCH] fix: mypy generic type to preserve signature

---
 aws_lambda_powertools/tracing/tracer.py | 36 ++++++++++++++++++-------
 1 file changed, 27 insertions(+), 9 deletions(-)

diff --git a/aws_lambda_powertools/tracing/tracer.py b/aws_lambda_powertools/tracing/tracer.py
index 48b7866cf0a..5709b1956c2 100644
--- a/aws_lambda_powertools/tracing/tracer.py
+++ b/aws_lambda_powertools/tracing/tracer.py
@@ -5,7 +5,7 @@
 import logging
 import numbers
 import os
-from typing import Any, Callable, Dict, Optional, Sequence, Union
+from typing import Any, Awaitable, Callable, Dict, Optional, Sequence, TypeVar, Union, cast, overload
 
 from ..shared import constants
 from ..shared.functions import resolve_env_var_choice, resolve_truthy_env_var_choice
@@ -18,6 +18,9 @@
 aws_xray_sdk = LazyLoader(constants.XRAY_SDK_MODULE, globals(), constants.XRAY_SDK_MODULE)
 aws_xray_sdk.core = LazyLoader(constants.XRAY_SDK_CORE_MODULE, globals(), constants.XRAY_SDK_CORE_MODULE)
 
+AnyCallableT = TypeVar("AnyCallableT", bound=Callable[..., Any])  # noqa: VNE001
+AnyAwaitableT = TypeVar("AnyAwaitableT", bound=Awaitable)
+
 
 class Tracer:
     """Tracer using AWS-XRay to provide decorators with known defaults for Lambda functions
@@ -329,12 +332,26 @@ def decorate(event, context, **kwargs):
 
         return decorate
 
+    # see #465
+    @overload
+    def capture_method(self, method: "AnyCallableT") -> "AnyCallableT":
+        ...
+
+    @overload
     def capture_method(
         self,
-        method: Optional[Callable] = None,
+        method: None = None,
         capture_response: Optional[bool] = None,
         capture_error: Optional[bool] = None,
-    ):
+    ) -> Callable[["AnyCallableT"], "AnyCallableT"]:
+        ...
+
+    def capture_method(
+        self,
+        method: Optional[AnyCallableT] = None,
+        capture_response: Optional[bool] = None,
+        capture_error: Optional[bool] = None,
+    ) -> AnyCallableT:
         """Decorator to create subsegment for arbitrary functions
 
         It also captures both response and exceptions as metadata
@@ -487,8 +504,9 @@ async def async_tasks():
         # Return a partial function with args filled
         if method is None:
             logger.debug("Decorator called with parameters")
-            return functools.partial(
-                self.capture_method, capture_response=capture_response, capture_error=capture_error
+            return cast(
+                AnyCallableT,
+                functools.partial(self.capture_method, capture_response=capture_response, capture_error=capture_error),
             )
 
         method_name = f"{method.__name__}"
@@ -509,7 +527,7 @@ async def async_tasks():
             return self._decorate_generator_function(
                 method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
             )
-        elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__):
+        elif hasattr(method, "__wrapped__") and inspect.isgeneratorfunction(method.__wrapped__):  # type: ignore
             return self._decorate_generator_function_with_context_manager(
                 method=method, capture_response=capture_response, capture_error=capture_error, method_name=method_name
             )
@@ -602,11 +620,11 @@ def decorate(*args, **kwargs):
 
     def _decorate_sync_function(
         self,
-        method: Callable,
+        method: AnyCallableT,
         capture_response: Optional[Union[bool, str]] = None,
         capture_error: Optional[Union[bool, str]] = None,
         method_name: Optional[str] = None,
-    ):
+    ) -> AnyCallableT:
         @functools.wraps(method)
         def decorate(*args, **kwargs):
             with self.provider.in_subsegment(name=f"## {method_name}") as subsegment:
@@ -628,7 +646,7 @@ def decorate(*args, **kwargs):
 
                 return response
 
-        return decorate
+        return cast(AnyCallableT, decorate)
 
     def _add_response_as_metadata(
         self,