From dbf61c65ae2899f6b3febbc5ffcc44e8200211cc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Fri, 5 Nov 2021 18:02:32 +0000 Subject: [PATCH] Adding interpolation in meta for all models and cleaning up unnecessary vars. --- .../prototype/models/detection/faster_rcnn.py | 7 ++++++- .../prototype/models/detection/mask_rcnn.py | 3 +++ .../prototype/models/detection/retinanet.py | 3 +++ .../prototype/models/segmentation/deeplabv3.py | 7 ++++++- .../prototype/models/segmentation/fcn.py | 7 ++++++- .../prototype/models/segmentation/lraspp.py | 3 +++ torchvision/prototype/models/vgg.py | 18 +++++++++--------- 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 4f8ec08edc3..27e30df69cd 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,6 +1,8 @@ import warnings from typing import Any, Optional, Union +from torchvision.transforms.functional import InterpolationMode + from ....models.detection.faster_rcnn import ( _mobilenet_extractor, _resnet_fpn_extractor, @@ -28,7 +30,10 @@ ] -_common_meta = {"categories": _COCO_CATEGORIES} +_common_meta = { + "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} class FasterRCNNResNet50FPNWeights(Weights): diff --git a/torchvision/prototype/models/detection/mask_rcnn.py b/torchvision/prototype/models/detection/mask_rcnn.py index 5a27db455a7..efce203f1bb 100644 --- a/torchvision/prototype/models/detection/mask_rcnn.py +++ b/torchvision/prototype/models/detection/mask_rcnn.py @@ -1,6 +1,8 @@ import warnings from typing import Any, Optional +from torchvision.transforms.functional import InterpolationMode + from ....models.detection.mask_rcnn import ( _resnet_fpn_extractor, _validate_trainable_layers, @@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights): transforms=CocoEval, meta={ "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn", "box_map": 37.9, "mask_map": 34.6, diff --git a/torchvision/prototype/models/detection/retinanet.py b/torchvision/prototype/models/detection/retinanet.py index b8d1f8e956f..e44e4fe9285 100644 --- a/torchvision/prototype/models/detection/retinanet.py +++ b/torchvision/prototype/models/detection/retinanet.py @@ -1,6 +1,8 @@ import warnings from typing import Any, Optional +from torchvision.transforms.functional import InterpolationMode + from ....models.detection.retinanet import ( _resnet_fpn_extractor, _validate_trainable_layers, @@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights): transforms=CocoEval, meta={ "categories": _COCO_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet", "map": 36.4, }, diff --git a/torchvision/prototype/models/segmentation/deeplabv3.py b/torchvision/prototype/models/segmentation/deeplabv3.py index 016e6cc507d..1e7225ab0f2 100644 --- a/torchvision/prototype/models/segmentation/deeplabv3.py +++ b/torchvision/prototype/models/segmentation/deeplabv3.py @@ -2,6 +2,8 @@ from functools import partial from typing import Any, Optional +from torchvision.transforms.functional import InterpolationMode + from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet from ...transforms.presets import VocEval from .._api import Weights, WeightEntry @@ -22,7 +24,10 @@ ] -_common_meta = {"categories": _VOC_CATEGORIES} +_common_meta = { + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} class DeepLabV3ResNet50Weights(Weights): diff --git a/torchvision/prototype/models/segmentation/fcn.py b/torchvision/prototype/models/segmentation/fcn.py index 0c671053bf5..3aa9e533d62 100644 --- a/torchvision/prototype/models/segmentation/fcn.py +++ b/torchvision/prototype/models/segmentation/fcn.py @@ -2,6 +2,8 @@ from functools import partial from typing import Any, Optional +from torchvision.transforms.functional import InterpolationMode + from ....models.segmentation.fcn import FCN, _fcn_resnet from ...transforms.presets import VocEval from .._api import Weights, WeightEntry @@ -12,7 +14,10 @@ __all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"] -_common_meta = {"categories": _VOC_CATEGORIES} +_common_meta = { + "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, +} class FCNResNet50Weights(Weights): diff --git a/torchvision/prototype/models/segmentation/lraspp.py b/torchvision/prototype/models/segmentation/lraspp.py index 2a696c24bd4..2cd1a7209e3 100644 --- a/torchvision/prototype/models/segmentation/lraspp.py +++ b/torchvision/prototype/models/segmentation/lraspp.py @@ -2,6 +2,8 @@ from functools import partial from typing import Any, Optional +from torchvision.transforms.functional import InterpolationMode + from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3 from ...transforms.presets import VocEval from .._api import Weights, WeightEntry @@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights): transforms=partial(VocEval, resize_size=520), meta={ "categories": _VOC_CATEGORIES, + "interpolation": InterpolationMode.BILINEAR, "recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large", "mIoU": 57.9, "acc": 91.2, diff --git a/torchvision/prototype/models/vgg.py b/torchvision/prototype/models/vgg.py index 54afd3b10fc..d031eece194 100644 --- a/torchvision/prototype/models/vgg.py +++ b/torchvision/prototype/models/vgg.py @@ -31,7 +31,7 @@ ] -def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: +def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG: if weights is not None: kwargs["num_classes"] = len(weights.meta["categories"]) model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs) @@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11Weights.verify(weights) - return _vgg("vgg11", "A", False, weights, progress, **kwargs) + return _vgg("A", False, weights, progress, **kwargs) def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, ** weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG11BNWeights.verify(weights) - return _vgg("vgg11_bn", "A", True, weights, progress, **kwargs) + return _vgg("A", True, weights, progress, **kwargs) def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13Weights.verify(weights) - return _vgg("vgg13", "B", False, weights, progress, **kwargs) + return _vgg("B", False, weights, progress, **kwargs) def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, ** weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG13BNWeights.verify(weights) - return _vgg("vgg13_bn", "B", True, weights, progress, **kwargs) + return _vgg("B", True, weights, progress, **kwargs) def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16Weights.verify(weights) - return _vgg("vgg16", "D", False, weights, progress, **kwargs) + return _vgg("D", False, weights, progress, **kwargs) def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, ** weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG16BNWeights.verify(weights) - return _vgg("vgg16_bn", "D", True, weights, progress, **kwargs) + return _vgg("D", True, weights, progress, **kwargs) def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19Weights.verify(weights) - return _vgg("vgg19", "E", False, weights, progress, **kwargs) + return _vgg("E", False, weights, progress, **kwargs) def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG: @@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, ** weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None weights = VGG19BNWeights.verify(weights) - return _vgg("vgg19_bn", "E", True, weights, progress, **kwargs) + return _vgg("E", True, weights, progress, **kwargs)