@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
@@ -363,24 +355,16 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
 
     return args
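
For reference, the native AMP pattern this diff adopts (autocast forward pass, GradScaler-driven backward/step, scaler state persisted in checkpoints) can be exercised in isolation. Below is a minimal sketch; the toy linear model, the random batch, and the "checkpoint.pth" path are illustrative stand-ins, not part of the PR:

# Minimal standalone sketch of the torch.cuda.amp training step used above.
# The model, data, and checkpoint file name here are illustrative only.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # mirrors the --amp flag; CUDA-only

model = torch.nn.Linear(16, 4).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
scaler = torch.cuda.amp.GradScaler() if use_amp else None

video = torch.randn(8, 16, device=device)            # stand-in batch
target = torch.randint(0, 4, (8,), device=device)    # stand-in labels

# Forward under autocast so eligible ops run in float16.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(video)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()  # scale the loss to avoid fp16 gradient underflow
    scaler.step(optimizer)         # unscales gradients; skips the step on inf/nan
    scaler.update()                # adjusts the loss scale for the next iteration
else:
    loss.backward()
    optimizer.step()

# Save the scaler alongside the model, as the checkpoint hunks above do,
# so the loss scale survives a resume.
state = {"model": model.state_dict(), "optimizer": optimizer.state_dict()}
if scaler is not None:
    state["scaler"] = scaler.state_dict()
torch.save(state, "checkpoint.pth")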