Commit ba299e8

kbozas (Konstantinos Bozas) and Konstantinos Bozas authored
support amp training for video classification models (#5023)
* support amp training for video classification models
* Removed extra empty line and used scaler instead of args.amp as function argument
* apply formatting to pass lint tests

Co-authored-by: Konstantinos Bozas <[email protected]>
1 parent fe4ba30 commit ba299e8

1 file changed (+19, -35 lines)

references/video_classification/train.py

Lines changed: 19 additions & 35 deletions
@@ -12,19 +12,13 @@
 from torch.utils.data.dataloader import default_collate
 from torchvision.datasets.samplers import DistributedSampler, UniformClipSampler, RandomClipSampler
 
-try:
-    from apex import amp
-except ImportError:
-    amp = None
-
-
 try:
     from torchvision.prototype import models as PM
 except ImportError:
     PM = None
 
 
-def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
+def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
     metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value}"))
@@ -34,16 +28,19 @@ def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, devi
     for video, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
         video, target = video.to(device), target.to(device)
-        output = model(video)
-        loss = criterion(output, target)
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
+            output = model(video)
+            loss = criterion(output, target)
 
         optimizer.zero_grad()
-        if apex:
-            with amp.scale_loss(loss, optimizer) as scaled_loss:
-                scaled_loss.backward()
+
+        if scaler is not None:
+            scaler.scale(loss).backward()
+            scaler.step(optimizer)
+            scaler.update()
         else:
             loss.backward()
-        optimizer.step()
+            optimizer.step()
 
         acc1, acc5 = utils.accuracy(output, target, topk=(1, 5))
         batch_size = video.shape[0]
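
For readers unfamiliar with the native AMP API, the pattern introduced above pairs torch.cuda.amp.autocast (a mixed-precision region around the forward pass) with torch.cuda.amp.GradScaler (loss scaling so small fp16 gradients do not underflow). Below is a minimal, self-contained sketch of one such training step; the toy linear model, optimizer settings, and random tensors are stand-ins for the real video classifier and data loader, not part of the commit.

import torch

# Toy stand-ins for the real video model and batch; names and sizes are illustrative only.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(64, 10).to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Mirrors the reference script: a GradScaler exists only when AMP is requested and usable.
scaler = torch.cuda.amp.GradScaler() if device == "cuda" else None

video = torch.randn(8, 64, device=device)
target = torch.randint(0, 10, (8,), device=device)

# The forward pass runs under autocast when a scaler is present, otherwise in full precision.
with torch.cuda.amp.autocast(enabled=scaler is not None):
    output = model(video)
    loss = criterion(output, target)

optimizer.zero_grad()
if scaler is not None:
    scaler.scale(loss).backward()  # backpropagate the scaled loss
    scaler.step(optimizer)         # unscale gradients, skip the step if they overflowed
    scaler.update()                # adjust the scale factor for the next iteration
else:
    loss.backward()
    optimizer.step()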
@@ -101,11 +98,6 @@ def collate_fn(batch):
 def main(args):
     if args.weights and PM is None:
         raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if args.apex and amp is None:
-        raise RuntimeError(
-            "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
-            "to enable mixed-precision training."
-        )
 
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -224,9 +216,7 @@ def main(args):
 
     lr = args.lr * args.world_size
     optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
-
-    if args.apex:
-        model, optimizer = amp.initialize(model, optimizer, opt_level=args.apex_opt_level)
+    scaler = torch.cuda.amp.GradScaler() if args.amp else None
 
     # convert scheduler to be per iteration, not per epoch, for warmup that lasts
     # between different epochs
@@ -267,6 +257,8 @@ def main(args):
         optimizer.load_state_dict(checkpoint["optimizer"])
         lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
         args.start_epoch = checkpoint["epoch"] + 1
+        if args.amp:
+            scaler.load_state_dict(checkpoint["scaler"])
 
     if args.test_only:
         evaluate(model, criterion, data_loader_test, device=device)
@@ -277,9 +269,7 @@ def main(args):
     for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(
-            model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, args.apex
-        )
+        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
         evaluate(model, criterion, data_loader_test, device=device)
         if args.output_dir:
             checkpoint = {
@@ -289,6 +279,8 @@ def main(args):
                 "epoch": epoch,
                 "args": args,
             }
+            if args.amp:
+                checkpoint["scaler"] = scaler.state_dict()
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
             utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
 
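
The two checkpoint hunks above are what make AMP runs resumable: the scaler's state (its current scale factor and growth tracker) is stored next to the model and optimizer state and restored on resume. A hedged sketch of that round trip follows; the file name and toy objects are placeholders standing in for the real checkpoint written by main().

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())

# Save: mirrors what main() writes when args.amp is set.
checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "scaler": scaler.state_dict(),
}
torch.save(checkpoint, "checkpoint.pth")

# Resume: restore the scaler alongside the model and optimizer.
checkpoint = torch.load("checkpoint.pth")
model.load_state_dict(checkpoint["model"])
optimizer.load_state_dict(checkpoint["optimizer"])
scaler.load_state_dict(checkpoint["scaler"])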
@@ -363,24 +355,16 @@ def parse_args():
         action="store_true",
     )
 
-    # Mixed precision training parameters
-    parser.add_argument("--apex", action="store_true", help="Use apex for mixed precision training")
-    parser.add_argument(
-        "--apex-opt-level",
-        default="O1",
-        type=str,
-        help="For apex mixed precision training"
-        "O0 for FP32 training, O1 for mixed precision training."
-        "For further detail, see https://github.com/NVIDIA/apex/tree/master/examples/imagenet",
-    )
-
     # distributed training parameters
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
+    # Mixed precision training parameters
+    parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
+
     args = parser.parse_args()
 
     return args
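
With the old apex flags gone, mixed precision is now opted into with a single switch. A hypothetical launch might look like the line below; the torchrun launcher, GPU count, and dataset path are placeholders chosen for illustration, not something this commit specifies.

# Illustrative invocation only; launcher, --nproc_per_node, and --data-path value are assumptions.
# torchrun --nproc_per_node=8 train.py --data-path /path/to/kinetics400 --amp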
