Skip to content

Commit a23778c

Browse files
authored
Adding interpolation in meta for all models and cleaning up unnecessary vars. (#4876)
1 parent ec6f12d commit a23778c

File tree

7 files changed

+36
-12
lines changed

7 files changed

+36
-12
lines changed

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import warnings
22
from typing import Any, Optional, Union
33

4+
from torchvision.transforms.functional import InterpolationMode
5+
46
from ....models.detection.faster_rcnn import (
57
_mobilenet_extractor,
68
_resnet_fpn_extractor,
@@ -28,7 +30,10 @@
2830
]
2931

3032

31-
_common_meta = {"categories": _COCO_CATEGORIES}
33+
_common_meta = {
34+
"categories": _COCO_CATEGORIES,
35+
"interpolation": InterpolationMode.BILINEAR,
36+
}
3237

3338

3439
class FasterRCNNResNet50FPNWeights(Weights):

torchvision/prototype/models/detection/mask_rcnn.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import warnings
22
from typing import Any, Optional
33

4+
from torchvision.transforms.functional import InterpolationMode
5+
46
from ....models.detection.mask_rcnn import (
57
_resnet_fpn_extractor,
68
_validate_trainable_layers,
@@ -27,6 +29,7 @@ class MaskRCNNResNet50FPNWeights(Weights):
2729
transforms=CocoEval,
2830
meta={
2931
"categories": _COCO_CATEGORIES,
32+
"interpolation": InterpolationMode.BILINEAR,
3033
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#mask-r-cnn",
3134
"box_map": 37.9,
3235
"mask_map": 34.6,

torchvision/prototype/models/detection/retinanet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import warnings
22
from typing import Any, Optional
33

4+
from torchvision.transforms.functional import InterpolationMode
5+
46
from ....models.detection.retinanet import (
57
_resnet_fpn_extractor,
68
_validate_trainable_layers,
@@ -28,6 +30,7 @@ class RetinaNetResNet50FPNWeights(Weights):
2830
transforms=CocoEval,
2931
meta={
3032
"categories": _COCO_CATEGORIES,
33+
"interpolation": InterpolationMode.BILINEAR,
3134
"recipe": "https://github.com/pytorch/vision/tree/main/references/detection#retinanet",
3235
"map": 36.4,
3336
},

torchvision/prototype/models/segmentation/deeplabv3.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import Any, Optional
44

5+
from torchvision.transforms.functional import InterpolationMode
6+
57
from ....models.segmentation.deeplabv3 import DeepLabV3, _deeplabv3_mobilenetv3, _deeplabv3_resnet
68
from ...transforms.presets import VocEval
79
from .._api import Weights, WeightEntry
@@ -22,7 +24,10 @@
2224
]
2325

2426

25-
_common_meta = {"categories": _VOC_CATEGORIES}
27+
_common_meta = {
28+
"categories": _VOC_CATEGORIES,
29+
"interpolation": InterpolationMode.BILINEAR,
30+
}
2631

2732

2833
class DeepLabV3ResNet50Weights(Weights):

torchvision/prototype/models/segmentation/fcn.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import Any, Optional
44

5+
from torchvision.transforms.functional import InterpolationMode
6+
57
from ....models.segmentation.fcn import FCN, _fcn_resnet
68
from ...transforms.presets import VocEval
79
from .._api import Weights, WeightEntry
@@ -12,7 +14,10 @@
1214
__all__ = ["FCN", "FCNResNet50Weights", "FCNResNet101Weights", "fcn_resnet50", "fcn_resnet101"]
1315

1416

15-
_common_meta = {"categories": _VOC_CATEGORIES}
17+
_common_meta = {
18+
"categories": _VOC_CATEGORIES,
19+
"interpolation": InterpolationMode.BILINEAR,
20+
}
1621

1722

1823
class FCNResNet50Weights(Weights):

torchvision/prototype/models/segmentation/lraspp.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
from functools import partial
33
from typing import Any, Optional
44

5+
from torchvision.transforms.functional import InterpolationMode
6+
57
from ....models.segmentation.lraspp import LRASPP, _lraspp_mobilenetv3
68
from ...transforms.presets import VocEval
79
from .._api import Weights, WeightEntry
@@ -18,6 +20,7 @@ class LRASPPMobileNetV3LargeWeights(Weights):
1820
transforms=partial(VocEval, resize_size=520),
1921
meta={
2022
"categories": _VOC_CATEGORIES,
23+
"interpolation": InterpolationMode.BILINEAR,
2124
"recipe": "https://github.com/pytorch/vision/tree/main/references/segmentation#lraspp_mobilenet_v3_large",
2225
"mIoU": 57.9,
2326
"acc": 91.2,

torchvision/prototype/models/vgg.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
]
3232

3333

34-
def _vgg(arch: str, cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
34+
def _vgg(cfg: str, batch_norm: bool, weights: Optional[Weights], progress: bool, **kwargs: Any) -> VGG:
3535
if weights is not None:
3636
kwargs["num_classes"] = len(weights.meta["categories"])
3737
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
@@ -150,7 +150,7 @@ def vgg11(weights: Optional[VGG11Weights] = None, progress: bool = True, **kwarg
150150
weights = VGG11Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
151151
weights = VGG11Weights.verify(weights)
152152

153-
return _vgg("vgg11", "A", False, weights, progress, **kwargs)
153+
return _vgg("A", False, weights, progress, **kwargs)
154154

155155

156156
def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -159,7 +159,7 @@ def vgg11_bn(weights: Optional[VGG11BNWeights] = None, progress: bool = True, **
159159
weights = VGG11BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
160160
weights = VGG11BNWeights.verify(weights)
161161

162-
return _vgg("vgg11_bn", "A", True, weights, progress, **kwargs)
162+
return _vgg("A", True, weights, progress, **kwargs)
163163

164164

165165
def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -168,7 +168,7 @@ def vgg13(weights: Optional[VGG13Weights] = None, progress: bool = True, **kwarg
168168
weights = VGG13Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
169169
weights = VGG13Weights.verify(weights)
170170

171-
return _vgg("vgg13", "B", False, weights, progress, **kwargs)
171+
return _vgg("B", False, weights, progress, **kwargs)
172172

173173

174174
def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -177,7 +177,7 @@ def vgg13_bn(weights: Optional[VGG13BNWeights] = None, progress: bool = True, **
177177
weights = VGG13BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
178178
weights = VGG13BNWeights.verify(weights)
179179

180-
return _vgg("vgg13_bn", "B", True, weights, progress, **kwargs)
180+
return _vgg("B", True, weights, progress, **kwargs)
181181

182182

183183
def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -186,7 +186,7 @@ def vgg16(weights: Optional[VGG16Weights] = None, progress: bool = True, **kwarg
186186
weights = VGG16Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
187187
weights = VGG16Weights.verify(weights)
188188

189-
return _vgg("vgg16", "D", False, weights, progress, **kwargs)
189+
return _vgg("D", False, weights, progress, **kwargs)
190190

191191

192192
def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -195,7 +195,7 @@ def vgg16_bn(weights: Optional[VGG16BNWeights] = None, progress: bool = True, **
195195
weights = VGG16BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
196196
weights = VGG16BNWeights.verify(weights)
197197

198-
return _vgg("vgg16_bn", "D", True, weights, progress, **kwargs)
198+
return _vgg("D", True, weights, progress, **kwargs)
199199

200200

201201
def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -204,7 +204,7 @@ def vgg19(weights: Optional[VGG19Weights] = None, progress: bool = True, **kwarg
204204
weights = VGG19Weights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
205205
weights = VGG19Weights.verify(weights)
206206

207-
return _vgg("vgg19", "E", False, weights, progress, **kwargs)
207+
return _vgg("E", False, weights, progress, **kwargs)
208208

209209

210210
def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **kwargs: Any) -> VGG:
@@ -213,4 +213,4 @@ def vgg19_bn(weights: Optional[VGG19BNWeights] = None, progress: bool = True, **
213213
weights = VGG19BNWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None
214214
weights = VGG19BNWeights.verify(weights)
215215

216-
return _vgg("vgg19_bn", "E", True, weights, progress, **kwargs)
216+
return _vgg("E", True, weights, progress, **kwargs)

0 commit comments

Comments
 (0)