Skip to content

Adding AugMix implementation #5411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 18, 2022
1 change: 1 addition & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ The new transform can be used standalone or mixed-and-matched with existing tran
AutoAugment
RandAugment
TrivialAugmentWide
AugMix

.. _functional_transforms:

Expand Down
8 changes: 8 additions & 0 deletions gallery/plot_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,14 @@ def plot(imgs, with_orig=True, row_title=None, **imshow_kwargs):
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# AugMix
# ~~~~~~
# The :class:`~torchvision.transforms.AugMix` transform automatically augments the data
# by blending several randomly sampled chains of augmentation ops with the original image.
augmenter = T.AugMix()
imgs = [augmenter(orig_img) for _ in range(4)]
plot(imgs)

####################################
# Randomly-applied transforms
# ---------------------------
Expand Down
2 changes: 2 additions & 0 deletions references/classification/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ def __init__(
trans.append(autoaugment.RandAugment(interpolation=interpolation))
elif auto_augment_policy == "ta_wide":
trans.append(autoaugment.TrivialAugmentWide(interpolation=interpolation))
elif auto_augment_policy == "augmix":
trans.append(autoaugment.AugMix(interpolation=interpolation))
else:
aa_policy = autoaugment.AutoAugmentPolicy(auto_augment_policy)
trans.append(autoaugment.AutoAugment(policy=aa_policy, interpolation=interpolation))
Expand Down
19 changes: 19 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1601,6 +1601,25 @@ def test_trivialaugmentwide(fill, num_magnitude_bins, grayscale):
transform.__repr__()


@pytest.mark.parametrize("fill", [None, 85, (128, 128, 128)])
@pytest.mark.parametrize("severity", [1, 10])
@pytest.mark.parametrize("mixture_width", [1, 2])
@pytest.mark.parametrize("chain_depth", [-1, 2])
@pytest.mark.parametrize("all_ops", [True, False])
@pytest.mark.parametrize("grayscale", [True, False])
def test_augmix(fill, severity, mixture_width, chain_depth, all_ops, grayscale):
    """Smoke-test ``transforms.AugMix`` on a PIL image.

    The transform is applied 100 times in a row (each output feeds the next
    call) for every combination of the parametrized arguments, and its
    ``__repr__`` is exercised once. The test only checks that nothing raises.
    """
    # Local import so this test does not depend on the module's import block.
    import torch

    random.seed(42)
    # AugMix samples with torch.randint / torch._sample_dirichlet, so Python's
    # ``random`` seed alone does not make its draws reproducible.
    torch.manual_seed(42)
    img = Image.open(GRACE_HOPPER)
    if grayscale:
        img, fill = _get_grayscale_test_image(img, fill)
    transform = transforms.AugMix(
        fill=fill, severity=severity, mixture_width=mixture_width, chain_depth=chain_depth, all_ops=all_ops
    )
    for _ in range(100):
        img = transform(img)
    transform.__repr__()


def test_random_crop():
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
Expand Down
33 changes: 32 additions & 1 deletion test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,38 @@ def test_trivialaugmentwide(device, fill):
_test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)


@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide])
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
    "fill",
    [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1], 1],
)
def test_augmix(device, fill):
    """Run the eager-vs-scripted helpers on AugMix for one image and one batch.

    Uses a subclass whose Dirichlet sampling is replaced by a softmax, and
    repeats the comparison 25 times on random uint8 inputs.
    """
    image = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=device)
    image_batch = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=device)

    class DeterministicAugMix(T.AugMix):
        def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
            # Deterministic stand-in for the random draw, so the order of
            # rand calls cannot change the outcome.
            return params.softmax(dim=-1)

    eager = DeterministicAugMix(fill=fill)
    scripted = torch.jit.script(eager)
    for _ in range(25):
        _test_transform_vs_scripted(eager, scripted, image)
        _test_transform_vs_scripted_on_batch(eager, scripted, image_batch)


@pytest.mark.parametrize("augmentation", [T.AutoAugment, T.RandAugment, T.TrivialAugmentWide, T.AugMix])
def test_autoaugment_save(augmentation, tmpdir):
transform = augmentation()
s_transform = torch.jit.script(transform)
Expand Down
153 changes: 152 additions & 1 deletion torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from . import functional as F, InterpolationMode

__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"]
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"]


def _apply_op(
Expand Down Expand Up @@ -458,3 +458,154 @@ def __repr__(self) -> str:
f")"
)
return s


class AugMix(torch.nn.Module):
r"""AugMix data augmentation method based on
`"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
If the image is torch Tensor, it should be of type torch.uint8, and it is expected
to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
If img is PIL Image, it is expected to be in mode "L" or "RGB".

Args:
severity (int): The severity of base augmentation operators. Must be in the range ``[1, 10]``.
Default is ``3``.
mixture_width (int): The number of augmentation chains. Default is ``3``.
chain_depth (int): The depth of augmentation chains. A non-positive value denotes stochastic depth sampled
from the interval [1, 3]. Default is ``-1``.
alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.

Raises:
ValueError: If ``severity`` is outside ``[1, 10]``.
"""

def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[List[float]] = None,
) -> None:
super().__init__()
# Upper bound accepted for ``severity``; ``forward`` also uses it as the number of magnitude bins.
self._PARAMETER_MAX = 10
if not (1 <= severity <= self._PARAMETER_MAX):
raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
self.severity = severity
self.mixture_width = mixture_width
self.chain_depth = chain_depth
self.alpha = alpha
self.all_ops = all_ops
self.interpolation = interpolation
self.fill = fill

# Build the table of candidate ops. Each entry maps an op name to ``(magnitudes, signed)``:
# ``magnitudes`` holds one value per bin (or is a 0-dim tensor for ops that take no magnitude),
# and ``signed`` tells ``forward`` whether the magnitude may be negated.
def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
s = {
# op_name: (magnitudes, signed)
"ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
"ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
"TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
"TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
"Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
"Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
"Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
"AutoContrast": (torch.tensor(0.0), False),
"Equalize": (torch.tensor(0.0), False),
}
# The four extra ops below are only offered when ``all_ops`` is set.
if self.all_ops:
s.update(
{
"Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
"Color": (torch.linspace(0.0, 0.9, num_bins), True),
"Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
"Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
}
)
return s

# Thin wrapper around ``F.pil_to_tensor``, decorated with ``@torch.jit.unused``.
@torch.jit.unused
def _pil_to_tensor(self, img) -> Tensor:
return F.pil_to_tensor(img)

# Thin wrapper around ``F.to_pil_image``, decorated with ``@torch.jit.unused``.
@torch.jit.unused
def _tensor_to_pil(self, img: Tensor):
return F.to_pil_image(img)

def _sample_dirichlet(self, params: Tensor) -> Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)

def forward(self, orig_img: Tensor) -> Tensor:
"""
Mix the input image with ``mixture_width`` randomly augmented copies of itself.

Args:
orig_img (PIL Image or Tensor): Image to be transformed.

Returns:
PIL Image or Tensor: Transformed image.
"""
fill = self.fill
# Tensor input: normalise ``fill`` to a list of floats (a scalar is repeated once per channel).
# Any other input is converted with ``_pil_to_tensor`` and ``fill`` is passed on unchanged.
if isinstance(orig_img, Tensor):
img = orig_img
if isinstance(fill, (int, float)):
fill = [float(fill)] * F.get_image_num_channels(img)
elif fill is not None:
fill = [float(f) for f in fill]
else:
img = self._pil_to_tensor(orig_img)
Comment on lines +549 to +557
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Later we may want to refactor this part of code that could be applicable to all other aug strategies...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Certainly. We would need to consider how Videos can be handled as well here too.


# Magnitude tables with ``_PARAMETER_MAX`` bins each, built for this image's size.
op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))

# Prepend singleton dims so the image has at least 4 dims. ``batch_dims`` keeps the first dim and sets
# every other dim to 1, so that one weight per element of the first dim broadcasts over the image.
orig_dims = list(img.shape)
batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

# Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
# with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
m = self._sample_dirichlet(
torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
)

# Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
combined_weights = self._sample_dirichlet(
torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
) * m[:, 1].view([batch_dims[0], -1])

# Start from the weighted original image, then add one weighted augmented copy per chain.
mix = m[:, 0].view(batch_dims) * batch
for i in range(self.mixture_width):
aug = batch
# A positive ``chain_depth`` is used as is; any other value draws the depth with
# ``torch.randint(low=1, high=4)``.
depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
for _ in range(depth):
# Pick one op at random from the augmentation space.
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
# The bin index is drawn with ``torch.randint(self.severity, ...)``, so ``severity`` caps which bins
# can be selected. Ops whose magnitude tensor is 0-dim get 0.0.
magnitude = (
float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
if magnitudes.ndim > 0
else 0.0
)
# Signed ops have their magnitude negated depending on a random draw.
if signed and torch.randint(2, (1,)):
magnitude *= -1.0
aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
# Restore the input's original shape and dtype.
mix = mix.view(orig_dims).to(dtype=img.dtype)

# Non-Tensor inputs are converted back with ``_tensor_to_pil``.
if not isinstance(orig_img, Tensor):
return self._tensor_to_pil(mix)
return mix

# Render the configured arguments in ``ClassName(name=value, ...)`` form.
def __repr__(self) -> str:
s = (
f"{self.__class__.__name__}("
f"severity={self.severity}"
f", mixture_width={self.mixture_width}"
f", chain_depth={self.chain_depth}"
f", alpha={self.alpha}"
f", all_ops={self.all_ops}"
f", interpolation={self.interpolation}"
f", fill={self.fill}"
f")"
)
return s