diff --git a/README.rst b/README.rst index c6de582fb08..40f5abc78da 100644 --- a/README.rst +++ b/README.rst @@ -332,6 +332,16 @@ integer, in which case the target will be of a square shape (size, size) If ``padding`` is non-zero, then the image is first zero-padded on each side with ``padding`` pixels. +``PairRandomCrop(size, padding=0)`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Crops the given PIL.Image at a random location to have a region of the +given size for both input image and its target image. size can be a +tuple (target\_height, target\_width) or an integer, in which case the +target will be of a square shape (size, size) +If ``padding`` is non-zero, then the image is first zero-padded on each +side with ``padding`` pixels. + ``RandomHorizontalFlip()`` ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 6d649ab18fa..d948bae6175 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -311,6 +311,51 @@ def __call__(self, img): return img.crop((x1, y1, x1 + tw, y1 + th)) +class PairRandomCrop(object): + """Crop the given PIL.Image at a random location for both input and target. + + Args: + size (sequence or int): Desired output size of the crop. If size is an + int instead of sequence like (h, w), a square crop (size, size) is + made. + padding (int or sequence, optional): Optional padding on each border + of the image. Default is 0, i.e no padding. If a sequence of length + 4 is provided, it is used to pad left, top, right, bottom borders + respectively. + """ + last_position = None + + def __init__(self, size, padding=0): + if isinstance(size, numbers.Number): + self.size = (int(size), int(size)) + else: + self.size = size + self.padding = padding + + def __call__(self, img): + """ + Args: + img (PIL.Image): Image to be cropped. + Returns: + PIL.Image: Cropped image. + """ + if self.padding > 0: + img = ImageOps.expand(img, border=self.padding, fill=0) + + w, h = img.size + th, tw = self.size + if w == tw and h == th: + return img + + if self.last_position is not None: + (x1, y1), self.last_position = self.last_position, None + else: + x1 = random.randint(0, w - tw) + y1 = random.randint(0, h - th) + self.last_position = (x1, y1) + return img.crop((x1, y1, x1 + tw, y1 + th)) + + class RandomHorizontalFlip(object): """Horizontally flip the given PIL.Image randomly with a probability of 0.5."""