@@ -55,21 +55,32 @@ def __call__(self, pic):
55
55
56
56
57
57
class ToPILImage (object ):
58
- """Converts a torch.*Tensor of range [0, 1] and shape C x H x W
59
- or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
60
- to a PIL.Image of range [0, 255]
58
+ """Converts a torch.*Tensor of shape C x H x W or a numpy ndarray of shape
59
+ H x W x C to a PIL.Image while preserving value range.
61
60
"""
62
61
63
62
def __call__ (self , pic ):
64
63
npimg = pic
65
64
mode = None
66
- if not isinstance (npimg , np .ndarray ):
67
- npimg = pic .mul (255 ).byte ().numpy ()
68
- npimg = np .transpose (npimg , (1 , 2 , 0 ))
65
+ if isinstance (pic , torch .FloatTensor ):
66
+ pic = pic .mul (255 ).byte ()
67
+ if torch .is_tensor (pic ):
68
+ npimg = np .transpose (pic .numpy (), (1 , 2 , 0 ))
69
+ assert isinstance (npimg , np .ndarray ), 'pic should be Tensor or ndarray'
69
70
70
71
if npimg .shape [2 ] == 1 :
71
72
npimg = npimg [:, :, 0 ]
72
- mode = "L"
73
+
74
+ if npimg .dtype == np .uint8 :
75
+ mode = 'L'
76
+ if npimg .dtype == np .uint16 :
77
+ mode = 'I;16'
78
+ elif npimg .dtype == np .float32 :
79
+ mode = 'F'
80
+ else :
81
+ if npimg .dtype == np .uint8 :
82
+ mode = 'RGB'
83
+ assert mode is not None , '{} is not supported' .format (npimg .dtype )
73
84
74
85
return Image .fromarray (npimg , mode = mode )
75
86
0 commit comments