Skip to content

Adding interpolation in meta for all models and cleaning up unused vars #4876

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into the base branch from the source branch
Nov 6, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion torchvision/prototype/models/detection/faster_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.detection.faster_rcnn import (
_mobilenet_extractor,
_resnet_fpn_extractor,
Expand Down Expand Up @@ -28,7 +30,10 @@
]


_common_meta = {"categories": _COCO_CATEGORIES}
_common_meta = {
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}


class FasterRCNNResNet50FPNWeights(Weights):
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/models/detection/mask_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ....models.detection.mask_rcnn import (
_resnet_fpn_extractor,
_validate_trainable_layers,
Expand All @@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights):
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
"box_map": 37.9,
"mask_map": 34.6,
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/models/detection/retinanet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import warnings
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ....models.detection.retinanet import (
_resnet_fpn_extractor,
_validate_trainable_layers,
Expand Down Expand Up @@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights):
transforms=CocoEval,
meta={
"categories": _COCO_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
"map": 36.4,
},
Expand Down
7 changes: 6 additions & 1 deletion torchvision/prototype/models/segmentation/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
Expand All @@ -22,7 +24,10 @@
]


_common_meta = {"categories": _VOC_CATEGORIES}
_common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}


class DeepLabV3ResNet50Weights(Weights):
Expand Down
7 changes: 6 additions & 1 deletion torchvision/prototype/models/segmentation/fcn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ....models.segmentation.fcn import FCN, _fcn_resnet
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
Expand All @@ -12,7 +14,10 @@
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]


_common_meta = {"categories": _VOC_CATEGORIES}
_common_meta = {
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
}


class FCNResNet50Weights(Weights):
Expand Down
3 changes: 3 additions & 0 deletions torchvision/prototype/models/segmentation/lraspp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from functools import partial
from typing import Any, Optional

from torchvision.transforms.functional import InterpolationMode

from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
from ...transforms.presets import VocEval
from .._api import Weights, WeightEntry
Expand All @@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights):
transforms=partial(VocEval, resize_size=520),
meta={
"categories": _VOC_CATEGORIES,
"interpolation": InterpolationMode.BILINEAR,
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
"mIoU": 57.9,
"acc": 91.2,
Expand Down
18 changes: 9 additions & 9 deletions torchvision/prototype/models/vgg.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
]


def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
if weights is not None:
kwargs["num_classes"] = len(weights.meta["categories"])
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
Expand Down Expand Up @@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11Weights.verify(weights)

return _vgg("vgg11", "A", False, weights, progress, **kwargs)
return _vgg("A", False, weights, progress, **kwargs)


def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG11BNWeights.verify(weights)

return _vgg("vgg11_bn", "A", True, weights, progress, **kwargs)
return _vgg("A", True, weights, progress, **kwargs)


def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13Weights.verify(weights)

return _vgg("vgg13", "B", False, weights, progress, **kwargs)
return _vgg("B", False, weights, progress, **kwargs)


def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG13BNWeights.verify(weights)

return _vgg("vgg13_bn", "B", True, weights, progress, **kwargs)
return _vgg("B", True, weights, progress, **kwargs)


def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16Weights.verify(weights)

return _vgg("vgg16", "D", False, weights, progress, **kwargs)
return _vgg("D", False, weights, progress, **kwargs)


def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG16BNWeights.verify(weights)

return _vgg("vgg16_bn", "D", True, weights, progress, **kwargs)
return _vgg("D", True, weights, progress, **kwargs)


def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19Weights.verify(weights)

return _vgg("vgg19", "E", False, weights, progress, **kwargs)
return _vgg("E", False, weights, progress, **kwargs)


def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
Expand All @@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
weights = VGG19BNWeights.verify(weights)

return _vgg("vgg19_bn", "E", True, weights, progress, **kwargs)
return _vgg("E", True, weights, progress, **kwargs)