diff --git a/references/segmentation/train.py b/references/segmentation/train.py
index 35f198c707a..866b733a180 100644
--- a/references/segmentation/train.py
+++ b/references/segmentation/train.py
@@ -11,6 +11,12 @@
 from torch import nn
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def get_dataset(dir_path, name, image_set, transform):
     def sbd(*args, **kwargs):
         return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
@@ -26,11 +32,15 @@ def sbd(*args, **kwargs):
     return ds, num_classes
 
 
-def get_transform(train):
-    base_size = 520
-    crop_size = 480
-
-    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
+def get_transform(train, args):
+    if train:
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+    elif not args.weights:
+        return presets.SegmentationPresetEval(base_size=520)
+    else:
+        fn = PM.segmentation.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        return weights.transforms()
 
 
 def criterion(inputs, target):
@@ -90,8 +100,8 @@ def main(args):
 
     device = torch.device(args.device)
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -113,9 +123,18 @@ def main(args):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
     )
 
-    model = torchvision.models.segmentation.__dict__[args.model](
-        num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained
-    )
+    if not args.weights:
+        model = torchvision.models.segmentation.__dict__[args.model](
+            pretrained=args.pretrained,
+            num_classes=num_classes,
+            aux_loss=args.aux_loss,
+        )
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.segmentation.__dict__[args.model](
+            weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
+        )
     model.to(device)
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     return parser
 
 
diff --git a/references/video_classification/train.py b/references/video_classification/train.py
index 968ea552dc7..44254acdd63 100644
--- a/references/video_classification/train.py
+++ b/references/video_classification/train.py
@@ -18,6 +18,12 @@
     amp = None
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
@@ -149,7 +155,12 @@ def main(args):
 
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
-    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    if not args.weights:
+        transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    else:
+        fn = PM.video.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        transform_test = weights.transforms()
 
     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
@@ -200,7 +211,12 @@ def main(args):
     )
 
     print("Creating model")
-    model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    if not args.weights:
+        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.video.__dict__[args.model](weights=args.weights)
     model.to(device)
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -363,6 +379,9 @@ def parse_args():
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     args = parser.parse_args()
     return args
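
Example invocation (a sketch: fcn_resnet50 is an existing torchvision segmentation model name, while the weights value is a placeholder; the valid enum names depend on the installed torchvision nightly):

    python references/segmentation/train.py --data-path /path/to/dataset --model fcn_resnet50 --weights <WeightsEnumName>

When --weights is omitted, both scripts keep the pre-existing --pretrained code path and the hard-coded evaluation presets, so current workflows are unaffected; when it is set, the matching prototype weights and their bundled transforms() preset are used instead.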