diff --git a/test/test_transforms.py b/test/test_transforms.py
index e67d8f0d6c5..2732e3d313a 100644
--- a/test/test_transforms.py
+++ b/test/test_transforms.py
@@ -851,6 +851,65 @@ def test_random_grayscale(self):
         assert gray_np_3.shape == tuple(x_shape[0:2]), 'should be 1 channel'
         np.testing.assert_equal(gray_np, gray_np_3)
 
+    def test_translate(self):
+        x = np.zeros((100, 100, 1), dtype=np.uint8)
+        x[40, 40] = 255
+
+        with self.assertRaises(TypeError):
+            F.translate(x, horizontal=10)
+
+        img = F.to_pil_image(x)
+
+        result = F.translate(img, horizontal=10)
+        assert result.size == (100, 100)
+        r, c = np.where(result)
+        np.testing.assert_equal(r, 40)
+        np.testing.assert_equal(c, 30)
+
+        result = F.translate(img, horizontal=-10)
+        assert result.size == (100, 100)
+        r, c = np.where(result)
+        np.testing.assert_equal(r, 40)
+        np.testing.assert_equal(c, 50)
+
+        result = F.translate(img, vertical=10)
+        assert result.size == (100, 100)
+        r, c = np.where(result)
+        np.testing.assert_equal(r, 30)
+        np.testing.assert_equal(c, 40)
+
+        result = F.translate(img, vertical=-10)
+        assert result.size == (100, 100)
+        r, c = np.where(result)
+        np.testing.assert_equal(r, 50)
+        np.testing.assert_equal(c, 40)
+
+        result = F.translate(img, horizontal=10, vertical=-10)
+        assert result.size == (100, 100)
+        r, c = np.where(result)
+        np.testing.assert_equal(r, 50)
+        np.testing.assert_equal(c, 30)
+
+    def test_random_translation(self):
+
+        # Each invalid argument needs its own context manager: assertRaises
+        # exits at the first raising statement, so any statement placed after
+        # it inside the same ``with`` block would never be executed.
+        for bad in (-10, [-10], [-10, 0, 10]):
+            with self.assertRaises(ValueError):
+                transforms.RandomTranslation(horizontal=bad)
+            with self.assertRaises(ValueError):
+                transforms.RandomTranslation(vertical=bad)
+
+        t = transforms.RandomTranslation(horizontal=10, vertical=10)
+        h, v = t.get_params(t.horizontal, t.vertical)
+        assert -10 <= h <= 10
+        assert -10 <= v <= 10
+
+        t = transforms.RandomTranslation(horizontal=(-10, 10), vertical=(-10, 10))
+        h, v = t.get_params(t.horizontal, t.vertical)
+        assert -10 <= h <= 10
+        assert -10 <= v <= 10
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index 85beb38b934..6c9882ec3f4 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -577,3 +577,27 @@ def to_grayscale(img, num_output_channels=1):
         raise ValueError('num_output_channels should be either 1 or 3')
 
     return img
+
+
+def translate(img, horizontal=0, vertical=0):
+    """Translate the img by horizontal and vertical pixels.
+
+    Args:
+        img (PIL Image): PIL Image to be translated.
+        horizontal (int or float): Number of horizontal pixels to translate.
+            If horizontal > 0, img will be translated LEFT.
+            If horizontal < 0, img will be translated RIGHT.
+        vertical (int or float): Number of vertical pixels to translate.
+            If vertical > 0, img will be translated UP.
+            If vertical < 0, img will be translated DOWN.
+
+    Returns:
+        PIL Image: Translated image, same size as ``img``.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    # PIL's affine coefficients (a, b, c, d, e, f) map each OUTPUT pixel (x, y)
+    # to the INPUT location (a*x + b*y + c, d*x + e*y + f), which is why a
+    # positive offset moves the visible content towards the origin (left/up).
+    return img.transform(img.size, Image.AFFINE, (1, 0, horizontal, 0, 1, vertical))
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index fd2baa32487..054775086ae 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -18,7 +18,7 @@
 __all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad",
            "Lambda", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop",
            "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation",
-           "Grayscale", "RandomGrayscale"]
+           "Grayscale", "RandomGrayscale", "RandomTranslation"]
 
 
 class Compose(object):
@@ -693,3 +693,66 @@ def __call__(self, img):
         if random.random() < self.p:
             return F.to_grayscale(img, num_output_channels=num_output_channels)
         return img
+
+
+class RandomTranslation(object):
+    """Translate the image horizontally and vertically by a random offset.
+
+    Args:
+        horizontal (sequence or number): Range of horizontal pixels to select from.
+            If horizontal is a number instead of sequence like (min, max), the range of pixels
+            will be (-horizontal, +horizontal).
+            If the sampled offset is > 0, img will be translated LEFT.
+            If the sampled offset is < 0, img will be translated RIGHT.
+        vertical (sequence or number): Range of vertical pixels to select from.
+            If vertical is a number instead of sequence like (min, max), the range of pixels
+            will be (-vertical, +vertical).
+            If the sampled offset is > 0, img will be translated UP.
+            If the sampled offset is < 0, img will be translated DOWN.
+
+    Raises:
+        ValueError: If a single number is negative, or a sequence is not of len 2.
+    """
+
+    def __init__(self, horizontal=0, vertical=0):
+        self.horizontal = self._to_range(horizontal, "horizontal")
+        self.vertical = self._to_range(vertical, "vertical")
+
+    @staticmethod
+    def _to_range(value, name):
+        """Normalize ``value`` into a ``(min, max)`` pair, validating it first."""
+        if isinstance(value, numbers.Number):
+            if value < 0:
+                raise ValueError("If {} is a single number, it must be positive.".format(name))
+            return (-value, value)
+        if len(value) != 2:
+            raise ValueError("If {} is a sequence, it must be of len 2.".format(name))
+        return value
+
+    @staticmethod
+    def get_params(horizontal, vertical):
+        """Get parameters for ``translate`` for a random translation.
+
+        Args:
+            horizontal (sequence): (min, max) range the horizontal offset is drawn from.
+            vertical (sequence): (min, max) range the vertical offset is drawn from.
+
+        Returns:
+            h, v: params to be passed to ``translate`` for random translation.
+        """
+        h = np.random.uniform(horizontal[0], horizontal[1])
+        v = np.random.uniform(vertical[0], vertical[1])
+
+        return h, v
+
+    def __call__(self, img):
+        """
+        Args:
+            img (PIL Image): Image to be translated.
+
+        Returns:
+            PIL Image: Translated image.
+        """
+        h, v = self.get_params(self.horizontal, self.vertical)
+
+        return F.translate(img, h, v)