
Moving normalize, erase to functional_tensor.py #1474

Closed · wants to merge 10 commits
4 changes: 2 additions & 2 deletions references/segmentation/transforms.py
@@ -5,7 +5,7 @@
import torch
from torchvision import transforms as T
from torchvision.transforms import functional as F

import torchvision.transforms.functional_tensor as F_t

def pad_if_smaller(img, size, fill=0):
min_size = min(img.size)
@@ -88,5 +88,5 @@ def __init__(self, mean, std):
self.std = std

def __call__(self, image, target):
image = F.normalize(image, mean=self.mean, std=self.std)
image = F_t.normalize(image, mean=self.mean, std=self.std)
return image, target
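
For context, a minimal sketch (not code from this repository) of how a paired normalize transform like the one above is applied in the segmentation references: the image is normalized while the target mask is passed through untouched. The class name SegNormalize and the toy shapes are illustrative:

import torch
import torchvision.transforms.functional_tensor as F_t

class SegNormalize:
    # Paired transform: normalize the image, leave the segmentation target untouched.
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, image, target):
        image = F_t.normalize(image, mean=self.mean, std=self.std)
        return image, target

image = torch.rand(3, 64, 64)                     # image tensor (C, H, W)
target = torch.zeros(64, 64, dtype=torch.int64)   # segmentation mask (H, W)
image, target = SegNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(image, target)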
5 changes: 3 additions & 2 deletions test/test_transforms.py
@@ -3,6 +3,7 @@
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
import unittest
import math
@@ -838,7 +839,7 @@ def test_normalize_different_dtype(self):
mean = torch.tensor([1, 2, 3], dtype=dtype2)
std = torch.tensor([1, 2, 1], dtype=dtype2)
# checks that it doesn't crash
transforms.functional.normalize(img, mean, std)
transforms.functional_tensor.normalize(img, mean, std)

def test_adjust_brightness(self):
x_shape = [2, 2, 3]
@@ -1363,7 +1364,7 @@ def test_random_erasing(self):
# Test Set 1: Erasing with int value
img_re = transforms.RandomErasing(value=0.2)
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F.erase(img, i, j, h, w, v)
img_output = F_t.erase(img, i, j, h, w, v)
assert img_output.size(0) == 3

# Test Set 2: Check if the unerased region is preserved
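
To make the flow of Test Set 1 above easier to follow, here is a hedged, standalone sketch of the same erasing sequence, assuming this PR's functional_tensor.erase: get_params samples a region (i, j, h, w) plus the fill value v, and erase applies it.

import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional_tensor as F_t

img = torch.rand(3, 60, 60)
img_re = transforms.RandomErasing(value=0.2)

# Sample the region to erase and the value to fill it with.
i, j, h, w, v = img_re.get_params(img, scale=img_re.scale, ratio=img_re.ratio, value=img_re.value)
img_output = F_t.erase(img, i, j, h, w, v)

assert img_output.size(0) == 3   # the channel dimension is preserved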
55 changes: 0 additions & 55 deletions torchvision/transforms/functional.py
@@ -188,36 +188,6 @@ def to_pil_image(pic, mode=None):
return Image.fromarray(npimg, mode=mode)


def normalize(tensor, mean, std, inplace=False):
Member: This is a breaking change, because users of torchvision.transforms.functional.normalize won't find it anymore. At least, you should re-export with `from .functional_tensor import normalize, erase`.
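
A minimal sketch of the backward-compatibility re-export the reviewer is suggesting here; it would sit in torchvision/transforms/functional.py so existing F.normalize / F.erase call sites keep working after the move:

# torchvision/transforms/functional.py
# Re-export the moved functions so that code doing
#   from torchvision.transforms import functional as F
#   F.normalize(...); F.erase(...)
# does not break once they live in functional_tensor.py.
from .functional_tensor import normalize, erase  # noqa: F401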

Contributor Author: Yeah, got it.

Member: Given what I mentioned above, I'm not sure if there is a need to move normalize and erase to functional_tensor.py, because we don't have any implementation based on PIL anyway. Thoughts?

Contributor Author: Yes, but if we want to do something like dispatching from functional -> functional_tensor / functional -> functional_pil, this might help in ensuring consistency. What do you think?
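
A minimal sketch of the kind of dispatch layer the author is describing, purely illustrative and not something this PR implements; the functional_pil branch is hypothetical, since normalize has no PIL implementation:

# torchvision/transforms/functional.py (hypothetical dispatch layer)
import torch

from . import functional_tensor as F_t
# from . import functional_pil as F_pil   # hypothetical module, not part of this PR

def normalize(tensor, mean, std, inplace=False):
    # Route tensor inputs to the tensor backend; a PIL branch would plug in
    # here for ops that have one (normalize does not).
    if isinstance(tensor, torch.Tensor):
        return F_t.normalize(tensor, mean, std, inplace=inplace)
    raise TypeError('normalize only supports torch.Tensor inputs, got {}'.format(type(tensor)))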

Contributor Author: @fmassa, let me know if I am on the same page as you?

Member: My thinking is that we don't need to move normalize or erase to functional_tensor.py. Users shouldn't need to worry about the existence of functional_pil or functional_tensor, just functional.

Contributor Author: Sure, got it! Thanks.

"""Normalize a tensor image with mean and standard deviation.

.. note::
This transform acts out of place by default, i.e., it does not mutates the input tensor.

See :class:`~torchvision.transforms.Normalize` for more details.

Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.

Returns:
Tensor: Normalized Tensor image.
"""
if not _is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.')

if not inplace:
tensor = tensor.clone()

dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor


def resize(img, size, interpolation=Image.BILINEAR):
r"""Resize the input PIL Image to the given size.

@@ -830,28 +800,3 @@ def to_grayscale(img, num_output_channels=1):
raise ValueError('num_output_channels should be either 1 or 3')

return img


def erase(img, i, j, h, w, v, inplace=False):
""" Erase the input Tensor Image with given value.

Args:
img (Tensor Image): Tensor image of size (C, H, W) to be erased
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool, optional): For in-place operations. By default is set False.

Returns:
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

if not inplace:
img = img.clone()

img[:, i:i + h, j:j + w] = v
return img
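
For reference, a short hedged sketch of what erase does once it lives in functional_tensor.py (see the added file below); the shapes and values here are arbitrary:

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.rand(3, 8, 8)                          # tensor image (C, H, W)
out = F_t.erase(img, i=2, j=3, h=4, w=2, v=0.0)    # fill the 4x2 patch at (2, 3) with 0

assert out[:, 2:6, 3:5].eq(0).all()                # the erased region holds the value v
assert not out.equal(img)                          # out of place by default (inplace=False)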
55 changes: 55 additions & 0 deletions torchvision/transforms/functional_tensor.py
@@ -31,3 +31,58 @@ def hflip(img_tensor):
raise TypeError('tensor is not a torch image.')

return img_tensor.flip(-1)


def erase(img, i, j, h, w, v, inplace=False):
""" Erase the input Tensor Image with given value.

Args:
img (Tensor Image): Tensor image of size (C, H, W) to be erased
i (int): i in (i,j) i.e coordinates of the upper left corner.
j (int): j in (i,j) i.e coordinates of the upper left corner.
h (int): Height of the erased region.
w (int): Width of the erased region.
v: Erasing value.
inplace(bool, optional): For in-place operations. By default is set False.

Returns:
Tensor Image: Erased image.
"""
if not isinstance(img, torch.Tensor):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

if not inplace:
img = img.clone()

img[:, i:i + h, j:j + w] = v
return img


def normalize(tensor, mean, std, inplace=False):
"""Normalize a tensor image with mean and standard deviation.

.. note::
This transform acts out of place by default, i.e., it does not mutates the input tensor.

See :class:`~torchvision.transforms.Normalize` for more details.

Args:
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
mean (sequence): Sequence of means for each channel.
std (sequence): Sequence of standard deviations for each channel.
inplace(bool,optional): Bool to make this operation inplace.

Returns:
Tensor: Normalized Tensor image.
"""
if not F._is_tensor_image(tensor):
raise TypeError('tensor is not a torch image.')

if not inplace:
tensor = tensor.clone()

dtype = tensor.dtype
mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
tensor.sub_(mean[:, None, None]).div_(std[:, None, None])
return tensor
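
The [:, None, None] broadcasting at the end of normalize above is what gives the per-channel behaviour. A small hedged sketch of the arithmetic it performs, assuming this PR's functional_tensor.normalize:

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.rand(3, 4, 4)                          # tensor image (C, H, W)
mean, std = [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]

out = F_t.normalize(img, mean, std)                # out of place by default

# Each channel c is transformed as (img[c] - mean[c]) / std[c],
# i.e. mean[:, None, None] and std[:, None, None] broadcast over (H, W).
expected = (img - torch.tensor(mean)[:, None, None]) / torch.tensor(std)[:, None, None]
assert torch.allclose(out, expected)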
5 changes: 3 additions & 2 deletions torchvision/transforms/transforms.py
@@ -15,6 +15,7 @@
import warnings

from . import functional as F
from . import functional_tensor as F_t

if sys.version_info < (3, 3):
Sequence = collections.Sequence
@@ -172,7 +173,7 @@ def __call__(self, tensor):
Returns:
Tensor: Normalized Tensor image.
"""
return F.normalize(tensor, self.mean, self.std, self.inplace)
return F_t.normalize(tensor, self.mean, self.std, self.inplace)

def __repr__(self):
return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)
@@ -1297,5 +1298,5 @@ def __call__(self, img):
"""
if random.uniform(0, 1) < self.p:
x, y, h, w, v = self.get_params(img, scale=self.scale, ratio=self.ratio, value=self.value)
return F.erase(img, x, y, h, w, v, self.inplace)
return F_t.erase(img, x, y, h, w, v, self.inplace)
return img
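
Finally, a hedged end-to-end sketch of how the two transforms touched above are typically composed on a tensor image; with this PR applied, both would delegate to functional_tensor under the hood:

import torch
import torchvision.transforms as T

tensor_img = torch.rand(3, 32, 32)
pipeline = T.Compose([
    T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    T.RandomErasing(p=1.0, value=0.0),    # p=1.0 so the erase always fires
])
out = pipeline(tensor_img)
print(out.shape)                          # torch.Size([3, 32, 32])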