-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Make crop scriptable #1379
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make crop scriptable #1379
Changes from 11 commits
2ce74ea
1b94441
38ef37a
05b2787
9353085
58b2240
6f54c3b
4be6334
9634785
7c5b4a6
1237fcc
a871402
12b8e26
bf12955
150c96f
4e16aec
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,9 @@ | ||
import torch | ||
import torchvision.transforms as transforms | ||
import torchvision.transforms.functional_tensor as F_t | ||
import torchvision.transforms.functional as F | ||
import unittest | ||
import torch | ||
import random | ||
|
||
|
||
class Tester(unittest.TestCase): | ||
|
@@ -19,6 +22,21 @@ def test_hflip(self): | |
self.assertEqual(hflipped_img.shape, img_tensor.shape) | ||
self.assertTrue(torch.equal(img_tensor, hflipped_img_again)) | ||
|
||
def test_crop(self): | ||
img_tensor = torch.FloatTensor(3, 16, 16).uniform_(0, 1) | ||
top = random.randint(0, 15) | ||
left = random.randint(0, 15) | ||
height = random.randint(1, 16-top) | ||
width = random.randint(1, 16-left) | ||
img_cropped = F_t.crop(img_tensor, top, left, height, width) | ||
img_PIL = transforms.ToPILImage()(img_tensor) | ||
img_PIL_cropped = F.crop(img_PIL, top, left, height, width) | ||
img_cropped_GT = transforms.ToTensor()(img_PIL_cropped) | ||
|
||
max_diff = (img_cropped_GT - img_cropped).abs().max() | ||
|
||
assert max_diff < 5e-3, "Functional crop not working" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ def vflip(img_tensor): | |
Tensor: Vertically flipped image Tensor. | ||
""" | ||
if not F._is_tensor_image(img_tensor): | ||
raise TypeError('tensor is not a torch image.') | ||
raise TypeError('Input image is not a tensor.') | ||
|
||
return img_tensor.flip(-2) | ||
|
||
|
@@ -28,6 +28,23 @@ def hflip(img_tensor): | |
""" | ||
|
||
if not F._is_tensor_image(img_tensor): | ||
raise TypeError('tensor is not a torch image.') | ||
raise TypeError('Input image is not a tensor.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This message was the same as supplied here https://github.com/pytorch/vision/blob/53b062ca58932bbf387b96f2dd3397c4495b735b/torchvision/transforms/functional.py#L209 which already works on input tensor. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The link you provided isn't working but I get your concern. It's a valid point and I will revert the change. |
||
|
||
return img_tensor.flip(-1) | ||
|
||
|
||
def crop(img, top, left, height, width): | ||
"""Crop the given Image Tensor. | ||
Args: | ||
img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image. | ||
top (int): Vertical component of the top left corner of the crop box. | ||
left (int): Horizontal component of the top left corner of the crop box. | ||
height (int): Height of the crop box. | ||
width (int): Width of the crop box. | ||
Returns: | ||
Tensor: Cropped image. | ||
""" | ||
if not F._is_tensor_image(img): | ||
raise TypeError('Input image is not a tensor.') | ||
|
||
return img[..., top:top + height, left:left + width] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you use
torch.rand(3, 16, 16)
instead?torch.FloatTensor
is a legacy API and is deprecated