Skip to content

Commit b8e69d8

Browse files
committed
updated transforms.ToPILImage, see #105
1 parent e65925d commit b8e69d8

File tree

2 files changed

+24
-7
lines changed

2 files changed

+24
-7
lines changed

test/test_transforms.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def test_ndarray_to_pil_image(self):
169169
l, = img.split()
170170
assert np.allclose(l, img_data[:, :, 0])
171171

172+
def test_ndarray16_to_pil_image(self):
173+
trans = transforms.ToPILImage()
174+
img_data = np.random.randint(0, 65535, [4, 4, 1], np.uint16)
175+
img = trans(img_data)
176+
assert img.mode == 'I;16'
177+
assert np.allclose(img, img_data[:, :, 0])
172178

173179
if __name__ == '__main__':
174180
unittest.main()

torchvision/transforms.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,32 @@ def __call__(self, pic):
5555

5656

5757
class ToPILImage(object):
58-
"""Converts a torch.*Tensor of range [0, 1] and shape C x H x W
59-
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
60-
to a PIL.Image of range [0, 255]
58+
"""Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
59+
H x W x C to a PIL.Image while preserving value range.
6160
"""
6261

6362
def __call__(self, pic):
6463
npimg = pic
6564
mode = None
66-
if not isinstance(npimg, np.ndarray):
67-
npimg = pic.mul(255).byte().numpy()
68-
npimg = np.transpose(npimg, (1, 2, 0))
65+
if isinstance(pic, torch.FloatTensor):
66+
pic = pic.mul(255).byte()
67+
if torch.is_tensor(pic):
68+
npimg = np.transpose(pic.numpy(), (1, 2, 0))
69+
assert isinstance(npimg, np.ndarray), 'pic should be Tensor or ndarray'
6970

7071
if npimg.shape[2] == 1:
7172
npimg = npimg[:, :, 0]
72-
mode = "L"
73+
74+
if npimg.dtype == np.uint8:
75+
mode = 'L'
76+
if npimg.dtype == np.uint16:
77+
mode = 'I;16'
78+
elif npimg.dtype == np.float32:
79+
mode = 'F'
80+
else:
81+
if npimg.dtype == np.uint8:
82+
mode = 'RGB'
83+
assert mode is not None, '{} is not supported'.format(npimg.dtype)
7384

7485
return Image.fromarray(npimg, mode=mode)
7586

0 commit comments

Comments
 (0)