Skip to content

Commit 1addc91

Browse files
committed
make decorator more generic
1 parent e2e0e3d commit 1addc91

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import warnings
44
from typing import Any, Optional, Union, Tuple, cast
55

6-
import PIL.Image
76
import torch
87
from torchvision.prototype.utils._internal import StrEnum
98
from torchvision.transforms.functional import to_pil_image
@@ -79,7 +78,7 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
7978
def show(self) -> None:
8079
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
8180
# promote this out of the prototype state
82-
cast(PIL.Image.Image, to_pil_image(make_grid(self.view(-1, *self.shape[-3:])))).show()
81+
to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
8382

8483
def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
8584
# TODO: this is useful for developing and debugging but we should remove or at least revisit this before we

torchvision/transforms/functional.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
import numbers
44
import warnings
55
from enum import Enum
6-
from typing import Callable
7-
from typing import List, Tuple, Any, Optional
6+
from typing import List, Tuple, Any, Optional, TypeVar, Callable, cast
87

98
import numpy as np
109
import torch
@@ -20,14 +19,16 @@
2019
from . import functional_pil as F_pil
2120
from . import functional_tensor as F_t
2221

22+
F = TypeVar("F", bound=Callable[..., Any])
2323

24-
def log_api_usage_once(fn: Callable[..., Tensor]) -> Callable[..., Tensor]:
24+
25+
def log_api_usage_once(fn: F) -> F:
2526
@functools.wraps(fn)
26-
def wrapper(*args: Any, **kwargs: Any) -> Tensor:
27+
def wrapper(*args, **kwargs):
2728
_log_api_usage_once(fn)
2829
return fn(*args, **kwargs)
2930

30-
return wrapper
31+
return cast(F, wrapper)
3132

3233

3334
class InterpolationMode(Enum):

0 commit comments

Comments
 (0)