
Commit c7eb44d

prabhat00155 authored and facebook-github-bot committed
[fbsync] Update Reference scripts to support the prototype models (#4837)
Summary:
* Adding prototype preprocessing on segmentation references.
* Adding prototype preprocessing on video references.

Reviewed By: kazhang
Differential Revision: D32216688
fbshipit-source-id: 219f9d8e3b34ecc5a30f9b93f9da0698631fd84e
1 parent bf2fa0a commit c7eb44d
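
Both diffs below follow the same pattern: an optional prototype import at the top of the file, plus a new --weights flag that is only honoured for prototype models. Keeping PM = None on ImportError means the scripts still run on a stable torchvision install, and the hard failure is deferred until --weights is actually requested. A minimal, self-contained sketch of that guard (the SimpleNamespace stands in for the parsed CLI arguments; not a literal excerpt from either file):

from types import SimpleNamespace

# The prototype namespace is optional: stay importable on stable torchvision.
try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None

args = SimpleNamespace(weights=None)  # stand-in for the argparse result
if args.weights and PM is None:
    # Same message the scripts raise when --weights is passed without a nightly build.
    raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")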

File tree (2 files changed: +53 -12 lines changed)

references/segmentation/train.py
references/video_classification/train.py

references/segmentation/train.py

Lines changed: 32 additions & 10 deletions
@@ -11,6 +11,12 @@
 from torch import nn
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def get_dataset(dir_path, name, image_set, transform):
     def sbd(*args, **kwargs):
         return torchvision.datasets.SBDataset(*args, mode="segmentation", **kwargs)
@@ -26,11 +32,15 @@ def sbd(*args, **kwargs):
     return ds, num_classes
 
 
-def get_transform(train):
-    base_size = 520
-    crop_size = 480
-
-    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
+def get_transform(train, args):
+    if train:
+        return presets.SegmentationPresetTrain(base_size=520, crop_size=480)
+    elif not args.weights:
+        return presets.SegmentationPresetEval(base_size=520)
+    else:
+        fn = PM.segmentation.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        return weights.transforms()
 
 
 def criterion(inputs, target):
@@ -90,8 +100,8 @@ def main(args):
 
     device = torch.device(args.device)
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(train=True))
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(train=False))
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train", get_transform(True, args))
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val", get_transform(False, args))
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -113,9 +123,18 @@ def main(args):
         dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
     )
 
-    model = torchvision.models.segmentation.__dict__[args.model](
-        num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained
-    )
+    if not args.weights:
+        model = torchvision.models.segmentation.__dict__[args.model](
+            pretrained=args.pretrained,
+            num_classes=num_classes,
+            aux_loss=args.aux_loss,
+        )
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.segmentation.__dict__[args.model](
+            weights=args.weights, num_classes=num_classes, aux_loss=args.aux_loss
+        )
     model.to(device)
     if args.distributed:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -247,6 +266,9 @@ def get_args_parser(add_help=True):
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
     parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     return parser

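To see what the new get_transform(train, args) resolves to, here is a short illustration meant to be run from within references/segmentation/train.py (e.g. appended under its __main__ block). fcn_resnet50 is one of torchvision's segmentation builders, while "SomeSegmentationWeights" is a placeholder, not a real enum name:

from types import SimpleNamespace

# Training: always the hard-coded train preset, with or without --weights.
train_tf = get_transform(True, SimpleNamespace(model="fcn_resnet50", weights=None))

# Evaluation without --weights: the previous hard-coded eval preset.
eval_tf = get_transform(False, SimpleNamespace(model="fcn_resnet50", weights=None))

# Evaluation with --weights: the preprocessing bundled with the named prototype enum
# (needs a nightly build and a real enum name in place of the placeholder).
proto_tf = get_transform(False, SimpleNamespace(model="fcn_resnet50", weights="SomeSegmentationWeights"))
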
references/video_classification/train.py

Lines changed: 21 additions & 2 deletions
@@ -18,6 +18,12 @@
     amp = None
 
 
+try:
+    from torchvision.prototype import models as PM
+except ImportError:
+    PM = None
+
+
 def train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, print_freq, apex=False):
     model.train()
     metric_logger = utils.MetricLogger(delimiter=" ")
@@ -149,7 +155,12 @@ def main(args):
     print("Loading validation data")
     cache_path = _get_cache_path(valdir)
 
-    transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    if not args.weights:
+        transform_test = presets.VideoClassificationPresetEval((128, 171), (112, 112))
+    else:
+        fn = PM.video.__dict__[args.model]
+        weights = PM._api.get_weight(fn, args.weights)
+        transform_test = weights.transforms()
 
     if args.cache_dataset and os.path.exists(cache_path):
         print(f"Loading dataset_test from {cache_path}")
@@ -200,7 +211,12 @@ def main(args):
     )
 
     print("Creating model")
-    model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    if not args.weights:
+        model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
+    else:
+        if PM is None:
+            raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
+        model = PM.video.__dict__[args.model](weights=args.weights)
     model.to(device)
     if args.distributed and args.sync_bn:
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
@@ -363,6 +379,9 @@ def parse_args():
     parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
    parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
 
+    # Prototype models only
+    parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
+
     args = parser.parse_args()
 
     return args

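For the video script, the same weights name drives both the evaluation preprocessing and the model construction. An end-to-end sketch of that path, condensed from the diff above; "r3d_18" and "SomeKinetics400Weights" are illustrative placeholders, so substitute a builder that exists under torchvision.prototype.models.video and one of its real weight enums:

from types import SimpleNamespace

try:
    from torchvision.prototype import models as PM
except ImportError:
    PM = None

args = SimpleNamespace(model="r3d_18", weights="SomeKinetics400Weights")  # placeholder values

if PM is None:
    raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
fn = PM.video.__dict__[args.model]
weights = PM._api.get_weight(fn, args.weights)   # resolve the enum by its name
transform_test = weights.transforms()            # eval preprocessing that matches the weights
model = PM.video.__dict__[args.model](weights=args.weights)
# Note: unlike the segmentation script, num_classes/aux_loss are not forwarded here.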