-
Notifications
You must be signed in to change notification settings - Fork 7.1k
Update Reference scripts to support the prototype models #4837
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
5b72af4
77c80f5
a0654dd
55ddb93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,12 @@ | |
from torch import nn | ||
|
||
|
||
try: | ||
from torchvision.prototype import models as PM | ||
except ImportError: | ||
PM = None | ||
|
||
|
||
def get_dataset(dir_path, name, image_set, transform): | ||
def sbd(*args, **kwargs): | ||
return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs) | ||
|
@@ -26,11 +32,15 @@ def sbd(*args, **kwargs): | |
return ds, num_classes | ||
|
||
|
||
def get_transform(train): | ||
base_size = 520 | ||
crop_size = 480 | ||
|
||
return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size) | ||
def get_transform(train, args): | ||
if train: | ||
return presets.SegmentationPresetTrain(base_size=520, crop_size=480) | ||
elif not args.weights: | ||
return presets.SegmentationPresetEval(base_size=520) | ||
else: | ||
fn = PM.segmentation.__dict__[args.model] | ||
weights = PM._api.get_weight(fn, args.weights) | ||
return weights.transforms() | ||
|
||
|
||
def criterion(inputs, target): | ||
|
@@ -90,8 +100,8 @@ def main(args): | |
|
||
device = torch.device(args.device) | ||
|
||
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True)) | ||
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False)) | ||
dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args)) | ||
dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args)) | ||
|
||
if args.distributed: | ||
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) | ||
|
@@ -113,9 +123,18 @@ def main(args): | |
dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn | ||
) | ||
|
||
model = torchvision.models.segmentation.__dict__[args.model]( | ||
num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained | ||
) | ||
if not args.weights: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the weights are not defined, we use the standard way. Else it's a prototype run which means we will use the prototype model mechanism. |
||
model = torchvision.models.segmentation.__dict__[args.model]( | ||
pretrained=args.pretrained, | ||
num_classes=num_classes, | ||
aux_loss=args.aux_loss, | ||
) | ||
else: | ||
if PM is None: | ||
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") | ||
model = PM.segmentation.__dict__[args.model]( | ||
weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss | ||
) | ||
model.to(device) | ||
if args.distributed: | ||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | ||
|
@@ -247,6 +266,9 @@ def get_args_parser(add_help=True): | |
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") | ||
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") | ||
|
||
# Prototype models only | ||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") | ||
|
||
return parser | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,6 +18,12 @@ | |
amp = None | ||
|
||
|
||
try: | ||
from torchvision.prototype import models as PM | ||
except ImportError: | ||
PM = None | ||
|
||
|
||
def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False): | ||
model.train() | ||
metric_logger = utils.MetricLogger(delimiter=" ") | ||
|
@@ -149,7 +155,12 @@ def main(args): | |
print("Loading validation data") | ||
cache_path = _get_cache_path(valdir) | ||
|
||
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
if not args.weights: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That's exactly the plan. |
||
transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112)) | ||
else: | ||
fn = PM.video.__dict__[args.model] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. are we providing some sort of registration API to get the models without having to resort to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeap, that's the plan. There will be a proper registration mechanism, possibly something similar to what was discussed here. There are still pending discussions with other domains, so I didn't want to adopt something before those discussions take place. |
||
weights = PM._api.get_weight(fn, args.weights) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: using a private API here. We probably don't want to advertise private APIs in the references There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
transform_test = weights.transforms() | ||
|
||
if args.cache_dataset and os.path.exists(cache_path): | ||
print(f"Loading dataset_test from {cache_path}") | ||
|
@@ -200,7 +211,12 @@ def main(args): | |
) | ||
|
||
print("Creating model") | ||
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) | ||
if not args.weights: | ||
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained) | ||
else: | ||
if PM is None: | ||
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") | ||
model = PM.video.__dict__[args.model](weights=args.weights) | ||
model.to(device) | ||
if args.distributed and args.sync_bn: | ||
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) | ||
|
@@ -363,6 +379,9 @@ def parse_args(): | |
parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") | ||
parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") | ||
|
||
# Prototype models only | ||
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") | ||
|
||
args = parser.parse_args() | ||
|
||
return args | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we are in train mode, we always initialize the SegmentationPresetTrain. For validation if the weights are not defined (aka not a prototype model) then use the old preprocessing method for evaluation. Else use the one attached to the weights.