From 807f98731c5c0927b12a668dfcf4cd265494cbcd Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Mon, 18 Apr 2022 12:46:39 +0530 Subject: [PATCH 01/22] added simple POC --- references/segmentation/transforms.py | 56 +++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 518048db2fa..d6c36a0e5ec 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,4 +1,5 @@ import random +from typing import Tuple import numpy as np import torch @@ -98,3 +99,58 @@ def __init__(self, mean, std): def __call__(self, image, target): image = F.normalize(image, mean=self.mean, std=self.std) return image, target + + +class SimpleCopyPaste(torch.nn.Module): + def __init__(self, p: float = 0.5, inplace: bool = False): + super().__init__() + self.p = p + self.inplace = inplace + + def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + # validate inputs + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}.") + if target.ndim != 3: + raise ValueError(f"Target ndim should be 3. Got {target.ndim}.") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if target.dtype != torch.int64: + raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + + # check inplace + if not self.inplace: + batch = batch.clone() + target = target.clone() + + batch_rolled = batch.roll(1, 0) + target_rolled = target.roll(1, 0) + + # TODO: Apply random scale jittering and random horizontal flipping + + # TODO: Pad images smaller than their original size with gray pixel values + + # TODO: select a random subset of objects from one of the images and paste them onto the other image + + # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask + + # get binary paste mask + paste_binary_mask = (target_rolled != 0).to(target_rolled.dtype) + # delete pixels from source mask using paste mask + target.mul_(1 - paste_binary_mask) + # Combine paste mask with source mask + target.add_(target_rolled) + + # get paste image using paste image mask + paste_image = batch_rolled * torch.unsqueeze(paste_binary_mask, 1) + # delete pixels from source image using paste binary mask + batch.mul_(torch.unsqueeze(1 - paste_binary_mask, 1)) + # Combine paste image with source image + batch.add_(paste_image) + + return batch, target + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" + return s From 2fe16e83aa5ff7359f86c93dde87a5bf8545cfb6 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 19 Apr 2022 14:44:54 +0530 Subject: [PATCH 02/22] added jitter and crop options --- references/segmentation/transforms.py | 140 ++++++++++++++++++++++++-- 1 file changed, 131 insertions(+), 9 deletions(-) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index d6c36a0e5ec..49de4ce941a 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -4,7 +4,7 @@ import numpy as np import torch from torchvision import transforms as T -from torchvision.transforms import functional as F +from torchvision.transforms import functional as F, InterpolationMode def pad_if_smaller(img, size, fill=0): @@ -101,12 +101,135 @@ def __call__(self, image, target): return image, target +class ScaleJitter: + """Randomly resizes 
the image and its mask within the specified scale range. + The class implements the Scale Jitter augmentation as described in the paper + `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. + + Args: + target_size (tuple of ints): The target size for the transform provided in (height, weight) format. + scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the + range a <= scale <= b. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. + """ + + def __init__( + self, + target_size: Tuple[int, int], + scale_range: Tuple[float, float] = (0.1, 2.0), + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ): + super().__init__() + self.target_size = target_size + self.scale_range = scale_range + self.interpolation = interpolation + + def __call__(self, image: torch.Tensor, target: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + + if isinstance(image, torch.Tensor): + if image.ndimension() not in {2, 3}: + raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.") + elif image.ndimension() == 2: + image = image.unsqueeze(0) + + _, orig_height, orig_width = F.get_dimensions(image) + + scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) + r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale + new_width = int(orig_width * r) + new_height = int(orig_height * r) + + image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) + target = F.resize(torch.unsqueeze(target, 0), [new_height, new_width], interpolation=InterpolationMode.NEAREST) + + return image, target + + +class FixedSizeCrop: + def __init__(self, size, fill=0, padding_mode="constant"): + super().__init__() + size = tuple(T.transforms._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) + self.crop_height = size[0] + self.crop_width = size[1] + self.fill = fill + self.padding_mode = padding_mode + + def _pad(self, image, target, padding): + # Taken from the functional_tensor.py pad + if isinstance(padding, int): + pad_left = pad_right = pad_top = pad_bottom = padding + elif len(padding) == 1: + pad_left = pad_right = pad_top = pad_bottom = padding[0] + elif len(padding) == 2: + pad_left = pad_right = padding[0] + pad_top = pad_bottom = padding[1] + elif len(padding) == 4: + pad_left = padding[0] + pad_top = padding[1] + pad_right = padding[2] + pad_bottom = padding[3] + else: + # TODO: fix this error + raise ValueError("padding ndim should be int, (int, int) or (int, int, int, int)") + + padding = [pad_left, pad_top, pad_right, pad_bottom] + image = F.pad(image, padding, self.fill, self.padding_mode) + target = F.pad(target, padding, 0, self.padding_mode) + + return image, target + + def _crop(self, image, target, top, left, height, width): + image = F.crop(image, top, left, height, width) + target = F.crop(target, top, left, height, width) + return image, target + + def __call__(self, img, target=None): + _, height, width = F.get_dimensions(img) + new_height = min(height, self.crop_height) + new_width = min(width, self.crop_width) + + if new_height != height or new_width != width: + offset_height = max(height - self.crop_height, 0) + offset_width = max(width - self.crop_width, 0) + + r = torch.rand(1) + top = int(offset_height * r) + left = int(offset_width * r) + 
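+            # A single draw of r picks top and left jointly, so the crop
+            # window slides along the image diagonal rather than sampling
+            # the two offsets independently. Illustrative numbers (not from
+            # this patch): a 1024x1024 crop of an 800x1200 input gives
+            # offset_height=0 and offset_width=176, so r=0.5 yields
+            # top=0, left=88.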
+ img, target = self._crop(img, target, top, left, new_height, new_width) + + pad_bottom = max(self.crop_height - new_height, 0) + pad_right = max(self.crop_width - new_width, 0) + if pad_bottom != 0 or pad_right != 0: + img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom]) + + return img, target + + class SimpleCopyPaste(torch.nn.Module): - def __init__(self, p: float = 0.5, inplace: bool = False): + def __init__(self, p: float = 0.5, jittering_type="LSJ", inplace: bool = False): super().__init__() self.p = p self.inplace = inplace + # TODO: Apply random scale jittering ( resize and crop ) + if jittering_type == "LSJ": + scale_range = (0.1, 2.0) + elif jittering_type == "SSJ": + scale_range = (0.8, 1.25) + else: + # TODO: add invalid option error + raise ValueError("Invalid jittering type") + + self.transforms = Compose( + [ + ScaleJitter(target_size=(1024, 1024), scale_range=scale_range), + FixedSizeCrop(size=(1024, 1024), fill=105), + RandomHorizontalFlip(0.5), + ] + ) + def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # validate inputs @@ -120,17 +243,16 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") # check inplace - if not self.inplace: - batch = batch.clone() - target = target.clone() + for i, (image, mask) in enumerate(zip(batch, target)): + batch[i], target[i] = self.transforms(image, mask) + + # if not self.inplace: + # batch = batch.clone() + # target = target.clone() batch_rolled = batch.roll(1, 0) target_rolled = target.roll(1, 0) - # TODO: Apply random scale jittering and random horizontal flipping - - # TODO: Pad images smaller than their original size with gray pixel values - # TODO: select a random subset of objects from one of the images and paste them onto the other image # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask From 690f03fad828fc323446c8e53d955f9f628df590 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 19 Apr 2022 14:50:46 +0530 Subject: [PATCH 03/22] added references --- references/segmentation/transforms.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index 49de4ce941a..eff203b4d31 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -102,6 +102,8 @@ def __call__(self, image, target): class ScaleJitter: + # Referenced from references/detection/transforms.py + """Randomly resizes the image and its mask within the specified scale range. The class implements the Scale Jitter augmentation as described in the paper `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. 
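For reference, the resize factor that ScaleJitter computes can be sketched standalone as below; the input size here is illustrative and not part of the patch:

    import torch

    target_size = (1024, 1024)  # (height, width); square, so the index order below is moot
    scale_range = (0.1, 2.0)    # the paper's Large Scale Jittering range
    orig_height, orig_width = 480, 640

    # sample a scale uniformly from [0.1, 2.0)
    scale = scale_range[0] + torch.rand(1).item() * (scale_range[1] - scale_range[0])
    # fit the image inside target_size, then jitter by the sampled scale
    r = min(target_size[1] / orig_height, target_size[0] / orig_width) * scale
    new_height, new_width = int(orig_height * r), int(orig_width * r)
    # e.g. scale=1.0 gives r=1.6 and (new_height, new_width) == (768, 1024)

The mask itself is resized with InterpolationMode.NEAREST so that integer label values survive the resize.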
@@ -125,7 +127,7 @@ def __init__( self.scale_range = scale_range self.interpolation = interpolation - def __call__(self, image: torch.Tensor, target: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: + def __call__(self, image: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: if isinstance(image, torch.Tensor): if image.ndimension() not in {2, 3}: @@ -147,6 +149,8 @@ def __call__(self, image: torch.Tensor, target: torch.Tensor = None) -> Tuple[to class FixedSizeCrop: + # Referenced from references/detection/transforms.py + def __init__(self, size, fill=0, padding_mode="constant"): super().__init__() size = tuple(T.transforms._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) @@ -156,7 +160,6 @@ def __init__(self, size, fill=0, padding_mode="constant"): self.padding_mode = padding_mode def _pad(self, image, target, padding): - # Taken from the functional_tensor.py pad if isinstance(padding, int): pad_left = pad_right = pad_top = pad_bottom = padding elif len(padding) == 1: From 0055d83ceab52ae151e69154196fcb7cbcaf9f84 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Thu, 21 Apr 2022 12:23:26 +0530 Subject: [PATCH 04/22] moved simplecopypaste to detection module --- references/detection/transforms.py | 69 ++++++++++ references/segmentation/transforms.py | 183 +------------------------- 2 files changed, 70 insertions(+), 182 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index bb4540d1aae..1238431470f 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -437,3 +437,72 @@ def forward( ) return image, target + + +class SimpleCopyPaste(torch.nn.Module): + def __init__(self, p: float = 0.5, jittering_type="LSJ", inplace: bool = False): + super().__init__() + self.p = p + self.inplace = inplace + + # TODO: Apply random scale jittering ( resize and crop ) + if jittering_type == "LSJ": + scale_range = (0.1, 2.0) + elif jittering_type == "SSJ": + scale_range = (0.8, 1.25) + else: + # TODO: add invalid option error + raise ValueError("Invalid jittering type") + + self.transforms = Compose( + [ + ScaleJitter(target_size=(1024, 1024), scale_range=scale_range), + FixedSizeCrop(size=(1024, 1024), fill=105), + RandomHorizontalFlip(0.5), + ] + ) + + def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + + # validate inputs + # if batch.ndim != 4: + # raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}.") + # if target.ndim != 3: + # raise ValueError(f"Target ndim should be 3. Got {target.ndim}.") + # if not batch.is_floating_point(): + # raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + # if target.dtype != torch.int64: + # raise TypeError(f"Target dtype should be torch.int64. 
Got {target.dtype}") + + for i, (image, mask) in enumerate(zip(batch, target)): + batch[i], target[i] = self.transforms(image, mask) + + batch_rolled = batch.roll(1, 0) + target_rolled = target[-1:] + target[:-1] # noqa: F841 + + # TODO: select a random subset of objects from one of the images and paste them onto the other image + + # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask + + # TODO: update masks and boxes and labels + # - uodate mask + # - update boxes + # - concat labels + + # get paste binary mask from rolled + # subtract form image + # get pixels from paste image + # paste onto new image + + # get paste image using paste image mask + paste_image = batch_rolled * torch.unsqueeze(paste_binary_mask, 1) + # delete pixels from source image using paste binary mask + batch.mul_(torch.unsqueeze(1 - paste_binary_mask, 1)) + # Combine paste image with source image + batch.add_(paste_image) + + return batch, target + + def __repr__(self) -> str: + s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" + return s diff --git a/references/segmentation/transforms.py b/references/segmentation/transforms.py index eff203b4d31..518048db2fa 100644 --- a/references/segmentation/transforms.py +++ b/references/segmentation/transforms.py @@ -1,10 +1,9 @@ import random -from typing import Tuple import numpy as np import torch from torchvision import transforms as T -from torchvision.transforms import functional as F, InterpolationMode +from torchvision.transforms import functional as F def pad_if_smaller(img, size, fill=0): @@ -99,183 +98,3 @@ def __init__(self, mean, std): def __call__(self, image, target): image = F.normalize(image, mean=self.mean, std=self.std) return image, target - - -class ScaleJitter: - # Referenced from references/detection/transforms.py - - """Randomly resizes the image and its mask within the specified scale range. - The class implements the Scale Jitter augmentation as described in the paper - `"Simple Copy-Paste is a Strong Data Augmentation Method for Instance Segmentation" `_. - - Args: - target_size (tuple of ints): The target size for the transform provided in (height, weight) format. - scale_range (tuple of ints): scaling factor interval, e.g (a, b), then scale is randomly sampled from the - range a <= scale <= b. - interpolation (InterpolationMode): Desired interpolation enum defined by - :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. - """ - - def __init__( - self, - target_size: Tuple[int, int], - scale_range: Tuple[float, float] = (0.1, 2.0), - interpolation: InterpolationMode = InterpolationMode.BILINEAR, - ): - super().__init__() - self.target_size = target_size - self.scale_range = scale_range - self.interpolation = interpolation - - def __call__(self, image: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - - if isinstance(image, torch.Tensor): - if image.ndimension() not in {2, 3}: - raise ValueError(f"image should be 2/3 dimensional. 
Got {image.ndimension()} dimensions.") - elif image.ndimension() == 2: - image = image.unsqueeze(0) - - _, orig_height, orig_width = F.get_dimensions(image) - - scale = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0]) - r = min(self.target_size[1] / orig_height, self.target_size[0] / orig_width) * scale - new_width = int(orig_width * r) - new_height = int(orig_height * r) - - image = F.resize(image, [new_height, new_width], interpolation=self.interpolation) - target = F.resize(torch.unsqueeze(target, 0), [new_height, new_width], interpolation=InterpolationMode.NEAREST) - - return image, target - - -class FixedSizeCrop: - # Referenced from references/detection/transforms.py - - def __init__(self, size, fill=0, padding_mode="constant"): - super().__init__() - size = tuple(T.transforms._setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")) - self.crop_height = size[0] - self.crop_width = size[1] - self.fill = fill - self.padding_mode = padding_mode - - def _pad(self, image, target, padding): - if isinstance(padding, int): - pad_left = pad_right = pad_top = pad_bottom = padding - elif len(padding) == 1: - pad_left = pad_right = pad_top = pad_bottom = padding[0] - elif len(padding) == 2: - pad_left = pad_right = padding[0] - pad_top = pad_bottom = padding[1] - elif len(padding) == 4: - pad_left = padding[0] - pad_top = padding[1] - pad_right = padding[2] - pad_bottom = padding[3] - else: - # TODO: fix this error - raise ValueError("padding ndim should be int, (int, int) or (int, int, int, int)") - - padding = [pad_left, pad_top, pad_right, pad_bottom] - image = F.pad(image, padding, self.fill, self.padding_mode) - target = F.pad(target, padding, 0, self.padding_mode) - - return image, target - - def _crop(self, image, target, top, left, height, width): - image = F.crop(image, top, left, height, width) - target = F.crop(target, top, left, height, width) - return image, target - - def __call__(self, img, target=None): - _, height, width = F.get_dimensions(img) - new_height = min(height, self.crop_height) - new_width = min(width, self.crop_width) - - if new_height != height or new_width != width: - offset_height = max(height - self.crop_height, 0) - offset_width = max(width - self.crop_width, 0) - - r = torch.rand(1) - top = int(offset_height * r) - left = int(offset_width * r) - - img, target = self._crop(img, target, top, left, new_height, new_width) - - pad_bottom = max(self.crop_height - new_height, 0) - pad_right = max(self.crop_width - new_width, 0) - if pad_bottom != 0 or pad_right != 0: - img, target = self._pad(img, target, [0, 0, pad_right, pad_bottom]) - - return img, target - - -class SimpleCopyPaste(torch.nn.Module): - def __init__(self, p: float = 0.5, jittering_type="LSJ", inplace: bool = False): - super().__init__() - self.p = p - self.inplace = inplace - - # TODO: Apply random scale jittering ( resize and crop ) - if jittering_type == "LSJ": - scale_range = (0.1, 2.0) - elif jittering_type == "SSJ": - scale_range = (0.8, 1.25) - else: - # TODO: add invalid option error - raise ValueError("Invalid jittering type") - - self.transforms = Compose( - [ - ScaleJitter(target_size=(1024, 1024), scale_range=scale_range), - FixedSizeCrop(size=(1024, 1024), fill=105), - RandomHorizontalFlip(0.5), - ] - ) - - def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - - # validate inputs - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. 
Got {batch.ndim}.") - if target.ndim != 3: - raise ValueError(f"Target ndim should be 3. Got {target.ndim}.") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - if target.dtype != torch.int64: - raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") - - # check inplace - for i, (image, mask) in enumerate(zip(batch, target)): - batch[i], target[i] = self.transforms(image, mask) - - # if not self.inplace: - # batch = batch.clone() - # target = target.clone() - - batch_rolled = batch.roll(1, 0) - target_rolled = target.roll(1, 0) - - # TODO: select a random subset of objects from one of the images and paste them onto the other image - - # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask - - # get binary paste mask - paste_binary_mask = (target_rolled != 0).to(target_rolled.dtype) - # delete pixels from source mask using paste mask - target.mul_(1 - paste_binary_mask) - # Combine paste mask with source mask - target.add_(target_rolled) - - # get paste image using paste image mask - paste_image = batch_rolled * torch.unsqueeze(paste_binary_mask, 1) - # delete pixels from source image using paste binary mask - batch.mul_(torch.unsqueeze(1 - paste_binary_mask, 1)) - # Combine paste image with source image - batch.add_(paste_image) - - return batch, target - - def __repr__(self) -> str: - s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" - return s From 5a6c2636b2ca3bbb473a8146b76427715320a3bb Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 22 Apr 2022 17:30:57 +0530 Subject: [PATCH 05/22] working POC for simple copy paste in detection --- references/detection/transforms.py | 66 +++++++++++++++++------------- 1 file changed, 37 insertions(+), 29 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 1238431470f..7693c482366 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,8 +1,10 @@ +import copy from typing import List, Tuple, Dict, Optional, Union import torch import torchvision from torch import nn, Tensor +from torchvision import ops from torchvision.transforms import functional as F from torchvision.transforms import transforms as T, InterpolationMode @@ -440,10 +442,8 @@ def forward( class SimpleCopyPaste(torch.nn.Module): - def __init__(self, p: float = 0.5, jittering_type="LSJ", inplace: bool = False): + def __init__(self, jittering_type: str = "LSJ"): super().__init__() - self.p = p - self.inplace = inplace # TODO: Apply random scale jittering ( resize and crop ) if jittering_type == "LSJ": @@ -462,44 +462,52 @@ def __init__(self, p: float = 0.5, jittering_type="LSJ", inplace: bool = False): ] ) + def combine_masks(self, masks): + return masks.sum(dim=0).greater(0) + def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # validate inputs - # if batch.ndim != 4: - # raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}.") - # if target.ndim != 3: - # raise ValueError(f"Target ndim should be 3. Got {target.ndim}.") - # if not batch.is_floating_point(): - # raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") - # if target.dtype != torch.int64: - # raise TypeError(f"Target dtype should be torch.int64. Got {target.dtype}") + if batch.ndim != 4: + raise ValueError(f"Batch ndim should be 4. 
Got {batch.ndim}.") + if not batch.is_floating_point(): + raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") for i, (image, mask) in enumerate(zip(batch, target)): batch[i], target[i] = self.transforms(image, mask) - batch_rolled = batch.roll(1, 0) - target_rolled = target[-1:] + target[:-1] # noqa: F841 + batch_rolled = batch.roll(1, 0).detach().clone() + target_rolled = copy.deepcopy(target[-1:] + target[:-1]) # TODO: select a random subset of objects from one of the images and paste them onto the other image # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask - # TODO: update masks and boxes and labels - # - uodate mask - # - update boxes - # - concat labels - - # get paste binary mask from rolled - # subtract form image - # get pixels from paste image - # paste onto new image - - # get paste image using paste image mask - paste_image = batch_rolled * torch.unsqueeze(paste_binary_mask, 1) - # delete pixels from source image using paste binary mask - batch.mul_(torch.unsqueeze(1 - paste_binary_mask, 1)) - # Combine paste image with source image - batch.add_(paste_image) + paste_masks = [] + + for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled): + paste_alpha_mask = self.combine_masks(paste_data["masks"]) + paste_masks.append(paste_alpha_mask) + + for i, mask in enumerate(source_data["masks"]): + source_data["masks"][i] = mask ^ paste_alpha_mask & mask + + mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) + filtered_masks = source_data["masks"][mask_filter] + source_data["boxes"] = ops.masks_to_boxes(filtered_masks) + # TODO: update area + + source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"])) + source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"])) + source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"])) + source_data["area"] = torch.cat((source_data["area"], paste_data["area"])) + source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"])) + + paste_masks = torch.stack(paste_masks) + batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) + + paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1) + batch.add_(paste_images) return batch, target From 7eefe7d3ee27dbba154a21d5441a675d4a729b3a Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 22 Apr 2022 17:44:52 +0530 Subject: [PATCH 06/22] added comments --- references/detection/transforms.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 7693c482366..3c6ac978ec4 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -445,7 +445,6 @@ class SimpleCopyPaste(torch.nn.Module): def __init__(self, jittering_type: str = "LSJ"): super().__init__() - # TODO: Apply random scale jittering ( resize and crop ) if jittering_type == "LSJ": scale_range = (0.1, 2.0) elif jittering_type == "SSJ": @@ -476,6 +475,7 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens for i, (image, mask) in enumerate(zip(batch, target)): batch[i], target[i] = self.transforms(image, mask) + # create copy of batch and target as the original will be modified batch_rolled = batch.roll(1, 0).detach().clone() target_rolled = copy.deepcopy(target[-1:] + target[:-1]) @@ -483,26 +483,33 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens # 
TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask + # collect binary paste masks for all images paste_masks = [] for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled): paste_alpha_mask = self.combine_masks(paste_data["masks"]) paste_masks.append(paste_alpha_mask) + # update original masks for i, mask in enumerate(source_data["masks"]): source_data["masks"][i] = mask ^ paste_alpha_mask & mask + # remove masks where no annotations are present (all values are 0) mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) filtered_masks = source_data["masks"][mask_filter] + + # update bboxes based on new masks source_data["boxes"] = ops.masks_to_boxes(filtered_masks) # TODO: update area + # concatenate paste data with original data source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"])) source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"])) source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"])) source_data["area"] = torch.cat((source_data["area"], paste_data["area"])) source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"])) + # update the original images with paste images paste_masks = torch.stack(paste_masks) batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) From bdf20a04be24f33b49d615a660e904d2c88df561 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 4 May 2022 21:19:47 +0530 Subject: [PATCH 07/22] remove transforms from class updated the labels added gaussian blur --- references/detection/transforms.py | 51 ++++++++++-------------------- 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 3c6ac978ec4..f4b959402da 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -442,25 +442,9 @@ def forward( class SimpleCopyPaste(torch.nn.Module): - def __init__(self, jittering_type: str = "LSJ"): + def __init__(self): super().__init__() - if jittering_type == "LSJ": - scale_range = (0.1, 2.0) - elif jittering_type == "SSJ": - scale_range = (0.8, 1.25) - else: - # TODO: add invalid option error - raise ValueError("Invalid jittering type") - - self.transforms = Compose( - [ - ScaleJitter(target_size=(1024, 1024), scale_range=scale_range), - FixedSizeCrop(size=(1024, 1024), fill=105), - RandomHorizontalFlip(0.5), - ] - ) - def combine_masks(self, masks): return masks.sum(dim=0).greater(0) @@ -472,22 +456,18 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens if not batch.is_floating_point(): raise TypeError(f"Batch dtype should be a float tensor. 
Got {batch.dtype}.") - for i, (image, mask) in enumerate(zip(batch, target)): - batch[i], target[i] = self.transforms(image, mask) - # create copy of batch and target as the original will be modified batch_rolled = batch.roll(1, 0).detach().clone() target_rolled = copy.deepcopy(target[-1:] + target[:-1]) - # TODO: select a random subset of objects from one of the images and paste them onto the other image - - # TODO: Smooth out the edges of the pasted objects using a Gaussian filter on the mask - # collect binary paste masks for all images paste_masks = [] for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled): - paste_alpha_mask = self.combine_masks(paste_data["masks"]) + number_of_masks = len(paste_data["masks"]) + random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() + + paste_alpha_mask = self.combine_masks(paste_data["masks"][random_selection]) paste_masks.append(paste_alpha_mask) # update original masks @@ -496,21 +476,24 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens # remove masks where no annotations are present (all values are 0) mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) - filtered_masks = source_data["masks"][mask_filter] + source_data["masks"] = source_data["masks"][mask_filter] + source_data["boxes"] = ops.masks_to_boxes(source_data["masks"]) + source_data["labels"] = source_data["labels"][mask_filter] + source_data["area"] = source_data["area"][mask_filter] + source_data["iscrowd"] = source_data["iscrowd"][mask_filter] - # update bboxes based on new masks - source_data["boxes"] = ops.masks_to_boxes(filtered_masks) # TODO: update area # concatenate paste data with original data - source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"])) - source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"])) - source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"])) - source_data["area"] = torch.cat((source_data["area"], paste_data["area"])) - source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"])) + source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection])) + source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection])) + source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection])) + source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection])) + source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection])) # update the original images with paste images - paste_masks = torch.stack(paste_masks) + paste_masks = torch.stack(paste_masks).to(torch.uint8) + paste_masks = T.GaussianBlur((5, 5), sigma=2)(paste_masks) # Adds Gaussian Filter batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1) From f1ba6cf2044bd0824ff4a53d7f4502aed91a0dc3 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Wed, 4 May 2022 21:50:14 +0530 Subject: [PATCH 08/22] removed loop for mask calculation --- references/detection/transforms.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index f4b959402da..399b42f0810 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -471,8 +471,7 @@ def forward(self, batch: torch.Tensor, 
target: torch.Tensor) -> Tuple[torch.Tens paste_masks.append(paste_alpha_mask) # update original masks - for i, mask in enumerate(source_data["masks"]): - source_data["masks"][i] = mask ^ paste_alpha_mask & mask + source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"] # remove masks where no annotations are present (all values are 0) mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) From 5b238cf1a552cfd717f59717a99e25d46b3cd994 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Tue, 10 May 2022 21:40:58 +0530 Subject: [PATCH 09/22] replaced Gaussian blur with functional api --- references/detection/transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 399b42f0810..8e45cc7d6d5 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -492,7 +492,7 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens # update the original images with paste images paste_masks = torch.stack(paste_masks).to(torch.uint8) - paste_masks = T.GaussianBlur((5, 5), sigma=2)(paste_masks) # Adds Gaussian Filter + paste_masks = F.gaussian_blur(paste_masks, kernel_size=(5, 5), sigma=2.0) # Adds Gaussian Filter batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1) From 7468480e7f83b3ba725819e3e96eb56c3fdac341 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 20 May 2022 16:12:07 +0530 Subject: [PATCH 10/22] added inplace operations --- references/detection/transforms.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 8e45cc7d6d5..a3b65e38704 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -442,28 +442,40 @@ def forward( class SimpleCopyPaste(torch.nn.Module): - def __init__(self): + def __init__(self, inplace=True): super().__init__() + self.inplace = inplace def combine_masks(self, masks): return masks.sum(dim=0).greater(0) - def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + def forward( + self, batch: torch.Tensor, target: List[Dict[str, Tensor]] + ) -> Tuple[torch.Tensor, List[Dict[str, Tensor]]]: # validate inputs if batch.ndim != 4: raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}.") if not batch.is_floating_point(): raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if not isinstance(target, tuple): + raise TypeError(f"Target type should be a tuple of dictionaries. Got {type(target)}.") + if not len(target) == len(batch): + raise ValueError(f"batch and target lengths do not match. 
Got {len(batch)} and {len(target)}.") + + if not self.inplace: + batch = batch.clone().detach() + target = copy.deepcopy(target) + + shift = 1 - # create copy of batch and target as the original will be modified - batch_rolled = batch.roll(1, 0).detach().clone() - target_rolled = copy.deepcopy(target[-1:] + target[:-1]) + # create shifted copy of target as the original will be modified + target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) # collect binary paste masks for all images paste_masks = [] - for source_image, paste_image, source_data, paste_data in zip(batch, batch_rolled, target, target_rolled): + for source_data, paste_data in zip(target, target_rolled): number_of_masks = len(paste_data["masks"]) random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() @@ -493,8 +505,14 @@ def forward(self, batch: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tens # update the original images with paste images paste_masks = torch.stack(paste_masks).to(torch.uint8) paste_masks = F.gaussian_blur(paste_masks, kernel_size=(5, 5), sigma=2.0) # Adds Gaussian Filter + + # Clone batch as it will be modified + batch_rolled = batch.roll(shift, 0).clone().detach() + + # Black out areas which will be replaced batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) + # Copy and paste areas from rolled batch and paste on source batch paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1) batch.add_(paste_images) From eb3446556ee7054b7f4411f387a5066b0d6bf429 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 20 May 2022 19:18:32 +0530 Subject: [PATCH 11/22] added changes to accept tuples instead of tensors --- references/detection/transforms.py | 138 +++++++++++++++++++---------- 1 file changed, 89 insertions(+), 49 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index a3b65e38704..701f6d73c11 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -446,77 +446,117 @@ def __init__(self, inplace=True): super().__init__() self.inplace = inplace - def combine_masks(self, masks): - return masks.sum(dim=0).greater(0) - def forward( - self, batch: torch.Tensor, target: List[Dict[str, Tensor]] - ) -> Tuple[torch.Tensor, List[Dict[str, Tensor]]]: + self, batch: Tuple[torch.Tensor], target: Tuple[Dict[str, Tensor]] + ) -> Tuple[Tuple[torch.Tensor], Tuple[Dict[str, Tensor]]]: # validate inputs - if batch.ndim != 4: - raise ValueError(f"Batch ndim should be 4. Got {batch.ndim}.") - if not batch.is_floating_point(): - raise TypeError(f"Batch dtype should be a float tensor. Got {batch.dtype}.") + if not isinstance(batch, tuple): + raise TypeError(f"Batch type should be a tuple of tensors. Got {type(target)}.") if not isinstance(target, tuple): raise TypeError(f"Target type should be a tuple of dictionaries. Got {type(target)}.") + if len(batch) == 0: + raise ValueError("'batch' tuple cannot be empty.") + if len(target) == 0: + raise ValueError("'target' tuple cannot be empty.") if not len(target) == len(batch): - raise ValueError(f"batch and target lengths do not match. Got {len(batch)} and {len(target)}.") + raise ValueError(f"'batch' and 'target' lengths do not match. 
Got {len(batch)} and {len(target)}.") + + # TODO: Validate that batch contains tensors + # TODO: Validate that target contains dictionaries with tensor values if not self.inplace: - batch = batch.clone().detach() + batch = copy.deepcopy(batch) target = copy.deepcopy(target) shift = 1 - # create shifted copy of target as the original will be modified + # create shifted copy of batch and target as the original will be modified + batch_rolled = copy.deepcopy(batch[-shift:] + batch[:-shift]) target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) - # collect binary paste masks for all images - paste_masks = [] - - for source_data, paste_data in zip(target, target_rolled): - number_of_masks = len(paste_data["masks"]) - random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() - - paste_alpha_mask = self.combine_masks(paste_data["masks"][random_selection]) - paste_masks.append(paste_alpha_mask) + for image, source_data, paste_image, paste_data in zip(batch, target, batch_rolled, target_rolled): + self.copy_paste(image, source_data, paste_image, paste_data) - # update original masks - source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"] - - # remove masks where no annotations are present (all values are 0) - mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) - source_data["masks"] = source_data["masks"][mask_filter] - source_data["boxes"] = ops.masks_to_boxes(source_data["masks"]) - source_data["labels"] = source_data["labels"][mask_filter] - source_data["area"] = source_data["area"][mask_filter] - source_data["iscrowd"] = source_data["iscrowd"][mask_filter] + return batch, target - # TODO: update area + def copy_paste(self, image, source_data, paste_image, paste_data): + number_of_masks = len(paste_data["masks"]) + random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() - # concatenate paste data with original data - source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection])) - source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection])) - source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection])) - source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection])) - source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection])) + # Update image + paste_alpha_mask = self.combine_masks(paste_data["masks"][random_selection]) + self.update_source_image(image, paste_image, paste_alpha_mask) - # update the original images with paste images - paste_masks = torch.stack(paste_masks).to(torch.uint8) - paste_masks = F.gaussian_blur(paste_masks, kernel_size=(5, 5), sigma=2.0) # Adds Gaussian Filter + # update original masks + source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"] - # Clone batch as it will be modified - batch_rolled = batch.roll(shift, 0).clone().detach() + # remove masks where no annotations are present (all values are 0) + mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) + # TODO: update area - # Black out areas which will be replaced - batch.mul_(torch.unsqueeze(torch.logical_not(paste_masks), 1)) + self.filter_source_data(source_data, mask_filter) + self.concat_paste_data(source_data, paste_data, random_selection) - # Copy and paste areas from rolled batch and paste on source batch - paste_images = batch_rolled * torch.unsqueeze(paste_masks, 1) - batch.add_(paste_images) + 
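+    # Precedence note for the mask update above: `&` binds tighter than `^`,
+    # so the expression is masks ^ (paste_alpha_mask & masks), which clears
+    # mask pixels covered by the pasted objects. Per pixel: mask=1, paste=1
+    # gives 1 ^ (1 & 1) = 0 (erased); mask=1, paste=0 gives 1 ^ (0 & 1) = 1 (kept).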
def combine_masks(self, masks): + """ + Combines multiple masks into one alpha mask + Args: + masks: tensor of masks that need to be combined + + Returns: + Tensor: Boolean tensor of combined alpha mask + """ + return masks.sum(dim=0).greater(0) - return batch, target + def concat_paste_data(self, source_data, paste_data, random_selection): + """ + Concatenates the masks, boxes, labels, area and iscrowd info + from the paste image data to the source image data. + Args: + source_data: masks, boxes, labels, area, iscrowd info related to source image + paste_data: masks, boxes, labels, area, iscrowd info related to paste image + random_selection: indices of random masks selected from paste data + + Returns: + """ + source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection])) + source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection])) + source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection])) + source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection])) + source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection])) + + def filter_source_data(self, source_data, mask_filter): + """ + Removes masks and related masks data which are no longer visible due to overwriting of paste image. + Args: + mask_filter: boolean tensor of visible and non-visible masks + source_data: masks, boxes, labels, area, iscrowd info related to source image + + Returns: + + """ + source_data["masks"] = source_data["masks"][mask_filter] + source_data["boxes"] = ops.masks_to_boxes(source_data["masks"]) + source_data["labels"] = source_data["labels"][mask_filter] + source_data["area"] = source_data["area"][mask_filter] + source_data["iscrowd"] = source_data["iscrowd"][mask_filter] + + def update_source_image(self, image, paste_image, paste_alpha_mask): + """ + Copies the pixels from the paste_image to image using the paste_alpha_mask + Args: + image: Source image which has to be updated + paste_image: Image from which pixels will be copied + paste_alpha_mask: Mask of pixels which need to be copied + + Returns: + + """ + paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) # add blur + image.mul_(torch.logical_not(paste_alpha_mask)) # Delete pixels from source image + paste_pixels = paste_image * paste_alpha_mask # Copy pixels from paste image + image.add_(paste_pixels) # paste it on source image def __repr__(self) -> str: s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" From 7676203a498953609892045b2998696b6eedf918 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 3 Jun 2022 15:50:10 +0530 Subject: [PATCH 12/22] - make copy paste functional - make only one copy of batch and target --- references/detection/transforms.py | 143 ++++++++++++----------------- 1 file changed, 57 insertions(+), 86 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 701f6d73c11..5abe9c6912a 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -441,6 +441,50 @@ def forward( return image, target +def copy_paste(image, source_data, paste_image, paste_data): + number_of_masks = len(paste_data["masks"]) + random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() + + # Combine masks + paste_alpha_mask = paste_data["masks"][random_selection].sum(dim=0).greater(0) + + 
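+    # The 5x5 Gaussian blur below is the paper's mask-edge smoothing step.
+    # Note that torchvision's gaussian_blur casts non-float inputs to float
+    # internally and back to the input dtype afterwards, so a boolean mask
+    # comes back re-binarized (slightly dilated) rather than as a soft
+    # alpha; a float mask would be needed for true alpha blending at the seams.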
paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) # add blur + + # Paste pixels from paste_image to image + image.mul_(torch.logical_not(paste_alpha_mask)) # Delete pixels from source image + paste_pixels = paste_image * paste_alpha_mask # Copy pixels from paste image + image.add_(paste_pixels) # paste it on source image + + # update original masks + source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"] + + # remove masks where no annotations are present (all values are 0) + mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) + # TODO: update area + + # Update other attributes + source_data["masks"] = source_data["masks"][mask_filter] + source_data["boxes"] = ops.masks_to_boxes(source_data["masks"]) + source_data["labels"] = source_data["labels"][mask_filter] + + # TODO: Fix IndexError: The shape of the mask [5] at index 0 does not match the shape of the indexed tensor [6] at index 0 + # if "area" in source_data: + # source_data["area"] = source_data["area"][mask_filter] + # if "iscrowd" in source_data: + # source_data["iscrowd"] = source_data["iscrowd"][mask_filter] + + source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection])) + source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection])) + source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection])) + + if "area" in source_data: + source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection])) + if "iscrowd" in source_data: + source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection])) + + return image, source_data + + class SimpleCopyPaste(torch.nn.Module): def __init__(self, inplace=True): super().__init__() @@ -465,99 +509,26 @@ def forward( # TODO: Validate that batch contains tensors # TODO: Validate that target contains dictionaries with tensor values - if not self.inplace: - batch = copy.deepcopy(batch) - target = copy.deepcopy(target) - shift = 1 - # create shifted copy of batch and target as the original will be modified - batch_rolled = copy.deepcopy(batch[-shift:] + batch[:-shift]) - target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) + if self.inplace: + # create shifted copy of batch and target as the original will be modified + batch_rolled = copy.deepcopy(batch[-shift:] + batch[:-shift]) + target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) + else: + # use the original batch as the shifted copy. 
Make a copy of original which will be modified + batch_copy = copy.deepcopy(batch) + target_copy = copy.deepcopy(target) + batch_rolled = batch[-shift:] + batch[:-shift] + target_rolled = target[-shift:] + target[:-shift] + batch = batch_copy + target = target_copy for image, source_data, paste_image, paste_data in zip(batch, target, batch_rolled, target_rolled): - self.copy_paste(image, source_data, paste_image, paste_data) + copy_paste(image, source_data, paste_image, paste_data) return batch, target - def copy_paste(self, image, source_data, paste_image, paste_data): - number_of_masks = len(paste_data["masks"]) - random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() - - # Update image - paste_alpha_mask = self.combine_masks(paste_data["masks"][random_selection]) - self.update_source_image(image, paste_image, paste_alpha_mask) - - # update original masks - source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"] - - # remove masks where no annotations are present (all values are 0) - mask_filter = source_data["masks"].sum((2, 1)).not_equal(0) - # TODO: update area - - self.filter_source_data(source_data, mask_filter) - self.concat_paste_data(source_data, paste_data, random_selection) - - def combine_masks(self, masks): - """ - Combines multiple masks into one alpha mask - Args: - masks: tensor of masks that need to be combined - - Returns: - Tensor: Boolean tensor of combined alpha mask - """ - return masks.sum(dim=0).greater(0) - - def concat_paste_data(self, source_data, paste_data, random_selection): - """ - Concatenates the masks, boxes, labels, area and iscrowd info - from the paste image data to the source image data. - Args: - source_data: masks, boxes, labels, area, iscrowd info related to source image - paste_data: masks, boxes, labels, area, iscrowd info related to paste image - random_selection: indices of random masks selected from paste data - - Returns: - """ - source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection])) - source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection])) - source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection])) - source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection])) - source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection])) - - def filter_source_data(self, source_data, mask_filter): - """ - Removes masks and related masks data which are no longer visible due to overwriting of paste image. 
- Args: - mask_filter: boolean tensor of visible and non-visible masks - source_data: masks, boxes, labels, area, iscrowd info related to source image - - Returns: - - """ - source_data["masks"] = source_data["masks"][mask_filter] - source_data["boxes"] = ops.masks_to_boxes(source_data["masks"]) - source_data["labels"] = source_data["labels"][mask_filter] - source_data["area"] = source_data["area"][mask_filter] - source_data["iscrowd"] = source_data["iscrowd"][mask_filter] - - def update_source_image(self, image, paste_image, paste_alpha_mask): - """ - Copies the pixels from the paste_image to image using the paste_alpha_mask - Args: - image: Source image which has to be updated - paste_image: Image from which pixels will be copied - paste_alpha_mask: Mask of pixels which need to be copied - - Returns: - - """ - paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) # add blur - image.mul_(torch.logical_not(paste_alpha_mask)) # Delete pixels from source image - paste_pixels = paste_image * paste_alpha_mask # Copy pixels from paste image - image.add_(paste_pixels) # paste it on source image - def __repr__(self) -> str: s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" return s From 15bc8db9191ef47d361572ab0d719cf2fd3b7b45 Mon Sep 17 00:00:00 2001 From: Lezwon Castelino Date: Fri, 3 Jun 2022 17:05:41 +0530 Subject: [PATCH 13/22] add inplace support within copy paste functional --- references/detection/transforms.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 5abe9c6912a..cbcfd1ced85 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple, Dict, Optional, Union +from typing import List, Tuple, Dict, Optional, Union, cast import torch import torchvision @@ -441,7 +441,11 @@ def forward( return image, target -def copy_paste(image, source_data, paste_image, paste_data): +def copy_paste(image, source_data, paste_image, paste_data, inplace=True): + + image = image.clone() if not inplace else image + source_data = copy.deepcopy(source_data) if not inplace else source_data + number_of_masks = len(paste_data["masks"]) random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() @@ -512,22 +516,21 @@ def forward( shift = 1 if self.inplace: - # create shifted copy of batch and target as the original will be modified batch_rolled = copy.deepcopy(batch[-shift:] + batch[:-shift]) target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) else: - # use the original batch as the shifted copy. 
Make a copy of original which will be modified - batch_copy = copy.deepcopy(batch) - target_copy = copy.deepcopy(target) batch_rolled = batch[-shift:] + batch[:-shift] target_rolled = target[-shift:] + target[:-shift] - batch = batch_copy - target = target_copy + + output_batch = [] + output_target = [] for image, source_data, paste_image, paste_data in zip(batch, target, batch_rolled, target_rolled): - copy_paste(image, source_data, paste_image, paste_data) + output_image, output_data = copy_paste(image, source_data, paste_image, paste_data, self.inplace) + output_batch.append(output_image) + output_target.append(output_data) - return batch, target + return cast(Tuple[Tuple[torch.Tensor], Tuple[Dict[str, Tensor]]], (tuple(output_batch), tuple(output_target))) def __repr__(self) -> str: s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" From c2a10a460247d44198eb3c66279bca376f51d9ee Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 10 Jun 2022 12:33:48 +0000 Subject: [PATCH 14/22] Updated code for copy-paste transform --- references/detection/train.py | 16 ++- references/detection/transforms.py | 207 ++++++++++++++++++----------- 2 files changed, 144 insertions(+), 79 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index 229278eb9b4..5d98300ab14 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -31,6 +31,7 @@ from coco_utils import get_coco, get_coco_kp from engine import train_one_epoch, evaluate from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups +from transforms import SimpleCopyPaste def get_dataset(name, image_set, transform, data_path): @@ -145,6 +146,9 @@ def get_args_parser(add_help=True): # Mixed precision training parameters parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training") + # Use CopyPaste augmentation training parameter + parser.add_argument("--use-copypaste", action="store_true", help="Use CopyPaste data augmentation") + return parser @@ -180,8 +184,18 @@ def main(args): else: train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True) + train_collate_fn = utils.collate_fn + if args.use_copypaste: + print("Use SimpleCopyPaste data aug") + copypaste = SimpleCopyPaste(inplace=False) + + def copypaste_collate_fn(batch): + return copypaste(*utils.collate_fn(batch)) + + train_collate_fn = copypaste_collate_fn + data_loader = torch.utils.data.DataLoader( - dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn + dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=train_collate_fn ) data_loader_test = torch.utils.data.DataLoader( diff --git a/references/detection/transforms.py b/references/detection/transforms.py index cbcfd1ced85..61e8516067f 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -441,52 +441,108 @@ def forward( return image, target -def copy_paste(image, source_data, paste_image, paste_data, inplace=True): - - image = image.clone() if not inplace else image - source_data = copy.deepcopy(source_data) if not inplace else source_data - - number_of_masks = len(paste_data["masks"]) - random_selection = torch.randint(0, number_of_masks, (number_of_masks,)).unique() - - # Combine masks - paste_alpha_mask = paste_data["masks"][random_selection].sum(dim=0).greater(0) - - paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), 
kernel_size=(5, 5), sigma=2.0)  # add blur
-
-    # Paste pixels from paste_image to image
-    image.mul_(torch.logical_not(paste_alpha_mask))  # Delete pixels from source image
-    paste_pixels = paste_image * paste_alpha_mask  # Copy pixels from paste image
-    image.add_(paste_pixels)  # paste it on source image
-
-    # update original masks
-    source_data["masks"] = source_data["masks"] ^ paste_alpha_mask & source_data["masks"]
-
-    # remove masks where no annotations are present (all values are 0)
-    mask_filter = source_data["masks"].sum((2, 1)).not_equal(0)
-    # TODO: update area
-
-    # Update other attributes
-    source_data["masks"] = source_data["masks"][mask_filter]
-    source_data["boxes"] = ops.masks_to_boxes(source_data["masks"])
-    source_data["labels"] = source_data["labels"][mask_filter]
-
-    # TODO: Fix IndexError: The shape of the mask [5] at index 0 does not match the shape of the indexed tensor [6] at index 0
-    # if "area" in source_data:
-    #     source_data["area"] = source_data["area"][mask_filter]
-    # if "iscrowd" in source_data:
-    #     source_data["iscrowd"] = source_data["iscrowd"][mask_filter]
-
-    source_data["masks"] = torch.cat((source_data["masks"], paste_data["masks"][random_selection]))
-    source_data["boxes"] = torch.cat((source_data["boxes"], paste_data["boxes"][random_selection]))
-    source_data["labels"] = torch.cat((source_data["labels"], paste_data["labels"][random_selection]))
-
-    if "area" in source_data:
-        source_data["area"] = torch.cat((source_data["area"], paste_data["area"][random_selection]))
-    if "iscrowd" in source_data:
-        source_data["iscrowd"] = torch.cat((source_data["iscrowd"], paste_data["iscrowd"][random_selection]))
-
-    return image, source_data
+def _copy_paste(image, target, paste_image, paste_target, inplace=True):
+
+    # Random paste targets selection:
+    num_masks = len(paste_target["masks"])
+    random_selection = torch.randint(0, num_masks, (num_masks,)).unique()
+
+    paste_masks = paste_target["masks"][random_selection]
+    paste_boxes = paste_target["boxes"][random_selection]
+    paste_labels = paste_target["labels"][random_selection]
+
+    paste_alpha_mask = paste_masks.sum(dim=0) > 0
+    paste_alpha_mask = F.gaussian_blur(
+        paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0
+    )
+
+    masks = target["masks"]
+    # Align images keeping top-left corner if source and paste data
+    # of different sizes
+    shape1 = image.shape[-2:]
+    shape2 = paste_image.shape[-2:]
+    reduced_paste_mask = False
+    if shape1 != shape2:
+
+        h1 = None if shape1[0] < shape2[0] else shape2[0]
+        w1 = None if shape1[1] < shape2[1] else shape2[1]
+        h2 = shape1[0] if shape1[0] < shape2[0] else None
+        w2 = shape1[1] if shape1[1] < shape2[1] else None
+
+        image = image[..., :h1, :w1]
+        masks = masks[..., :h1, :w1]
+        paste_image = paste_image[..., :h2, :w2]
+        paste_masks = paste_masks[..., :h2, :w2]
+        paste_alpha_mask = paste_alpha_mask[..., :h2, :w2]
+
+        if h2 is not None or w2 is not None:
+            reduced_paste_mask = True
+
+    # Copy-paste images:
+    image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask)
+
+    # Copy-paste masks:
+    masks = masks * (~paste_alpha_mask)
+    # to avoid degenerate bboxes (with width or height of 1)
+    # let's use a threshold of 10 to filter out small masks
+    non_all_zero_masks = masks.sum((-1, -2)) > 10
+    masks = masks[non_all_zero_masks]
+
+    # As paste_masks was aligned with masks, we can remove small masks
+    # thus we need to keep only non-zero masks
+    non_all_zero_pmasks = None
+    if reduced_paste_mask:
+        # let's use a threshold of 10 to filter out small masks
+        non_all_zero_pmasks = paste_masks.sum((-1, -2)) > 10
+        paste_masks = paste_masks[non_all_zero_pmasks]
+        paste_boxes = paste_boxes[non_all_zero_pmasks]
+        paste_labels = paste_labels[non_all_zero_pmasks]
+
+    if inplace:
+        out_target = target
+    else:
+        # Do a shallow copy of the target dict
+        out_target = copy.copy(target)
+
+    out_target["masks"] = torch.cat([masks, paste_masks])
+
+    # Copy-paste boxes and labels
+    boxes = ops.masks_to_boxes(masks)
+    out_target["boxes"] = torch.cat([boxes, paste_boxes])
+
+    labels = target["labels"][non_all_zero_masks]
+    out_target["labels"] = torch.cat([labels, paste_labels])
+
+    # Update additional optional keys: area and iscrowd if they exist
+    if "area" in target:
+        out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
+
+    if "iscrowd" in target and "iscrowd" in paste_target:
+        iscrowd = target["iscrowd"][non_all_zero_masks]
+        paste_iscrowd = paste_target["iscrowd"][random_selection]
+        if reduced_paste_mask:
+            paste_iscrowd = paste_iscrowd[non_all_zero_pmasks]
+        out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+
+    # Check for degenerate boxes and remove them
+    boxes = out_target["boxes"]
+    degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
+    if degenerate_boxes.any():
+        valid_targets = ~degenerate_boxes.any(dim=1)
+
+        out_target["boxes"] = boxes[valid_targets]
+        out_target["masks"] = out_target["masks"][valid_targets]
+        out_target["labels"] = out_target["labels"][valid_targets]
+
+        if "area" in out_target:
+            out_target["area"] = out_target["area"][valid_targets]
+        if "iscrowd" in out_target:
+            out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
+
+    assert len(out_target["boxes"]) == len(out_target["masks"]), f"{len(out_target['boxes'])}, {len(out_target['masks'])}"
+    assert len(out_target["labels"]) == len(out_target["masks"]), f"{len(out_target['labels'])}, {len(out_target['masks'])}"
+
+    return image, out_target
 
 
 class SimpleCopyPaste(torch.nn.Module):
@@ -494,44 +550,39 @@ def __init__(self, inplace=True):
         super().__init__()
         self.inplace = inplace
 
-    def forward(
-        self, batch: Tuple[torch.Tensor], target: Tuple[Dict[str, Tensor]]
-    ) -> Tuple[Tuple[torch.Tensor], Tuple[Dict[str, Tensor]]]:
-
-        # validate inputs
-        if not isinstance(batch, tuple):
-            raise TypeError(f"Batch type should be a tuple of tensors. Got {type(target)}.")
-        if not isinstance(target, tuple):
-            raise TypeError(f"Target type should be a tuple of dictionaries. Got {type(target)}.")
-        if len(batch) == 0:
-            raise ValueError("'batch' tuple cannot be empty.")
-        if len(target) == 0:
-            raise ValueError("'target' tuple cannot be empty.")
-        if not len(target) == len(batch):
-            raise ValueError(f"'batch' and 'target' lengths do not match. 
Got {len(batch)} and {len(target)}.") - - # TODO: Validate that batch contains tensors - # TODO: Validate that target contains dictionaries with tensor values - + def forward(self, images, targets=None): + assert targets is not None + assert isinstance(images, tuple) and all([isinstance(v, torch.Tensor) for v in images]) + assert isinstance(targets, tuple) and len(images) == len(targets) + for target in targets: + assert isinstance(target, dict) + for k in ["masks", "boxes", "labels"]: + assert k in target, f"Key {k} should be present in targets" + assert isinstance(target[k], torch.Tensor) + + # images = [t1, t2, ..., tN] + # Let's define paste_images as shifted list of input images + # paste_images = [t2, t3, ..., tN, t1] + # FYI: in TF they mix data on the dataset level shift = 1 + images_rolled = images[-shift:] + images[:-shift] + targets_rolled = targets[-shift:] + targets[:-shift] if self.inplace: - batch_rolled = copy.deepcopy(batch[-shift:] + batch[:-shift]) - target_rolled = copy.deepcopy(target[-shift:] + target[:-shift]) - else: - batch_rolled = batch[-shift:] + batch[:-shift] - target_rolled = target[-shift:] + target[:-shift] + images_rolled = copy.deepcopy(images_rolled) + targets_rolled = copy.deepcopy(targets_rolled) - output_batch = [] - output_target = [] + output_images = [] + output_targets = [] - for image, source_data, paste_image, paste_data in zip(batch, target, batch_rolled, target_rolled): - output_image, output_data = copy_paste(image, source_data, paste_image, paste_data, self.inplace) - output_batch.append(output_image) - output_target.append(output_data) + data = [images, targets, images_rolled, targets_rolled] + for image, target, paste_image, paste_target in zip(*data): + output_image, output_data = _copy_paste(image, target, paste_image, paste_target, self.inplace) + output_images.append(output_image) + output_targets.append(output_data) - return cast(Tuple[Tuple[torch.Tensor], Tuple[Dict[str, Tensor]]], (tuple(output_batch), tuple(output_target))) + return tuple(output_images), tuple(output_targets) def __repr__(self) -> str: - s = f"{self.__class__.__name__}(" f", p={self.p}" f", inplace={self.inplace}" f")" + s = f"{self.__class__.__name__}(inplace={self.inplace})" return s From 117c6da318e82fb073f659326fdbe7338827ba01 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 10 Jun 2022 12:40:13 +0000 Subject: [PATCH 15/22] Fixed code formatting --- references/detection/transforms.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 61e8516067f..84947e3637f 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,5 +1,5 @@ import copy -from typing import List, Tuple, Dict, Optional, Union, cast +from typing import List, Tuple, Dict, Optional, Union import torch import torchvision @@ -452,9 +452,7 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): paste_labels = paste_target["labels"][random_selection] paste_alpha_mask = paste_masks.sum(dim=0) > 0 - paste_alpha_mask = F.gaussian_blur( - paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0 - ) + paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) masks = target["masks"] # Align images keeping top-left corner if source and paste data @@ -539,8 +537,12 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): if "iscrowd" in out_target: out_target["iscrowd"] = 
out_target["iscrowd"][valid_targets] - assert len(out_target["boxes"]) == len(out_target["masks"]), f"{len(out_target['boxes'])}, {len(out_target['masks'])}" - assert len(out_target["labels"]) == len(out_target["masks"]), f"{len(out_target['labels'])}, {len(out_target['masks'])}" + assert len(out_target["boxes"]) == len( + out_target["masks"] + ), f"{len(out_target['boxes'])}, {len(out_target['masks'])}" + assert len(out_target["labels"]) == len( + out_target["masks"] + ), f"{len(out_target['labels'])}, {len(out_target['masks'])}" return image, out_target From 21ac77542a78f4d113f9949f95241d86206bc513 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 10 Jun 2022 13:19:19 +0000 Subject: [PATCH 16/22] [skip ci] removed manual thresholding --- references/detection/transforms.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index 84947e3637f..fb3246707f4 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -481,17 +481,14 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): # Copy-paste masks: masks = masks * (~paste_alpha_mask) - # to avoid degenerated bboxes (with width or height of 1) - # let's use a threshold of 10 to filter out small masks - non_all_zero_masks = masks.sum((-1, -2)) > 10 + non_all_zero_masks = masks.sum((-1, -2)) > 0 masks = masks[non_all_zero_masks] # As paste_masks was aligned with masks, we can remove small masks # thus we need to keep only non-zero masks non_all_zero_pmasks = None if reduced_paste_mask: - # let's use a threshold of 4 to filter out small masks - non_all_zero_pmasks = paste_masks.sum((-1, -2)) > 10 + non_all_zero_pmasks = paste_masks.sum((-1, -2)) > 0 paste_masks = paste_masks[non_all_zero_pmasks] paste_boxes = paste_boxes[non_all_zero_pmasks] paste_labels = paste_labels[non_all_zero_pmasks] @@ -502,6 +499,8 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): # Do a shallow copy of the target dict out_target = copy.copy(target) + # TODO: what if paste_masks, paste_boxes and paste_labels are empty now ? 
+ out_target["masks"] = torch.cat([masks, paste_masks]) # Copy-paste boxes and labels From ad546faa94be196942b66262342dbd733769eb7c Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 10 Jun 2022 13:54:33 +0000 Subject: [PATCH 17/22] Replaced cropping by resizing data to paste --- references/detection/transforms.py | 54 +++++++----------------------- 1 file changed, 12 insertions(+), 42 deletions(-) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index fb3246707f4..721b5084128 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -451,31 +451,21 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): paste_boxes = paste_target["boxes"][random_selection] paste_labels = paste_target["labels"][random_selection] + masks = target["masks"] + # If source and paste data have different sizes + # Let's resize paste data + size1 = image.shape[-2:] + size2 = paste_image.shape[-2:] + if size1 != size2: + paste_image = F.resize(paste_image, size1, interpolation=F.InterpolationMode.BICUBIC) + paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST) + # resize bboxes: + ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device) + paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape) + paste_alpha_mask = paste_masks.sum(dim=0) > 0 paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) - masks = target["masks"] - # Align images keeping top-left corner if source and paste data - # of different sizes - shape1 = image.shape[-2:] - shape2 = paste_image.shape[-2:] - reduced_paste_mask = False - if shape1 != shape2: - - h1 = None if shape1[0] < shape2[0] else shape2[0] - w1 = None if shape1[1] < shape2[1] else shape2[1] - h2 = shape1[0] if shape1[0] < shape2[0] else None - w2 = shape1[1] if shape1[1] < shape2[1] else None - - image = image[..., :h1, :w1] - masks = masks[..., :h1, :w1] - paste_image = paste_image[..., :h2, :w2] - paste_masks = paste_masks[..., :h2, :w2] - paste_alpha_mask = paste_alpha_mask[..., :h2, :w2] - - if h2 is not None or w2 is not None: - reduced_paste_mask = True - # Copy-paste images: image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) @@ -484,23 +474,12 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): non_all_zero_masks = masks.sum((-1, -2)) > 0 masks = masks[non_all_zero_masks] - # As paste_masks was aligned with masks, we can remove small masks - # thus we need to keep only non-zero masks - non_all_zero_pmasks = None - if reduced_paste_mask: - non_all_zero_pmasks = paste_masks.sum((-1, -2)) > 0 - paste_masks = paste_masks[non_all_zero_pmasks] - paste_boxes = paste_boxes[non_all_zero_pmasks] - paste_labels = paste_labels[non_all_zero_pmasks] - if inplace: out_target = target else: # Do a shallow copy of the target dict out_target = copy.copy(target) - # TODO: what if paste_masks, paste_boxes and paste_labels are empty now ? 
-
     out_target["masks"] = torch.cat([masks, paste_masks])
 
     # Copy-paste boxes and labels
@@ -517,8 +496,6 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
     if "iscrowd" in target and "iscrowd" in paste_target:
         iscrowd = target["iscrowd"][non_all_zero_masks]
         paste_iscrowd = paste_target["iscrowd"][random_selection]
-        if reduced_paste_mask:
-            paste_iscrowd = paste_iscrowd[non_all_zero_pmasks]
         out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
 
     # Check for degenerate boxes and remove them
@@ -536,13 +513,6 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
     if "iscrowd" in out_target:
         out_target["iscrowd"] = out_target["iscrowd"][valid_targets]
 
-    assert len(out_target["boxes"]) == len(
-        out_target["masks"]
-    ), f"{len(out_target['boxes'])}, {len(out_target['masks'])}"
-    assert len(out_target["labels"]) == len(
-        out_target["masks"]
-    ), f"{len(out_target['labels'])}, {len(out_target['masks'])}"
-
     return image, out_target

From 064c44d8695e3c95f10333192f9c447e55fe01a6 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 13 Jun 2022 08:29:49 +0000
Subject: [PATCH 18/22] Removed inplace arg (as useless) and put a check on iscrowd target

---
 references/detection/train.py      |  2 +-
 references/detection/transforms.py | 28 +++++++++++-----------------
 2 files changed, 12 insertions(+), 18 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 5d98300ab14..a98dfd0deae 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -187,7 +187,7 @@ def main(args):
     train_collate_fn = utils.collate_fn
     if args.use_copypaste:
         print("Use SimpleCopyPaste data aug")
-        copypaste = SimpleCopyPaste(inplace=False)
+        copypaste = SimpleCopyPaste()
 
         def copypaste_collate_fn(batch):
             return copypaste(*utils.collate_fn(batch))
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 721b5084128..5d634d8938e 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -474,11 +474,8 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
     non_all_zero_masks = masks.sum((-1, -2)) > 0
     masks = masks[non_all_zero_masks]
 
-    if inplace:
-        out_target = target
-    else:
-        # Do a shallow copy of the target dict
-        out_target = copy.copy(target)
+    # Do a shallow copy of the target dict
+    out_target = copy.copy(target)
 
     out_target["masks"] = torch.cat([masks, paste_masks])
 
@@ -494,9 +491,13 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
         out_target["area"] = out_target["masks"].sum((-1, -2)).to(torch.float32)
 
     if "iscrowd" in target and "iscrowd" in paste_target:
-        iscrowd = target["iscrowd"][non_all_zero_masks]
-        paste_iscrowd = paste_target["iscrowd"][random_selection]
-        out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+        # target['iscrowd'] size can differ from mask size (non_all_zero_masks)
+        if len(target["iscrowd"]) == len(non_all_zero_masks):
+            iscrowd = target["iscrowd"][non_all_zero_masks]
+            paste_iscrowd = paste_target["iscrowd"][random_selection]
+            out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd])
+        else:
+            out_target["iscrowd"] = target["iscrowd"]
 
     # Check for degenerate boxes and remove them
     boxes = out_target["boxes"]
@@ -517,9 +518,6 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
 
 
 class SimpleCopyPaste(torch.nn.Module):
-    def __init__(self, inplace=True):
-        super().__init__()
-        self.inplace = inplace
 
     def forward(self, images, targets=None):
         assert targets is not None
@@ -539,21 +537,17 @@ def forward(self, images, targets=None):
         images_rolled = images[-shift:] + images[:-shift]
         targets_rolled = targets[-shift:] + targets[:-shift]
 
-        if self.inplace:
-            images_rolled = copy.deepcopy(images_rolled)
-            targets_rolled = copy.deepcopy(targets_rolled)
-
         output_images = []
         output_targets = []
 
         data = [images, targets, images_rolled, targets_rolled]
         for image, target, paste_image, paste_target in zip(*data):
-            output_image, output_data = _copy_paste(image, target, paste_image, paste_target, self.inplace)
+            output_image, output_data = _copy_paste(image, target, paste_image, paste_target)
             output_images.append(output_image)
             output_targets.append(output_data)
 
         return tuple(output_images), tuple(output_targets)
 
     def __repr__(self) -> str:
-        s = f"{self.__class__.__name__}(inplace={self.inplace})"
+        s = f"{self.__class__.__name__}()"
         return s

From 09b4db07715a0182629a3803cab495e48929155a Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 13 Jun 2022 11:48:51 +0000
Subject: [PATCH 19/22] code-formatting

---
 references/detection/transforms.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 5d634d8938e..b43654131be 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -518,7 +518,6 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True):
 
 
 class SimpleCopyPaste(torch.nn.Module):
-
     def forward(self, images, targets=None):
         assert targets is not None
         assert isinstance(images, tuple) and all([isinstance(v, torch.Tensor) for v in images])

From f1cc84bb7942e4be28702bcb7a70cf066debc0bf Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Tue, 14 Jun 2022 13:53:03 +0000
Subject: [PATCH 20/22] Updated copypaste op to make it torch scriptable

Added fallbacks to support LSJ
---
 references/detection/train.py      |  16 ++-
 references/detection/transforms.py | 207 ++++++++++++++++++-----------
 2 files changed, 144 insertions(+), 79 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index a98dfd0deae..6fd63b9c2ee 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -31,6 +31,7 @@
 from coco_utils import get_coco, get_coco_kp
 from engine import train_one_epoch, evaluate
 from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
+from torchvision.transforms import InterpolationMode
 from transforms import SimpleCopyPaste
 
 
@@ -147,7 +148,11 @@ def get_args_parser(add_help=True):
     parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
 
     # Use CopyPaste augmentation training parameter
-    parser.add_argument("--use-copypaste", action="store_true", help="Use CopyPaste data augmentation")
+    parser.add_argument(
+        "--use-copypaste",
+        action="store_true",
+        help="Use CopyPaste data augmentation. It is intended to work together with data-augmentation='lsj'.",
+    )
 
     return parser
 
@@ -186,8 +191,17 @@ def main(args):
 
     train_collate_fn = utils.collate_fn
     if args.use_copypaste:
+        if args.data_augmentation in ["ssd", "ssdlite"]:
+            raise RuntimeError("SimpleCopyPaste algorithm does not support 'ssd', 'ssdlite' data augmentation policies")
+
         print("Use SimpleCopyPaste data aug")
-        copypaste = SimpleCopyPaste()
+        if args.data_augmentation not in ["lsj"]:
+            print(
+                "INFO: SimpleCopyPaste is intended to work together with data-augmentation='lsj'. "
" + "Currently, the algorithm can work with any data-augmentation policy " + "but an additional resize is applied to the pasted data" + ) + copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True) def copypaste_collate_fn(batch): return copypaste(*utils.collate_fn(batch)) diff --git a/references/detection/transforms.py b/references/detection/transforms.py index b43654131be..eabbbbc8569 100644 --- a/references/detection/transforms.py +++ b/references/detection/transforms.py @@ -1,5 +1,4 @@ -import copy -from typing import List, Tuple, Dict, Optional, Union +from typing import List, Tuple, Dict, Optional, Union, cast import torch import torchvision @@ -441,30 +440,54 @@ def forward( return image, target -def _copy_paste(image, target, paste_image, paste_target, inplace=True): +def _copy_paste( + image: torch.Tensor, + target: Dict[str, Tensor], + paste_image: torch.Tensor, + paste_target: Dict[str, Tensor], + blending: bool = True, + resize_interpolation: F.InterpolationMode = F.InterpolationMode.BILINEAR, +) -> Tuple[torch.Tensor, Dict[str, Tensor]]: # Random paste targets selection: num_masks = len(paste_target["masks"]) - random_selection = torch.randint(0, num_masks, (num_masks,)).unique() + + if num_masks < 1: + # Such degerante case with num_masks=0 can happen with LSJ + # Let's just return (image, target) + return image, target + + # We have to please torch script by explicitly specifying dtype as torch.long + random_selection = torch.unique(torch.randint(0, num_masks, (num_masks,))).to(torch.long) paste_masks = paste_target["masks"][random_selection] paste_boxes = paste_target["boxes"][random_selection] paste_labels = paste_target["labels"][random_selection] masks = target["masks"] - # If source and paste data have different sizes - # Let's resize paste data + + # We resize source and paste data if they have different sizes + # This is something we introduced here as originally the algorithm works + # on equal-sized data (for example, coming from LSJ data augmentations) size1 = image.shape[-2:] size2 = paste_image.shape[-2:] if size1 != size2: - paste_image = F.resize(paste_image, size1, interpolation=F.InterpolationMode.BICUBIC) + paste_image = F.resize(paste_image, size1, interpolation=resize_interpolation) paste_masks = F.resize(paste_masks, size1, interpolation=F.InterpolationMode.NEAREST) # resize bboxes: ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]), device=paste_boxes.device) paste_boxes = paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape) paste_alpha_mask = paste_masks.sum(dim=0) > 0 - paste_alpha_mask = F.gaussian_blur(paste_alpha_mask.unsqueeze(0), kernel_size=(5, 5), sigma=2.0) + + if blending: + paste_alpha_mask = F.gaussian_blur( + paste_alpha_mask.unsqueeze(0), + kernel_size=(5, 5), + sigma=[ + 2.0, + ], + ) # Copy-paste images: image = (image * (~paste_alpha_mask)) + (paste_image * paste_alpha_mask) @@ -475,7 +498,7 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): masks = masks[non_all_zero_masks] # Do a shallow copy of the target dict - out_target = copy.copy(target) + out_target = {k: v for k, v in target.items()} out_target["masks"] = torch.cat([masks, paste_masks]) @@ -492,12 +515,12 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): if "iscrowd" in target and "iscrowd" in paste_target: # target['iscrowd'] size can be differ from mask size (non_all_zero_masks) + # For example, if previous transforms geometrically modifies masks/boxes/labels but + # 
does not update "iscrowd" if len(target["iscrowd"]) == len(non_all_zero_masks): iscrowd = target["iscrowd"][non_all_zero_masks] paste_iscrowd = paste_target["iscrowd"][random_selection] out_target["iscrowd"] = torch.cat([iscrowd, paste_iscrowd]) - else: - out_target["iscrowd"] = target["iscrowd"] # Check for degenerated boxes and remove them boxes = out_target["boxes"] @@ -511,41 +534,59 @@ def _copy_paste(image, target, paste_image, paste_target, inplace=True): if "area" in out_target: out_target["area"] = out_target["area"][valid_targets] - if "iscrowd" in out_target: + if "iscrowd" in out_target and len(out_target["iscrowd"]) == len(valid_targets): out_target["iscrowd"] = out_target["iscrowd"][valid_targets] return image, out_target class SimpleCopyPaste(torch.nn.Module): - def forward(self, images, targets=None): - assert targets is not None - assert isinstance(images, tuple) and all([isinstance(v, torch.Tensor) for v in images]) - assert isinstance(targets, tuple) and len(images) == len(targets) + def __init__(self, blending=True, resize_interpolation=F.InterpolationMode.BILINEAR): + super().__init__() + self.resize_interpolation = resize_interpolation + self.blending = blending + + def forward( + self, images: List[torch.Tensor], targets: List[Dict[str, Tensor]] + ) -> Tuple[List[torch.Tensor], List[Dict[str, Tensor]]]: + torch._assert( + isinstance(images, (list, tuple)) and all([isinstance(v, torch.Tensor) for v in images]), + "images should be a list of tensors", + ) + torch._assert( + isinstance(targets, (list, tuple)) and len(images) == len(targets), + "targets should be a list of the same size as images", + ) for target in targets: - assert isinstance(target, dict) + # Can not check for instance type dict with inside torch.jit.script + # torch._assert(isinstance(target, dict), "targets item should be a dict") for k in ["masks", "boxes", "labels"]: - assert k in target, f"Key {k} should be present in targets" - assert isinstance(target[k], torch.Tensor) + torch._assert(k in target, f"Key {k} should be present in targets") + torch._assert(isinstance(target[k], torch.Tensor), f"Value for the key {k} should be a tensor") # images = [t1, t2, ..., tN] # Let's define paste_images as shifted list of input images # paste_images = [t2, t3, ..., tN, t1] # FYI: in TF they mix data on the dataset level - shift = 1 - images_rolled = images[-shift:] + images[:-shift] - targets_rolled = targets[-shift:] + targets[:-shift] - - output_images = [] - output_targets = [] - - data = [images, targets, images_rolled, targets_rolled] - for image, target, paste_image, paste_target in zip(*data): - output_image, output_data = _copy_paste(image, target, paste_image, paste_target) + images_rolled = images[-1:] + images[:-1] + targets_rolled = targets[-1:] + targets[:-1] + + output_images: List[torch.Tensor] = [] + output_targets: List[Dict[str, Tensor]] = [] + + for image, target, paste_image, paste_target in zip(images, targets, images_rolled, targets_rolled): + output_image, output_data = _copy_paste( + image, + target, + paste_image, + paste_target, + blending=self.blending, + resize_interpolation=self.resize_interpolation, + ) output_images.append(output_image) output_targets.append(output_data) - return tuple(output_images), tuple(output_targets) + return output_images, output_targets def __repr__(self) -> str: s = f"{self.__class__.__name__}()" From 956aa7783956433d63403bbf5db65672d08341d4 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Tue, 14 Jun 2022 14:01:08 +0000 Subject: [PATCH 21/22] Fixed flake8 
---
 references/detection/transforms.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index eabbbbc8569..162ea4e8983 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -1,4 +1,4 @@
-from typing import List, Tuple, Dict, Optional, Union, cast
+from typing import List, Tuple, Dict, Optional, Union
 
 import torch
 import torchvision

From 7020eb8287665e49ce28f3f1b8fa98321cfe0b75 Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Wed, 15 Jun 2022 10:17:07 +0000
Subject: [PATCH 22/22] Updates according to the review

---
 references/detection/train.py      | 15 ++++-----------
 references/detection/transforms.py |  5 +++--
 2 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/references/detection/train.py b/references/detection/train.py
index 6fd63b9c2ee..4cad35ccb0a 100644
--- a/references/detection/train.py
+++ b/references/detection/train.py
@@ -151,7 +151,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--use-copypaste",
         action="store_true",
-        help="Use CopyPaste data augmentation. It is intended to work together with data-augmentation='lsj'.",
+        help="Use CopyPaste data augmentation. Works only with data-augmentation='lsj'.",
     )
 
     return parser
@@ -191,16 +191,9 @@ def main(args):
 
     train_collate_fn = utils.collate_fn
     if args.use_copypaste:
-        if args.data_augmentation in ["ssd", "ssdlite"]:
-            raise RuntimeError("SimpleCopyPaste algorithm does not support 'ssd', 'ssdlite' data augmentation policies")
-
-        print("Use SimpleCopyPaste data aug")
-        if args.data_augmentation not in ["lsj"]:
-            print(
-                "INFO: SimpleCopyPaste is intended to work together with data-augmentation='lsj'. "
-                "Currently, the algorithm can work with any data-augmentation policy "
-                "but an additional resize is applied to the pasted data"
-            )
+        if args.data_augmentation != "lsj":
+            raise RuntimeError("SimpleCopyPaste algorithm currently only supports the 'lsj' data augmentation policy")
+
         copypaste = SimpleCopyPaste(resize_interpolation=InterpolationMode.BILINEAR, blending=True)
 
         def copypaste_collate_fn(batch):
diff --git a/references/detection/transforms.py b/references/detection/transforms.py
index 162ea4e8983..35ae34bd56a 100644
--- a/references/detection/transforms.py
+++ b/references/detection/transforms.py
@@ -458,7 +458,8 @@ def _copy_paste(
         return image, target
 
     # We have to please torch script by explicitly specifying dtype as torch.long
-    random_selection = torch.unique(torch.randint(0, num_masks, (num_masks,))).to(torch.long)
+    random_selection = torch.randint(0, num_masks, (num_masks,), device=paste_image.device)
+    random_selection = torch.unique(random_selection).to(torch.long)
 
     paste_masks = paste_target["masks"][random_selection]
     paste_boxes = paste_target["boxes"][random_selection]
@@ -589,5 +590,5 @@ def forward(
         return output_images, output_targets
 
     def __repr__(self) -> str:
-        s = f"{self.__class__.__name__}()"
+        s = f"{self.__class__.__name__}(blending={self.blending}, resize_interpolation={self.resize_interpolation})"
         return s
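
A minimal usage sketch of the final transform (not part of the patch series). It assumes the references/detection/transforms.py from PATCH 22/22 is importable as `transforms`; the synthetic images, masks and sizes below are invented purely for illustration, the only real API used is SimpleCopyPaste as defined above.

import torch
from torchvision.transforms import InterpolationMode

from transforms import SimpleCopyPaste  # transforms.py as of PATCH 22/22, assumed on sys.path


def make_sample(height, width, num_objects):
    # One synthetic (image, target) pair shaped like the detection references
    # produce: a float image plus "masks"/"boxes"/"labels" tensors.
    image = torch.rand(3, height, width)
    masks = torch.zeros(num_objects, height, width, dtype=torch.uint8)
    boxes = []
    for i in range(num_objects):
        x1 = int(torch.randint(0, width - 20, (1,)))
        y1 = int(torch.randint(0, height - 20, (1,)))
        masks[i, y1 : y1 + 15, x1 : x1 + 15] = 1  # a 15x15 square object
        boxes.append([x1, y1, x1 + 15, y1 + 15])
    target = {
        "masks": masks,
        "boxes": torch.tensor(boxes, dtype=torch.float32),
        "labels": torch.randint(1, 91, (num_objects,)),
    }
    return image, target


copypaste = SimpleCopyPaste(blending=True, resize_interpolation=InterpolationMode.BILINEAR)

# Mimic utils.collate_fn on a batch of four dataset samples; with LSJ all images
# are already equal-sized, otherwise _copy_paste resizes the paste data itself.
images, targets = tuple(zip(*[make_sample(224, 224, num_objects=3) for _ in range(4)]))

out_images, out_targets = copypaste(images, targets)
print(len(out_images), out_targets[0]["masks"].shape, out_targets[0]["boxes"].shape)

Each output image now contains objects pasted from its neighbour in the batch (the rolled list), and each output target carries the concatenated, degeneracy-filtered masks, boxes and labels. In training this call is hidden behind copypaste_collate_fn, exactly as wired up in train.py above.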
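A second short note: the bbox rescaling one-liner introduced in PATCH 17 and kept through PATCH 22, paste_boxes.view(-1, 2, 2).mul(ratios).view(paste_boxes.shape), is compact enough to deserve a worked example. The sizes and box below are invented; only the reshaping trick itself comes from the patches.

import torch

size2 = (200, 100)  # paste image (H, W)
size1 = (400, 300)  # destination image (H, W)
boxes = torch.tensor([[10.0, 20.0, 50.0, 80.0]])  # one xyxy box on the paste image

# ratios holds (w_ratio, h_ratio); view(-1, 2, 2) turns each xyxy row into
# [[x1, y1], [x2, y2]], so broadcasting multiplies every x by w_ratio and
# every y by h_ratio, and the final view restores the (N, 4) layout.
ratios = torch.tensor((size1[1] / size2[1], size1[0] / size2[0]))
scaled = boxes.view(-1, 2, 2).mul(ratios).view(boxes.shape)
print(scaled)  # tensor([[ 30.,  40., 150., 160.]])

Here w_ratio = 300/100 = 3 and h_ratio = 400/200 = 2, so x-coordinates triple and y-coordinates double, matching what F.resize does to the paste image and its masks.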