
Commit e2e0e3d

make decorator annotations more concrete

1 parent 6682f7c

4 files changed (+8 lines, -7 lines)


torchvision/prototype/features/_image.py

Lines changed: 2 additions & 1 deletion

@@ -3,6 +3,7 @@
 import warnings
 from typing import Any, Optional, Union, Tuple, cast

+import PIL.Image
 import torch
 from torchvision.prototype.utils._internal import StrEnum
 from torchvision.transforms.functional import to_pil_image

@@ -78,7 +79,7 @@ def guess_color_space(data: torch.Tensor) -> ColorSpace:
     def show(self) -> None:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
         # promote this out of the prototype state
-        to_pil_image(make_grid(self.view(-1, *self.shape[-3:]))).show()
+        cast(PIL.Image.Image, to_pil_image(make_grid(self.view(-1, *self.shape[-3:])))).show()

     def draw_bounding_box(self, bounding_box: BoundingBox, **kwargs: Any) -> Image:
         # TODO: this is useful for developing and debugging but we should remove or at least revisit this before we
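Why the cast: typing.cast is a no-op at runtime and only informs the type checker. With the decorator change in the last file below, functions wrapped by log_api_usage_once are presumably typed as returning Tensor, while to_pil_image actually produces a PIL image, so the call site asserts the real type. A minimal sketch of the pattern, with a hypothetical loosely_typed helper standing in for to_pil_image:

    from typing import Any, cast

    import PIL.Image

    def loosely_typed() -> Any:
        # hypothetical stand-in for a function whose annotation is broader
        # than the value it actually returns
        return PIL.Image.new("RGB", (4, 4))

    img = cast(PIL.Image.Image, loosely_typed())  # no runtime check or conversion
    img.show()  # attribute access now type-checks against PIL.Image.Image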

torchvision/prototype/transforms/_presets.py

Lines changed: 3 additions & 3 deletions

@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Tuple, cast
+from typing import Dict, Optional, Tuple

 import torch
 from torch import Tensor, nn

@@ -41,7 +41,7 @@ def forward(self, img: Tensor) -> Tensor:
         img = F.pil_to_tensor(img)
         img = F.convert_image_dtype(img, torch.float)
         img = F.normalize(img, mean=self._mean, std=self._std)
-        return cast(Tensor, img)
+        return img


 class Kinect400Eval(nn.Module):

@@ -65,7 +65,7 @@ def forward(self, vid: Tensor) -> Tensor:
         vid = F.resize(vid, self._size, interpolation=self._interpolation)
         vid = F.center_crop(vid, self._crop_size)
         vid = F.convert_image_dtype(vid, torch.float)
-        vid = cast(Tensor, F.normalize(vid, mean=self._mean, std=self._std))
+        vid = F.normalize(vid, mean=self._mean, std=self._std)
         return vid.permute(1, 0, 2, 3)  # (T, C, H, W) => (C, T, H, W)
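These casts were only needed while the bare Callable annotation on log_api_usage_once made calls to the wrapped ops look like they return Any; with the concrete annotation (last file below), F.normalize and friends are already typed as returning Tensor. A minimal sketch with a hypothetical stand-in for F.normalize:

    import torch
    from torch import Tensor

    def normalize_stub(t: Tensor) -> Tensor:
        # hypothetical stand-in for a wrapped op typed Callable[..., Tensor]
        return t

    out: Tensor = normalize_stub(torch.rand(3, 8, 8))  # no cast(Tensor, ...) needed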

torchvision/prototype/transforms/kernels/_geometry.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@ def resize_image(
     new_height, new_width = size
     num_channels, old_height, old_width = image.shape[-3:]
     batch_shape = image.shape[:-3]
-    return _F.resize(  # type: ignore[no-any-return]
+    return _F.resize(
         image.reshape((-1, num_channels, old_height, old_width)),
         size=size,
         interpolation=interpolation,
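The dropped comment suppressed mypy's no-any-return error, which fires when a function with a concrete return annotation returns a value mypy sees as Any. Once _F.resize is typed as returning Tensor, there is nothing left to suppress. A minimal sketch of the error itself, with hypothetical names:

    from typing import Any

    def untyped() -> Any:
        # hypothetical stand-in for a call mypy used to see as returning Any
        return 0

    def typed() -> int:
        # with mypy's warn_return_any enabled, this line reports:
        # error: Returning Any from function declared to return "int"  [no-any-return]
        return untyped()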

torchvision/transforms/functional.py

Lines changed: 2 additions & 2 deletions

@@ -21,9 +21,9 @@
 from . import functional_tensor as F_t


-def log_api_usage_once(fn: Callable) -> Callable:
+def log_api_usage_once(fn: Callable[..., Tensor]) -> Callable[..., Tensor]:
     @functools.wraps(fn)
-    def wrapper(*args, **kwargs):
+    def wrapper(*args: Any, **kwargs: Any) -> Tensor:
         _log_api_usage_once(fn)
         return fn(*args, **kwargs)
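This is the change the commit title points at: the decorator now declares that it accepts and returns a Callable[..., Tensor], so wrapped functions keep a concrete Tensor return type instead of decaying to Any (a bare Callable erases the signature). A self-contained sketch of the annotated pattern, with the internal logging hook stubbed out:

    import functools
    from typing import Any, Callable

    import torch
    from torch import Tensor

    def _log_api_usage_once(fn: Callable[..., Tensor]) -> None:
        # stub standing in for torchvision's internal usage-logging hook
        pass

    def log_api_usage_once(fn: Callable[..., Tensor]) -> Callable[..., Tensor]:
        @functools.wraps(fn)
        def wrapper(*args: Any, **kwargs: Any) -> Tensor:
            _log_api_usage_once(fn)
            return fn(*args, **kwargs)
        return wrapper

    @log_api_usage_once
    def double(t: Tensor) -> Tensor:
        return t * 2

    out = double(torch.ones(2))  # mypy infers Tensor here, not Any

One tradeoff worth noting: Callable[..., Tensor] fixes the return type for every wrapped function, which is presumably why the PIL-returning to_pil_image above now needs a cast at its call site.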

0 commit comments
