Add scriptable transform: center_crop, five crop and ten_crop #1615

Merged 5 commits on Dec 4, 2019
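This PR adds tensor versions of center_crop, five_crop and ten_crop to torchvision.transforms.functional_tensor, mirroring the existing PIL-based functions. A minimal usage sketch, assuming the import alias F_t used in the PR's tests below; shapes follow the signatures in the diff:

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)  # C x H x W tensor image

center = F_t.center_crop(img, [10, 10])  # single 3 x 10 x 10 tensor
five = F_t.five_crop(img, [10, 10])      # tuple of 5 crops: (tl, tr, bl, br, center)
ten = F_t.ten_crop(img, [10, 10])        # tuple of 10 crops (the 5 above plus 5 from the flipped image)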
47 changes: 47 additions & 0 deletions test/test_functional_tensor.py
@@ -76,6 +76,53 @@ def test_rgb_to_grayscale(self):
max_diff = (grayscale_tensor - grayscale_pil_img).abs().max()
self.assertLess(max_diff, 1.0001)

def test_center_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))

def test_five_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor), [10, 10])
self.assertTrue(torch.equal(cropped_tensor[0],
(transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[1],
(transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[2],
(transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[3],
(transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[4],
(transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))

def test_ten_crop(self):
img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])
self.assertTrue(torch.equal(cropped_tensor[0],
(transforms.ToTensor()(cropped_pil_image[0]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[1],
(transforms.ToTensor()(cropped_pil_image[1]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[2],
(transforms.ToTensor()(cropped_pil_image[2]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[3],
(transforms.ToTensor()(cropped_pil_image[3]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[4],
(transforms.ToTensor()(cropped_pil_image[4]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[5],
(transforms.ToTensor()(cropped_pil_image[5]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[6],
(transforms.ToTensor()(cropped_pil_image[6]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[7],
(transforms.ToTensor()(cropped_pil_image[7]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[8],
(transforms.ToTensor()(cropped_pil_image[8]) * 255).to(torch.uint8)))
self.assertTrue(torch.equal(cropped_tensor[9],
(transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))


if __name__ == '__main__':
unittest.main()
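Each assertion above compares a tensor crop against the matching PIL crop by converting the PIL result back to a uint8 tensor. A small sketch of that round-trip, factored into a hypothetical helper (not part of the PR) to make the repeated expression explicit:

import torch
from torchvision import transforms

def pil_to_uint8_tensor(pil_img):
    # ToTensor scales pixel values to [0, 1]; multiplying by 255 and casting
    # back to uint8 recovers the original integer values for comparison.
    return (transforms.ToTensor()(pil_img) * 255).to(torch.uint8)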
91 changes: 91 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -125,6 +125,97 @@ def adjust_saturation(img, saturation_factor):
return _blend(img, rgb_to_grayscale(img), saturation_factor)


def center_crop(img, output_size):
"""Crop the Image Tensor and resize it to desired size.

Args:
img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
output_size (sequence): (height, width) of the crop box.

Returns:
Tensor: Cropped image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')

_, image_height, image_width = img.size()  # tensor images are C x H x W
crop_height, crop_width = output_size
crop_top = int(round((image_height - crop_height) / 2.))
crop_left = int(round((image_width - crop_width) / 2.))

return crop(img, crop_top, crop_left, crop_height, crop_width)


def five_crop(img, size):
"""Crop the given Image Tensor into four corners and the central crop.
.. Note::
This transform returns a tuple of Tensors and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.

Args:
img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
size (sequence): Desired output size of the crop as (h, w).

Returns:
tuple: tuple (tl, tr, bl, br, center)
Corresponding top left, top right, bottom left, bottom right and center crop.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

_, image_height, image_width = img.size()  # tensor images are C x H x W
crop_height, crop_width = size
if crop_width > image_width or crop_height > image_height:
msg = "Requested crop size {} is bigger than input size {}"
raise ValueError(msg.format(size, (image_height, image_width)))

tl = crop(img, 0, 0, crop_height, crop_width)  # crop(img, top, left, height, width)
tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
center = center_crop(img, (crop_height, crop_width))

return (tl, tr, bl, br, center)


def ten_crop(img, size, vertical_flip=False):
"""Crop the given Image Tensor into four corners and the central crop plus the
flipped version of these (horizontal flipping is used by default).
.. Note::
This transform returns a tuple of Tensors and there may be a
mismatch in the number of inputs and targets your ``Dataset`` returns.

Args:
img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
size (sequence): Desired output size of the crop as (h, w).
vertical_flip (bool): Use vertical flipping instead of horizontal.

Returns:
tuple: tuple (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
Corresponding top left, top right, bottom left, bottom right and center crop
of the original image, followed by the same five crops of the flipped image.
"""
if not F._is_tensor_image(img):
raise TypeError('tensor is not a torch image.')

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)

if vertical_flip:
img = vflip(img)
else:
img = hflip(img)

second_five = five_crop(img, size)

return first_five + second_five


def _blend(img1, img2, ratio):
bound = 1 if img1.dtype.is_floating_point else 255
return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
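As the ten_crop implementation above shows, the result is simply five_crop on the original image followed by five_crop on the flipped image (horizontal flip by default). A self-contained sketch of that relationship, under the same import assumptions as the example after the PR header:

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.randint(0, 255, (3, 32, 32), dtype=torch.uint8)
crops = F_t.ten_crop(img, [10, 10])
assert len(crops) == 10

# The last five entries are the five crops of the horizontally flipped image.
flipped_crops = F_t.five_crop(F_t.hflip(img), [10, 10])
for a, b in zip(crops[5:], flipped_crops):
    assert torch.equal(a, b)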