Error in mean and std computation in torchvision.models #3657

Closed
@krishnap25

📚 Documentation

Hello all,

Thanks for the fantastic library. I would like to point out an error in the mean and std computation on the torchvision.models documentation page. Specifically, I'm referring to the code following the line
"The process for obtaining the values of mean and std is roughly equivalent to:"

import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)

means = []
stds = []
for img in subset(dataset):
    means.append(torch.mean(img))
    stds.append(torch.std(img)) # Bug here 

mean = torch.mean(torch.tensor(means)) # Error here
std = torch.mean(torch.tensor(stds)) # Error here

There are two issues here:

  • Once the per-image means are computed per channel (as they must be for per-channel statistics), torch.tensor(means) throws ValueError: only one element tensors can be converted to Python scalars. torch.stack(means) should be used instead.
  • Averaging the per-image standard deviations is incorrect. The per-image variances should be averaged instead, and the square root taken only at the end.

Here is a version which fixes these:
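import torch
from torchvision import datasets, transforms as T

transform = T.Compose([T.Resize(256), T.CenterCrop(224), T.ToTensor()])
dataset = datasets.ImageNet(".", split="train", transform=transform)

means = []
variances = []
# subset() is left undefined in the docs snippet; assumed to yield (3, 224, 224) image tensors
for img in subset(dataset):
    # Reduce over height and width only, keeping one value per channel
    means.append(torch.mean(img, dim=(1, 2)))
    variances.append(torch.var(img, dim=(1, 2)))

mean = torch.mean(torch.stack(means), dim=0)  # per-channel mean, shape (3,)
std = torch.sqrt(torch.mean(torch.stack(variances), dim=0))  # square root of the mean variance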
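Both points can be checked on small synthetic tensors without downloading ImageNet. This is a sketch, assuming random (3, 4, 4) tensors as stand-ins for ToTensor() outputs: building a tensor from a list of per-channel means with torch.tensor fails, torch.stack works, and because the square root is concave (Jensen's inequality), the mean of the per-image stds can never exceed the square root of the mean variance, so averaging stds underestimates.

```python
import torch

torch.manual_seed(0)

# Synthetic stand-ins for ToTensor() outputs: 8 images of shape (C=3, H=4, W=4),
# values in [0, 1). Each image is scaled differently so per-image statistics differ.
images = [torch.rand(3, 4, 4) * (0.2 + 0.1 * i) for i in range(8)]

# Per-channel statistics of each image: reduce over H and W only, leaving shape (3,).
means = [img.mean(dim=(1, 2)) for img in images]
variances = [img.var(dim=(1, 2)) for img in images]

# Issue 1: torch.tensor() cannot build a tensor from a list of multi-element tensors.
try:
    torch.tensor(means)
    tensor_failed = False
except ValueError as err:
    tensor_failed = True
    print("torch.tensor failed:", err)

stacked = torch.stack(means)  # shape (8, 3): one row per image
mean = stacked.mean(dim=0)    # per-channel mean, shape (3,)

# Issue 2: averaging standard deviations is not the same as averaging variances.
stds = [v.sqrt() for v in variances]
std_from_stds = torch.stack(stds).mean(dim=0)
std_from_vars = torch.stack(variances).mean(dim=0).sqrt()
print("mean of stds:      ", std_from_stds)
print("sqrt(mean of vars):", std_from_vars)
```

The gap between the two printed rows grows with how much the per-image variances differ from one another.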

I would further argue that, since we are interested in the per-channel mean and std of the whole dataset, we should first compute the per-channel mean across all images and then compute the std of all pixels around this "batch mean". The snippets above instead measure each pixel's deviation from the mean of its own image, which leaves out the variation between image means. I have not run either version on a large dataset, so I do not know how much of a difference this would make in practice.
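One way to compute the "batch mean" statistics described above, sketched under assumptions: `dataset_mean_std` is a hypothetical helper (not a torchvision API), synthetic tensors stand in for ImageNet, and every image is a (C, H, W) tensor. It accumulates per-channel running sums of x and x², so a single pass is enough. The check at the end uses the law of total variance: for equally sized images (as after Resize(256) and CenterCrop(224)), the pooled population variance equals the mean of the per-image variances ("within") plus the variance of the per-image means ("between"), so the per-image approach drops the "between" term.

```python
import torch


def dataset_mean_std(images, num_channels=3):
    """Hypothetical helper: per-channel mean and std over every pixel of every image.

    Accumulates per-channel sums of x and x**2 in float64, so one pass suffices.
    The variance is E[x^2] - E[x]^2 over all pixels, i.e. every deviation is
    measured from the dataset-wide ("batch") mean.
    """
    total = torch.zeros(num_channels, dtype=torch.float64)
    total_sq = torch.zeros(num_channels, dtype=torch.float64)
    count = 0
    for img in images:                          # img: (C, H, W) tensor
        x = img.double().flatten(start_dim=1)   # (C, H*W)
        total += x.sum(dim=1)
        total_sq += (x * x).sum(dim=1)
        count += x.shape[1]
    mean = total / count
    var = total_sq / count - mean ** 2          # population variance per channel
    return mean, var.clamp(min=0).sqrt()        # clamp guards against tiny negative rounding


torch.manual_seed(0)
# Synthetic stand-ins for equally sized ToTensor() outputs with different brightness.
images = [torch.rand(3, 8, 8) * (0.2 + 0.1 * i) for i in range(8)]

mean, std = dataset_mean_std(images)

# Reference: population std of all pixels pooled together, per channel.
pixels = torch.cat([img.double().flatten(start_dim=1) for img in images], dim=1)
ref_std = ((pixels - pixels.mean(dim=1, keepdim=True)) ** 2).mean(dim=1).sqrt()

# "within": average of each image's variance around its own mean (the per-image approach).
per_image = [img.double().flatten(start_dim=1) for img in images]
within = torch.stack(
    [((x - x.mean(dim=1, keepdim=True)) ** 2).mean(dim=1) for x in per_image]
).mean(dim=0)
# "between": spread of the image means around the batch mean.
img_means = torch.stack([x.mean(dim=1) for x in per_image])  # (N, C)
between = ((img_means - img_means.mean(dim=0)) ** 2).mean(dim=0)

print("batch-mean std:       ", std)
print("per-image (within) std:", within.sqrt())
```

With the real dataset, which yields (image, label) pairs, the helper could be called as `dataset_mean_std(img for img, _ in dataset)`. The float64 accumulators are a deliberate choice: the E[x²] minus E[x]² formula loses precision through cancellation when summed in float32 over a dataset the size of ImageNet.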

Thank you!
