diff --git a/torchvision/prototype/models/video/resnet.py b/torchvision/prototype/models/video/resnet.py index a7abddf9d55..268043fdac3 100644 --- a/torchvision/prototype/models/video/resnet.py +++ b/torchvision/prototype/models/video/resnet.py @@ -62,7 +62,7 @@ def _video_resnet( class R3D_18Weights(Weights): Kinetics400_RefV1 = WeightEntry( url="https://download.pytorch.org/models/r3d_18-b3b3357e.pth", - transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), + transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "acc@1": 52.75, @@ -74,7 +74,7 @@ class R3D_18Weights(Weights): class MC3_18Weights(Weights): Kinetics400_RefV1 = WeightEntry( url="https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", - transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), + transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "acc@1": 53.90, @@ -86,7 +86,7 @@ class MC3_18Weights(Weights): class R2Plus1D_18Weights(Weights): Kinetics400_RefV1 = WeightEntry( url="https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", - transforms=partial(Kinect400Eval, resize_size=(128, 171), crop_size=(112, 112)), + transforms=partial(Kinect400Eval, crop_size=(112, 112), resize_size=(128, 171)), meta={ **_COMMON_META, "acc@1": 57.50, diff --git a/torchvision/prototype/transforms/_presets.py b/torchvision/prototype/transforms/_presets.py index 3aed58bb8d6..3b9d733d8df 100644 --- a/torchvision/prototype/transforms/_presets.py +++ b/torchvision/prototype/transforms/_presets.py @@ -3,8 +3,7 @@ import torch from torch import Tensor, nn -from ... import transforms as T -from ...transforms import functional as F +from ...transforms import functional as F, InterpolationMode __all__ = ["CocoEval", "ImageNetEval", "Kinect400Eval", "VocEval"] @@ -26,42 +25,47 @@ def __init__( resize_size: int = 256, mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), - interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self._resize = T.Resize(resize_size, interpolation=interpolation) - self._crop = T.CenterCrop(crop_size) - self._normalize = T.Normalize(mean=mean, std=std) + self._crop_size = [crop_size] + self._size = [resize_size] + self._mean = list(mean) + self._std = list(std) + self._interpolation = interpolation def forward(self, img: Tensor) -> Tensor: - img = self._crop(self._resize(img)) + img = F.resize(img, self._size, interpolation=self._interpolation) + img = F.center_crop(img, self._crop_size) if not isinstance(img, Tensor): img = F.pil_to_tensor(img) img = F.convert_image_dtype(img, torch.float) - return self._normalize(img) + img = F.normalize(img, mean=self._mean, std=self._std) + return img class Kinect400Eval(nn.Module): def __init__( self, - resize_size: Tuple[int, int], crop_size: Tuple[int, int], + resize_size: Tuple[int, int], mean: Tuple[float, ...] = (0.43216, 0.394666, 0.37645), std: Tuple[float, ...] = (0.22803, 0.22145, 0.216989), - interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, ) -> None: super().__init__() - self._convert = T.ConvertImageDtype(torch.float) - self._resize = T.Resize(resize_size, interpolation=interpolation) - self._normalize = T.Normalize(mean=mean, std=std) - self._crop = T.CenterCrop(crop_size) + self._crop_size = list(crop_size) + self._size = list(resize_size) + self._mean = list(mean) + self._std = list(std) + self._interpolation = interpolation def forward(self, vid: Tensor) -> Tensor: vid = vid.permute(0, 3, 1, 2) # (T, H, W, C) => (T, C, H, W) - vid = self._convert(vid) - vid = self._resize(vid) - vid = self._normalize(vid) - vid = self._crop(vid) + vid = F.resize(vid, self._size, interpolation=self._interpolation) + vid = F.center_crop(vid, self._crop_size) + vid = F.convert_image_dtype(vid, torch.float) + vid = F.normalize(vid, mean=self._mean, std=self._std) return vid.permute(1, 0, 2, 3) # (T, C, H, W) => (C, T, H, W) @@ -71,8 +75,8 @@ def __init__( resize_size: int, mean: Tuple[float, ...] = (0.485, 0.456, 0.406), std: Tuple[float, ...] = (0.229, 0.224, 0.225), - interpolation: T.InterpolationMode = T.InterpolationMode.BILINEAR, - interpolation_target: T.InterpolationMode = T.InterpolationMode.NEAREST, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + interpolation_target: InterpolationMode = InterpolationMode.NEAREST, ) -> None: super().__init__() self._size = [resize_size]