diff --git a/references/classification/train.py b/references/classification/train.py
index 2fbe61dd65f..b16ed3d2a42 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
     for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
         start_time = time.time()
         image, target = image.to(device), target.to(device)
-        with torch.cuda.amp.autocast(enabled=args.amp):
+        with torch.cuda.amp.autocast(enabled=scaler is not None):
             output = model(image)
             loss = criterion(output, target)

         optimizer.zero_grad()
-        if args.amp:
+        if scaler is not None:
             scaler.scale(loss).backward()
             if args.clip_grad_norm is not None:
                 # we should unscale the gradients of optimizer's assigned params if do gradient clipping
diff --git a/references/classification/train_quantization.py b/references/classification/train_quantization.py
index 7c8a2df91ff..96ab4fcac97 100644
--- a/references/classification/train_quantization.py
+++ b/references/classification/train_quantization.py
@@ -121,7 +121,7 @@ def main(args):
         if args.distributed:
             train_sampler.set_epoch(epoch)
         print("Starting training for epoch", epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
+        train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
         lr_scheduler.step()
         with torch.inference_mode():
             if epoch >= args.num_observer_update_epochs:
@@ -132,7 +132,7 @@ def main(args):
                 model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
             print("Evaluate QAT model")

-            evaluate(model, criterion, data_loader_test, device=device)
+            evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
             quantized_eval_model = copy.deepcopy(model_without_ddp)
             quantized_eval_model.eval()
             quantized_eval_model.to(torch.device("cpu"))
@@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
+    parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

     # Prototype models only
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
diff --git a/torchvision/prototype/models/_api.py b/torchvision/prototype/models/_api.py
index 98a96f80934..be26b10b5a0 100644
--- a/torchvision/prototype/models/_api.py
+++ b/torchvision/prototype/models/_api.py
@@ -101,6 +101,11 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
     # TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
     for t in ann.__args__:  # type: ignore[union-attr]
         if isinstance(t, type) and issubclass(t, Weights):
+            # ensure the name exists. handles builders with multiple types of weights like in quantization
+            try:
+                t.from_str(weight_name)
+            except ValueError:
+                continue
             weights_class = t
             break

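
Why `train.py` now keys AMP off the scaler: `train_quantization.py` forwards its full `args` namespace into the shared `train_one_epoch`, and its parser gains `--clip-grad-norm` but (as far as this diff shows) defines no `--amp` flag, so gating on `args.amp` would raise an `AttributeError` when called from the quantization script. Gating on `scaler is not None` keeps a single code path for both scripts. Below is a minimal, self-contained sketch of the pattern; the `training_step` helper and its signature are illustrative, not the reference script's actual structure. The key detail is that gradients must be unscaled before `clip_grad_norm_`, otherwise the threshold is applied to loss-scaled values:

```python
import torch


def training_step(model, criterion, optimizer, image, target, scaler=None, clip_grad_norm=None):
    # Autocast is enabled exactly when a GradScaler was provided, mirroring the diff.
    with torch.cuda.amp.autocast(enabled=scaler is not None):
        output = model(image)
        loss = criterion(output, target)

    optimizer.zero_grad()
    if scaler is not None:
        scaler.scale(loss).backward()
        if clip_grad_norm is not None:
            # Unscale the gradients of the optimizer's assigned params first,
            # so the clipping threshold applies to the true gradient norm.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        scaler.step(optimizer)  # skips the step if inf/NaN gradients were found
        scaler.update()
    else:
        loss.backward()
        if clip_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        optimizer.step()
```

With the new argument wired through, the quantization script can opt into clipping via `--clip-grad-norm <max-norm>` while still running entirely without a scaler.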
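On the `_api.py` fix: a builder can annotate `weights` with a `Union` of several `Weights` enums (the quantized builders are the case the added comment calls out), and the old loop returned the first `Weights` subclass regardless of whether it actually defined `weight_name`. The new try/except relies on `from_str` raising `ValueError` for unknown names. Here is a toy reproduction of that probing; all class and member names are made up for illustration and only stand in for the prototype API:

```python
from enum import Enum


class Weights(Enum):
    # Stand-in for the prototype Weights class: from_str raises ValueError
    # for unknown names, which is what the new try/except keys off.
    @classmethod
    def from_str(cls, value: str) -> "Weights":
        for v in cls:
            if v.name == value:
                return v
        raise ValueError(f"Invalid value {value} for enum {cls.__name__}.")


class PlainWeights(Weights):
    ImageNet1K_V1 = "plain"


class QuantizedWeights(Weights):
    ImageNet1K_FBGEMM_V1 = "quantized"


def pick_weights_class(candidates, weight_name):
    # Mirrors the fixed loop: skip Union members that don't define the
    # requested name instead of taking the first Weights subclass found.
    for t in candidates:
        if isinstance(t, type) and issubclass(t, Weights):
            try:
                t.from_str(weight_name)
            except ValueError:
                continue
            return t
    return None


# With the old behavior PlainWeights would have been picked and the lookup
# would have failed; probing selects the enum that really owns the name.
assert pick_weights_class([PlainWeights, QuantizedWeights], "ImageNet1K_FBGEMM_V1") is QuantizedWeights
```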