diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py
index e318420102b..e464bf733a8 100644
--- a/test/test_functional_tensor.py
+++ b/test/test_functional_tensor.py
@@ -76,6 +76,31 @@ def test_rgb_to_grayscale(self):
         max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
         self.assertLess(max_diff, 1.0001)
 
+    def test_center_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
+        cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
+        self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
+
+    def test_five_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensors = F_t.five_crop(img_tensor, [10, 10])
+        cropped_pil_images = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        # Both implementations return crops in (tl, tr, bl, br, center) order.
+        for cropped_tensor, cropped_pil_image in zip(cropped_tensors, cropped_pil_images):
+            cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
+            self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
+
+    def test_ten_crop(self):
+        img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
+        cropped_tensors = F_t.ten_crop(img_tensor, [10, 10])
+        cropped_pil_images = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
+        # The five flipped crops follow the first five in the same order.
+        for cropped_tensor, cropped_pil_image in zip(cropped_tensors, cropped_pil_images):
+            cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
+            self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py
index c741ab2e7e8..bd56ae3a131 100644
--- a/torchvision/transforms/functional_tensor.py
+++ b/torchvision/transforms/functional_tensor.py
@@ -125,6 +125,97 @@ def adjust_saturation(img, saturation_factor):
     return _blend(img, rgb_to_grayscale(img), saturation_factor)
 
 
+def center_crop(img, output_size):
+    """Crop the central region of the Image Tensor to the given size.
+
+    Args:
+        img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
+        output_size (sequence): (height, width) of the crop box.
+
+    Returns:
+        Tensor: Cropped image.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    _, image_height, image_width = img.size()
+    crop_height, crop_width = output_size
+    crop_top = int(round((image_height - crop_height) / 2.))
+    crop_left = int(round((image_width - crop_width) / 2.))
+
+    return crop(img, crop_top, crop_left, crop_height, crop_width)
+
+
+def five_crop(img, size):
+    """Crop the given Image Tensor into four corners and the central crop.
+
+    .. Note::
+        This transform returns a tuple of Tensors and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        img (Tensor): Image to be cropped.
+        size (sequence): Desired output size (h, w) of each crop.
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center)
+            Corresponding top left, top right, bottom left, bottom right and center crop.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    _, image_height, image_width = img.size()
+    crop_height, crop_width = size
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop(img, 0, 0, crop_height, crop_width)
+    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop(img, (crop_height, crop_width))
+
+    return (tl, tr, bl, br, center)
+
+
+def ten_crop(img, size, vertical_flip=False):
+    """Crop the given Image Tensor into four corners and the central crop plus the
+    flipped version of these (horizontal flipping is used by default).
+
+    .. Note::
+        This transform returns a tuple of images and there may be a
+        mismatch in the number of inputs and targets your ``Dataset`` returns.
+
+    Args:
+        img (Tensor): Image to be cropped.
+        size (sequence): Desired output size (h, w) of each crop.
+        vertical_flip (bool): Use vertical flipping instead of horizontal.
+
+    Returns:
+        tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
+            Corresponding top left, top right, bottom left, bottom right and center crop
+            and same for the flipped image's tensor.
+    """
+    if not F._is_tensor_image(img):
+        raise TypeError('tensor is not a torch image.')
+
+    assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+
+    first_five = five_crop(img, size)
+
+    if vertical_flip:
+        img = vflip(img)
+    else:
+        img = hflip(img)
+
+    second_five = five_crop(img, size)
+
+    return first_five + second_five
+
+
 def _blend(img1, img2, ratio):
     bound = 1 if img1.dtype.is_floating_point else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
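
For reference, the typical application of the new tensor `ten_crop` mirrors the PIL-based functional API: test-time augmentation, where a model is evaluated on all ten crops and its scores are averaged. A minimal sketch under that assumption (the `model` below is a hypothetical classifier, not part of this change):

    import torch
    import torchvision.transforms.functional_tensor as F_t

    img = torch.randint(0, 255, (3, 256, 256), dtype=torch.uint8)

    # ten_crop returns a tuple of ten (C, 224, 224) tensors:
    # (tl, tr, bl, br, center) plus their horizontally flipped versions.
    crops = F_t.ten_crop(img, [224, 224])
    batch = torch.stack(crops).float() / 255.0  # (10, C, 224, 224), scaled to [0, 1]

    # scores = model(batch)             # hypothetical: (10, num_classes)
    # avg_scores = scores.mean(dim=0)   # average predictions over the ten crops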