
Scriptable Resize Added #1666


Closed · wants to merge 2 commits
30 changes: 30 additions & 0 deletions test/test_functional_tensor.py
@@ -138,6 +138,36 @@ def test_ten_crop(self):
                                    (transforms.ToTensor()(cropped_pil_image[9]) * 255).to(torch.uint8)))
        self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

    def test_resize(self):

Review comment (Member): Can you also add a test for scriptability?

        height = random.randint(24, 32) * 2
        width = random.randint(24, 32) * 2
        img = torch.ones(3, height, width)
        img_clone = img.clone()
        modes = ["bilinear", "nearest", "bicubic"]

        for mode in modes:
            # int size: the smaller edge is matched to output_size
            output_size = random.randint(5, 12) * 2
            result = F_t.resize(img, output_size, interpolation=mode)
            if height < width:
                self.assertEqual(output_size, result.shape[1])
            else:
                self.assertEqual(output_size, result.shape[2])

            # (int, int) size: both dimensions are matched exactly
            output_size = (random.randint(5, 12) * 2, random.randint(5, 12) * 2)
            result = F_t.resize(img, output_size, interpolation=mode)
            self.assertEqual((output_size[0], output_size[1]), (result.shape[1], result.shape[2]))

            # check that the input tensor is not mutated
            self.assertTrue(torch.equal(img, img_clone))

        # bicubic interpolation can overshoot the input range; resize clamps to
        # [0, 255], so the output should equal its clamped copy
        output_size = (random.randint(5, 12) * 2, random.randint(5, 12) * 2)
        result = F_t.resize(img, output_size, interpolation="bicubic")
        clamped_tensor = result.clamp(min=0, max=255)
        self.assertTrue(torch.equal(result, clamped_tensor))


if __name__ == '__main__':
    unittest.main()
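
A scriptability check along the lines of the first review comment could look like the following. This is a sketch only: it assumes F_t.resize is annotated so that torch.jit.script accepts it, and the test name and values are hypothetical, not part of this diff.

    def test_resize_scriptability(self):
        # Hypothetical test: scripting resize should succeed and match
        # eager execution (assumes F_t.resize is TorchScript-compatible).
        script_resize = torch.jit.script(F_t.resize)
        img = torch.ones(3, 32, 46)
        expected = F_t.resize(img, (16, 24), interpolation="bilinear")
        self.assertTrue(torch.equal(script_resize(img, (16, 24), "bilinear"), expected))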
38 changes: 38 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -1,5 +1,6 @@
import torch
import torchvision.transforms.functional as F
import torch.nn.functional as Fn


def vflip(img_tensor):
@@ -219,3 +220,40 @@ def ten_crop(img, size, vertical_flip=False):
def _blend(img1, img2, ratio):
    bound = 1 if img1.dtype.is_floating_point else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)


def resize(img, size, interpolation="bilinear"):

Review comment (Member): This is something that I still need to figure out: should we use the old PIL-based interpolation with int, or the torch-based one? I'm not sure.

r"""Resize the input Tensor Image to the given size.

Args:
img (Tensor): Image to be resized.
size (sequence or int): Desired output size. If size is a sequence like
(h, w), the output size will be matched to this. If size is an int,
the smaller edge of the image will be matched to this number maintaing
the aspect ratio. i.e, if height > width, then image will be rescaled to
:math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
interpolation (string, optional): Desired interpolation ["bilinear", "nearest", "bicubic"]. Default is
``bilinear``

Returns:
Tensor: Resized image Tensor.
"""
    if not F._is_tensor_image(img):
        raise TypeError('tensor is not a torch image.')

    if isinstance(size, int):
        w, h = img.shape[2], img.shape[1]
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation)

Review comment (Member): This assumes that the input is not batched (but it could potentially be, like for videos). Instead, I think we should unsqueeze only if it's a 3d tensor, and in this case we should squeeze back.

        else:
            oh = size
            ow = int(size * w / h)
            out_img = Fn.interpolate(img.unsqueeze(0), size=(oh, ow), mode=interpolation)
    else:
        out_img = Fn.interpolate(img.unsqueeze(0), size=size, mode=interpolation)

Review comment (Member): This is fine as is, but I think we can simplify the code a bit so that we only need to call interpolate in a single codepath.

A solution would be to build size as a tuple in the int case as well. Something like

    if isinstance(size, int):
        w, h = img.shape[2], img.shape[1]
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
            size = (oh, ow)
        else:
            oh = size
            ow = int(size * w / h)
            size = (oh, ow)
    out_img = Fn.interpolate(..., size=size, mode=interpolation)


    return(out_img.clamp(min=0, max=255).squeeze(0))

Review comment (Member): I'm not sure we should be clamping the input to 0-255 in the function. Although this makes it match the behavior of the Pillow implementation, this only works if the input is in 0-255, which is rarely the case for floating-point values.

My take is that we should remove the clamp in this function.

Also, can you remove the parentheses and add a space after the return?
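
Taken together, the three suggestions above (unsqueeze/squeeze only for 3d inputs, a single interpolate codepath, and dropping the clamp) point toward a revision along these lines. This is a sketch under those assumptions, reusing the Fn alias from the diff, not the code merged in this PR:

    def resize(img, size, interpolation="bilinear"):
        # Sketch incorporating the review feedback; the _is_tensor_image check
        # is omitted here because it would reject batched (4d) inputs.
        if isinstance(size, int):
            # Negative indexing so the same code handles batched input.
            w, h = img.shape[-1], img.shape[-2]
            if (w <= h and w == size) or (h <= w and h == size):
                return img
            if w < h:
                size = (int(size * h / w), size)
            else:
                size = (size, int(size * w / h))
        # Unsqueeze only unbatched (3d) inputs and squeeze back afterwards,
        # so batched video tensors pass through unchanged.
        needs_squeeze = img.dim() == 3
        if needs_squeeze:
            img = img.unsqueeze(0)
        # Single interpolate call, and no clamp: clamping to 0-255 only makes
        # sense for uint8-style ranges, not general float inputs.
        out_img = Fn.interpolate(img, size=size, mode=interpolation)
        if needs_squeeze:
            out_img = out_img.squeeze(0)
        return out_img

For example, resize(torch.rand(3, 64, 128), 32) would return a tensor of shape (3, 32, 64): the smaller edge (height) is matched to 32 and the aspect ratio is preserved.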