Skip to content

Commit cfb0f39

Browse files
committed
add range to make_grid
1 parent 429dbeb commit cfb0f39

File tree

2 files changed

+71
-16
lines changed

2 files changed

+71
-16
lines changed

README.rst

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -349,15 +349,30 @@ For example:
349349
Utils
350350
=====
351351

352-
make\_grid(tensor, nrow=8, padding=2)
352+
make\_grid(tensor, nrow=8, padding=2, normalize=False, range=None, scale\_each=False)
353353
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
354354

355-
Given a 4D mini-batch Tensor of shape (B x C x H x W), makes a grid of
356-
images
355+
Given a 4D mini-batch Tensor of shape (B x C x H x W),
356+
or a list of images all of the same size,
357+
makes a grid of images
357358

358-
save\_image(tensor, filename, nrow=8, padding=2)
359-
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
359+
normalize=True will shift the image to the range (0, 1),
360+
by subtracting the minimum and dividing by the maximum pixel value.
361+
362+
if range=(min, max) where min and max are numbers, then these numbers are used to
363+
normalize the image.
364+
365+
scale_each=True will scale each image in the batch of images separately rather than
366+
computing the (min, max) over all images.
367+
368+
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
369+
370+
save\_image(tensor, filename, nrow=8, padding=2, normalize=False, range=None, scale\_each=False)
371+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
360372

361373
Saves a given Tensor into an image file.
362374

363375
If given a mini-batch tensor, will save the tensor as a grid of images.
376+
377+
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
378+
more details

torchvision/utils.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,74 @@
11
import torch
22
import math
3+
irange = range
34

45

5-
def make_grid(tensor, nrow=8, padding=2):
6+
def make_grid(tensor, nrow=8, padding=2,
7+
normalize=False, range=None, scale_each=False):
68
"""
79
Given a 4D mini-batch Tensor of shape (B x C x H x W),
810
or a list of images all of the same size,
911
makes a grid of images
12+
13+
normalize=True will shift the image to the range (0, 1),
14+
by subtracting the minimum and dividing by the maximum pixel value.
15+
16+
if range=(min, max) where min and max are numbers, then these numbers are used to
17+
normalize the image.
18+
19+
scale_each=True will scale each image in the batch of images separately rather than
20+
computing the (min, max) over all images.
21+
22+
[Example usage is given in this notebook](https://gist.github.com/anonymous/bf16430f7750c023141c562f3e9f2a91)
1023
"""
11-
tensorlist = None
24+
# if list of tensors, convert to a 4D mini-batch Tensor
1225
if isinstance(tensor, list):
1326
tensorlist = tensor
1427
numImages = len(tensorlist)
1528
size = torch.Size(torch.Size([long(numImages)]) + tensorlist[0].size())
1629
tensor = tensorlist[0].new(size)
17-
for i in range(numImages):
30+
for i in irange(numImages):
1831
tensor[i].copy_(tensorlist[i])
32+
1933
if tensor.dim() == 2: # single image H x W
2034
tensor = tensor.view(1, tensor.size(0), tensor.size(1))
2135
if tensor.dim() == 3: # single image
22-
if tensor.size(0) == 1:
36+
if tensor.size(0) == 1: # if single-channel, convert to 3-channel
2337
tensor = torch.cat((tensor, tensor, tensor), 0)
2438
return tensor
2539
if tensor.dim() == 4 and tensor.size(1) == 1: # single-channel images
2640
tensor = torch.cat((tensor, tensor, tensor), 1)
41+
42+
if normalize is True:
43+
if range is not None:
44+
assert isinstance(range, tuple), \
45+
"range has to be a tuple (min, max) if specified. min and max are numbers"
46+
47+
def norm_ip(img, min, max):
48+
img.clamp_(min=min, max=max)
49+
img.add_(-min).div_(max - min)
50+
51+
def norm_range(t, range):
52+
if range is not None:
53+
norm_ip(t, range[0], range[1])
54+
else:
55+
norm_ip(t, t.min(), t.max())
56+
57+
if scale_each is True:
58+
for t in tensor: # loop over mini-batch dimension
59+
norm_range(t, range)
60+
else:
61+
norm_range(tensor, range)
62+
2763
# make the mini-batch of images into a grid
2864
nmaps = tensor.size(0)
2965
xmaps = min(nrow, nmaps)
3066
ymaps = int(math.ceil(float(nmaps) / xmaps))
3167
height, width = int(tensor.size(2) + padding), int(tensor.size(3) + padding)
32-
grid = tensor.new(3, height * ymaps, width * xmaps).fill_(tensor.max())
68+
grid = tensor.new(3, height * ymaps, width * xmaps).fill_(0)
3369
k = 0
34-
for y in range(ymaps):
35-
for x in range(xmaps):
70+
for y in irange(ymaps):
71+
for x in irange(xmaps):
3672
if k >= nmaps:
3773
break
3874
grid.narrow(1, y * height + 1 + padding // 2, height - padding)\
@@ -42,14 +78,18 @@ def make_grid(tensor, nrow=8, padding=2):
4278
return grid
4379

4480

45-
def save_image(tensor, filename, nrow=8, padding=2):
81+
def save_image(tensor, filename, nrow=8, padding=2,
82+
normalize=False, range=None, scale_each=False):
4683
"""
4784
Saves a given Tensor into an image file.
48-
If given a mini-batch tensor, will save the tensor as a grid of images.
85+
If given a mini-batch tensor, will save the tensor as a grid of images by calling `make_grid`.
86+
All options after `filename` are passed through to `make_grid`. Refer to it's documentation for
87+
more details
4988
"""
5089
from PIL import Image
5190
tensor = tensor.cpu()
52-
grid = make_grid(tensor, nrow=nrow, padding=padding)
53-
ndarr = grid.mul(255).byte().transpose(0,2).transpose(0,1).numpy()
91+
grid = make_grid(tensor, nrow=nrow, padding=padding,
92+
normalize=normalize, range=range, scale_each=scale_each)
93+
ndarr = grid.mul(255).byte().transpose(0, 2).transpose(0, 1).numpy()
5494
im = Image.fromarray(ndarr)
5595
im.save(filename)

0 commit comments

Comments
 (0)