-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Adding AugMix implementation #5411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d3e8b08
6c1a388
bc5667c
bf1a17e
649305e
b05d712
547bcf9
03176f2
9c284cc
e4b62be
c722c02
ecc598e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
|
||
from . import functional as F, InterpolationMode | ||
|
||
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide"] | ||
__all__ = ["AutoAugmentPolicy", "AutoAugment", "RandAugment", "TrivialAugmentWide", "AugMix"] | ||
|
||
|
||
def _apply_op( | ||
|
@@ -458,3 +458,154 @@ def __repr__(self) -> str: | |
f")" | ||
) | ||
return s | ||
|
||
|
||
class AugMix(torch.nn.Module):
    r"""AugMix data augmentation method based on
    `"AugMix: A Simple Data Processing Method to Improve Robustness and Uncertainty" <https://arxiv.org/abs/1912.02781>`_.
    If the image is torch Tensor, it should be of type torch.uint8, and it is expected
    to have [..., 1 or 3, H, W] shape, where ... means an arbitrary number of leading dimensions.
    If img is PIL Image, it is expected to be in mode "L" or "RGB".

    Args:
        severity (int): The severity of base augmentation operators. Default is ``3``.
        mixture_width (int): The number of augmentation chains. Default is ``3``.
        chain_depth (int): The depth of augmentation chains. A negative value denotes stochastic depth sampled from the interval [1, 3].
            Default is ``-1``.
        alpha (float): The hyperparameter for the probability distributions. Default is ``1.0``.
        all_ops (bool): Use all operations (including brightness, contrast, color and sharpness). Default is ``True``.
        interpolation (InterpolationMode): Desired interpolation enum defined by
            :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
            If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
        fill (sequence or number, optional): Pixel fill value for the area outside the transformed
            image. If given a number, the value is used for all bands respectively.
    """

    def __init__(
        self,
        severity: int = 3,
        mixture_width: int = 3,
        chain_depth: int = -1,
        alpha: float = 1.0,
        all_ops: bool = True,
        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
        fill: Optional[List[float]] = None,
    ) -> None:
        super().__init__()
        # Number of magnitude bins per operation; `severity` selects how far
        # into those bins the sampled magnitudes may reach.
        self._PARAMETER_MAX = 10
        if not (1 <= severity <= self._PARAMETER_MAX):
            raise ValueError(f"The severity must be between [1, {self._PARAMETER_MAX}]. Got {severity} instead.")
        self.severity = severity
        self.mixture_width = mixture_width
        self.chain_depth = chain_depth
        self.alpha = alpha
        self.all_ops = all_ops
        self.interpolation = interpolation
        self.fill = fill

    def _augmentation_space(self, num_bins: int, image_size: List[int]) -> Dict[str, Tuple[Tensor, bool]]:
        """Return the op table: name -> (magnitude bins, whether magnitude may be negated).

        ``image_size`` is ``[width, height]`` (the order returned by
        ``F.get_image_size``), so translations scale with the matching axis.
        """
        s = {
            # op_name: (magnitudes, signed)
            "ShearX": (torch.linspace(0.0, 0.3, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.3, num_bins), True),
            "TranslateX": (torch.linspace(0.0, image_size[0] / 3.0, num_bins), True),
            "TranslateY": (torch.linspace(0.0, image_size[1] / 3.0, num_bins), True),
            "Rotate": (torch.linspace(0.0, 30.0, num_bins), True),
            "Posterize": (4 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "Solarize": (torch.linspace(255.0, 0.0, num_bins), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }
        if self.all_ops:
            # The AugMix paper excludes these color ops when evaluating on
            # corruption benchmarks; they are included by default here.
            s.update(
                {
                    "Brightness": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Color": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Contrast": (torch.linspace(0.0, 0.9, num_bins), True),
                    "Sharpness": (torch.linspace(0.0, 0.9, num_bins), True),
                }
            )
        return s

    @torch.jit.unused
    def _pil_to_tensor(self, img) -> Tensor:
        # Kept as a method (rather than calling F directly) so it can be
        # excluded from TorchScript via @torch.jit.unused.
        return F.pil_to_tensor(img)

    @torch.jit.unused
    def _tensor_to_pil(self, img: Tensor):
        return F.to_pil_image(img)

    def _sample_dirichlet(self, params: Tensor) -> Tensor:
        # Must be on a separate method so that we can overwrite it in tests.
        return torch._sample_dirichlet(params)

    def forward(self, orig_img: Tensor) -> Tensor:
        """
        img (PIL Image or Tensor): Image to be transformed.

        Returns:
            PIL Image or Tensor: Transformed image.
        """
        fill = self.fill
        if isinstance(orig_img, Tensor):
            img = orig_img
            # Normalize `fill` to a per-channel float list for tensor inputs.
            if isinstance(fill, (int, float)):
                fill = [float(fill)] * F.get_image_num_channels(img)
            elif fill is not None:
                fill = [float(f) for f in fill]
        else:
            img = self._pil_to_tensor(orig_img)

        op_meta = self._augmentation_space(self._PARAMETER_MAX, F.get_image_size(img))

        # View the input as a 4D batch [B, C, H, W] so the mixing weights can
        # be broadcast per-sample; restore the original shape at the end.
        orig_dims = list(img.shape)
        batch = img.view([1] * max(4 - img.ndim, 0) + orig_dims)
        batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

        # Sample the beta weights for combining the original and augmented image. To get Beta, we use a Dirichlet
        # with 2 parameters. The 1st column stores the weights of the original and the 2nd the ones of augmented image.
        m = self._sample_dirichlet(
            torch.tensor([self.alpha, self.alpha], device=batch.device).expand(batch_dims[0], -1)
        )

        # Sample the mixing weights and combine them with the ones sampled from Beta for the augmented images.
        combined_weights = self._sample_dirichlet(
            torch.tensor([self.alpha] * self.mixture_width, device=batch.device).expand(batch_dims[0], -1)
        ) * m[:, 1].view([batch_dims[0], -1])

        mix = m[:, 0].view(batch_dims) * batch
        for i in range(self.mixture_width):
            aug = batch
            # chain_depth <= 0 means stochastic depth uniformly in {1, 2, 3}.
            depth = self.chain_depth if self.chain_depth > 0 else int(torch.randint(low=1, high=4, size=(1,)).item())
            for _ in range(depth):
                op_index = int(torch.randint(len(op_meta), (1,)).item())
                op_name = list(op_meta.keys())[op_index]
                magnitudes, signed = op_meta[op_name]
                # 0-dim magnitude tensors (AutoContrast/Equalize) take no magnitude.
                magnitude = (
                    float(magnitudes[torch.randint(self.severity, (1,), dtype=torch.long)].item())
                    if magnitudes.ndim > 0
                    else 0.0
                )
                if signed and torch.randint(2, (1,)):
                    magnitude *= -1.0
                aug = _apply_op(aug, op_name, magnitude, interpolation=self.interpolation, fill=fill)
            mix.add_(combined_weights[:, i].view(batch_dims) * aug)
        mix = mix.view(orig_dims).to(dtype=img.dtype)

        if not isinstance(orig_img, Tensor):
            return self._tensor_to_pil(mix)
        return mix

    def __repr__(self) -> str:
        s = (
            f"{self.__class__.__name__}("
            f"severity={self.severity}"
            f", mixture_width={self.mixture_width}"
            f", chain_depth={self.chain_depth}"
            f", alpha={self.alpha}"
            f", all_ops={self.all_ops}"
            f", interpolation={self.interpolation}"
            f", fill={self.fill}"
            f")"
        )
        return s
Uh oh!
There was an error while loading. Please reload this page.