Skip to content

Commit 6682f7c

Browse files
committed
expand decorator test to all functional transforms
1 parent d1f9cc4 commit 6682f7c

File tree

3 files changed

+35
-68
lines changed

3 files changed

+35
-68
lines changed

torchvision/prototype/transforms/_presets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Dict, Optional, Tuple
1+
from typing import Dict, Optional, Tuple, cast
22

33
import torch
44
from torch import Tensor, nn
@@ -41,7 +41,7 @@ def forward(self, img: Tensor) -> Tensor:
4141
img = F.pil_to_tensor(img)
4242
img = F.convert_image_dtype(img, torch.float)
4343
img = F.normalize(img, mean=self._mean, std=self._std)
44-
return img
44+
return cast(Tensor, img)
4545

4646

4747
class Kinect400Eval(nn.Module):
@@ -65,7 +65,7 @@ def forward(self, vid: Tensor) -> Tensor:
6565
vid = F.resize(vid, self._size, interpolation=self._interpolation)
6666
vid = F.center_crop(vid, self._crop_size)
6767
vid = F.convert_image_dtype(vid, torch.float)
68-
vid = F.normalize(vid, mean=self._mean, std=self._std)
68+
vid = cast(Tensor, F.normalize(vid, mean=self._mean, std=self._std))
6969
return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W)
7070

7171

torchvision/prototype/transforms/kernels/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def resize_image(
3939
new_height, new_width = size
4040
num_channels, old_height, old_width = image.shape[-3:]
4141
batch_shape = image.shape[:-3]
42-
return _F.resize(
42+
return _F.resize( # type: ignore[no-any-return]
4343
image.reshape((-1, num_channels, old_height, old_width)),
4444
size=size,
4545
interpolation=interpolation,

0 commit comments

Comments
 (0)