diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d944a923387..903798d19e3 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1,6 +1,9 @@ +import torch +import torchvision.transforms as transforms import torchvision.transforms.functional_tensor as F_t +import torchvision.transforms.functional as F import unittest -import torch +import random class Tester(unittest.TestCase): @@ -19,6 +22,20 @@ def test_hflip(self): self.assertEqual(hflipped_img.shape, img_tensor.shape) self.assertTrue(torch.equal(img_tensor, hflipped_img_again)) + def test_crop(self): + img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8) + top = random.randint(0, 15) + left = random.randint(0, 15) + height = random.randint(1, 16 - top) + width = random.randint(1, 16 - left) + img_cropped = F_t.crop(img_tensor, top, left, height, width) + img_PIL = transforms.ToPILImage()(img_tensor) + img_PIL_cropped = F.crop(img_PIL, top, left, height, width) + img_cropped_GT = transforms.ToTensor()(img_PIL_cropped) + + self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)), + "functional_tensor crop not working") + if __name__ == '__main__': unittest.main() diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index ec530dc2f3a..5027958164c 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -31,3 +31,20 @@ def hflip(img_tensor): raise TypeError('tensor is not a torch image.') return img_tensor.flip(-1) + + +def crop(img, top, left, height, width): + """Crop the given Image Tensor. + Args: + img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image. + top (int): Vertical component of the top left corner of the crop box. + left (int): Horizontal component of the top left corner of the crop box. + height (int): Height of the crop box. + width (int): Width of the crop box. + Returns: + Tensor: Cropped image. + """ + if not F._is_tensor_image(img): + raise TypeError('tensor is not a torch image.') + + return img[..., top:top + height, left:left + width]