diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py new file mode 100644 index 00000000000..4bdab11f6be --- /dev/null +++ b/test/test_functional_tensor.py @@ -0,0 +1,24 @@ +import torchvision.transforms.functional_tensor as F_t +import unittest +import torch + + +class Tester(unittest.TestCase): + + def test_vflip(self): + img_tensor = torch.randn(3, 16, 16) + vflipped_img = F_t.vflip(img_tensor) + vflipped_img_again = F_t.vflip(vflipped_img) + self.assertEqual(vflipped_img.shape, img_tensor.shape) + self.assertTrue(torch.equal(img_tensor, vflipped_img_again)) + + def test_hflip(self): + img_tensor = torch.randn(3, 16, 16) + hflipped_img = F_t.hflip(img_tensor) + hflipped_img_again = F_t.hflip(hflipped_img) + self.assertEqual(hflipped_img.shape, img_tensor.shape) + self.assertTrue(torch.equal(img_tensor, hflipped_img_again)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py new file mode 100644 index 00000000000..bc12327f39b --- /dev/null +++ b/torchvision/transforms/functional_tensor.py @@ -0,0 +1,33 @@ +import torch +import torchvision.transforms.functional as F + + +def vflip(img_tensor): + """Vertically flip the given the Image Tensor. + + Args: + img_tensor (Tensor): Image Tensor to be flipped in the form CXHXW. + + Returns: + Tensor: Vertically flipped image Tensor. + """ + if not F._is_tensor_image(img_tensor): + raise TypeError('tensor is not a torch image.') + + return img_tensor.flip(-2) + + +def hflip(img_tensor): + """Horizontally flip the given the Image Tensor. + + Args: + img_tensor (Tensor): Image Tensor to be flipped in the form CXHXW. + + Returns: + Tensor: Horizontally flipped image Tensor. + """ + + if not F._is_tensor_image(img_tensor): + raise TypeError('tensor is not a torch image.') + + return img_tensor.flip(-1)