From 144c6cf06b67e85573cfb43633a97de88be01993 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 3 Nov 2021 12:14:27 +0000 Subject: [PATCH 1/3] Aligning exception with all other models. --- torchvision/models/shufflenetv2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/shufflenetv2.py b/torchvision/models/shufflenetv2.py index c2af51d8ecf..f3758c54aaf 100644 --- a/torchvision/models/shufflenetv2.py +++ b/torchvision/models/shufflenetv2.py @@ -162,7 +162,7 @@ def _shufflenetv2(arch: str, pretrained: bool, progress: bool, *args: Any, **kwa if pretrained: model_url = model_urls[arch] if model_url is None: - raise NotImplementedError(f"pretrained {arch} is not supported as of now") + raise ValueError(f"No checkpoint is available for model type {arch}") else: state_dict = load_state_dict_from_url(model_url, progress=progress) model.load_state_dict(state_dict) From b1215fbb88cbaa27f591e869340c5e3baf76e5c7 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 3 Nov 2021 12:31:57 +0000 Subject: [PATCH 2/3] Adding prototype preprocessing on detection references. 
--- references/detection/train.py | 37 ++++++++++++++++++++++++++--------- test/test_prototype_models.py | 14 +++++++++++++ 2 files changed, 42 insertions(+), 9 deletions(-) diff --git a/references/detection/train.py b/references/detection/train.py index ce74ff22b30..c05f80f8740 100644 --- a/references/detection/train.py +++ b/references/detection/train.py @@ -33,6 +33,12 @@ from group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups +try: + from torchvision.prototype import models as PM +except ImportError: + PM = None + + def get_dataset(name, image_set, transform, data_path): paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)} p, ds_fn, num_classes = paths[name] @@ -41,8 +47,15 @@ def get_dataset(name, image_set, transform, data_path): return ds, num_classes -def get_transform(train, data_augmentation): - return presets.DetectionPresetTrain(data_augmentation) if train else presets.DetectionPresetEval() +def get_transform(train, args): + if train: + return presets.DetectionPresetTrain(args.data_augmentation) + elif not args.weights: + return presets.DetectionPresetEval() + else: + fn = PM.detection.__dict__[args.model] + weights = PM._api.get_weight(fn, args.weights) + return weights.transforms() def get_args_parser(add_help=True): @@ -128,6 +141,9 @@ def get_args_parser(add_help=True): parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes") parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training") + # Prototype models only + parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load") + return parser @@ -143,10 +159,8 @@ def main(args): # Data loading code print("Loading data") - dataset, num_classes = get_dataset( - args.dataset, "train", get_transform(True, args.data_augmentation), args.data_path - ) - dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, 
args.data_augmentation), args.data_path) + dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path) + dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path) print("Creating data loaders") if args.distributed: @@ -175,9 +189,14 @@ def main(args): if "rcnn" in args.model: if args.rpn_score_thresh is not None: kwargs["rpn_score_thresh"] = args.rpn_score_thresh - model = torchvision.models.detection.__dict__[args.model]( - num_classes=num_classes, pretrained=args.pretrained, **kwargs - ) + if not args.weights: + model = torchvision.models.detection.__dict__[args.model]( + pretrained=args.pretrained, num_classes=num_classes, **kwargs + ) + else: + if PM is None: + raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.") + model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs) model.to(device) if args.distributed and args.sync_bn: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) diff --git a/test/test_prototype_models.py b/test/test_prototype_models.py index 7a39b1c9a7f..e781653f073 100644 --- a/test/test_prototype_models.py +++ b/test/test_prototype_models.py @@ -48,6 +48,13 @@ def test_classification_model(model_fn, dev): TM.test_classification_model(model_fn, dev) +@pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.detection)) +@pytest.mark.parametrize("dev", cpu_and_gpu()) +@pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") +def test_detection_model(model_fn, dev): + TM.test_detection_model(model_fn, dev) + + @pytest.mark.parametrize("model_fn", TM.get_models_from_module(models.quantization)) @pytest.mark.skipif(os.getenv("PYTORCH_TEST_WITH_PROTOTYPE", "0") == "0", reason="Prototype code tests are disabled") def test_quantized_classification_model(model_fn): @@ -71,6 +78,7 @@ def 
test_video_model(model_fn, dev): @pytest.mark.parametrize( "model_fn, module_name", get_models_with_module_names(models) + + get_models_with_module_names(models.detection) + get_models_with_module_names(models.quantization) + get_models_with_module_names(models.segmentation) + get_models_with_module_names(models.video), @@ -82,6 +90,9 @@ def test_old_vs_new_factory(model_fn, module_name, dev): "models": { "input_shape": (1, 3, 224, 224), }, + "detection": { + "input_shape": (3, 300, 300), + }, "quantization": { "input_shape": (1, 3, 224, 224), }, @@ -95,7 +106,10 @@ def test_old_vs_new_factory(model_fn, module_name, dev): model_name = model_fn.__name__ kwargs = {"pretrained": True, **defaults[module_name], **TM._model_params.get(model_name, {})} input_shape = kwargs.pop("input_shape") + kwargs.pop("num_classes", None) # ignore this as it's an incompatible speed optimization for pre-trained models x = torch.rand(input_shape).to(device=dev) + if module_name == "detection": + x = [x] # compare with new model builder parameterized in the old fashion way model_old = _build_model(_get_original_model(model_fn), **kwargs).to(device=dev) From 2829f98fae1c70397a20ab12f3b5944279c498be Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Wed, 3 Nov 2021 12:32:14 +0000 Subject: [PATCH 3/3] Adding the rest of model builders on faster_rcnn. 
--- .../prototype/models/detection/faster_rcnn.py | 152 +++++++++++++++++- 1 file changed, 148 insertions(+), 4 deletions(-) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index bb3817c6b45..4f8ec08edc3 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,9 +1,11 @@ import warnings -from typing import Any, Optional +from typing import Any, Optional, Union from ....models.detection.faster_rcnn import ( - _validate_trainable_layers, + _mobilenet_extractor, _resnet_fpn_extractor, + _validate_trainable_layers, + AnchorGenerator, FasterRCNN, misc_nn_ops, overwrite_eps, @@ -11,10 +13,22 @@ from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES +from ..mobilenetv3 import MobileNetV3LargeWeights, mobilenet_v3_large from ..resnet import ResNet50Weights, resnet50 -__all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] +__all__ = [ + "FasterRCNN", + "FasterRCNNResNet50FPNWeights", + "FasterRCNNMobileNetV3LargeFPNWeights", + "FasterRCNNMobileNetV3Large320FPNWeights", + "fasterrcnn_resnet50_fpn", + "fasterrcnn_mobilenet_v3_large_fpn", + "fasterrcnn_mobilenet_v3_large_320_fpn", +] + + +_common_meta = {"categories": _COCO_CATEGORIES} class FasterRCNNResNet50FPNWeights(Weights): @@ -22,13 +36,37 @@ class FasterRCNNResNet50FPNWeights(Weights): url="https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth", transforms=CocoEval, meta={ - "categories": _COCO_CATEGORIES, + **_common_meta, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-resnet-50-fpn", "map": 37.0, }, ) +class FasterRCNNMobileNetV3LargeFPNWeights(Weights): + Coco_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-fb6a3cc7.pth", + transforms=CocoEval, + meta={ + 
**_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-fpn", + "map": 32.8, + }, + ) + + +class FasterRCNNMobileNetV3Large320FPNWeights(Weights): + Coco_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_320_fpn-907ea3f9.pth", + transforms=CocoEval, + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#faster-r-cnn-mobilenetv3-large-320-fpn", + "map": 22.8, + }, + ) + + def fasterrcnn_resnet50_fpn( weights: Optional[FasterRCNNResNet50FPNWeights] = None, weights_backbone: Optional[ResNet50Weights] = None, @@ -64,3 +102,109 @@ def fasterrcnn_resnet50_fpn( overwrite_eps(model, 0.0) return model + + +def _fasterrcnn_mobilenet_v3_large_fpn( + weights: Optional[Union[FasterRCNNMobileNetV3LargeFPNWeights, FasterRCNNMobileNetV3Large320FPNWeights]] = None, + weights_backbone: Optional[MobileNetV3LargeWeights] = None, + progress: bool = True, + num_classes: int = 91, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + if weights is not None: + weights_backbone = None + num_classes = len(weights.meta["categories"]) + + trainable_backbone_layers = _validate_trainable_layers( + weights is not None or weights_backbone is not None, trainable_backbone_layers, 6, 3 + ) + + backbone = mobilenet_v3_large(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) + anchor_sizes = ( + ( + 32, + 64, + 128, + 256, + 512, + ), + ) * 3 + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + model = FasterRCNN( + backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs + ) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +def fasterrcnn_mobilenet_v3_large_fpn( + weights: 
Optional[FasterRCNNMobileNetV3LargeFPNWeights] = None, + weights_backbone: Optional[MobileNetV3LargeWeights] = None, + progress: bool = True, + num_classes: int = 91, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = FasterRCNNMobileNetV3LargeFPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None + weights = FasterRCNNMobileNetV3LargeFPNWeights.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = MobileNetV3LargeWeights.verify(weights_backbone) + + defaults = { + "rpn_score_thresh": 0.05, + } + + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn( + weights, + weights_backbone, + progress, + num_classes, + trainable_backbone_layers, + **kwargs, + ) + + +def fasterrcnn_mobilenet_v3_large_320_fpn( + weights: Optional[FasterRCNNMobileNetV3Large320FPNWeights] = None, + weights_backbone: Optional[MobileNetV3LargeWeights] = None, + progress: bool = True, + num_classes: int = 91, + trainable_backbone_layers: Optional[int] = None, + **kwargs: Any, +) -> FasterRCNN: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = FasterRCNNMobileNetV3Large320FPNWeights.Coco_RefV1 if kwargs.pop("pretrained") else None + weights = FasterRCNNMobileNetV3Large320FPNWeights.verify(weights) + if "pretrained_backbone" in kwargs: + warnings.warn("The argument pretrained_backbone is deprecated, please use weights_backbone instead.") + weights_backbone = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained_backbone") else None + weights_backbone = 
MobileNetV3LargeWeights.verify(weights_backbone) + + defaults = { + "min_size": 320, + "max_size": 640, + "rpn_pre_nms_top_n_test": 150, + "rpn_post_nms_top_n_test": 150, + "rpn_score_thresh": 0.05, + } + + kwargs = {**defaults, **kwargs} + return _fasterrcnn_mobilenet_v3_large_fpn( + weights, + weights_backbone, + progress, + num_classes, + trainable_backbone_layers, + **kwargs, + )