Skip to content

Commit eeacb39

Browse files
authored
Merge pull request #7 from pytorch/pad
adding padding to RandomCrop, as well as transforms.Pad
2 parents 4d247b0 + 685799b commit eeacb39

File tree

4 files changed

+279
-3
lines changed

4 files changed

+279
-3
lines changed

README.md

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,11 @@ Crops the given PIL.Image at the center to have a region of
177177
the given size. size can be a tuple (target_height, target_width)
178178
or an integer, in which case the target will be of a square shape (size, size)
179179

180-
### `RandomCrop(size)`
180+
### `RandomCrop(size, padding=0)`
181181
Crops the given PIL.Image at a random location to have a region of
182182
the given size. size can be a tuple (target_height, target_width)
183183
or an integer, in which case the target will be of a square shape (size, size)
184+
If `padding` is non-zero, then the image is first zero-padded on each side with `padding` pixels.
184185

185186
### `RandomHorizontalFlip()`
186187
Randomly horizontally flips the given PIL.Image with a probability of 0.5
@@ -193,6 +194,12 @@ This is popularly used to train the Inception networks
193194
- size: size of the smaller edge
194195
- interpolation: Default: PIL.Image.BILINEAR
195196

197+
198+
### `Pad(padding, fill=0)`
199+
Pads the given image on each side with `padding` number of pixels, and the padding pixels are filled with
200+
pixel value `fill`.
201+
If a `5x5` image is padded with `padding=1` then it becomes `7x7`
202+
196203
## Transforms on torch.*Tensor
197204

198205
### `Normalize(mean, std)`

test/sanity_checks.ipynb

Lines changed: 234 additions & 0 deletions
Large diffs are not rendered by default.

test/test_transforms.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,29 @@ def test_random_crop(self):
8181
assert result.size(1) == oheight
8282
assert result.size(2) == owidth
8383

84+
padding = random.randint(1, 20)
85+
result = transforms.Compose([
86+
transforms.ToPILImage(),
87+
transforms.RandomCrop((oheight, owidth), padding=padding),
88+
transforms.ToTensor(),
89+
])(img)
90+
assert result.size(1) == oheight
91+
assert result.size(2) == owidth
92+
93+
def test_pad(self):
94+
height = random.randint(10, 32) * 2
95+
width = random.randint(10, 32) * 2
96+
img = torch.ones(3, height, width)
97+
padding = random.randint(1, 20)
98+
result = transforms.Compose([
99+
transforms.ToPILImage(),
100+
transforms.Pad(padding),
101+
transforms.ToTensor(),
102+
])(img)
103+
print(height, width, padding)
104+
print(result.size(1), result.size(2))
105+
assert result.size(1) == height + 2*padding
106+
assert result.size(2) == width + 2*padding
84107

85108

86109
if __name__ == '__main__':

torchvision/transforms.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
import math
44
import random
5-
from PIL import Image
5+
from PIL import Image, ImageOps
66
import numpy as np
77
import numbers
88

@@ -115,6 +115,18 @@ def __call__(self, img):
115115
return img.crop((x1, y1, x1 + tw, y1 + th))
116116

117117

118+
class Pad(object):
119+
"""Pads the given PIL.Image on all sides with the given "pad" value"""
120+
def __init__(self, padding, fill=0):
121+
assert isinstance(padding, numbers.Number)
122+
assert isinstance(fill, numbers.Number)
123+
self.padding = padding
124+
self.fill = fill
125+
126+
def __call__(self, img):
127+
return ImageOps.expand(img, border=self.padding, fill=self.fill)
128+
129+
118130
class RandomCrop(object):
119131
"""Crops the given PIL.Image at a random location to have a region of
120132
the given size. size can be a tuple (target_height, target_width)
@@ -129,7 +141,7 @@ def __init__(self, size, padding=0):
129141

130142
def __call__(self, img):
131143
if self.padding > 0:
132-
raise NotImplementedError()
144+
img = ImageOps.expand(img, border=self.padding, fill=0)
133145

134146
w, h = img.size
135147
th, tw = self.size

0 commit comments

Comments
 (0)