Skip to content

Commit d194082

Browse files
ekagra-ranjanfmassa
authored andcommitted
Make crop scriptable (#1379)
* Make crop torchscriptable Relevant #1375 * Invert x and y axis * fix lint * Add crop test * revert deletion of space in functional * add import random * add dimension in doc * add import * fix flake8 * change to self.assert* * convert to uint8 * assertTrue * lint
1 parent b0f88df commit d194082

File tree

2 files changed

+35
-1
lines changed

2 files changed

+35
-1
lines changed

test/test_functional_tensor.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
import torch
2+
import torchvision.transforms as transforms
13
import torchvision.transforms.functional_tensor as F_t
4+
import torchvision.transforms.functional as F
25
import unittest
3-
import torch
6+
import random
47

58

69
class Tester(unittest.TestCase):
@@ -19,6 +22,20 @@ def test_hflip(self):
1922
self.assertEqual(hflipped_img.shape, img_tensor.shape)
2023
self.assertTrue(torch.equal(img_tensor, hflipped_img_again))
2124

25+
def test_crop(self):
26+
img_tensor = torch.randint(0, 255, (3, 16, 16), dtype=torch.uint8)
27+
top = random.randint(0, 15)
28+
left = random.randint(0, 15)
29+
height = random.randint(1, 16 - top)
30+
width = random.randint(1, 16 - left)
31+
img_cropped = F_t.crop(img_tensor, top, left, height, width)
32+
img_PIL = transforms.ToPILImage()(img_tensor)
33+
img_PIL_cropped = F.crop(img_PIL, top, left, height, width)
34+
img_cropped_GT = transforms.ToTensor()(img_PIL_cropped)
35+
36+
self.assertTrue(torch.equal(img_cropped, (img_cropped_GT * 255).to(torch.uint8)),
37+
"functional_tensor crop not working")
38+
2239

2340
if __name__ == '__main__':
2441
unittest.main()

torchvision/transforms/functional_tensor.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,20 @@ def hflip(img_tensor):
3131
raise TypeError('tensor is not a torch image.')
3232

3333
return img_tensor.flip(-1)
34+
35+
36+
def crop(img, top, left, height, width):
37+
"""Crop the given Image Tensor.
38+
Args:
39+
img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
40+
top (int): Vertical component of the top left corner of the crop box.
41+
left (int): Horizontal component of the top left corner of the crop box.
42+
height (int): Height of the crop box.
43+
width (int): Width of the crop box.
44+
Returns:
45+
Tensor: Cropped image.
46+
"""
47+
if not F._is_tensor_image(img):
48+
raise TypeError('tensor is not a torch image.')
49+
50+
return img[..., top:top + height, left:left + width]

0 commit comments

Comments
 (0)