diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index d76c35b72aa..f3110d05101 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -91,7 +91,7 @@ If you have modified the code by adding a new feature or a bug-fix, please add u
 test:
 ```bash
 pytest test/ -vvv -k <name of the test>
-# e.g. pytest test/test_transforms.py -vvv -k test_crop
+# e.g. pytest test/test_transforms.py -vvv -k test_center_crop
 ```
 
 If you would like to run all tests:
diff --git a/test/test_transforms.py b/test/test_transforms.py
index 0562adb2a90..25d61fafeb4 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -1,3 +1,4 @@
+import itertools
 import os
 import torch
 import torchvision.transforms as transforms
@@ -29,7 +30,7 @@ class Tester(unittest.TestCase):
-    def test_crop(self):
+    def test_center_crop(self):
         height = random.randint(10, 32) * 2
         width = random.randint(10, 32) * 2
         oheight = random.randint(5, (height - 2) / 2) * 2
@@ -70,6 +71,64 @@ def test_crop(self):
         self.assertGreater(sum2, sum1,
                            "height: {} width: {} oheight: {} owdith: {}".format(height, width, oheight, owidth))
 
+    def test_center_crop_2(self):
+        """ Tests when center crop size is larger than image size, along any dimension"""
+        even_image_size = (random.randint(10, 32) * 2, random.randint(10, 32) * 2)
+        odd_image_size = (even_image_size[0] + 1, even_image_size[1] + 1)
+
+        # Since height is independent of width, we can ignore images with odd height and even width and vice-versa.
+        input_image_sizes = [even_image_size, odd_image_size]
+
+        # Get different crop sizes
+        delta = random.choice((1, 3, 5))
+        crop_size_delta = [-2 * delta, -delta, 0, delta, 2 * delta]
+        crop_size_params = itertools.product(input_image_sizes, crop_size_delta, crop_size_delta)
+
+        for (input_image_size, delta_height, delta_width) in crop_size_params:
+            img = torch.ones(3, *input_image_size)
+            crop_size = (input_image_size[0] + delta_height, input_image_size[1] + delta_width)
+
+            # Test both transforms, one with PIL input and one with tensor
+            output_pil = transforms.Compose([
+                transforms.ToPILImage(),
+                transforms.CenterCrop(crop_size),
+                transforms.ToTensor()],
+            )(img)
+            self.assertEqual(output_pil.size()[1:3], crop_size,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            output_tensor = transforms.CenterCrop(crop_size)(img)
+            self.assertEqual(output_tensor.size()[1:3], crop_size,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            # Ensure output for PIL and Tensor are equal
+            self.assertEqual((output_tensor - output_pil).sum(), 0,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
+            # Check if content in center of both image and cropped output is the same.
+            center_size = (min(crop_size[0], input_image_size[0]), min(crop_size[1], input_image_size[1]))
+            crop_center_tl, input_center_tl = [0, 0], [0, 0]
+            for index in range(2):
+                if crop_size[index] > input_image_size[index]:
+                    crop_center_tl[index] = (crop_size[index] - input_image_size[index]) // 2
+                else:
+                    input_center_tl[index] = (input_image_size[index] - crop_size[index]) // 2
+
+            output_center = output_pil[
+                :,
+                crop_center_tl[0]:crop_center_tl[0] + center_size[0],
+                crop_center_tl[1]:crop_center_tl[1] + center_size[1]
+            ]
+
+            img_center = img[
+                :,
+                input_center_tl[0]:input_center_tl[0] + center_size[0],
+                input_center_tl[1]:input_center_tl[1] + center_size[1]
+            ]
+
+            self.assertEqual((output_center - img_center).sum(), 0,
+                             "image_size: {} crop_size: {}".format(input_image_size, crop_size))
+
     def test_five_crop(self):
         to_pil_image = transforms.ToPILImage()
         h = random.randint(5, 25)
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 993a17db1eb..ab1e2e9b29b 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -451,7 +451,8 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     """Crops the given image at the center.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
 
     Args:
         img (PIL Image or Tensor): Image to be cropped.
@@ -469,6 +470,18 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     image_width, image_height = _get_image_size(img)
     crop_height, crop_width = output_size
 
+    if crop_width > image_width or crop_height > image_height:
+        padding_ltrb = [
+            (crop_width - image_width) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height) // 2 if crop_height > image_height else 0,
+            (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+            (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+        ]
+        img = pad(img, padding_ltrb, fill=0)  # PIL uses fill value 0
+        image_width, image_height = _get_image_size(img)
+        if crop_width == image_width and crop_height == image_height:
+            return img
+
     crop_top = int(round((image_height - crop_height) / 2.))
     crop_left = int(round((image_width - crop_width) / 2.))
     return crop(img, crop_top, crop_left, crop_height, crop_width)
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index ff12a070571..30911978558 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -290,7 +290,8 @@ def __init__(self, *args, **kwargs):
 class CenterCrop(torch.nn.Module):
     """Crops the given image at the center.
     If the image is torch Tensor, it is expected
-    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.
+    If image size is smaller than output size along any edge, image is padded with 0 and then center cropped.
 
     Args:
         size (sequence or int): Desired output size of the crop. If size is an
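For reference, a minimal usage sketch of the behavior introduced by this patch (assuming a torchvision build that includes it; the shapes below are illustrative): requesting a crop larger than the input now zero-pads the image and then center-crops, instead of failing.

```python
import torch
import torchvision.transforms as transforms

# 20x20 input, 24x24 crop: the image is padded with zeros by 2 px on every side.
img = torch.ones(3, 20, 20)
out = transforms.CenterCrop(24)(img)

print(out.shape)                          # torch.Size([3, 24, 24])
print(out[:, 0, :].sum())                 # tensor(0.) -- the top row is zero padding
print(out[:, 2:22, 2:22].eq(img).all())   # tensor(True) -- original content sits in the center
```

When the size difference is odd (e.g. a 25x25 crop of a 20x20 image), the extra pixel of padding lands on the right/bottom, since the left/top padding uses `(crop - image) // 2` while the right/bottom uses `(crop - image + 1) // 2`.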