Adding Mixup and Cutmix #4379
references/classification/train.py

@@ -4,11 +4,13 @@
 import torch
 import torch.utils.data
+from torch.utils.data.dataloader import default_collate
 from torch import nn
 import torchvision
 from torchvision.transforms.functional import InterpolationMode

 import presets
+import transforms
 import utils

 try:
@@ -164,10 +166,21 @@ def main(args):
     train_dir = os.path.join(args.data_path, 'train')
     val_dir = os.path.join(args.data_path, 'val')
     dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

+    collate_fn = None
+    num_classes = len(dataset.classes)
+    mixup_transforms = []
+    if args.mixup_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomMixup(num_classes, p=1.0, alpha=args.mixup_alpha))
+    if args.cutmix_alpha > 0.0:
+        mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
+    if mixup_transforms:
+        mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
+        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
-        sampler=train_sampler, num_workers=args.workers, pin_memory=True)
+        sampler=train_sampler, num_workers=args.workers, pin_memory=True,
+        collate_fn=collate_fn)
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=args.batch_size,
         sampler=test_sampler, num_workers=args.workers, pin_memory=True)
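For orientation, here is a minimal sketch of what this collate path produces, assuming the RandomMixup transform from the new transforms module added below is importable (the dummy TensorDataset, image size, batch size, and alpha value are illustrative, not part of the PR):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.dataloader import default_collate

import transforms  # the new module added by this PR (e.g. run from references/classification)

num_classes = 10
# Stand-in for an ImageFolder-style dataset: random "images" with integer labels.
dataset = TensorDataset(torch.rand(32, 3, 224, 224), torch.randint(0, num_classes, (32,)))

mixup = transforms.RandomMixup(num_classes, p=1.0, alpha=0.2)
# Same wiring as above: collate the samples into a batch first, then mix the whole batch.
collate_fn = lambda batch: mixup(*default_collate(batch))  # noqa: E731

loader = DataLoader(dataset, batch_size=8, collate_fn=collate_fn)
images, targets = next(iter(loader))
print(images.shape)   # torch.Size([8, 3, 224, 224])
print(targets.shape)  # torch.Size([8, 10]) -- soft targets instead of class indices

The script composes both transforms with torchvision.transforms.RandomChoice, so for each training batch either mixup or cutmix is applied, never both at once.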
@@ -272,6 +285,8 @@ def get_args_parser(add_help=True): | |
parser.add_argument('--label-smoothing', default=0.0, type=float, | ||
help='label smoothing (default: 0.0)', | ||
dest='label_smoothing') | ||
parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)') | ||
parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm exposing here very few options (I'm using hardcoded There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @datumbox nn.CrossEntropyLoss dose not work when you use mixup or cutmix, because the traget shape is (N, K), rather (N,) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this was added at pytorch/pytorch#63122 This should be available on the latest stable version of pytorch. See doc. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @datumbox ok, thanks. |
||
parser.add_argument('--lr-step-size', default=30, type=int, help='decrease lr every step-size epochs') | ||
parser.add_argument('--lr-gamma', default=0.1, type=float, help='decrease lr by a factor of lr-gamma') | ||
parser.add_argument('--print-freq', default=10, type=int, help='print frequency') | ||
|
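The resolution of that thread can be checked with a small sketch (the shapes and values are illustrative): since pytorch/pytorch#63122, nn.CrossEntropyLoss accepts class-probability targets of shape (N, C) in addition to class-index targets of shape (N,), which is exactly what the mixup/cutmix collate produces.

import torch
from torch import nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(4, 10)

# Hard targets: class indices of shape (N,)
hard_targets = torch.randint(0, 10, (4,))
print(criterion(logits, hard_targets))

# Soft targets of shape (N, C), as produced by RandomMixup/RandomCutmix
soft_targets = torch.rand(4, 10)
soft_targets = soft_targets / soft_targets.sum(dim=1, keepdim=True)
print(criterion(logits, soft_targets))  # requires a PyTorch release that includes #63122 (1.10+)

The training script can additionally combine this with the existing --label-smoothing option, which is also handled by nn.CrossEntropyLoss.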
references/classification/transforms.py (new file)

@@ -0,0 +1,175 @@
import math
import torch

from typing import Tuple
from torch import Tensor
from torchvision.transforms import functional as F


class RandomMixup(torch.nn.Module):
    """Randomly apply Mixup to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"mixup: Beyond Empirical Risk Minimization" <https://arxiv.org/abs/1710.09412>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for mixup.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on mixup paper, page 3.
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        batch_rolled.mul_(1.0 - lambda_param)
        batch.mul_(lambda_param).add_(batch_rolled)

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)


class RandomCutmix(torch.nn.Module):
    """Randomly apply Cutmix to the provided batch and targets.
    The class implements the data augmentations as described in the paper
    `"CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features"
    <https://arxiv.org/abs/1905.04899>`_.

    Args:
        num_classes (int): number of classes used for one-hot encoding.
        p (float): probability of the batch being transformed. Default value is 0.5.
        alpha (float): hyperparameter of the Beta distribution used for cutmix.
            Default value is 1.0.
        inplace (bool): boolean to make this transform inplace. Default set to False.
    """

    def __init__(self, num_classes: int,
                 p: float = 0.5, alpha: float = 1.0,
                 inplace: bool = False) -> None:
        super().__init__()
        assert num_classes > 0, "Please provide a valid positive value for the num_classes."
        assert alpha > 0, "Alpha param can't be zero."

        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha
        self.inplace = inplace

    def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:
        """
        Args:
            batch (Tensor): Float tensor of size (B, C, H, W)
            target (Tensor): Integer tensor of size (B, )

        Returns:
            Tuple[Tensor, Tensor]: Randomly transformed batch and target.
        """
        if batch.ndim != 4:
            raise ValueError("Batch ndim should be 4. Got {}".format(batch.ndim))
        elif target.ndim != 1:
            raise ValueError("Target ndim should be 1. Got {}".format(target.ndim))
        elif not batch.is_floating_point():
            raise TypeError('Batch dtype should be a float tensor. Got {}.'.format(batch.dtype))
        elif target.dtype != torch.int64:
            raise TypeError("Target dtype should be torch.int64. Got {}".format(target.dtype))

        if not self.inplace:
            batch = batch.clone()
            target = target.clone()

        if target.ndim == 1:
            target = torch.nn.functional.one_hot(target, num_classes=self.num_classes).to(dtype=batch.dtype)

        if torch.rand(1).item() >= self.p:
            return batch, target

        # It's faster to roll the batch by one instead of shuffling it to create image pairs
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
        lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
        W, H = F.get_image_size(batch)

        r_x = torch.randint(W, (1,))
        r_y = torch.randint(H, (1,))

        r = 0.5 * math.sqrt(1.0 - lambda_param)
        r_w_half = int(r * W)
        r_h_half = int(r * H)

        x1 = int(torch.clamp(r_x - r_w_half, min=0))
        y1 = int(torch.clamp(r_y - r_h_half, min=0))
        x2 = int(torch.clamp(r_x + r_w_half, max=W))
        y2 = int(torch.clamp(r_y + r_h_half, max=H))

        batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]
        lambda_param = float(1.0 - (x2 - x1) * (y2 - y1) / (W * H))

        target_rolled.mul_(1.0 - lambda_param)
        target.mul_(lambda_param).add_(target_rolled)

        return batch, target

    def __repr__(self) -> str:
        s = self.__class__.__name__ + '('
        s += 'num_classes={num_classes}'
        s += ', p={p}'
        s += ', alpha={alpha}'
        s += ', inplace={inplace}'
        s += ')'
        return s.format(**self.__dict__)
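As a quick illustration of what these transforms do to a batch, here is a hedged, self-contained sketch (the batch size, image size, class count and alpha values are arbitrary, and it assumes the file above is importable as transforms):

import torch

import transforms  # the module defined above

num_classes = 4
images = torch.rand(8, 3, 32, 32)              # (B, C, H, W), float
labels = torch.randint(0, num_classes, (8,))   # (B,), int64 class indices

mixup = transforms.RandomMixup(num_classes, p=1.0, alpha=0.2)
cutmix = transforms.RandomCutmix(num_classes, p=1.0, alpha=1.0)

mixed_images, mixed_targets = mixup(images, labels)
cut_images, cut_targets = cutmix(images, labels)

# Targets are now soft: each row is a convex combination of two one-hot vectors.
print(mixed_targets.shape)        # torch.Size([8, 4])
print(mixed_targets.sum(dim=1))   # all ones
print(cut_targets.shape)          # torch.Size([8, 4])

The mixing weight lambda is drawn from Beta(alpha, alpha); the code obtains it by sampling a two-parameter Dirichlet with torch._sample_dirichlet and taking the first component, which follows that same Beta distribution. For cutmix, lambda is then recomputed from the actual (clamped) box area so that the target weights match the pasted region.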
Further review discussion:

Comment: Not for now, but this exposes a limitation of our current datasets, which is that we don't consistently enforce a way of querying the number of classes in a dataset. The dataset refactoring work from @pmeier will address this.

Reply: With #4432, you will be able to look up categories, a list of strings in which the index corresponds to the label.