diff --git a/torchvision/transforms/autoaugment.py b/torchvision/transforms/autoaugment.py index e241b821871..9b2e1ac212e 100644 --- a/torchvision/transforms/autoaugment.py +++ b/torchvision/transforms/autoaugment.py @@ -10,6 +10,45 @@ __all__ = ["AutoAugmentPolicy", "AutoAugment"] +def _apply_op(img: Tensor, op_name: str, magnitude: float, + interpolation: InterpolationMode, fill: Optional[List[float]]): + if op_name == "ShearX": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], + interpolation=interpolation, fill=fill) + elif op_name == "ShearY": + img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], + interpolation=interpolation, fill=fill) + elif op_name == "TranslateX": + img = F.affine(img, angle=0.0, translate=[int(magnitude), 0], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "TranslateY": + img = F.affine(img, angle=0.0, translate=[0, int(magnitude)], scale=1.0, + interpolation=interpolation, shear=[0.0, 0.0], fill=fill) + elif op_name == "Rotate": + img = F.rotate(img, magnitude, interpolation=interpolation, fill=fill) + elif op_name == "Brightness": + img = F.adjust_brightness(img, 1.0 + magnitude) + elif op_name == "Color": + img = F.adjust_saturation(img, 1.0 + magnitude) + elif op_name == "Contrast": + img = F.adjust_contrast(img, 1.0 + magnitude) + elif op_name == "Sharpness": + img = F.adjust_sharpness(img, 1.0 + magnitude) + elif op_name == "Posterize": + img = F.posterize(img, int(magnitude)) + elif op_name == "Solarize": + img = F.solarize(img, magnitude) + elif op_name == "AutoContrast": + img = F.autocontrast(img) + elif op_name == "Equalize": + img = F.equalize(img) + elif op_name == "Invert": + img = F.invert(img) + else: + raise ValueError("The provided operator {} is not recognized.".format(op_name)) + return img + + class AutoAugmentPolicy(Enum): """AutoAugment policies learned on different datasets. Available policies are IMAGENET, CIFAR10 and SVHN. @@ -19,116 +58,6 @@ class AutoAugmentPolicy(Enum): SVHN = "svhn" -def _get_transforms( # type: ignore[return] - policy: AutoAugmentPolicy -) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: - if policy == AutoAugmentPolicy.IMAGENET: - return [ - (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), - (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), - (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), - (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), - (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), - (("Rotate", 0.8, 8), ("Color", 0.4, 0)), - (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), - (("Equalize", 0.0, None), ("Equalize", 0.8, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Rotate", 0.8, 8), ("Color", 1.0, 2)), - (("Color", 0.8, 8), ("Solarize", 0.8, 7)), - (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), - (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), - (("Color", 0.4, 0), ("Equalize", 0.6, None)), - (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), - (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), - (("Invert", 0.6, None), ("Equalize", 1.0, None)), - (("Color", 0.6, 4), ("Contrast", 1.0, 8)), - (("Equalize", 0.8, None), ("Equalize", 0.6, None)), - ] - elif policy == AutoAugmentPolicy.CIFAR10: - return [ - (("Invert", 0.1, None), ("Contrast", 0.2, 6)), - (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), - (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), - (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), - (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), - (("Color", 0.4, 3), ("Brightness", 0.6, 7)), - (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), - (("Equalize", 0.6, None), ("Equalize", 0.5, None)), - (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), - (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), - (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), - (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), - (("Brightness", 0.9, 6), ("Color", 0.2, 8)), - (("Solarize", 0.5, 2), ("Invert", 0.0, None)), - (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), - (("Equalize", 0.2, None), ("Equalize", 0.6, None)), - (("Color", 0.9, 9), ("Equalize", 0.6, None)), - (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), - (("Brightness", 0.1, 3), ("Color", 0.7, 0)), - (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), - (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), - (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), - (("Equalize", 0.8, None), ("Invert", 0.1, None)), - (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), - ] - elif policy == AutoAugmentPolicy.SVHN: - return [ - (("ShearX", 0.9, 4), ("Invert", 0.2, None)), - (("ShearY", 0.9, 8), ("Invert", 0.7, None)), - (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), - (("ShearY", 0.9, 8), ("Invert", 0.4, None)), - (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), - (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), - (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), - (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), - (("ShearY", 0.8, 8), ("Invert", 0.7, None)), - (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), - (("Invert", 0.9, None), ("Equalize", 0.6, None)), - (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), - (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), - (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), - (("Invert", 0.6, None), ("Rotate", 0.8, 4)), - (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), - (("ShearX", 0.1, 6), ("Invert", 0.6, None)), - (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), - (("ShearY", 0.8, 4), ("Invert", 0.8, None)), - (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), - (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), - (("ShearX", 0.7, 2), ("Invert", 0.1, None)), - ] - - -def _get_magnitudes() -> Dict[str, Tuple[Optional[Tensor], Optional[bool]]]: - _BINS = 10 - return { - # name: (magnitudes, signed) - "ShearX": (torch.linspace(0.0, 0.3, _BINS), True), - "ShearY": (torch.linspace(0.0, 0.3, _BINS), True), - "TranslateX": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "TranslateY": (torch.linspace(0.0, 150.0 / 331.0, _BINS), True), - "Rotate": (torch.linspace(0.0, 30.0, _BINS), True), - "Brightness": (torch.linspace(0.0, 0.9, _BINS), True), - "Color": (torch.linspace(0.0, 0.9, _BINS), True), - "Contrast": (torch.linspace(0.0, 0.9, _BINS), True), - "Sharpness": (torch.linspace(0.0, 0.9, _BINS), True), - "Posterize": (torch.tensor([8, 8, 7, 7, 6, 6, 5, 5, 4, 4]), False), - "Solarize": (torch.linspace(256.0, 0.0, _BINS), False), - "AutoContrast": (None, None), - "Equalize": (None, None), - "Invert": (None, None), - } - - class AutoAugment(torch.nn.Module): r"""AutoAugment data augmentation method based on `"AutoAugment: Learning Augmentation Strategies from Data" `_. @@ -156,11 +85,117 @@ def __init__( self.policy = policy self.interpolation = interpolation self.fill = fill + self.transforms = self._get_transforms(policy) - self.transforms = _get_transforms(policy) - if self.transforms is None: + def _get_transforms( + self, + policy: AutoAugmentPolicy + ) -> List[Tuple[Tuple[str, float, Optional[int]], Tuple[str, float, Optional[int]]]]: + if policy == AutoAugmentPolicy.IMAGENET: + return [ + (("Posterize", 0.4, 8), ("Rotate", 0.6, 9)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + (("Posterize", 0.6, 7), ("Posterize", 0.6, 6)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Equalize", 0.4, None), ("Rotate", 0.8, 8)), + (("Solarize", 0.6, 3), ("Equalize", 0.6, None)), + (("Posterize", 0.8, 5), ("Equalize", 1.0, None)), + (("Rotate", 0.2, 3), ("Solarize", 0.6, 8)), + (("Equalize", 0.6, None), ("Posterize", 0.4, 6)), + (("Rotate", 0.8, 8), ("Color", 0.4, 0)), + (("Rotate", 0.4, 9), ("Equalize", 0.6, None)), + (("Equalize", 0.0, None), ("Equalize", 0.8, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Rotate", 0.8, 8), ("Color", 1.0, 2)), + (("Color", 0.8, 8), ("Solarize", 0.8, 7)), + (("Sharpness", 0.4, 7), ("Invert", 0.6, None)), + (("ShearX", 0.6, 5), ("Equalize", 1.0, None)), + (("Color", 0.4, 0), ("Equalize", 0.6, None)), + (("Equalize", 0.4, None), ("Solarize", 0.2, 4)), + (("Solarize", 0.6, 5), ("AutoContrast", 0.6, None)), + (("Invert", 0.6, None), ("Equalize", 1.0, None)), + (("Color", 0.6, 4), ("Contrast", 1.0, 8)), + (("Equalize", 0.8, None), ("Equalize", 0.6, None)), + ] + elif policy == AutoAugmentPolicy.CIFAR10: + return [ + (("Invert", 0.1, None), ("Contrast", 0.2, 6)), + (("Rotate", 0.7, 2), ("TranslateX", 0.3, 9)), + (("Sharpness", 0.8, 1), ("Sharpness", 0.9, 3)), + (("ShearY", 0.5, 8), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.5, None), ("Equalize", 0.9, None)), + (("ShearY", 0.2, 7), ("Posterize", 0.3, 7)), + (("Color", 0.4, 3), ("Brightness", 0.6, 7)), + (("Sharpness", 0.3, 9), ("Brightness", 0.7, 9)), + (("Equalize", 0.6, None), ("Equalize", 0.5, None)), + (("Contrast", 0.6, 7), ("Sharpness", 0.6, 5)), + (("Color", 0.7, 7), ("TranslateX", 0.5, 8)), + (("Equalize", 0.3, None), ("AutoContrast", 0.4, None)), + (("TranslateY", 0.4, 3), ("Sharpness", 0.2, 6)), + (("Brightness", 0.9, 6), ("Color", 0.2, 8)), + (("Solarize", 0.5, 2), ("Invert", 0.0, None)), + (("Equalize", 0.2, None), ("AutoContrast", 0.6, None)), + (("Equalize", 0.2, None), ("Equalize", 0.6, None)), + (("Color", 0.9, 9), ("Equalize", 0.6, None)), + (("AutoContrast", 0.8, None), ("Solarize", 0.2, 8)), + (("Brightness", 0.1, 3), ("Color", 0.7, 0)), + (("Solarize", 0.4, 5), ("AutoContrast", 0.9, None)), + (("TranslateY", 0.9, 9), ("TranslateY", 0.7, 9)), + (("AutoContrast", 0.9, None), ("Solarize", 0.8, 3)), + (("Equalize", 0.8, None), ("Invert", 0.1, None)), + (("TranslateY", 0.7, 9), ("AutoContrast", 0.9, None)), + ] + elif policy == AutoAugmentPolicy.SVHN: + return [ + (("ShearX", 0.9, 4), ("Invert", 0.2, None)), + (("ShearY", 0.9, 8), ("Invert", 0.7, None)), + (("Equalize", 0.6, None), ("Solarize", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("AutoContrast", 0.8, None)), + (("ShearY", 0.9, 8), ("Invert", 0.4, None)), + (("ShearY", 0.9, 5), ("Solarize", 0.2, 6)), + (("Invert", 0.9, None), ("AutoContrast", 0.8, None)), + (("Equalize", 0.6, None), ("Rotate", 0.9, 3)), + (("ShearX", 0.9, 4), ("Solarize", 0.3, 3)), + (("ShearY", 0.8, 8), ("Invert", 0.7, None)), + (("Equalize", 0.9, None), ("TranslateY", 0.6, 6)), + (("Invert", 0.9, None), ("Equalize", 0.6, None)), + (("Contrast", 0.3, 3), ("Rotate", 0.8, 4)), + (("Invert", 0.8, None), ("TranslateY", 0.0, 2)), + (("ShearY", 0.7, 6), ("Solarize", 0.4, 8)), + (("Invert", 0.6, None), ("Rotate", 0.8, 4)), + (("ShearY", 0.3, 7), ("TranslateX", 0.9, 3)), + (("ShearX", 0.1, 6), ("Invert", 0.6, None)), + (("Solarize", 0.7, 2), ("TranslateY", 0.6, 7)), + (("ShearY", 0.8, 4), ("Invert", 0.8, None)), + (("ShearX", 0.7, 9), ("TranslateY", 0.8, 3)), + (("ShearY", 0.8, 5), ("AutoContrast", 0.7, None)), + (("ShearX", 0.7, 2), ("Invert", 0.1, None)), + ] + else: raise ValueError("The provided policy {} is not recognized.".format(policy)) - self._op_meta = _get_magnitudes() + + def _get_magnitudes(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]: + return { + # name: (magnitudes, signed) + "ShearX": (torch.linspace(0.0, 0.3, num_bins), True), + "ShearY": (torch.linspace(0.0, 0.3, num_bins), True), + "TranslateX": (torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True), + "TranslateY": (torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True), + "Rotate": (torch.linspace(0.0, 30.0, num_bins), True), + "Brightness": (torch.linspace(0.0, 0.9, num_bins), True), + "Color": (torch.linspace(0.0, 0.9, num_bins), True), + "Contrast": (torch.linspace(0.0, 0.9, num_bins), True), + "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True), + "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False), + "Solarize": (torch.linspace(256.0, 0.0, num_bins), False), + "AutoContrast": (torch.tensor(0.0), False), + "Equalize": (torch.tensor(0.0), False), + "Invert": (torch.tensor(0.0), False), + } @staticmethod def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: @@ -175,9 +210,6 @@ def get_params(transform_num: int) -> Tuple[int, Tensor, Tensor]: return policy_id, probs, signs - def _get_op_meta(self, name: str) -> Tuple[Optional[Tensor], Optional[bool]]: - return self._op_meta[name] - def forward(self, img: Tensor) -> Tensor: """ img (PIL Image or Tensor): Image to be transformed. @@ -196,46 +228,12 @@ def forward(self, img: Tensor) -> Tensor: for i, (op_name, p, magnitude_id) in enumerate(self.transforms[transform_id]): if probs[i] <= p: - magnitudes, signed = self._get_op_meta(op_name) - magnitude = float(magnitudes[magnitude_id].item()) \ - if magnitudes is not None and magnitude_id is not None else 0.0 - if signed is not None and signed and signs[i] == 0: + op_meta = self._get_magnitudes(10, F.get_image_size(img)) + magnitudes, signed = op_meta[op_name] + magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0 + if signed and signs[i] == 0: magnitude *= -1.0 - - if op_name == "ShearX": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[math.degrees(magnitude), 0.0], - interpolation=self.interpolation, fill=fill) - elif op_name == "ShearY": - img = F.affine(img, angle=0.0, translate=[0, 0], scale=1.0, shear=[0.0, math.degrees(magnitude)], - interpolation=self.interpolation, fill=fill) - elif op_name == "TranslateX": - img = F.affine(img, angle=0.0, translate=[int(F.get_image_size(img)[0] * magnitude), 0], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "TranslateY": - img = F.affine(img, angle=0.0, translate=[0, int(F.get_image_size(img)[1] * magnitude)], scale=1.0, - interpolation=self.interpolation, shear=[0.0, 0.0], fill=fill) - elif op_name == "Rotate": - img = F.rotate(img, magnitude, interpolation=self.interpolation, fill=fill) - elif op_name == "Brightness": - img = F.adjust_brightness(img, 1.0 + magnitude) - elif op_name == "Color": - img = F.adjust_saturation(img, 1.0 + magnitude) - elif op_name == "Contrast": - img = F.adjust_contrast(img, 1.0 + magnitude) - elif op_name == "Sharpness": - img = F.adjust_sharpness(img, 1.0 + magnitude) - elif op_name == "Posterize": - img = F.posterize(img, int(magnitude)) - elif op_name == "Solarize": - img = F.solarize(img, magnitude) - elif op_name == "AutoContrast": - img = F.autocontrast(img) - elif op_name == "Equalize": - img = F.equalize(img) - elif op_name == "Invert": - img = F.invert(img) - else: - raise ValueError("The provided operator {} is not recognized.".format(op_name)) + img = _apply_op(img, op_name, magnitude, interpolation=self.interpolation, fill=fill) return img