@@ -113,7 +113,12 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    val_resize_size, val_crop_size, train_crop_size, center_crop = (
+        args.val_resize_size,
+        args.val_crop_size,
+        args.train_crop_size,
+        args.train_center_crop,
+    )
     interpolation = InterpolationMode(args.interpolation)
 
     print("Loading training data")
@@ -126,13 +131,18 @@ def load_data(traindir, valdir, args):
     else:
         auto_augment_policy = getattr(args, "auto_augment", None)
         random_erase_prob = getattr(args, "random_erase", 0.0)
+        ra_magnitude = args.ra_magnitude
+        augmix_severity = args.augmix_severity
         dataset = torchvision.datasets.ImageFolder(
             traindir,
             presets.ClassificationPresetTrain(
+                center_crop=center_crop,
                 crop_size=train_crop_size,
                 interpolation=interpolation,
                 auto_augment_policy=auto_augment_policy,
                 random_erase_prob=random_erase_prob,
+                ra_magnitude=ra_magnitude,
+                augmix_severity=augmix_severity,
             ),
         )
         if args.cache_dataset:
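
The companion changes to presets.ClassificationPresetTrain are not part of this diff, so the exact routing of the new keyword arguments is not shown here. Below is a minimal sketch of how a preset could consume center_crop, ra_magnitude and augmix_severity; the class name, transform choices and parameter routing are assumptions for illustration, not the actual presets.py.

# Hypothetical sketch, not the real presets.ClassificationPresetTrain:
# ra_magnitude is assumed to feed RandAugment, augmix_severity to feed AugMix
# (AugMix needs a recent torchvision), and center_crop to swap the random
# resized crop for a deterministic Resize + CenterCrop.
import torch
from torchvision.transforms import autoaugment, transforms
from torchvision.transforms.functional import InterpolationMode


class ClassificationPresetTrainSketch:
    def __init__(
        self,
        *,
        crop_size,
        center_crop=False,
        interpolation=InterpolationMode.BILINEAR,
        auto_augment_policy=None,
        ra_magnitude=9,
        augmix_severity=3,
        random_erase_prob=0.0,
    ):
        if center_crop:
            # Deterministic crop, e.g. for fine-tuning style recipes.
            crop = [transforms.Resize(crop_size, interpolation=interpolation), transforms.CenterCrop(crop_size)]
        else:
            crop = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
        trans = crop + [transforms.RandomHorizontalFlip()]
        if auto_augment_policy == "ra":
            trans.append(autoaugment.RandAugment(interpolation=interpolation, magnitude=ra_magnitude))
        elif auto_augment_policy == "augmix":
            trans.append(autoaugment.AugMix(interpolation=interpolation, severity=augmix_severity))
        trans += [transforms.PILToTensor(), transforms.ConvertImageDtype(torch.float)]
        if random_erase_prob > 0:
            trans.append(transforms.RandomErasing(p=random_erase_prob))
        self.transforms = transforms.Compose(trans)

    def __call__(self, img):
        return self.transforms(img)
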
@@ -207,7 +217,10 @@ def main(args):
         mixup_transforms.append(transforms.RandomCutmix(num_classes, p=1.0, alpha=args.cutmix_alpha))
     if mixup_transforms:
         mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)
-        collate_fn = lambda batch: mixupcutmix(*default_collate(batch))  # noqa: E731
+
+        def collate_fn(batch):
+            return mixupcutmix(*default_collate(batch))
+
     data_loader = torch.utils.data.DataLoader(
         dataset,
         batch_size=args.batch_size,
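
Replacing the lambda with a named nested function drops the # noqa: E731 suppression and gives the collate function a proper name in tracebacks; behaviour is unchanged. For reference, the collate function first stacks the individual (image, target) samples with default_collate and only then applies one randomly chosen batch-level transform (mixup or cutmix). A self-contained illustration with a stand-in transform:

# Illustration only: fake_mixupcutmix stands in for
# torchvision.transforms.RandomChoice([RandomMixup(...), RandomCutmix(...)]).
import torch
from torch.utils.data.dataloader import default_collate


def fake_mixupcutmix(images, targets):
    return images, targets  # a real transform would mix images/targets batch-wise


def collate_fn(batch):
    return fake_mixupcutmix(*default_collate(batch))


samples = [(torch.rand(3, 224, 224), torch.tensor(0)) for _ in range(4)]
images, targets = collate_fn(samples)
print(images.shape, targets.shape)  # torch.Size([4, 3, 224, 224]) torch.Size([4])
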
@@ -448,6 +461,8 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
+    parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
     # Mixed precision training parameters
@@ -486,13 +501,17 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--train-center-crop",
+        action="store_true",
+        help="use center crop instead of random crop for training (default: False)",
+    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-
     return parser
 
 
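
For a quick sanity check of the new options, a self-contained parser with the same three add_argument calls can be exercised directly; the values passed below are illustrative, not recommended settings.

# Mirrors only the three new options from get_args_parser, for illustration.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--ra-magnitude", default=9, type=int, help="magnitude of auto augment policy")
parser.add_argument("--augmix-severity", default=3, type=int, help="severity of augmix policy")
parser.add_argument(
    "--train-center-crop",
    action="store_true",
    help="use center crop instead of random crop for training (default: False)",
)
args = parser.parse_args(["--ra-magnitude", "7", "--train-center-crop"])
print(args.ra_magnitude, args.augmix_severity, args.train_center_crop)  # 7 3 True
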