From 21287b516925e1f6262e9284a26029a1655361b1 Mon Sep 17 00:00:00 2001
From: Holger Kohr
Date: Mon, 28 May 2018 21:31:44 +0200
Subject: [PATCH] Partially revert #519 due to performance regression & other issues

---
 test/test_transforms.py              | 2 +-
 torchvision/transforms/functional.py | 7 ++++---
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/test/test_transforms.py b/test/test_transforms.py
index 467b39459e8..34024acc6dd 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -652,7 +652,7 @@ def test_random_horizontal_flip(self):
         # Checking if RandomHorizontalFlip can be printed as string
         transforms.RandomHorizontalFlip().__repr__()
 
-    @unittest.skipIf(stats is None, 'scipt.stats is not available')
+    @unittest.skipIf(stats is None, 'scipy.stats is not available')
     def test_normalize(self):
         def samples_from_standard_normal(tensor):
             p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index b85f28cf92d..7ebbc7d6ab6 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -167,9 +167,10 @@ def normalize(tensor, mean, std):
     if not _is_tensor_image(tensor):
         raise TypeError('tensor is not a torch image.')
 
-    mean = torch.Tensor(mean).view((tensor.shape[0], 1, 1))
-    std = torch.Tensor(std).view((tensor.shape[0], 1, 1))
-    return tensor.sub_(mean).div_(std)
+    # This is faster than using broadcasting, don't change without benchmarking
+    for t, m, s in zip(tensor, mean, std):
+        t.sub_(m).div_(s)
+    return tensor
 
 
 def resize(img, size, interpolation=Image.BILINEAR):
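
For reference, a minimal micro-benchmark sketch comparing the two normalize variants touched by this patch (not part of the patch itself; the tensor shape, repeat count, and function names below are illustrative assumptions):

# Hypothetical benchmark sketch: broadcasting variant removed by this patch
# vs. the per-channel in-place loop it restores.
import timeit

import torch

tensor = torch.randn(3, 224, 224)  # assumed CHW image size
mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


def normalize_broadcast(t, mean, std):
    # Variant removed by this patch: build (C, 1, 1) tensors and broadcast.
    m = torch.Tensor(mean).view((t.shape[0], 1, 1))
    s = torch.Tensor(std).view((t.shape[0], 1, 1))
    return t.sub_(m).div_(s)


def normalize_loop(t, mean, std):
    # Variant restored by this patch: per-channel in-place subtract/divide.
    for c, m, s in zip(t, mean, std):
        c.sub_(m).div_(s)
    return t


for name, fn in [('broadcast', normalize_broadcast), ('loop', normalize_loop)]:
    # The clone cost is included in both timings, so the comparison stays fair.
    elapsed = timeit.timeit(lambda: fn(tensor.clone(), mean, std), number=1000)
    print('{}: {:.4f} s per 1000 calls'.format(name, elapsed))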