diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst index 7f835267200..cae53728b96 100644 --- a/docs/source/transforms.rst +++ b/docs/source/transforms.rst @@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran AutoAugment RandAugment TrivialAugmentWide + AugMix .. _functional_transforms: diff --git a/gallery/plot_transforms.py b/gallery/plot_transforms.py index ab0cb892b16..d781f8f35ed 100644 --- a/gallery/plot_transforms.py +++ b/gallery/plot_transforms.py @@ -263,6 +263,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs): imgs = [augmenter(orig_img) for _ in range(4)] plot(imgs) +#################################### +# AugMix +# ~~~~~~ +# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data. +augmenter = T.AugMix() +imgs = [augmenter(orig_img) for _ in range(4)] +plot(imgs) + #################################### # Randomly-applied transforms # --------------------------- diff --git a/references/classification/presets.py b/references/classification/presets.py index 6e1000174ab..418ef3e2e07 100644 --- a/references/classification/presets.py +++ b/references/classification/presets.py @@ -22,6 +22,8 @@ def __init__( trans.append(autoaugment.RandAugment(interpolation=interpolation)) elif auto_augment_policy == "ta_wide": trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation)) + elif auto_augment_policy == "augmix": + trans.append(autoaugment.AugMix(interpolation=interpolation)) else: aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy) trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation)) diff --git a/test/test_transforms.py b/test/test_transforms.py index 4219f4d1645..246afa84802 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale): transform.__repr__() +@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)]) +@pytest.mark.parametrize("severity", [1, 10]) +@pytest.mark.parametrize("mixture_width", [1, 2]) +@pytest.mark.parametrize("chain_depth", [-1, 2]) +@pytest.mark.parametrize("all_ops", [True, False]) +@pytest.mark.parametrize("grayscale", [True, False]) +def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale): + random.seed(42) + img = Image.open(GRACE_HOPPER) + if grayscale: + img, fill = _get_grayscale_test_image(img, fill) + transform = transforms.AugMix( + fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops + ) + for _ in range(100): + img = transform(img) + transform.__repr__() + + def test_random_crop(): height = random.randint(10, 32) * 2 width = random.randint(10, 32) * 2 diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index 9bc499467b7..fb6bec5bb9b 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill): _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) -@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide]) +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fill", + [ + None, + 85, + (10, -10, 10), + 0.7, + [0.0, 0.0, 0.0], + [ + 1, + ], + 1, + ], +) +def test_augmix(device, fill): + tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device) + batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device) + + class DeterministicAugMix(T.AugMix): + def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor: + # patch the method to ensure that the order of rand calls doesn't affect the outcome + return params.softmax(dim=-1) + + transform = DeterministicAugMix(fill=fill) + s_transform = torch.jit.script(transform) + for _ in range(25): + _test_transform_vs_scripted(transform, s_transform, tensor) + _test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors) + + +@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix]) def test_autoaugment_save(augmentation, tmpdir): transform = augmentation() s_transform = torch.jit.script(transform) diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index a6109ad7030..d820e5126a1 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -7,7 +7,7 @@ from . import functional as F, InterpolationMode -__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] +__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"] def _apply_op( @@ -458,3 +458,154 @@ def __repr__(self) -> str: f")" ) return s + + +class AugMix(torch.nn.Module): + r"""AugMix data augmentation method based on + `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" `_. + If the image is torch Tensor, it should be of type torch.uint8, and it is expected + to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions. + If img is PIL Image, it is expected to be in mode "L" or "RGB". + + Args: + severity (int): The severity of base augmentation operators. Default is ``3``. + mixture_width (int): The number of augmentation chains. Default is ``3``. + chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3]. + Default is ``-1``. + alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``. + all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``. + interpolation (InterpolationMode): Desired interpolation enum defined by + :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. + If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + fill (sequence or number, optional): Pixel fill value for the area outside the transformed + image. If given a number, the value is used for all bands respectively. + """ + + def __init__( + self, + severity: int = 3, + mixture_width: int = 3, + chain_depth: int = -1, + alpha: float = 1.0, + all_ops: bool = True, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, + ) -> None: + super().__init__() + self._PARAMETER_MAX = 10 + if not (1 <= severity <= self._PARAMETER_MAX): + raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.") + self.severity = severity + self.mixture_width = mixture_width + self.chain_depth = chain_depth + self.alpha = alpha + self.all_ops = all_ops + self.interpolation = interpolation + self.fill = fill + + def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + s = { + # op_name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True), + "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(255.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + } + if self.all_ops: + s.update( + { + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + } + ) + return s + + @torch.jit.unused + def _pil_to_tensor(self, img) -> Tensor: + return F.pil_to_tensor(img) + + @torch.jit.unused + def _tensor_to_pil(self, img: Tensor): + return F.to_pil_image(img) + + def _sample_dirichlet(self, params: Tensor) -> Tensor: + # Must be on a separate method so that we can overwrite it in tests. + return torch._sample_dirichlet(params) + + def forward(self, orig_img: Tensor) -> Tensor: + """ + img (PIL Image or Tensor): Image to be transformed. + + Returns: + PIL Image or Tensor: Transformed image. + """ + fill = self.fill + if isinstance(orig_img, Tensor): + img = orig_img + if isinstance(fill, (int, float)): + fill = [float(fill)] * F.get_image_num_channels(img) + elif fill is not None: + fill = [float(f) for f in fill] + else: + img = self._pil_to_tensor(orig_img) + + op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img)) + + orig_dims = list(img.shape) + batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims) + batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) + + # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet + # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image. + m = self._sample_dirichlet( + torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1) + ) + + # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images. + combined_weights = self._sample_dirichlet( + torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1) + ) * m[:, 1].view([batch_dims[0], -1]) + + mix = m[:, 0].view(batch_dims) * batch + for i in range(self.mixture_width): + aug = batch + depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item()) + for _ in range(depth): + op_index = int(torch.randint(len(op_meta), (1,)).item()) + op_name = list(op_meta.keys())[op_index] + magnitudes, signed = op_meta[op_name] + magnitude = ( + float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item()) + if magnitudes.ndim > 0 + else 0.0 + ) + if signed and torch.randint(2, (1,)): + magnitude *= -1.0 + aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill) + mix.add_(combined_weights[:, i].view(batch_dims) * aug) + mix = mix.view(orig_dims).to(dtype=img.dtype) + + if not isinstance(orig_img, Tensor): + return self._tensor_to_pil(mix) + return mix + + def __repr__(self) -> str: + s = ( + f"{self.__class__.__name__}(" + f"severity={self.severity}" + f", mixture_width={self.mixture_width}" + f", chain_depth={self.chain_depth}" + f", alpha={self.alpha}" + f", all_ops={self.all_ops}" + f", interpolation={self.interpolation}" + f", fill={self.fill}" + f")" + ) + return s