
Commit 7b2c5f8

NicolasHug authored and datumbox committed
Bugfix - same output for PIL and tensor when centercrop size is greater than imgsize (#3333)
Summary:
* Renamed original method to test_center_crop.
* Added test method and docs; added padding when image size < crop size.
* BugFix - keep odd_crop_size odd.
* Do not crop when image size after padding matches crop size; updated test.

Reviewed By: datumbox
Differential Revision: D26226610
fbshipit-source-id: d1697edc05f4dfe3469443ca88428a6466cc7eee
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 188f8eb commit 7b2c5f8
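
To make the fix concrete, here is a minimal sketch of the behavior this commit targets (the 3 x 20 x 20 image and 24 x 24 crop size are illustrative, not from the patch): with a crop larger than the input, CenterCrop now zero-pads the tensor path the same way the PIL path does, so both produce identical output.

import torch
import torchvision.transforms as transforms

# Illustrative sizes: a 3 x 20 x 20 image of ones, crop larger than the image.
img = torch.ones(3, 20, 20)
crop_size = (24, 24)

# Tensor path: CenterCrop applied directly to the tensor.
out_tensor = transforms.CenterCrop(crop_size)(img)

# PIL path: round-trip through a PIL image, as the new test does.
out_pil = transforms.Compose([
    transforms.ToPILImage(),
    transforms.CenterCrop(crop_size),
    transforms.ToTensor(),
])(img)

# With this fix both paths zero-pad, so shapes and contents match exactly.
assert out_tensor.shape == (3, 24, 24)
assert (out_tensor - out_pil).abs().sum() == 0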

File tree: 4 files changed, +77 -4 lines

* CONTRIBUTING.md
* test/test_transforms.py
* torchvision/transforms/functional.py
* torchvision/transforms/transforms.py

CONTRIBUTING.md

Lines changed: 1 addition & 1 deletion
@@ -91,7 +91,7 @@ If you have modified the code by adding a new feature or a bug-fix, please add u
 test:
 ```bash
 pytest test/<test-module.py> -vvv -k <test_myfunc>
-# e.g. pytest test/test_transforms.py -vvv -k test_crop
+# e.g. pytest test/test_transforms.py -vvv -k test_center_crop
 ```
 
 If you would like to run all tests:

test/test_transforms.py

Lines changed: 60 additions & 1 deletion
@@ -1,3 +1,4 @@
+import itertools
 import os
 import torch
 import torchvision.transforms as transforms
@@ -29,7 +30,7 @@
 
 class Tester(unittest.TestCase):
 
-    def test_crop(self):
+    def test_center_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
         oheight = random.randint(5, (height - 2) / 2) * 2
@@ -70,6 +71,64 @@ def test_crop(self):
         self.assertGreater(sum2, sum1,
                            "height: {} width: {} oheight: {} owdith: {}".format(height, width, oheight, owidth))
 
+    def test_center_crop_2(self):
+        """ Tests when center crop size is larger than image size, along any dimension"""
+        even_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
+        odd_image_size = (even_image_size[0] + 1, even_image_size[1] + 1)
+
+        # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
+        input_image_sizes = [even_image_size, odd_image_size]
+
+        # Get different crop sizes
+        delta = random.choice((1, 3, 5))
+        crop_size_delta = [-2 * delta, -delta, 0, delta, 2 * delta]
+        crop_size_params = itertools.product(input_image_sizes, crop_size_delta, crop_size_delta)
+
+        for (input_image_size, delta_height, delta_width) in crop_size_params:
+            img = torch.ones(3, *input_image_size)
+            crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)
+
+            # Test both transforms, one with PIL input and one with tensor
+            output_pil = transforms.Compose([
+                transforms.ToPILImage(),
+                transforms.CenterCrop(crop_size),
+                transforms.ToTensor()],
+            )(img)
+            self.assertEqual(output_pil.size()[1:3], crop_size,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            output_tensor = transforms.CenterCrop(crop_size)(img)
+            self.assertEqual(output_tensor.size()[1:3], crop_size,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            # Ensure output for PIL and Tensor are equal
+            self.assertEqual((output_tensor - output_pil).sum(), 0,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            # Check if content in center of both image and cropped output is same.
+            center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
+            crop_center_tl, input_center_tl = [0, 0], [0, 0]
+            for index in range(2):
+                if crop_size[index] > input_image_size[index]:
+                    crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
+                else:
+                    input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2
+
+            output_center = output_pil[
+                :,
+                crop_center_tl[0]:crop_center_tl[0] + center_size[0],
+                crop_center_tl[1]:crop_center_tl[1] + center_size[1]
+            ]
+
+            img_center = img[
+                :,
+                input_center_tl[0]:input_center_tl[0] + center_size[0],
+                input_center_tl[1]:input_center_tl[1] + center_size[1]
+            ]
+
+            self.assertEqual((output_center - img_center).sum(), 0,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
     def test_five_crop(self):
         to_pil_image = transforms.ToPILImage()
         h = random.randint(5, 25)
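
The densest part of the new test is the top-left ("tl") bookkeeping: per dimension, either the crop output contains the whole (padded) image, or the image contains the whole crop. A standalone sketch of that indexing, with made-up sizes and assuming the same conventions as the test above:

# Made-up 2-D example: image 10 x 10; crop 14 px in one dim, 6 px in the other.
image_size, crop_size = (10, 10), (14, 6)

crop_center_tl, input_center_tl = [0, 0], [0, 0]
for i in range(2):
    if crop_size[i] > image_size[i]:
        # Crop is larger: the original image sits inside the padded output.
        crop_center_tl[i] = (crop_size[i] - image_size[i]) // 2
    else:
        # Crop is smaller: the output came from the center of the image.
        input_center_tl[i] = (image_size[i] - crop_size[i]) // 2

assert crop_center_tl == [2, 0]   # image starts 2 px into the padded output
assert input_center_tl == [0, 2]  # crop starts 2 px into the image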

torchvision/transforms/functional.py

Lines changed: 14 additions & 1 deletion
@@ -451,7 +451,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     """Crops the given image at the center.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
 
     Args:
         img (PIL Image or Tensor): Image to be cropped.
@@ -469,6 +470,18 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     image_width, image_height = _get_image_size(img)
     crop_height, crop_width = output_size
 
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
+        image_width, image_height = _get_image_size(img)
+        if crop_width == image_width and crop_height == image_height:
+            return img
+
     crop_top = int(round((image_height - crop_height) / 2.))
     crop_left = int(round((image_width - crop_width) / 2.))
     return crop(img, crop_top, crop_left, crop_height, crop_width)
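
As a reading aid for the hunk above: splitting the difference as // 2 on the left/top and (difference + 1) // 2 on the right/bottom puts the extra pixel of an odd padding difference on the right/bottom, which appears to be what the summary's "keep odd_crop_size odd" item refers to. A quick arithmetic check with made-up sizes:

# Made-up sizes: image 28 px wide, crop 31 px wide (odd difference of 3).
image_width, crop_width = 28, 31

pad_left = (crop_width - image_width) // 2       # 1
pad_right = (crop_width - image_width + 1) // 2  # 2: the extra pixel goes right

assert pad_left + image_width + pad_right == crop_width  # 1 + 28 + 2 == 31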

torchvision/transforms/transforms.py

Lines changed: 2 additions & 1 deletion
@@ -290,7 +290,8 @@ def __init__(self, *args, **kwargs):
 class CenterCrop(torch.nn.Module):
     """Crops the given image at the center.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
 
     Args:
         size (sequence or int): Desired output size of the crop. If size is an