Adding Mixup and Cutmix #4379
Conversation
Highlighting some "interesting" parts below to provide context. The PR is not yet ready.
torchvision/transforms/transforms.py
Outdated
def __init__(self, num_classes: int,
             p: float = 1.0, mixup_alpha: float = 1.0,
             cutmix_p: float = 0.5, cutmix_alpha: float = 0.0,
             label_smoothing: float = 0.0, inplace: bool = False) -> None:
I understand that the choice of offering inplace support won't make everyone happy. The reasons I decided to support it are:
- It is more performant in terms of memory and speed (see the sketch below).
- Most of the tensor-only transforms in torchvision already support inplace operations (see Normalize and RandomErasing).
- The ClassyVision version of this method supported inplace operations, so to remain equally performant we should support it as well.
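For illustration, here is a minimal sketch of what the in-place path saves, assuming a made-up batch tensor and a precomputed mixing weight `lam` (both are placeholders, not the PR's actual variables):

```python
import torch

batch = torch.rand(8, 3, 224, 224)  # hypothetical input batch
batch_rolled = batch.roll(1, 0)     # pair each sample with its neighbour in the batch

lam = 0.3  # mixing weight, normally drawn from a Beta distribution

# Out-of-place: allocates a new tensor for the blended batch.
mixed = batch.mul(lam).add(batch_rolled.mul(1.0 - lam))

# In-place: reuses the existing batch storage, avoiding the extra allocation.
batch.mul_(lam).add_(batch_rolled, alpha=1.0 - lam)
```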
I'm still a bit skeptical about adding this in-place operation, as most of the operators here don't actually perform any in-place work, so the gains from an inplace flag might not be that large. We would need to benchmark to know for sure.
I don't see why we can't support an in-place operation if one can be implemented correctly. For now I plan to move the entire implementation to the references script so that we have time to think this through before we commit to the API, so we don't have to solve this now. Overall I think inplace is a valid optimization and, provided it can be implemented properly, there is no reason not to get the extra speed improvements.
As discussed offline with @fmassa, we are going to review the use of inplace before moving this from references to transforms.
Thanks a ton for the PR!
Adding a few high-level comments, let me know what you think
I took care of splitting the methods. Below I highlight other interesting bits that were not there in the previous review.
torchvision/transforms/transforms.py
Outdated
""" | ||
|
||
def __init__(self, num_classes: int, | ||
p: float = 0.5, alpha: float = 1.0, |
I changed the default to p=0.5 to keep it consistent with other transforms.
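As a rough sketch of how such a transform can behave (illustrative names, not the exact PR code): `p` gates whether the batch is mixed at all, and `alpha` parametrises the Beta distribution that draws the mixing weight.

```python
import torch
from torch.distributions.beta import Beta
from torch.nn.functional import one_hot


class RandomMixupSketch(torch.nn.Module):
    """Illustrative batch-level mixup; expects images (B, C, H, W) and integer targets (B,)."""

    def __init__(self, num_classes: int, p: float = 0.5, alpha: float = 1.0) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.p = p
        self.alpha = alpha

    def forward(self, batch, target):
        # Work on float one-hot targets so they can be linearly blended.
        target = one_hot(target, num_classes=self.num_classes).float()
        if torch.rand(1).item() >= self.p:
            return batch, target

        # Pair each sample with its neighbour in the batch (roll instead of flip).
        batch_rolled = batch.roll(1, 0)
        target_rolled = target.roll(1, 0)

        lam = Beta(self.alpha, self.alpha).sample().item()
        batch = batch.mul(lam).add(batch_rolled, alpha=1.0 - lam)
        target = target.mul(lam).add(target_rolled, alpha=1.0 - lam)
        return batch, target
```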
torchvision/transforms/transforms.py
Outdated
    return s.format(**self.__dict__)


class RandomCutmix(torch.nn.Module):
There is some code duplication, and thus bits that can be shared across the two classes, but this will be fixed in the new API with proper class inheritance.
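For context, the cutmix-specific part mainly differs in how the two images are combined: instead of a global blend, a rectangular patch from the rolled batch is pasted in, and the mixing weight is recomputed from the area that was actually pasted. A rough sketch of that step (the helper name is made up; the targets are then blended with the returned weight exactly as in the mixup sketch above, which is why a shared base class is appealing):

```python
import torch


def cutmix_patch(batch: torch.Tensor, batch_rolled: torch.Tensor, lam: float) -> float:
    """Paste a random rectangle from batch_rolled into batch (in place) and return
    the mixing weight adjusted to the area that was actually pasted."""
    _, _, height, width = batch.shape

    # The box side lengths follow sqrt(1 - lam), so the box covers roughly a
    # (1 - lam) fraction of the image; its centre is sampled uniformly.
    ratio = (1.0 - lam) ** 0.5
    cut_w, cut_h = int(width * ratio), int(height * ratio)
    cx = int(torch.randint(width, (1,)))
    cy = int(torch.randint(height, (1,)))

    x1, x2 = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    y1, y2 = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)

    batch[:, :, y1:y2, x1:x2] = batch_rolled[:, :, y1:y2, x1:x2]

    # Re-derive lambda from the clipped box so the target weights match the pixels.
    return 1.0 - (x2 - x1) * (y2 - y1) / (width * height)
```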
@@ -273,6 +285,8 @@ def get_args_parser(add_help=True):
    parser.add_argument('--label-smoothing', default=0.0, type=float,
                        help='label smoothing (default: 0.0)',
                        dest='label_smoothing')
    parser.add_argument('--mixup-alpha', default=0.0, type=float, help='mixup alpha (default: 0.0)')
    parser.add_argument('--cutmix-alpha', default=0.0, type=float, help='cutmix alpha (default: 0.0)')
I'm exposing very few options here (I'm using hardcoded p params). Keeping it simple until we use it in models and see what we want to support.
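For context, batch-level transforms like these are typically wired into the training script through the DataLoader's collate function, roughly along these lines (a sketch only: `RandomMixupSketch` is the illustrative class from above, `dataset` and `num_classes` stand in for objects the reference script already builds, and the p/alpha values are arbitrary; a cutmix transform would be plugged in the same way):

```python
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate

# Hypothetical batch-level transform with the (batch, target) -> (batch, target)
# interface sketched earlier.
mixup = RandomMixupSketch(num_classes=num_classes, p=1.0, alpha=0.2)


def collate_fn(samples):
    # Collate the individual samples into tensors first, then mix the whole batch.
    batch, target = default_collate(samples)
    return mixup(batch, target)


data_loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
```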
@datumbox nn.CrossEntropyLoss does not work when you use mixup or cutmix, because the target shape is (N, K) rather than (N,).
I think this was added in pytorch/pytorch#63122. It should be available in the latest stable version of PyTorch. See the docs.
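For reference, on PyTorch 1.10 and later the loss accepts both index targets and class-probability targets (the tensors below are made up purely for illustration):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

logits = torch.randn(4, 10)                      # (N, K) model outputs
hard_target = torch.randint(10, (4,))            # (N,) class indices still work
soft_target = torch.rand(4, 10).softmax(dim=1)   # (N, K) probabilities, e.g. from mixup/cutmix

loss_hard = criterion(logits, hard_target)
loss_soft = criterion(logits, soft_target)
```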
@datumbox ok, thanks.
LGTM, thanks!
I've made a few minor comments, none of which are merge-blocking.
@@ -165,10 +166,21 @@ def main(args):
    train_dir = os.path.join(args.data_path, 'train')
    val_dir = os.path.join(args.data_path, 'val')
    dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args)

    collate_fn = None
    num_classes = len(dataset.classes)
Not for now, but this exposes a limitation of our current datasets: we don't consistently enforce a way of querying the number of classes in a dataset. The dataset refactoring work from @pmeier will address this.
With #4432, you will be able to do
info = torchvision.datasets.info(name)
info.categories
where categories is a list of strings in which the index corresponds to the label.
Summary:
* Add RandomMixupCutmix.
* Add test with real data.
* Use dataloader and collate in the test.
* Making RandomMixupCutmix JIT scriptable.
* Move out label_smoothing and try roll instead of flip.
* Adding mixup/cutmix in references script.
* Handle one-hot encoded target in accuracy.
* Add support of devices on tests.
* Separate Mixup from Cutmix.
* Add check for floats.
* Adding device on expect value.
* Remove hardcoded weights.
* One-hot only when necessary.
* Fix linter.
* Moving mixup and cutmix to references.
* Final code clean up.

Reviewed By: datumbox
Differential Revision: D31268036
fbshipit-source-id: 6a73c079d667443da898e3b175b88978b24d52ad
Partially resolves #3817, resolves #4281.

I'm adding Mixup and Cutmix to the References instead of to Transforms, to allow the investigation of the new Transforms API to conclude. The transforms were tested with unit tests, which can be found in an earlier commit. Prior to moving them into the Transforms area, we should consider introducing a base class and eliminating the duplicate code between the two implementations.
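One of the items in the squashed summary above, handling one-hot encoded targets in the accuracy utility, boils down to a small guard along these lines (a sketch under the assumption that soft targets are reduced to hard labels before the top-k comparison, not the exact reference code):

```python
import torch


def accuracy(output: torch.Tensor, target: torch.Tensor, topk=(1,)):
    """Top-k accuracy that also accepts soft / one-hot targets of shape (N, K)."""
    with torch.inference_mode():
        if target.ndim == 2:
            # Mixup/cutmix produce (N, K) probability targets; reduce to hard labels.
            target = target.argmax(dim=1)

        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
        correct = pred.eq(target.view(-1, 1))

        batch_size = target.size(0)
        return [correct[:, :k].any(dim=1).float().sum() * 100.0 / batch_size for k in topk]
```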