Porting reference scripts and updating presets #5629

Merged: 20 commits, merged on Mar 17, 2022

54 changes: 1 addition & 53 deletions docs/source/models.rst
@@ -98,58 +98,6 @@ You can construct a model with random weights by calling its constructor:
convnext_large = models.convnext_large()

We provide pre-trained models, using the PyTorch :mod:`torch.utils.model_zoo`.
These can be constructed by passing ``pretrained=True``:

.. code:: python

import torchvision.models as models
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
squeezenet = models.squeezenet1_0(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
googlenet = models.googlenet(pretrained=True)
shufflenet = models.shufflenet_v2_x1_0(pretrained=True)
mobilenet_v2 = models.mobilenet_v2(pretrained=True)
mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True)
mobilenet_v3_small = models.mobilenet_v3_small(pretrained=True)
resnext50_32x4d = models.resnext50_32x4d(pretrained=True)
wide_resnet50_2 = models.wide_resnet50_2(pretrained=True)
mnasnet = models.mnasnet1_0(pretrained=True)
efficientnet_b0 = models.efficientnet_b0(pretrained=True)
efficientnet_b1 = models.efficientnet_b1(pretrained=True)
efficientnet_b2 = models.efficientnet_b2(pretrained=True)
efficientnet_b3 = models.efficientnet_b3(pretrained=True)
efficientnet_b4 = models.efficientnet_b4(pretrained=True)
efficientnet_b5 = models.efficientnet_b5(pretrained=True)
efficientnet_b6 = models.efficientnet_b6(pretrained=True)
efficientnet_b7 = models.efficientnet_b7(pretrained=True)
efficientnet_v2_s = models.efficientnet_v2_s(pretrained=True)
efficientnet_v2_m = models.efficientnet_v2_m(pretrained=True)
efficientnet_v2_l = models.efficientnet_v2_l(pretrained=True)
regnet_y_400mf = models.regnet_y_400mf(pretrained=True)
regnet_y_800mf = models.regnet_y_800mf(pretrained=True)
regnet_y_1_6gf = models.regnet_y_1_6gf(pretrained=True)
regnet_y_3_2gf = models.regnet_y_3_2gf(pretrained=True)
regnet_y_8gf = models.regnet_y_8gf(pretrained=True)
regnet_y_16gf = models.regnet_y_16gf(pretrained=True)
regnet_y_32gf = models.regnet_y_32gf(pretrained=True)
regnet_x_400mf = models.regnet_x_400mf(pretrained=True)
regnet_x_800mf = models.regnet_x_800mf(pretrained=True)
regnet_x_1_6gf = models.regnet_x_1_6gf(pretrained=True)
regnet_x_3_2gf = models.regnet_x_3_2gf(pretrained=True)
regnet_x_8gf = models.regnet_x_8gf(pretrained=True)
    regnet_x_16gf = models.regnet_x_16gf(pretrained=True)
regnet_x_32gf = models.regnet_x_32gf(pretrained=True)
vit_b_16 = models.vit_b_16(pretrained=True)
vit_b_32 = models.vit_b_32(pretrained=True)
vit_l_16 = models.vit_l_16(pretrained=True)
vit_l_32 = models.vit_l_32(pretrained=True)
convnext_tiny = models.convnext_tiny(pretrained=True)
convnext_small = models.convnext_small(pretrained=True)
convnext_base = models.convnext_base(pretrained=True)
convnext_large = models.convnext_large(pretrained=True)
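
For comparison, the multi-weight API that this PR moves the docs and reference scripts onto replaces the boolean flag with explicit weight enums. A minimal sketch of the new pattern (enum names follow the convention used elsewhere in this diff, e.g. ``IMAGENET1K_V1``):

```python
import torchvision.models as models
from torchvision.models import ResNet18_Weights

# roughly equivalent to the removed ``resnet18(pretrained=True)`` call
resnet18 = models.resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)

# weight enums can also be resolved from their string names
weights = models.get_weight("ResNet18_Weights.IMAGENET1K_V1")
resnet18 = models.resnet18(weights=weights)
```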

Instancing a pre-trained model will download its weights to a cache directory.
This directory can be set using the `TORCH_HOME` environment variable. See
@@ -525,7 +473,7 @@ Obtaining a pre-trained quantized model can be done with a few lines of code:
.. code:: python

import torchvision.models as models
model = models.quantization.mobilenet_v2(pretrained=True, quantize=True)
model = models.quantization.mobilenet_v2(weights=MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1, quantize=True)
model.eval()
# run the model with quantized inputs and weights
out = model(torch.rand(1, 3, 224, 224))
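
A possible continuation of this snippet, assuming the ``meta`` dictionary that the new weight enums carry, is to decode the top prediction (sketch only; the input above is random, so the label is meaningless):

```python
from torchvision.models.quantization import MobileNet_V2_QuantizedWeights

weights = MobileNet_V2_QuantizedWeights.IMAGENET1K_QNNPACK_V1
# pick the highest-scoring class and look up its name in the weight metadata
score, class_id = out.softmax(dim=1).max(dim=1)
print(weights.meta["categories"][class_id.item()], f"{score.item():.1%}")
```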
2 changes: 1 addition & 1 deletion gallery/plot_optical_flow.py
@@ -96,7 +96,7 @@ def plot(imgs, **imshow_kwargs):
def preprocess(img1_batch, img2_batch):
img1_batch = F.resize(img1_batch, size=[520, 960])
img2_batch = F.resize(img2_batch, size=[520, 960])
return transforms(img1_batch, img2_batch)[:2]
return transforms(img1_batch, img2_batch)


img1_batch, img2_batch = preprocess(img1_batch, img2_batch)
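
Further down the gallery, the preprocessed batches are fed to the RAFT model. A minimal sketch, assuming the `raft_large` builder and `Raft_Large_Weights` enum from `torchvision.models.optical_flow`:

```python
import torch
from torchvision.models.optical_flow import Raft_Large_Weights, raft_large

# RAFT returns a list of iteratively refined flow estimates; the last one is the finest
model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).eval()
with torch.no_grad():
    predicted_flow = model(img1_batch, img2_batch)[-1]
```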
2 changes: 1 addition & 1 deletion gallery/plot_repurposing_annotations.py
@@ -146,7 +146,7 @@ def show(imgs):
print(img.size())

tranforms = weights.transforms()
img, _ = tranforms(img)
img = tranforms(img)
target = {}
target["boxes"] = boxes
target["labels"] = labels = torch.ones((masks.size(0),), dtype=torch.int64)
8 changes: 4 additions & 4 deletions gallery/plot_visualization_utils.py
@@ -81,7 +81,7 @@ def show(imgs):
weights = FasterRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

batch, _ = transforms(batch_int)
batch = transforms(batch_int)

model = fasterrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()
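
Downstream of this snippet, the detections can be drawn back onto the original uint8 images. A small sketch using `torchvision.utils.draw_bounding_boxes` (the 0.8 score threshold is illustrative):

```python
from torchvision.utils import draw_bounding_boxes

outputs = model(batch)
# keep confident detections only and overlay them on the un-normalized images
drawn = [
    draw_bounding_boxes(img_int, boxes=out["boxes"][out["scores"] > 0.8], width=4)
    for img_int, out in zip(batch_int, outputs)
]
```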
@@ -131,7 +131,7 @@ def show(imgs):
model = fcn_resnet50(weights=weights, progress=False)
model = model.eval()

normalized_batch, _ = transforms(batch)
normalized_batch = transforms(batch)
output = model(normalized_batch)['out']
print(output.shape, output.min().item(), output.max().item())
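
The raw logits can then be turned into per-class probability masks. A brief sketch, assuming the usual `(N, num_classes, H, W)` output layout and the class list attached to the weights:

```python
import torch

# normalize over the class dimension to get per-pixel class probabilities,
# then select a single class plane (the class name is illustrative)
normalized_masks = torch.nn.functional.softmax(output, dim=1)
dog_masks = normalized_masks[:, weights.meta["categories"].index("dog")]
```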

@@ -272,7 +272,7 @@ def show(imgs):
weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

batch, _ = transforms(batch_int)
batch = transforms(batch_int)

model = maskrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()
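
The instance masks returned by the model are soft (one float map per detection). A sketch of thresholding them and overlaying the result with `torchvision.utils.draw_segmentation_masks` (the 0.5 cut-off is illustrative):

```python
from torchvision.utils import draw_segmentation_masks

output = model(batch)[0]
# masks come out as (num_instances, 1, H, W) floats; binarize and drop the channel dim
bool_masks = output["masks"].squeeze(1) > 0.5
result = draw_segmentation_masks(batch_int[0], bool_masks, alpha=0.7)
```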
@@ -397,7 +397,7 @@ def show(imgs):
weights = KeypointRCNN_ResNet50_FPN_Weights.DEFAULT
transforms = weights.transforms()

person_float, _ = transforms(person_int)
person_float = transforms(person_int)

model = keypointrcnn_resnet50_fpn(weights=weights, progress=False)
model = model.eval()
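
Further on, the detected keypoints can be drawn back onto the uint8 image. A sketch using `torchvision.utils.draw_keypoints`, keeping only the x/y columns of the predictions:

```python
from torchvision.utils import draw_keypoints

outputs = model([person_float])
# keypoints come back as (num_instances, num_keypoints, 3); the last column is a visibility flag
keypoints = outputs[0]["keypoints"][:, :, :2]
res = draw_keypoints(person_int, keypoints, colors="blue", radius=3)
```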
26 changes: 9 additions & 17 deletions references/classification/README.md
@@ -43,7 +43,7 @@ Since it expects tensors with a size of N x 3 x 299 x 299, to validate the model

```
torchrun --nproc_per_node=8 train.py --model inception_v3\
--val-resize-size 342 --val-crop-size 299 --train-crop-size 299 --test-only --pretrained
--test-only --weights Inception_V3_Weights.IMAGENET1K_V1
```

### ResNet
@@ -96,22 +96,14 @@ The weights of the B5-B7 variants are ported from Luke Melas' [EfficientNet-PyTo

All models were trained using Bicubic interpolation and each have custom crop and resize sizes. To validate the models use the following commands:
```
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 224 --train-crop-size 224 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --interpolation bicubic\
--val-resize-size 256 --val-crop-size 240 --train-crop-size 240 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --interpolation bicubic\
--val-resize-size 288 --val-crop-size 288 --train-crop-size 288 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --interpolation bicubic\
--val-resize-size 320 --val-crop-size 300 --train-crop-size 300 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --interpolation bicubic\
--val-resize-size 384 --val-crop-size 380 --train-crop-size 380 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --interpolation bicubic\
--val-resize-size 456 --val-crop-size 456 --train-crop-size 456 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --interpolation bicubic\
--val-resize-size 528 --val-crop-size 528 --train-crop-size 528 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --interpolation bicubic\
--val-resize-size 600 --val-crop-size 600 --train-crop-size 600 --test-only --pretrained
torchrun --nproc_per_node=8 train.py --model efficientnet_b0 --test-only --weights EfficientNet_B0_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b1 --test-only --weights EfficientNet_B1_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b2 --test-only --weights EfficientNet_B2_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b3 --test-only --weights EfficientNet_B3_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b4 --test-only --weights EfficientNet_B4_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b5 --test-only --weights EfficientNet_B5_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b6 --test-only --weights EfficientNet_B6_Weights.IMAGENET1K_V1
torchrun --nproc_per_node=8 train.py --model efficientnet_b7 --test-only --weights EfficientNet_B7_Weights.IMAGENET1K_V1
```
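
The enum names passed to `--weights` above are the same strings that `torchvision.models.get_weight` resolves inside the training script; an illustrative sketch of using one directly:

```python
import torchvision.models as models

# the command-line string maps to a weights enum with bundled evaluation transforms
weights = models.get_weight("EfficientNet_B0_Weights.IMAGENET1K_V1")
model = models.efficientnet_b0(weights=weights)
preprocess = weights.transforms()
```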


42 changes: 5 additions & 37 deletions references/classification/train.py
@@ -15,12 +15,6 @@
from torchvision.transforms.functional import InterpolationMode


try:
from torchvision import prototype
except ImportError:
prototype = None


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, args, model_ema=None, scaler=None):
model.train()
metric_logger = utils.MetricLogger(delimiter=" ")
@@ -154,18 +148,13 @@ def load_data(traindir, valdir, args):
print(f"Loading dataset_test from {cache_path}")
dataset_test, _ = torch.load(cache_path)
else:
if not args.prototype:
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = presets.ClassificationPresetEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)
else:
if args.weights:
weights = prototype.models.get_weight(args.weights)
preprocessing = weights.transforms()
else:
preprocessing = prototype.transforms.ImageClassificationEval(
crop_size=val_crop_size, resize_size=val_resize_size, interpolation=interpolation
)

dataset_test = torchvision.datasets.ImageFolder(
valdir,
@@ -191,10 +180,6 @@ def load_data(traindir, valdir, args):


def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

@@ -236,10 +221,7 @@ def main(args):
)

print("Creating model")
if not args.prototype:
model = torchvision.models.__dict__[args.model](pretrained=args.pretrained, num_classes=num_classes)
else:
model = prototype.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model = torchvision.models.__dict__[args.model](weights=args.weights, num_classes=num_classes)
model.to(device)

if args.distributed and args.sync_bn:
@@ -446,12 +428,6 @@ def get_args_parser(add_help=True):
help="Only test the model",
action="store_true",
)
parser.add_argument(
"--pretrained",
dest="pretrained",
help="Use pre-trained models from the modelzoo",
action="store_true",
)
parser.add_argument("--auto-augment", default=None, type=str, help="auto augment policy (default: None)")
parser.add_argument("--random-erase", default=0.0, type=float, help="random erasing probability (default: 0.0)")

@@ -496,14 +472,6 @@ def get_args_parser(add_help=True):
parser.add_argument(
"--ra-reps", default=3, type=int, help="number of repetitions for Repeated Augmentation (default: 3)"
)

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser
23 changes: 1 addition & 22 deletions references/classification/train_quantization.py
@@ -12,17 +12,7 @@
from train import train_one_epoch, evaluate, load_data


try:
from torchvision import prototype
except ImportError:
prototype = None


def main(args):
if args.prototype and prototype is None:
raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
if not args.prototype and args.weights:
raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
if args.output_dir:
utils.mkdir(args.output_dir)

@@ -56,10 +46,7 @@ def main(args):

print("Creating model", args.model)
# when training quantized models, we always start from a pre-trained fp32 reference model
if not args.prototype:
model = torchvision.models.quantization.__dict__[args.model](pretrained=True, quantize=args.test_only)
else:
model = prototype.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model = torchvision.models.quantization.__dict__[args.model](weights=args.weights, quantize=args.test_only)
model.to(device)

if not (args.test_only or args.post_training_quantize):
@@ -264,14 +251,6 @@ def get_args_parser(add_help=True):
"--train-crop-size", default=224, type=int, help="the random crop size used for training (default: 224)"
)
parser.add_argument("--clip-grad-norm", default=None, type=float, help="the maximum gradient norm (default None)")

# Prototype models only
parser.add_argument(
"--prototype",
dest="prototype",
help="Use prototype model builders instead those from main area",
action="store_true",
)
parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")

return parser
8 changes: 4 additions & 4 deletions references/classification/utils.py
@@ -330,22 +330,22 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
from torchvision import models as M

# Classification
model = M.mobilenet_v3_large(pretrained=False)
model = M.mobilenet_v3_large()
print(store_model_weights(model, './class.pth'))

# Quantized Classification
model = M.quantization.mobilenet_v3_large(pretrained=False, quantize=False)
model = M.quantization.mobilenet_v3_large(quantize=False)
model.fuse_model(is_qat=True)
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('qnnpack')
_ = torch.ao.quantization.prepare_qat(model, inplace=True)
print(store_model_weights(model, './qat.pth'))

# Object Detection
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, pretrained_backbone=False)
model = M.detection.fasterrcnn_mobilenet_v3_large_fpn()
print(store_model_weights(model, './obj.pth'))

# Segmentation
model = M.segmentation.deeplabv3_mobilenet_v3_large(pretrained=False, pretrained_backbone=False, aux_loss=True)
model = M.segmentation.deeplabv3_mobilenet_v3_large(aux_loss=True)
print(store_model_weights(model, './segm.pth', strict=False))

Args:
18 changes: 9 additions & 9 deletions references/detection/README.md
@@ -24,65 +24,65 @@ Except otherwise noted, all models have been trained on 8x V100 GPUs.
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```

### Faster R-CNN MobileNetV3-Large FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```

### Faster R-CNN MobileNetV3-Large 320 FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```

### FCOS ResNet-50 FPN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model fcos_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --amp --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```

### RetinaNet
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model retinanet_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
--lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```

### SSD300 VGG16
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssd300_vgg16 --epochs 120\
--lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
--weight-decay 0.0005 --data-augmentation ssd
--weight-decay 0.0005 --data-augmentation ssd --weights-backbone VGG16_Weights.IMAGENET1K_FEATURES
```

### SSDlite320 MobileNetV3-Large
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
--aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
--weight-decay 0.00004 --data-augmentation ssdlite
--weight-decay 0.00004 --data-augmentation ssdlite --weights-backbone MobileNet_V3_Large_Weights.IMAGENET1K_V1
```


### Mask R-CNN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
--lr-steps 16 22 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```


### Keypoint R-CNN
```
torchrun --nproc_per_node=8 train.py\
--dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
--lr-steps 36 43 --aspect-ratio-group-factor 3
--lr-steps 36 43 --aspect-ratio-group-factor 3 --weights-backbone ResNet50_Weights.IMAGENET1K_V1
```
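
`--weights-backbone` uses the same enum-name convention for the classification backbone. Roughly, the training script resolves it like the sketch below (shown here with the Keypoint R-CNN builder; the exact wiring lives in `references/detection/train.py`):

```python
import torchvision

# the backbone weights named on the command line resolve to a weight enum;
# detection builders accept it (or the plain string) via weights_backbone
backbone_weights = torchvision.models.get_weight("ResNet50_Weights.IMAGENET1K_V1")
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    weights=None, weights_backbone=backbone_weights
)
```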