1
1
import torch
2
2
import math
3
+ irange = range
3
4
4
5
5
- def make_grid (tensor , nrow = 8 , padding = 2 ):
6
+ def make_grid (tensor , nrow = 8 , padding = 2 ,
7
+ normalize = False , range = None , scale_each = False ):
6
8
"""
7
9
Given a 4D mini-batch Tensor of shape (B x C x H x W),
8
10
or a list of images all of the same size,
9
11
makes a grid of images
12
+
13
+ normalize=True will shift the image to the range (0, 1),
14
+ by subtracting the minimum and dividing by the maximum pixel value.
15
+
16
+ if range=(min, max) where min and max are numbers, then these numbers are used to
17
+ normalize the image.
18
+
19
+ scale_each=True will scale each image in the batch of images separately rather than
20
+ computing the (min, max) over all images.
21
+
22
+ [Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
10
23
"""
11
- tensorlist = None
24
+ # if list of tensors, convert to a 4D mini-batch Tensor
12
25
if isinstance (tensor , list ):
13
26
tensorlist = tensor
14
27
numImages = len (tensorlist )
15
28
size = torch .Size (torch .Size ([long (numImages )]) + tensorlist [0 ].size ())
16
29
tensor = tensorlist [0 ].new (size )
17
- for i in range (numImages ):
30
+ for i in irange (numImages ):
18
31
tensor [i ].copy_ (tensorlist [i ])
32
+
19
33
if tensor .dim () == 2 : # single image H x W
20
34
tensor = tensor .view (1 , tensor .size (0 ), tensor .size (1 ))
21
35
if tensor .dim () == 3 : # single image
22
- if tensor .size (0 ) == 1 :
36
+ if tensor .size (0 ) == 1 : # if single-channel, convert to 3-channel
23
37
tensor = torch .cat ((tensor , tensor , tensor ), 0 )
24
38
return tensor
25
39
if tensor .dim () == 4 and tensor .size (1 ) == 1 : # single-channel images
26
40
tensor = torch .cat ((tensor , tensor , tensor ), 1 )
41
+
42
+ if normalize is True :
43
+ if range is not None :
44
+ assert isinstance (range , tuple ), \
45
+ "range has to be a tuple (min, max) if specified. min and max are numbers"
46
+
47
+ def norm_ip (img , min , max ):
48
+ img .clamp_ (min = min , max = max )
49
+ img .add_ (- min ).div_ (max - min )
50
+
51
+ def norm_range (t , range ):
52
+ if range is not None :
53
+ norm_ip (t , range [0 ], range [1 ])
54
+ else :
55
+ norm_ip (t , t .min (), t .max ())
56
+
57
+ if scale_each is True :
58
+ for t in tensor : # loop over mini-batch dimension
59
+ norm_range (t , range )
60
+ else :
61
+ norm_range (tensor , range )
62
+
27
63
# make the mini-batch of images into a grid
28
64
nmaps = tensor .size (0 )
29
65
xmaps = min (nrow , nmaps )
30
66
ymaps = int (math .ceil (float (nmaps ) / xmaps ))
31
67
height , width = int (tensor .size (2 ) + padding ), int (tensor .size (3 ) + padding )
32
- grid = tensor .new (3 , height * ymaps , width * xmaps ).fill_ (tensor . max () )
68
+ grid = tensor .new (3 , height * ymaps , width * xmaps ).fill_ (0 )
33
69
k = 0
34
- for y in range (ymaps ):
35
- for x in range (xmaps ):
70
+ for y in irange (ymaps ):
71
+ for x in irange (xmaps ):
36
72
if k >= nmaps :
37
73
break
38
74
grid .narrow (1 , y * height + 1 + padding // 2 , height - padding )\
@@ -42,14 +78,18 @@ def make_grid(tensor, nrow=8, padding=2):
42
78
return grid
43
79
44
80
45
- def save_image (tensor , filename , nrow = 8 , padding = 2 ):
81
+ def save_image (tensor , filename , nrow = 8 , padding = 2 ,
82
+ normalize = False , range = None , scale_each = False ):
46
83
"""
47
84
Saves a given Tensor into an image file.
48
- If given a mini-batch tensor, will save the tensor as a grid of images.
85
+ If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
86
+ All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
87
+ more details
49
88
"""
50
89
from PIL import Image
51
90
tensor = tensor .cpu ()
52
- grid = make_grid (tensor , nrow = nrow , padding = padding )
91
+ grid = make_grid (tensor , nrow = nrow , padding = padding ,
92
+ normalize = normalize , range = range , scale_each = scale_each )
53
93
ndarr = grid .mul (255 ).byte ().transpose (0 , 2 ).transpose (0 , 1 ).numpy ()
54
94
im = Image .fromarray (ndarr )
55
95
im .save (filename )
0 commit comments