
Commit 0a612cb

Porting reference scripts and updating presets (#5629)
* Making _preset.py classes
* Remove support of targets on presets
* Rewriting the video preset
* Adding tests to check that the bundled transforms are JIT-scriptable
* Rename all presets from *Eval to *Inference
* Minor refactoring
* Remove --prototype and --pretrained from reference scripts
* Remove pretrained_backbone refs
* Corrections and simplifications
* Fixing bug
* Fixing linter
* Fix flake8
* Restore documentation example
* Minor fixes
* Fix optical flow missing param
* Fixing commands
* Adding weights_backbone support in detection and segmentation
* Updating the commands for InceptionV3
1 parent 96b99dc commit 0a612cb


53 files changed: +343 / -522 lines

docs/source/models.rst

Lines changed: 1 addition & 53 deletions
@@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor:
     convnext_large = models.convnext_large()
 
 We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
-These can be constructed by passing ``pretrained=True``:
-
-.. code:: python
-
-    import torchvision.models as models
-    resnet18 = models.resnet18(pretrained=True)
-    alexnet = models.alexnet(pretrained=True)
-    squeezenet = models.squeezenet1_0(pretrained=True)
-    vgg16 = models.vgg16(pretrained=True)
-    densenet = models.densenet161(pretrained=True)
-    inception = models.inception_v3(pretrained=True)
-    googlenet = models.googlenet(pretrained=True)
-    shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
-    mobilenet_v2 = models.mobilenet_v2(pretrained=True)
-    mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
-    mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
-    resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
-    wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
-    mnasnet = models.mnasnet1_0(pretrained=True)
-    efficientnet_b0 = models.efficientnet_b0(pretrained=True)
-    efficientnet_b1 = models.efficientnet_b1(pretrained=True)
-    efficientnet_b2 = models.efficientnet_b2(pretrained=True)
-    efficientnet_b3 = models.efficientnet_b3(pretrained=True)
-    efficientnet_b4 = models.efficientnet_b4(pretrained=True)
-    efficientnet_b5 = models.efficientnet_b5(pretrained=True)
-    efficientnet_b6 = models.efficientnet_b6(pretrained=True)
-    efficientnet_b7 = models.efficientnet_b7(pretrained=True)
-    efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True)
-    efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True)
-    efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True)
-    regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
-    regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
-    regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
-    regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
-    regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
-    regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
-    regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
-    regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
-    regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
-    regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
-    regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
-    regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
-    regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
-    regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
-    vit_b_16 = models.vit_b_16(pretrained=True)
-    vit_b_32 = models.vit_b_32(pretrained=True)
-    vit_l_16 = models.vit_l_16(pretrained=True)
-    vit_l_32 = models.vit_l_32(pretrained=True)
-    convnext_tiny = models.convnext_tiny(pretrained=True)
-    convnext_small = models.convnext_small(pretrained=True)
-    convnext_base = models.convnext_base(pretrained=True)
-    convnext_large = models.convnext_large(pretrained=True)
 
 Instancing a pre-trained model will download its weights to a cache directory.
 This directory can be set using the `TORCH_HOME` environment variable. See
@@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code:
 .. code:: python
 
     import torchvision.models as models
-    model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
+    model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True)
     model.eval()
    # run the model with quantized inputs and weights
    out = model(torch.rand(1, 3, 224, 224))
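
Note for readers migrating their own code: the deleted `pretrained=True` examples map onto the new multi-weight API as in the sketch below (`ResNet50_Weights` is one concrete example not shown in this diff; every listed model follows the same pattern):

```python
import torch
from torchvision.models import resnet50, ResNet50_Weights

# Old: resnet50 = models.resnet50(pretrained=True)
# New: pass an explicit weights enum, or weights=None for random init.
weights = ResNet50_Weights.IMAGENET1K_V1
model = resnet50(weights=weights)
model.eval()

# Each weights enum also bundles its matching inference preset.
preprocess = weights.transforms()
batch = preprocess(torch.rand(3, 256, 256)).unsqueeze(0)  # dummy image
scores = model(batch)
```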

gallery/plot_optical_flow.py

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs):
 def preprocess(img1_batch, img2_batch):
     img1_batch = F.resize(img1_batch, size=[520, 960])
     img2_batch = F.resize(img2_batch, size=[520, 960])
-    return transforms(img1_batch, img2_batch)[:2]
+    return transforms(img1_batch, img2_batch)
 
 
 img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
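
The `[:2]` slice is gone because the rewritten optical-flow preset now returns exactly the two transformed batches. A hedged, self-contained sketch of the updated call pattern (random tensors stand in for the gallery's video frames):

```python
import torch
from torchvision.models.optical_flow import Raft_Large_Weights

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

# Stand-ins for two batches of video frames (uint8, NCHW).
img1_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)
img2_batch = torch.randint(0, 256, (2, 3, 520, 960), dtype=torch.uint8)

# The preset now returns the preprocessed (img1, img2) pair directly.
img1_batch, img2_batch = transforms(img1_batch, img2_batch)
```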

gallery/plot_repurposing_annotations.py

Lines changed: 1 addition & 1 deletion
@@ -146,7 +146,7 @@ def show(imgs):
 print(img.size())
 
 tranforms = weights.transforms()
-img, _ = tranforms(img)
+img = tranforms(img)
 target = {}
 target["boxes"] = boxes
 target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)

gallery/plot_visualization_utils.py

Lines changed: 4 additions & 4 deletions
@@ -81,7 +81,7 @@ def show(imgs):
 weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -131,7 +131,7 @@ def show(imgs):
 model = fcn_resnet50(weights=weights, progress=False)
 model = model.eval()
 
-normalized_batch, _ = transforms(batch)
+normalized_batch = transforms(batch)
 output = model(normalized_batch)['out']
 print(output.shape, output.min().item(), output.max().item())
 
@@ -272,7 +272,7 @@ def show(imgs):
 weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-batch, _ = transforms(batch_int)
+batch = transforms(batch_int)
 
 model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
@@ -397,7 +397,7 @@ def show(imgs):
 weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
 transforms = weights.transforms()
 
-person_float, _ = transforms(person_int)
+person_float = transforms(person_int)
 
 model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
 model = model.eval()
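
Putting the updated preset together with one of these detection models, a hedged end-to-end sketch (random uint8 images stand in for the gallery's photos; the 0.8 score threshold is illustrative):

```python
import torch
from torchvision.models.detection import (
    fasterrcnn_resnet50_fpn,
    FasterRCNN_ResNet50_FPN_Weights,
)

weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

batch_int = torch.randint(0, 256, (2, 3, 480, 640), dtype=torch.uint8)
batch = transforms(batch_int)  # the preset now returns only the transformed batch

model = fasterrcnn_resnet50_fpn(weights=weights, progress=False).eval()
with torch.no_grad():
    outputs = model(batch)

# The class names ship with the weights metadata.
categories = weights.meta["categories"]
for out in outputs:
    keep = out["scores"] > 0.8
    print([categories[label] for label in out["labels"][keep]])
```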

references/classification/README.md

Lines changed: 9 additions & 17 deletions
@@ -43,7 +43,7 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model
 
 ```
 torchrun --nproc_per_node=8 train.py --model inception_v3\
-    --val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
+    --test-only --weights Inception_V3_Weights.IMAGENET1K_V1
 ```
 
 ### ResNet
@@ -96,22 +96,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo
 
 All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands:
 ```
-torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
-    --val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
-    --val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
-    --val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
-    --val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
-    --val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
-    --val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
-    --val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
-torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
-    --val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
+torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1
+torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1
 ```
 
 
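Note: the old `--val-resize-size`/`--val-crop-size`/`--interpolation` flags disappear because each weights enum now bundles its own validation preset. A hedged sketch of inspecting one (the attribute names are assumptions about the new preset class; verify against your torchvision version):

```python
from torchvision.models import get_weight

weights = get_weight("EfficientNet_B3_Weights.IMAGENET1K_V1")
preset = weights.transforms()
# Assumed attributes: the preset stores the sizes the old commands passed
# by hand (--val-resize-size 320 --val-crop-size 300 for efficientnet_b3).
print(preset.crop_size, preset.resize_size, preset.interpolation)
```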
references/classification/train.py

Lines changed: 5 additions & 37 deletions
@@ -15,12 +15,6 @@
 from torchvision.transforms.functional import InterpolationMode
 
 
-try:
-    from torchvision import prototype
-except ImportError:
-    prototype = None
-
-
 def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
@@ -154,18 +148,13 @@ def load_data(traindir, valdir, args):
         print(f"Loading dataset_test from {cache_path}")
         dataset_test, _ = torch.load(cache_path)
     else:
-        if not args.prototype:
+        if args.weights and args.test_only:
+            weights = torchvision.models.get_weight(args.weights)
+            preprocessing = weights.transforms()
+        else:
             preprocessing = presets.ClassificationPresetEval(
                 crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
             )
-        else:
-            if args.weights:
-                weights = prototype.models.get_weight(args.weights)
-                preprocessing = weights.transforms()
-            else:
-                preprocessing = prototype.transforms.ImageClassificationEval(
-                    crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
-                )
 
         dataset_test = torchvision.datasets.ImageFolder(
             valdir,
@@ -191,10 +180,6 @@ def load_data(traindir, valdir, args):
 
 
 def main(args):
-    if args.prototype and prototype is None:
-        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if not args.prototype and args.weights:
-        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -236,10 +221,7 @@ def main(args):
     )
 
     print("Creating model")
-    if not args.prototype:
-        model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
-    else:
-        model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
+    model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
     model.to(device)
 
     if args.distributed and args.sync_bn:
@@ -446,12 +428,6 @@ def get_args_parser(add_help=True):
         help="Only test the model",
         action="store_true",
     )
-    parser.add_argument(
-        "--pretrained",
-        dest="pretrained",
-        help="Use pre-trained models from the modelzoo",
-        action="store_true",
-    )
     parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
     parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")
 
@@ -496,14 +472,6 @@ def get_args_parser(add_help=True):
     parser.add_argument(
         "--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
     )
-
-    # Prototype models only
-    parser.add_argument(
-        "--prototype",
-        dest="prototype",
-        help="Use prototype model builders instead those from main area",
-        action="store_true",
-    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
     return parser
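
Note: the removed `--pretrained` boolean is subsumed by `--weights`, which takes a fully qualified enum name. A condensed sketch of the new flow (mirroring the `load_data` and model-creation changes above; "resnet18" stands in for `args.model`):

```python
import torchvision

weights_arg = "ResNet18_Weights.IMAGENET1K_V1"  # what --weights receives

# Test-only path: reuse the preset bundled with the weights.
weights = torchvision.models.get_weight(weights_arg)
preprocessing = weights.transforms()

# weights=None reproduces the old default (random init); an enum name
# reproduces the old --pretrained behaviour.
model = torchvision.models.__dict__["resnet18"](weights=weights_arg)
```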

references/classification/train_quantization.py

Lines changed: 1 addition & 22 deletions
@@ -12,17 +12,7 @@
 from train import train_one_epoch, evaluate, load_data
 
 
-try:
-    from torchvision import prototype
-except ImportError:
-    prototype = None
-
-
 def main(args):
-    if args.prototype and prototype is None:
-        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
-    if not args.prototype and args.weights:
-        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
     if args.output_dir:
         utils.mkdir(args.output_dir)
 
@@ -56,10 +46,7 @@ def main(args):
 
     print("Creating model", args.model)
     # when training quantized models, we always start from a pre-trained fp32 reference model
-    if not args.prototype:
-        model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
-    else:
-        model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
+    model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
     model.to(device)
 
     if not (args.test_only or args.post_training_quantize):
@@ -264,14 +251,6 @@ def get_args_parser(add_help=True):
         "--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
     )
     parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")
-
-    # Prototype models only
-    parser.add_argument(
-        "--prototype",
-        dest="prototype",
-        help="Use prototype model builders instead those from main area",
-        action="store_true",
-    )
     parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
 
     return parser
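
Note: with `quantize=args.test_only`, evaluation loads fully quantized weights while training starts from the fp32 reference, as the comment in the diff says. A hedged sketch of the test-only case (the quantized enum name is an example from the multi-weight API, not shown in this diff):

```python
import torchvision

# --test-only: load already-quantized weights and evaluate.
model = torchvision.models.quantization.__dict__["mobilenet_v3_large"](
    weights="MobileNet_V3_Large_QuantizedWeights.IMAGENET1K_QNNPACK_V1",
    quantize=True,
)
model.eval()
```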

references/classification/utils.py

Lines changed: 4 additions & 4 deletions
@@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
         from torchvision import models as M
 
         # Classification
-        model = M.mobilenet_v3_large(pretrained=False)
+        model = M.mobilenet_v3_large()
         print(store_model_weights(model, './class.pth'))
 
         # Quantized Classification
-        model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
+        model = M.quantization.mobilenet_v3_large(quantize=False)
         model.fuse_model(is_qat=True)
         model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
         _ = torch.ao.quantization.prepare_qat(model, inplace=True)
         print(store_model_weights(model, './qat.pth'))
 
         # Object Detection
-        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
+        model = M.detection.fasterrcnn_mobilenet_v3_large_fpn()
         print(store_model_weights(model, './obj.pth'))
 
         # Segmentation
-        model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
+        model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True)
         print(store_model_weights(model, './segm.pth', strict=False))
 
     Args:

references/detection/README.md

Lines changed: 9 additions & 9 deletions
@@ -24,65 +24,65 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs.
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```
 
 ### Faster R-CNN MobileNetV3-Large FPN
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
 ```
 
 ### Faster R-CNN MobileNetV3-Large 320 FPN
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
 ```
 
 ### FCOS ResNet-50 FPN
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model fcos_resnet50_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```
 
 ### RetinaNet
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model retinanet_resnet50_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```
 
 ### SSD300 VGG16
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model ssd300_vgg16 --epochs 120\
     --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
-    --weight-decay 0.0005 --data-augmentation ssd
+    --weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES
 ```
 
 ### SSDlite320 MobileNetV3-Large
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
     --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
-    --weight-decay 0.00004 --data-augmentation ssdlite
+    --weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
 ```
 
 
 ### Mask R-CNN
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
-    --lr-steps 16 22 --aspect-ratio-group-factor 3
+    --lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```
 
 
 ### Keypoint R-CNN
 ```
 torchrun --nproc_per_node=8 train.py\
     --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
-    --lr-steps 36 43 --aspect-ratio-group-factor 3
+    --lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
 ```
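
The new `--weights-backbone` flag corresponds to the `weights_backbone` parameter of the detection builders. A minimal sketch of the equivalent Python call for the Faster R-CNN command above (`num_classes=91` is COCO's default, shown only for clarity):

```python
from torchvision.models import ResNet50_Weights
from torchvision.models.detection import fasterrcnn_resnet50_fpn

# Randomly initialized detector, ImageNet-pretrained backbone,
# mirroring: --weights-backbone ResNet50_Weights.IMAGENET1K_V1
model = fasterrcnn_resnet50_fpn(
    weights=None,
    weights_backbone=ResNet50_Weights.IMAGENET1K_V1,
    num_classes=91,
)
```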
