Denormalize option in torchvision.utils.save_image() #848

SamSweere opened this issue Apr 11, 2019 · 4 comments


SamSweere commented Apr 11, 2019

🚀 Feature

A function to denormalize an image based on mean and standard deviation.

Motivation

When working with NNs trained on a specific dataset (for example ImageNet), an image is first normalized to the mean and standard deviation of that dataset. When we want to save such an image later in the process, we can use torchvision.utils.save_image(). However, the image is still normalized at that point, so its mean and standard deviation differ from those of the original image. There is no option to denormalize the image so that the initial normalization is undone and the saved image has the original mean and std.

Pitch

An extra parameter for the torchvision.utils.save_image() function that denormalizes an image based on mean and standard deviation arrays.

Alternatives

One way to tackle the problem currently is to use the transforms.Normalize() function. My current implementation is shown below. One flaw of this implementation is that the image has to be clamped to keep the values between 0 and 1, so some information is lost. I am not sure how to do this operation losslessly.

def img_denorm(img, mean, std):
    # For ImageNet the mean and std are:
    # mean = np.asarray([0.485, 0.456, 0.406])
    # std = np.asarray([0.229, 0.224, 0.225])

    # Normalize(-mean/std, 1/std) inverts Normalize(mean, std)
    denormalize = transforms.Normalize((-1 * mean / std), (1.0 / std))

    res = img.squeeze(0)
    res = denormalize(res)

    # The image needs to be clamped, since denormalization maps some
    # values below 0 and above 1
    res = torch.clamp(res, 0, 1)

    return res

cc @vfdev-5


pmeier commented Apr 12, 2019

  1. Why would you lose information just by normalising and subsequently denormalising? If you do not alter the image before denormalising, you should be able to reconstruct the image perfectly. There might be some numerical errors near 0 or 1, but this hardly counts as 'loss of information'.

  2. We could add something like transforms.Denormalize:

class Denormalize(object):
    def __init__(self, mean, std, inplace=False):
        self.mean = mean
        self.demean = [-m/s for m, s in zip(mean, std)]
        self.std = std
        self.destd = [1/s for s in std]
        self.inplace = inplace

    def __call__(self, tensor):
        tensor = F.normalize(tensor, self.demean, self.destd, self.inplace)
        # clamp to get rid of numerical errors
        return torch.clamp(tensor, 0.0, 1.0)

With this you could do:

import torchvision
from torchvision import transforms

tensors = [tensor1, tensor2, ...]

transform = transforms.Denormalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))

torchvision.utils.save_image([transform(tensor) for tensor in tensors], "denormalized.png")


fmassa commented Apr 12, 2019

I think another solution, which has been proposed recently, is to make normalization part of the model.
This way, the normalization parameters are built-in inside the model parameters, and we can just visualize the images without any "denormalization"
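A minimal sketch of that idea (the NormalizedModel wrapper and the inner Conv2d are illustrative, not an existing torchvision API): the model normalizes internally, so you feed and visualize raw [0, 1] images.

```python
import torch
from torch import nn

class NormalizedModel(nn.Module):
    """Wraps a model so it accepts raw [0, 1] images and normalizes internally."""

    def __init__(self, model, mean, std):
        super().__init__()
        self.model = model
        # buffers move with .to(device) and are saved in the state_dict
        self.register_buffer("mean", torch.tensor(mean).view(-1, 1, 1))
        self.register_buffer("std", torch.tensor(std).view(-1, 1, 1))

    def forward(self, x):
        return self.model((x - self.mean) / self.std)

net = NormalizedModel(nn.Conv2d(3, 8, 3),
                      mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
out = net(torch.rand(2, 3, 16, 16))  # raw images go straight in
```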


pmeier commented Apr 15, 2019

@SamSweere See #782 for the discussion @fmassa mentioned.


ustczwq commented May 23, 2021

Another elegant solution:

IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

def denormalize(x, mean=IMG_MEAN, std=IMG_STD):
    # 3, H, W, B
    ten = x.clone().permute(1, 2, 3, 0)
    for t, m, s in zip(ten, mean, std):
        t.mul_(s).add_(m)
    # B, 3, H, W
    return torch.clamp(ten, 0, 1).permute(3, 0, 1, 2)

# tensor_x: size [B, 3, H, W]
torchvision.utils.save_image(denormalize(tensor_x), "denormalized.png")
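The same per-channel rescaling can also be written with broadcasting, avoiding the permutes and the Python loop (a sketch; denormalize_bc is just an illustrative name for an equivalent of the function above):

```python
import torch

IMG_MEAN = [0.485, 0.456, 0.406]
IMG_STD = [0.229, 0.224, 0.225]

def denormalize_bc(x, mean=IMG_MEAN, std=IMG_STD):
    # reshape the per-channel stats to (1, 3, 1, 1) so they
    # broadcast against a batch of shape (B, 3, H, W)
    mean = torch.tensor(mean).view(1, -1, 1, 1)
    std = torch.tensor(std).view(1, -1, 1, 1)
    return torch.clamp(x * std + mean, 0, 1)

batch = torch.rand(2, 3, 4, 4)
out = denormalize_bc(batch)
```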
