Skip to content

Add multi-type support on get_weight() #4967

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions references/classification/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,12 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, arg
for i, (image, target) in enumerate(metric_logger.log_every(data_loader, args.print_freq, header)):
start_time = time.time()
image, target = image.to(device), target.to(device)
with torch.cuda.amp.autocast(enabled=args.amp):
with torch.cuda.amp.autocast(enabled=scaler is not None):
output = model(image)
loss = criterion(output, target)

optimizer.zero_grad()
if args.amp:
if scaler is not None:
scaler.scale(loss).backward()
if args.clip_grad_norm is not None:
# we should unscale the gradients of optimizer's assigned params if do gradient clipping
Expand Down
5 changes: 3 additions & 2 deletions references/classification/train_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def main(args):
if args.distributed:
train_sampler.set_epoch(epoch)
print("Starting training for epoch", epoch)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args.print_freq)
train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args)
lr_scheduler.step()
with torch.inference_mode():
if epoch >= args.num_observer_update_epochs:
Expand All @@ -132,7 +132,7 @@ def main(args):
model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)
print("Evaluate QAT model")

evaluate(model, criterion, data_loader_test, device=device)
evaluate(model, criterion, data_loader_test, device=device, log_suffix="QAT")
quantized_eval_model = copy.deepcopy(model_without_ddp)
quantized_eval_model.eval()
quantized_eval_model.to(torch.device("cpu"))
Expand Down Expand Up @@ -261,6 +261,7 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

# Prototype models only
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
Expand Down
5 changes: 5 additions & 0 deletions torchvision/prototype/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ def get_weight(fn: Callable, weight_name: str) -> Weights:
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, Weights):
# ensure the name exists. handles builders with multiple types of weights like in quantization
try:
t.from_str(weight_name)
except ValueError:
continue
weights_class = t
break

Expand Down