@@ -1,5 +1,6 @@
 import datetime
 import os
+import random
 import time
 import warnings

@@ -15,7 +16,7 @@
 from torchvision.transforms.functional import InterpolationMode


-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
+def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None, scheduler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -43,6 +44,9 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
             if args.clip_grad_norm is not None:
                 nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad_norm)
             optimizer.step()
+
+        if scheduler is not None and args.lr_step_every_batch:
+            scheduler.step()

         if model_ema and i % args.model_ema_steps == 0:
             model_ema.update_parameters(model)
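Note: the effect of the new flag is that the learning-rate schedule advances once per optimizer update instead of once per epoch. A minimal standalone sketch of the intended call pattern (toy model and illustrative values, not this script's objects):

import torch

# Toy setup; the point is only the optimizer.step() / scheduler.step() ordering.
model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

lr_step_every_batch = True  # mirrors the new --lr-step-every-batch flag
for batch in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).sum()
    loss.backward()
    optimizer.step()      # update the weights first
    if lr_step_every_batch:
        scheduler.step()  # then advance the schedule once per batch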
@@ -113,7 +117,7 @@ def _get_cache_path(filepath):
 def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
-    val_resize_size, val_crop_size, train_crop_size = args.val_resize_size, args.val_crop_size, args.train_crop_size
+    val_resize_size, val_crop_size, train_crop_size, center_crop, policy_magnitude = args.val_resize_size, args.val_crop_size, args.train_crop_size, args.train_center_crop, args.policy_magnitude
     interpolation = InterpolationMode(args.interpolation)

     print("Loading training data")
@@ -129,10 +133,12 @@ def load_data(traindir, valdir, args):
     dataset = torchvision.datasets.ImageFolder(
         traindir,
         presets.ClassificationPresetTrain(
+            center_crop=center_crop,
             crop_size=train_crop_size,
             interpolation=interpolation,
             auto_augment_policy=auto_augment_policy,
             random_erase_prob=random_erase_prob,
+            policy_magnitude=policy_magnitude,
         ),
     )
     if args.cache_dataset:
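The two new keyword arguments belong to this fork's presets.ClassificationPresetTrain; stock torchvision presets accept neither. A hedged sketch of how such a preset might wire them up (the helper name and wiring below are assumptions, not the fork's actual code):

import torch
from torchvision import transforms

def build_train_transform(crop_size, center_crop=False, policy_magnitude=9):
    # Hypothetical: choose a deterministic center crop over the usual
    # random resized crop, and forward the magnitude to RandAugment.
    crop = (
        transforms.CenterCrop(crop_size)
        if center_crop
        else transforms.RandomResizedCrop(crop_size)
    )
    return transforms.Compose([
        crop,
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(magnitude=policy_magnitude),
        transforms.PILToTensor(),
        transforms.ConvertImageDtype(torch.float),
    ])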
@@ -182,7 +188,12 @@ def load_data(traindir, valdir, args):
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
-
+
+    if args.seed is None:
+        # randomly choose a seed
+        args.seed = random.randint(0, 2**32)
+    utils.set_seed(args.seed)
+
     utils.init_distributed_mode(args)
     print(args)

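utils.set_seed is a helper added in this fork (torchvision's reference utils.py does not ship one), so the body below is an assumption about what it does, not the actual code. One caveat worth flagging in the hunk above: random.randint(0, 2**32) is inclusive on both ends, so it can return 2**32, which numpy's legacy seeding rejects (it accepts 0 to 2**32 - 1); 2**32 - 1 would be the safer upper bound.

import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    # Assumed implementation: seed every RNG the training script touches
    # so runs are reproducible.
    random.seed(seed)
    np.random.seed(seed)     # raises ValueError if seed == 2**32
    torch.manual_seed(seed)  # also seeds the CUDA RNGs on current PyTorch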
@@ -261,13 +272,21 @@ def main(args):
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")

     scaler = torch.cuda.amp.GradScaler() if args.amp else None
+
+    batches_per_epoch = len(data_loader)
+    warmup_iters = args.lr_warmup_epochs
+    total_iters = args.epochs
+
+    if args.lr_step_every_batch:
+        warmup_iters *= batches_per_epoch
+        total_iters *= batches_per_epoch

     args.lr_scheduler = args.lr_scheduler.lower()
     if args.lr_scheduler == "steplr":
         main_lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step_size, gamma=args.lr_gamma)
     elif args.lr_scheduler == "cosineannealinglr":
         main_lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
-            optimizer, T_max=args.epochs - args.lr_warmup_epochs, eta_min=args.lr_min
+            optimizer, T_max=total_iters - warmup_iters, eta_min=args.lr_min
         )
     elif args.lr_scheduler == "exponentiallr":
         main_lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=args.lr_gamma)
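Worked example with illustrative values: at 90 epochs, 5 warmup epochs, and 1,000 batches per epoch, passing --lr-step-every-batch gives warmup_iters = 5 * 1000 = 5,000 and total_iters = 90 * 1000 = 90,000, so CosineAnnealingLR gets T_max = 85,000 per-batch steps instead of 85 per-epoch steps; without the flag, both values stay in epoch units and the old behavior is preserved.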
@@ -280,18 +299,18 @@ def main(args):
     if args.lr_warmup_epochs > 0:
         if args.lr_warmup_method == "linear":
             warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
-                optimizer, start_factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
             )
         elif args.lr_warmup_method == "constant":
             warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
-                optimizer, factor=args.lr_warmup_decay, total_iters=args.lr_warmup_epochs
+                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
             )
         else:
             raise RuntimeError(
                 f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
             )
         lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
-            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[args.lr_warmup_epochs]
+            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
         )
     else:
         lr_scheduler = main_lr_scheduler
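With per-batch stepping, the SequentialLR milestone must also be in iteration units, which is what the hunk above does. A standalone sketch of the resulting schedule (illustrative values, not the script's defaults):

import torch

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

warmup_iters = 100  # e.g. 1 warmup epoch x 100 batches
total_iters = 1000  # e.g. 10 epochs x 100 batches

warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.01, total_iters=warmup_iters)
cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=total_iters - warmup_iters)
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer, schedulers=[warmup, cosine], milestones=[warmup_iters]
)
# scheduler.step() is now called once per batch; after warmup_iters steps,
# SequentialLR switches from the linear warmup to cosine annealing.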
@@ -341,8 +360,9 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler)
-        lr_scheduler.step()
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema, scaler, lr_scheduler)
+        if not args.lr_step_every_batch:
+            lr_scheduler.step()
         evaluate(model, criterion, data_loader_test, device=device)
         if model_ema:
             evaluate(model_ema, criterion, data_loader_test, device=device, log_suffix="EMA")
@@ -371,7 +391,7 @@ def get_args_parser(add_help=True):

     parser = argparse.ArgumentParser(description="PyTorch Classification Training", add_help=add_help)

-    parser.add_argument("--data-path", default="/datasets01/imagenet_full_size/061417/", type=str, help="dataset path")
+    parser.add_argument("--data-path", default="/datasets01_ontap/imagenet_full_size/061417/", type=str, help="dataset path")
     parser.add_argument("--model", default="resnet18", type=str, help="model name")
     parser.add_argument("--device", default="cuda", type=str, help="device (Use cuda or cpu Default: cuda)")
     parser.add_argument(
@@ -425,6 +445,7 @@ def get_args_parser(add_help=True):
     parser.add_argument("--lr-step-size", default=30, type=int, help="decrease lr every step-size epochs")
     parser.add_argument("--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma")
     parser.add_argument("--lr-min", default=0.0, type=float, help="minimum lr of lr schedule (default: 0.0)")
+    parser.add_argument("--lr-step-every-batch", action="store_true", help="step the lr scheduler every batch instead of every epoch", default=False)
     parser.add_argument("--print-freq", default=10, type=int, help="print frequency")
     parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
     parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
@@ -448,6 +469,7 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
+    parser.add_argument("--policy-magnitude", default=9, type=int, help="magnitude of auto augment policy")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

     # Mixed precision training parameters
@@ -486,13 +508,16 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument(
+        "--train-center-crop", action="store_true", help="use center crop instead of random crop for training (default: False)"
+    )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
     parser.add_argument("--ra-sampler", action="store_true", help="whether to use Repeated Augmentation in training")
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
-
+    parser.add_argument("--seed", default=None, type=int, help="the seed for randomness (default: None). A `None` value means a seed will be randomly generated")
     return parser
