Skip to content

Commit 650eb32

Browse files
authored
Merge pull request #5 from pytorch/tests
adding unit tests for image transforms
2 parents 44da562 + bd62df6 commit 650eb32

File tree

3 files changed

+209
-28
lines changed

3 files changed

+209
-28
lines changed

README.md

Lines changed: 42 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,6 @@ The data is preprocessed [as described here](https://github.com/facebook/fb.resn
144144
Transforms are common image transforms.
145145
They can be chained together using `transforms.Compose`
146146

147-
- `ToTensor()` - converts PIL Image to Tensor
148-
- `Normalize(mean, std)` - normalizes the image given mean, std (for example: mean = [0.3, 1.2, 2.1])
149-
- `Scale(size, interpolation=Image.BILINEAR)` - Scales the smaller image edge to the given size. Interpolation modes are options from PIL
150-
- `CenterCrop(size)` - center-crops the image to the given size
151-
- `RandomCrop(size)` - Random crops the image to the given size.
152-
- `RandomHorizontalFlip()` - hflip the image with probability 0.5
153-
- `RandomSizedCrop(size, interpolation=Image.BILINEAR)` - Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)
154-
155147
### `transforms.Compose`
156148

157149
One can compose several transforms together.
@@ -166,3 +158,45 @@ transform = transforms.Compose([
166158
std = [ 0.229, 0.224, 0.225 ]),
167159
])
168160
```
161+
162+
## Transforms on PIL.Image
163+
164+
### `Scale(size, interpolation=Image.BILINEAR)`
165+
Rescales the input PIL.Image to the given 'size'.
166+
'size' will be the size of the smaller edge.
167+
168+
For example, if height > width, then image will be
169+
rescaled to (size * height / width, size)
170+
- size: size of the smaller edge
171+
- interpolation: Default: PIL.Image.BILINEAR
172+
173+
### `CenterCrop(size)` - center-crops the image to the given size
174+
Crops the given PIL.Image at the center to have a region of
175+
the given size. size can be a tuple (target_height, target_width)
176+
or an integer, in which case the target will be of a square shape (size, size)
177+
178+
### `RandomCrop(size)`
179+
Crops the given PIL.Image at a random location to have a region of
180+
the given size. size can be a tuple (target_height, target_width)
181+
or an integer, in which case the target will be of a square shape (size, size)
182+
183+
### `RandomHorizontalFlip()`
184+
Randomly horizontally flips the given PIL.Image with a probability of 0.5
185+
186+
### `RandomSizedCrop(size, interpolation=Image.BILINEAR)`
187+
Randomly crops the given PIL.Image to a random size of (0.08 to 1.0) of the original size
188+
and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
189+
190+
This is popularly used to train the Inception networks
191+
- size: size of the smaller edge
192+
- interpolation: Default: PIL.Image.BILINEAR
193+
194+
## Transforms on torch.*Tensor
195+
196+
### `Normalize(mean, std)`
197+
Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the torch.*Tensor, i.e. channel = (channel - mean) / std
198+
199+
## Conversion Transforms
200+
- `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
201+
- `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy.ndarray of dtype=uint8, range [0, 255] and shape H x W x C to a PIL.Image of range [0, 255]
202+

test/test_transforms.py

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
import torch
2+
import torchvision.transforms as transforms
3+
import torchvision.datasets as datasets
4+
import numpy as np
5+
import unittest
6+
import random
7+
8+
class Tester(unittest.TestCase):
9+
def test_crop(self):
10+
height = random.randint(10, 32) * 2
11+
width = random.randint(10, 32) * 2
12+
oheight = random.randint(5, (height - 2) / 2) * 2
13+
owidth = random.randint(5, (width - 2) / 2) * 2
14+
15+
img = torch.ones(3, height, width)
16+
oh1 = (height - oheight) / 2
17+
ow1 = (width - owidth) / 2
18+
imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth]
19+
imgnarrow.fill_(0)
20+
result = transforms.Compose([
21+
transforms.ToPILImage(),
22+
transforms.CenterCrop((oheight, owidth)),
23+
transforms.ToTensor(),
24+
])(img)
25+
assert result.sum() == 0, "height: " + str(height) + " width: " \
26+
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
27+
oheight += 1
28+
owidth += 1
29+
result = transforms.Compose([
30+
transforms.ToPILImage(),
31+
transforms.CenterCrop((oheight, owidth)),
32+
transforms.ToTensor(),
33+
])(img)
34+
sum1 = result.sum()
35+
assert sum1 > 1, "height: " + str(height) + " width: " \
36+
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
37+
oheight += 1
38+
owidth += 1
39+
result = transforms.Compose([
40+
transforms.ToPILImage(),
41+
transforms.CenterCrop((oheight, owidth)),
42+
transforms.ToTensor(),
43+
])(img)
44+
sum2 = result.sum()
45+
assert sum2 > 0, "height: " + str(height) + " width: " \
46+
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
47+
assert sum2 > sum1, "height: " + str(height) + " width: " \
48+
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
49+
50+
def test_scale(self):
51+
height = random.randint(24, 32) * 2
52+
width = random.randint(24, 32) * 2
53+
osize = random.randint(5, 12) * 2
54+
55+
img = torch.ones(3, height, width)
56+
result = transforms.Compose([
57+
transforms.ToPILImage(),
58+
transforms.Scale(osize),
59+
transforms.ToTensor(),
60+
])(img)
61+
# print img.size()
62+
# print 'output size:', osize
63+
# print result.size()
64+
assert osize in result.size()
65+
if height < width:
66+
assert result.size(1) <= result.size(2)
67+
elif width < height:
68+
assert result.size(1) >= result.size(2)
69+
70+
def test_random_crop(self):
71+
height = random.randint(10, 32) * 2
72+
width = random.randint(10, 32) * 2
73+
oheight = random.randint(5, (height - 2) / 2) * 2
74+
owidth = random.randint(5, (width - 2) / 2) * 2
75+
img = torch.ones(3, height, width)
76+
result = transforms.Compose([
77+
transforms.ToPILImage(),
78+
transforms.RandomCrop((oheight, owidth)),
79+
transforms.ToTensor(),
80+
])(img)
81+
assert result.size(1) == oheight
82+
assert result.size(2) == owidth
83+
84+
85+
86+
if __name__ == '__main__':
87+
unittest.main()

torchvision/transforms.py

Lines changed: 80 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,19 @@
1+
from __future__ import division
12
import torch
23
import math
34
import random
45
from PIL import Image
56
import numpy as np
6-
7+
import numbers
78

89
class Compose(object):
10+
""" Composes several transforms together.
11+
For example:
12+
>>> transforms.Compose([
13+
>>> transforms.CenterCrop(10),
14+
>>> transforms.ToTensor(),
15+
>>> ])
16+
"""
917
def __init__(self, transforms):
1018
self.transforms = transforms
1119

@@ -16,6 +24,8 @@ def __call__(self, img):
1624

1725

1826
class ToTensor(object):
27+
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
28+
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
1929
def __call__(self, pic):
2030
if isinstance(pic, np.ndarray):
2131
# handle numpy array
@@ -24,24 +34,50 @@ def __call__(self, pic):
2434
# handle PIL Image
2535
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
2636
img = img.view(pic.size[0], pic.size[1], 3)
27-
# put it in CHW format
37+
# put it from WHC to CHW format
2838
# yikes, this transpose takes 80% of the loading time/CPU
29-
img = img.transpose(0, 2).transpose(1, 2).contiguous()
30-
return img.float()
39+
img = img.transpose(0, 2).contiguous()
40+
return img.float().div(255)
41+
42+
class ToPILImage(object):
43+
""" Converts a torch.*Tensor of range [0, 1] and shape C x H x W
44+
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
45+
to a PIL.Image of range [0, 255]
46+
"""
47+
def __call__(self, pic):
48+
if isinstance(pic, np.ndarray):
49+
# handle numpy array
50+
img = Image.fromarray(pic)
51+
else:
52+
npimg = pic.mul(255).byte().numpy()
53+
npimg = np.transpose(npimg, (1,2,0))
54+
img = Image.fromarray(npimg)
55+
return img
3156

3257
class Normalize(object):
58+
""" Given mean: (R, G, B) and std: (R, G, B),
59+
will normalize each channel of the torch.*Tensor, i.e.
60+
channel = (channel - mean) / std
61+
"""
3362
def __init__(self, mean, std):
3463
self.mean = mean
3564
self.std = std
3665

3766
def __call__(self, tensor):
67+
# TODO: make efficient
3868
for t, m, s in zip(tensor, self.mean, self.std):
3969
t.sub_(m).div_(s)
4070
return tensor
4171

4272

4373
class Scale(object):
44-
"Scales the smaller edge to size"
74+
""" Rescales the input PIL.Image to the given 'size'.
75+
'size' will be the size of the smaller edge.
76+
For example, if height > width, then image will be
77+
rescaled to (size * height / width, size)
78+
size: size of the smaller edge
79+
interpolation: Default: PIL.Image.BILINEAR
80+
"""
4581
def __init__(self, size, interpolation=Image.BILINEAR):
4682
self.size = size
4783
self.interpolation = interpolation
@@ -51,52 +87,76 @@ def __call__(self, img):
5187
if (w <= h and w == self.size) or (h <= w and h == self.size):
5288
return img
5389
if w < h:
54-
return img.resize((w, int(round(h / w * self.size))), self.interpolation)
90+
ow = self.size
91+
oh = int(self.size * h / w)
92+
return img.resize((ow, oh), self.interpolation)
5593
else:
56-
return img.resize((int(round(w / h * self.size)), h), self.interpolation)
94+
oh = self.size
95+
ow = int(self.size * w / h)
96+
return img.resize((ow, oh), self.interpolation)
5797

5898

5999
class CenterCrop(object):
60-
"Crop to centered rectangle"
100+
"""Crops the given PIL.Image at the center to have a region of
101+
the given size. size can be a tuple (target_height, target_width)
102+
or an integer, in which case the target will be of a square shape (size, size)
103+
"""
61104
def __init__(self, size):
62-
self.size = size
105+
if isinstance(size, numbers.Number):
106+
self.size = (int(size), int(size))
107+
else:
108+
self.size = size
63109

64110
def __call__(self, img):
65111
w, h = img.size
66-
x1 = int(round((w - self.size) / 2))
67-
y1 = int(round((h - self.size) / 2))
68-
return img.crop((x1, y1, x1 + self.size, y1 + self.size))
112+
th, tw = self.size
113+
x1 = int(round((w - tw) / 2))
114+
y1 = int(round((h - th) / 2))
115+
return img.crop((x1, y1, x1 + tw, y1 + th))
69116

70117

71118
class RandomCrop(object):
72-
"Random crop form larger image with optional zero padding"
119+
"""Crops the given PIL.Image at a random location to have a region of
120+
the given size. size can be a tuple (target_height, target_width)
121+
or an integer, in which case the target will be of a square shape (size, size)
122+
"""
73123
def __init__(self, size, padding=0):
74-
self.size = size
124+
if isinstance(size, numbers.Number):
125+
self.size = (int(size), int(size))
126+
else:
127+
self.size = size
75128
self.padding = padding
76129

77130
def __call__(self, img):
78131
if self.padding > 0:
79132
raise NotImplementedError()
80133

81134
w, h = img.size
82-
if w == self.size and h == self.size:
135+
th, tw = self.size
136+
if w == tw and h == th:
83137
return img
84138

85-
x1 = random.randint(0, w - self.size)
86-
y1 = random.randint(0, h - self.size)
87-
return img.crop((x1, y1, x1 + self.size, y1 + self.size))
139+
x1 = random.randint(0, w - tw)
140+
y1 = random.randint(0, h - th)
141+
return img.crop((x1, y1, x1 + tw, y1 + th))
88142

89143

90144
class RandomHorizontalFlip(object):
91-
"Horizontal flip with 0.5 probability"
145+
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
146+
"""
92147
def __call__(self, img):
93148
if random.random() < 0.5:
94149
return img.transpose(Image.FLIP_LEFT_RIGHT)
95150
return img
96151

97152

98153
class RandomSizedCrop(object):
99-
"Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)"
154+
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
155+
and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
156+
This is popularly used to train the Inception networks
157+
size: size of the smaller edge
158+
interpolation: Default: PIL.Image.BILINEAR
159+
"""
100160
def __init__(self, size, interpolation=Image.BILINEAR):
101161
self.size = size
102162
self.interpolation = interpolation

0 commit comments

Comments
 (0)