From ef1c22021e7db203a0e0d9fb5c733d74de73f221 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 16:06:15 +0100 Subject: [PATCH 1/7] Refactoring resnet_fpn backbone building. --- .../models/detection/backbone_utils.py | 8 ++++---- torchvision/models/detection/faster_rcnn.py | 8 ++++++-- .../models/detection/backbone_utils.py | 20 ------------------- .../prototype/models/detection/faster_rcnn.py | 14 +++++++++---- 4 files changed, 20 insertions(+), 30 deletions(-) delete mode 100644 torchvision/prototype/models/detection/backbone_utils.py diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index e697e06929a..cb6a835c62c 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -100,14 +100,14 @@ def resnet_fpn_backbone( default a ``LastLevelMaxPool`` is used. """ backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) - return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) + return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks) -def _resnet_backbone_config( +def _resnet_fpn_extractor( backbone: resnet.ResNet, trainable_layers: int, - returned_layers: Optional[List[int]], - extra_blocks: Optional[ExtraFPNBlock], + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, ) -> BackboneWithFPN: # select layers that wont be frozen diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 02da39e8c73..7231aef297a 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -3,9 +3,11 @@ from torchvision.ops import MultiScaleRoIAlign from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import misc as misc_nn_ops +from ..resnet import resnet50 from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, mobilenet_backbone from .generalized_rcnn import GeneralizedRCNN from .roi_heads import RoIHeads from .rpn import RPNHead, RegionProposalNetwork @@ -385,7 +387,9 @@ def fasterrcnn_resnet50_fpn( if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = FasterRCNN(backbone, num_classes, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress) diff --git a/torchvision/prototype/models/detection/backbone_utils.py b/torchvision/prototype/models/detection/backbone_utils.py deleted file mode 100644 index d95b1ab52f0..00000000000 --- a/torchvision/prototype/models/detection/backbone_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Callable, Optional, List - -from torch import nn - -from ....models.detection.backbone_utils import misc_nn_ops, _resnet_backbone_config, BackboneWithFPN, ExtraFPNBlock -from .. import resnet -from .._api import Weights - - -def resnet_fpn_backbone( - backbone_name: str, - weights: Optional[Weights], - norm_layer: Callable[..., nn.Module] = misc_nn_ops.FrozenBatchNorm2d, - trainable_layers: int = 3, - returned_layers: Optional[List[int]] = None, - extra_blocks: Optional[ExtraFPNBlock] = None, -) -> BackboneWithFPN: - - backbone = resnet.__dict__[backbone_name](weights=weights, norm_layer=norm_layer) - return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks) diff --git a/torchvision/prototype/models/detection/faster_rcnn.py b/torchvision/prototype/models/detection/faster_rcnn.py index 0b27eb50a37..bb3817c6b45 100644 --- a/torchvision/prototype/models/detection/faster_rcnn.py +++ b/torchvision/prototype/models/detection/faster_rcnn.py @@ -1,12 +1,17 @@ import warnings from typing import Any, Optional -from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers +from ....models.detection.faster_rcnn import ( + _validate_trainable_layers, + _resnet_fpn_extractor, + FasterRCNN, + misc_nn_ops, + overwrite_eps, +) from ...transforms.presets import CocoEval from .._api import Weights, WeightEntry from .._meta import _COCO_CATEGORIES -from ..resnet import ResNet50Weights -from .backbone_utils import resnet_fpn_backbone +from ..resnet import ResNet50Weights, resnet50 __all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"] @@ -49,7 +54,8 @@ def fasterrcnn_resnet50_fpn( weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3 ) - backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers) + backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = FasterRCNN(backbone, num_classes=num_classes, **kwargs) if weights is not None: From ed319d7ed1182f70c074edee7e0e85acf2c89852 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 16:15:12 +0100 Subject: [PATCH 2/7] Passing the change to *_rcnn and retinanet. --- torchvision/models/detection/keypoint_rcnn.py | 8 ++++++-- torchvision/models/detection/mask_rcnn.py | 8 ++++++-- torchvision/models/detection/retinanet.py | 14 +++++++------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 7cd975ea6a0..d16d1600c77 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -3,8 +3,10 @@ from torchvision.ops import MultiScaleRoIAlign from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import misc as misc_nn_ops +from ..resnet import resnet50 from ._utils import overwrite_eps -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN @@ -367,7 +369,9 @@ def keypointrcnn_resnet50_fpn( if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs) if pretrained: key = "keypointrcnn_resnet50_fpn_coco" diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 6b8208b19d8..ca2cbc114f0 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -4,8 +4,10 @@ from torchvision.ops import MultiScaleRoIAlign from ..._internally_replaced_utils import load_state_dict_from_url +from ...ops import misc as misc_nn_ops +from ..resnet import resnet50 from ._utils import overwrite_eps -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .faster_rcnn import FasterRCNN __all__ = [ @@ -364,7 +366,9 @@ def maskrcnn_resnet50_fpn( if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers) + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) + backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers) model = MaskRCNN(backbone, num_classes, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress) diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index eb05144cb0c..4b16d7edc7f 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -9,11 +9,13 @@ from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import sigmoid_focal_loss from ...ops import boxes as box_ops +from ...ops import misc as misc_nn_ops from ...ops.feature_pyramid_network import LastLevelP6P7 +from ..resnet import resnet50 from . import _utils as det_utils from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers from .transform import GeneralizedRCNNTransform @@ -630,13 +632,11 @@ def retinanet_resnet50_fpn( if pretrained: # no need to download the backbone if pretrained is set pretrained_backbone = False + + backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d) # skip P2 because it generates too many anchors (according to their paper) - backbone = resnet_fpn_backbone( - "resnet50", - pretrained_backbone, - returned_layers=[2, 3, 4], - extra_blocks=LastLevelP6P7(256, 256), - trainable_layers=trainable_backbone_layers, + backbone = _resnet_fpn_extractor( + backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256) ) model = RetinaNet(backbone, num_classes, **kwargs) if pretrained: From 9ff8df54d05a7d9aeaf1a5194b35c816997740cf Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 16:30:42 +0100 Subject: [PATCH 3/7] Applying for faster_rcnn + mobilenetv3 --- torchvision/models/detection/backbone_utils.py | 13 +++++++++++-- torchvision/models/detection/faster_rcnn.py | 9 ++++++--- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index cb6a835c62c..1c2bceacda0 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,5 +1,5 @@ import warnings -from typing import Callable, Dict, Optional, List +from typing import Callable, Dict, Optional, List, Union from torch import nn, Tensor from torchvision.ops import misc as misc_nn_ops @@ -165,9 +165,18 @@ def mobilenet_backbone( returned_layers: Optional[List[int]] = None, extra_blocks: Optional[ExtraFPNBlock] = None, ) -> nn.Module: + backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer) + return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks) - backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features +def _mobilenet_extractor( + backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], + fpn: bool, + trainable_layers, + returned_layers: Optional[List[int]] = None, + extra_blocks: Optional[ExtraFPNBlock] = None, +) -> nn.Module: + backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 7231aef297a..30a22529a00 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -4,10 +4,11 @@ from ..._internally_replaced_utils import load_state_dict_from_url from ...ops import misc as misc_nn_ops +from ..mobilenetv3 import mobilenet_v3_large from ..resnet import resnet50 from ._utils import overwrite_eps from .anchor_utils import AnchorGenerator -from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, mobilenet_backbone +from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor from .generalized_rcnn import GeneralizedRCNN from .roi_heads import RoIHeads from .rpn import RPNHead, RegionProposalNetwork @@ -413,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn( if pretrained: pretrained_backbone = False - backbone = mobilenet_backbone( - "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers + + backbone = mobilenet_v3_large( + pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d ) + backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers) anchor_sizes = ( ( From 33f6db1e50e5d4aee413f0c1e5845dcf032e7c5e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 16:46:31 +0100 Subject: [PATCH 4/7] Applying for ssdlite + mobilenetv3 --- torchvision/models/detection/ssdlite.py | 30 +++++++++---------------- 1 file changed, 11 insertions(+), 19 deletions(-) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 503b69d7380..e32aebcf839 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -1,7 +1,7 @@ import warnings from collections import OrderedDict from functools import partial -from typing import Any, Callable, Dict, List, Optional +from typing import Any, Callable, Dict, List, Optional, Union import torch from torch import nn, Tensor @@ -117,7 +117,6 @@ def __init__( norm_layer: Callable[..., nn.Module], width_mult: float = 1.0, min_depth: int = 16, - **kwargs: Any, ): super().__init__() @@ -156,20 +155,11 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def _mobilenet_extractor( - backbone_name: str, - progress: bool, - pretrained: bool, + backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3], trainable_layers: int, norm_layer: Callable[..., nn.Module], - **kwargs: Any, ): - backbone = mobilenet.__dict__[backbone_name]( - pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs - ).features - if not pretrained: - # Change the default initialization scheme if not pretrained - _normal_init(backbone) - + backbone = backbone.features # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks. # The first and last blocks are always included because they are the C0 (conv1) and Cn. stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1] @@ -183,7 +173,7 @@ def _mobilenet_extractor( for parameter in b.parameters(): parameter.requires_grad_(False) - return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) + return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer) def ssdlite320_mobilenet_v3_large( @@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large( if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) + backbone = mobilenet.mobilenet_v3_large( + pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs + ) + if not pretrained_backbone: + # Change the default initialization scheme if not pretrained + _normal_init(backbone) backbone = _mobilenet_extractor( - "mobilenet_v3_large", - progress, - pretrained_backbone, + backbone, trainable_backbone_layers, norm_layer, - reduced_tail=reduce_tail, - **kwargs, ) size = (320, 320) From 893ff410c332a0dcc2bd3ea9b9f9f22e6f8ce37c Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 16:54:54 +0100 Subject: [PATCH 5/7] Applying for ssd + vgg16 --- torchvision/models/detection/ssd.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index c7d74b4e1af..1af8c45936b 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -519,18 +519,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): - if backbone_name in backbone_urls: - # Use custom backbones more appropriate for SSD - arch = backbone_name.split("_")[0] - backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features - if pretrained: - state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress) - backbone.load_state_dict(state_dict) - else: - # Use standard backbones from TorchVision - backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features - +def _vgg_extractor(backbone: nn.Module, highres: bool, trainable_layers: int): # Gather the indices of maxpools. These are the locations of output blocks. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] num_stages = len(stage_indices) @@ -609,7 +598,13 @@ def ssd300_vgg16( # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) + # Use custom backbones more appropriate for SSD + backbone = vgg.vgg16(pretrained=False, progress=progress).features # TODO: pass a full vgg instead + if pretrained_backbone: + state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) + backbone.load_state_dict(state_dict) + + backbone = _vgg_extractor(backbone, False, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator( [[2], [2, 3], [2, 3], [2, 3], [2], [2]], scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], From 6d60c5fc8ee14f37b64c117ecb85e5f64093ee17 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 18:46:09 +0100 Subject: [PATCH 6/7] Update the expected file of retinanet_resnet50_fpn to fix order of initialization. --- ...ter.test_retinanet_resnet50_fpn_expect.pkl | Bin 9571 -> 9571 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_resnet50_fpn_expect.pkl index b8fd97cbfe148da67a93ab51e1583b4b14efa252..7fb8d66b080dfdcebb4bed386cd752b99398b779 100644 GIT binary patch literal 9571 zcmb_i2{={T+dqe6I_5DnAVWl^kk||ReI-IklO!RP?p2(SCS6IQdAP}KE*fu1X)aB- zQY!b7R5z7dw-gP?5E;I6w$}aM&VBmd@A;l@J|=*YxK+xYKeeL#&D`c+AlnJnU^i1?Nva$*lrar34nh-9WF z#)*lxQ<4&A#wG}-%$XLIG&44OO5(KH32}3#%ykrsWGyX26vTuCF*0(-G?Ac(BJ_!k zi%cBTV-$+yLil2gy+}SpRxGE+Cy5jS3j~QG#nJw*{yP4G43UyQ51VW z(uh$AwBaX;G)McZ`RgnbY5A)!6KQA0Na#9=BHdt-9!HOjjG7jkAnMga|B`k<-0T^D zWv$O;?I<$%;?eMn$KDc;eX=-@eG^6ff<;D2k+wnp3W}EQpF&x2S0y#*n-ikO!>=a9 z&fW(+rKaPnx%uBH<*ij99u#qjz?^nvOwaN{stat5#J?j@F+-1)$O7IVET|9ut9^-O z6S%BbN4KKWtW1hOrq}!mQ5W06>)bX^e{W9=yKxy>S`Tpk_eL3^o9AIz`27=zTDYB! zeLWa=op=qUw}T-loMmT)cLP)N5Nz7+K-JHOpg?~J?r`_TQn!uZ+CtUy3Mi@VL$P~@$K9_0t2uqzPX z$C;sXjR%)ci0W?uetSV{_7V2=#IumyUIx$SkHD}3MOf{`gW9yGP+!>sxx=c&edqB8 zTjj&0f!jgBeH?hDrh?lP9(Jo%eMVr^i3un_w1~q8@vlSZhSTt>Y!fT| za|z7Tx&tXi{c-x0`&@jyW`U~~r9x0ZwwJ_zpAnJ3i#`aeYIm>;l(QkRN|qb@K~D#u z#_CRXV|O?@l|JUU>z3{JYTn=sEk?$8(YXzxWrVmt(3!)tKg{GXrJ27&o{1IOy-Vlv zv3QjRvhEG=Tit!Q=gGrA`zY55<)Lii0E+inx=aesu%P}(rJd(`dQGOG8uKIXyLYrYJkelQ$k2~Hgw;83s(1~_zqP1ywWaLB^% z2H+H5KU92eC5{<^vi0Uy$%m$shgri0Bb>bHC4`$D14pZs5Pkh6 z2qtS{TEPP@-)Co>v1WS?xNAIPFMXE7?9I6_=7-5xnjVZ+gBNiAn=16tt9=m64flcS z_x3&g)sjWM?A2`8UKSJ0Y9K%UCFd`9wH;PobAtYz8#&#~NDa3%WkBf>RWuwj4?^nw z&_YfP$JnHD?Jq(;Li`pxSk-lz(?dU-;QZz5{&D@B>D%g6;jY3qYGlKXj%JWv&sCLv z*Z?_V9}(EG&k`e@b2)!jh3XhGfr0GERM>aZ2-tbrC?Cs1=Ftm?GB?MZsv$Ia2`tW1L$gvLD(*MKj0z?6h;>CK(;wCR z#^Ft)Lmc1ug)QbB8-nRRrs6q`K*#z$TrP_iYs0DaXE_Y-xeD^i%h`yxo+!xdfYU!U zz%|i0h+eLPzJuGrro0Fit?P%ny;QmJe(=T~tJAVL?&fPMxZrLOcrQN$eU^yl;VNQU zV<#75s0+al7H47qT^=s5Pltt({cudK67s@&0q@RYNOP})G_L^|@F5Q}+owQ2RKlza z0`%0afGORIkfI`o_Et|oIKu?}FY92&XUV=8fwfzQbL&^~``3F*xLA_w+w|>C_Vn1} zu=?(`FSZC=J|GjcSVa!M=70}|!Fp(E^uYOsd@lc(ojq+5e)K3u5GF>#$Y{ehcFq+DjZcIpBiq1spbXlwWw3jF0n``CpqYUgK6`Kp?C+d_afd&`zN^LgT-uKE3N{o67qRy+y8W^dR>k6*(` z_uFvFR2@fLNrUMIw%A8U8D0ExLC118yVpeqg9;fuVAUHPG*$89HhVn3%8%>Y>$*O` zpEwmtW-6h}5>t%(q>2y65oFr~utU#?eKb7)VtntyIZIXC661oeM;oB;&@rg@Y$Rqx z$l?LzK`<}Z2{rcU;E_BzyitD-j`cOePIG&_Hc5QnDYxLpKI-5^^zT0y_tc$&(J7fW6H<65Z@lD`3dupWdaV-+|XD+B1^XM_ih&2g(=Iq>6#qNS1z z>OAR(lON{5^maFNeWi&zYxhIRjAO8O_CvV!uICz)#FMNakX`d7R~}D3Eg8oD)WJ?o zvc(7&0a{N7T=38XRr^1N6&gaE=O~X($^j^MTA$;{Gga_hv<9Bt8ISKrjS$-+aP~?b z&KX<+{$nRVv4SmrDAVC`l|)N?CW6nVWpIXXiyDbHIecuF9p2Z~#Os=>pgz?ezZ+fw ze6H6AZ^#wqU)gT;U&;F4s4e&CPZb;hW^%@RWnxQcU&K%wDK*e+{; zN`ezyPC@YzUR$3Q!>Ri@uZ;qG3C(9a`-Y@_cKDiCA?yBObCGGH|{RoWp z*T%GkY8dM}5Sy|Bal>Xs?238`$FG~=y`A=0N>ei!PdCDA59^@uk{gCU zwMUuGKG+VHa53K=Ki)2Zwwo$g``|nbf8v3x2_g~Cb-Ymqd%w~g`^`H60~Ne*LE|vw zsgLEx%Xp^}PTn&H^zK`rj7JNo6n4YRlo|+cH^iHVZLs8WDIE9s9+G6uam6VYl)fLO zUL1_GUL=D2`q3ze{tQ_nUA(YqFJQX^mOR$O;M5l|BKJN!u2LBzCf|j<^WTGe!%!@p zDBh>V^EH1<9SFxN;pPH0j4Bv`)>S*fvTU>0lHM<1?2^Zjn@|M@%ZB2JK>`l%o2`OF z3RN+5mpKNW2*5RUzd+?-dB{-t1WO_xKzo}K`X1N{ccRSk+&F!_I3y6Q$5wJ{VbwYo zKY7?;$%Gp4E)GD`gn?M~qXzm6bm7{)O$bK0zJXOwWifJ|G8$h*+)~j2`43%D?x{8x z|GD}cZvN~I*1*GBH8!-MFZS7y1k7DW^s@a3fp#vqtull2C+T0M>pKn~99$;0@pz|x zt>ttX_B2%QAHljWAAxVS`r@S#lR5uZE3|RU8cTR1KJU$zy$3C}zknd@56I7MhrES@ zaNqAvu=@Q*n00CsY>QDwTfK2&o}@%z|JZnF*q#O*Csx6gLq)J^j{=(1Y=-ES#V}O0 z1j1acadDUsRfnix%Lo(kIV%9?#W};wYMsfJi98YvB z?S}Y|6H&du1v8%6bNn}h?mD?{h+fj zsMZ)$#;T)>?`>$fWrKA;+G2|)i$QjheO|oZuKWnSyUsz0wh&DQyn>MneDSBlra1o3 z1CX5WiuSAZ(f;RS+WW3PF(*z5zXyehLOMzHg z;i*XioEK?_?q*(S@8ivlXTn()jsLjeDTyC#R15)QXF#~;M6gwBfam2wXtUuqlq=?r!^1GARt0}_ z^}za(Z(+;xV7xil0hf9hq5RdMm@?myi*FIo2E0E6cyEyoZfOCWuwwwar6a(D&D^b+-GG?0R!BBhv(gnsP(EAu6FvN)dm}Uc13b_;PJ@$u~_hS`1~m-U7ymc>a1y*;hIkfltpUZR9iX~N6J3(D(Js>g_w)`%|9cf2zNtwA^U`jD z;JO>e4(K_{+8iOc!!8^R3sH=Tz|bfjCiIfUpWi2dl9@g_eUQU5dy2fee1Cy7TnmNe zqcJsj6$G_6K!)mh$UE!Cjp>ZUIF5hHMje&%+Q91%Z}#xjUKsQI0@OJq!`lPhpr&00 z<=X>sdC)U1pB1xRuza}%*4j6~##zCzwp#-iEcyg{ij**jIRoj*{n2r!3GUw_hg}Dk zLD@=kB!{%|ypJpXa&ZfIP8x|7@d!A@<*PqLM{{uxhF?1|c>@Nh(7Jppod;uQt<#FrSI_T@` zi5qWeV7+s=SFao!4DH+n>PE|8db1jKhnS+FM>DE^Vk#C@%IfE^QMz&&>){xAXM6Ui%#sN9MBGR;IYevJ#p!lfdS3E95Eeh4&kK&QMF&{>7SY9{h%DRUFt?s}kYZZE7@lEt>%bH1wRk*9{cn%Y1YbWsbMO&+Il6BfIX!$6 z*%Y5iR<2AZq$8EwSdc;{3{D^kDk3s$$ZR72W-{4!X)-A?nM96XnMjVl3L*Rd3?y19 zfkZLDpA=siPP`usBQD+E5YkQJh3qt0{4Izeaok(7| z6H#k&Ag5LANu|J^{A6fN_}8sSoQ}9owIo-CmLzNc0OD3-P7HUO5l>=9thLR^5s@iL zAJL!8h%zBeMn5v_Mqe@|xG&LD?nCTb4T)lrA&EO?NPfvSBs-TGl4C0jh)Jpe`L$S| z{J`pyrm|jSP;xI)Tc$@c?DR;Ok{-Fyp-UWI=n$_Q9g?7-P14?K5#0$|yNfG^&$6picH#sT01bI?-%VBX~)Tq^(mUazCq*Pr0hZe5)#XxL%cL zE>k6|7panxBvq1PtV;A7R7gmT3OTe6BB)9l3pex!I?s$J5NZ=qJ?C!uaL}Uh2*-akYwg6k+nHW zL?NZMnw`5s7Us?E0W8XxUJ{S&AHyS6qj*Hso=145JaS2sM~=zxh+|zhWAv(Xm>HgYdV?rmE!tSC*!%HlR4?s$?yhtGK)1j znF@tY=1p4%qjtH2F)i$1))aIw=jU`Vk0*67@e?|j&|B;KkPH(IQBE+ zW%rrU)BVgW`lFp$y}O+ePH$(f3}|N*jN6%CTickTn{CYKD{V~Vt~Lgf+8B$`ZA`p> z8}mZ9jZts@#Q077#KaE&#Pl8ZiOH>QWm2xUGRM!gGW^4>%-zsd#y+H#`C!?~=#;cD zGaXwPQB^aO^01jPC~9Vw9BgLt_BJ!=6PuasgPWP9y3NeCr%lY@OHE95XcN<0Topo^ zm}?GA;!CfIS$p*(Q)Tv%Ij-}O*~{|1BkNVkAY{ue6=!^(`F5EA_bh%OAB<9@VM*-`XioT2eps#Tjj)BezQ zil;bjseVvA?VmKB>QqYmrFQC{>eR0^p5kab)v1)`Qyk5Y>J(2UwNpNgOWUcP>e7<( zDUa%uMW})Taj8FQr#fw?JZhJg6i4Tq;;CI4Pur#AQ5@xc3rBerM|B#X z+NnQjDeaf?rKL1aTBmqvyR=SmG{0}fr#y=Pw|-E(v_ES9H%WQ_hLg^N>NHQvr~as& z;%L9AqI+c`1bt=Dwr*<0uTljCqmySnu>W}iMr15ARilaJ}6erE2b}FUu|K?8` zPyI+sX+Od*`!U5yw^P4VO5-S=j*m2s+9{9ro8qXXc8Zhc$w=i%+bN&sA?=6yqwQ3u zyl=^G;idV~@u)7%r@D0B)E{l9lKPR(kK(DM{=TJCJC)QQ^+^jYz-nk!Jt>zqQ(bXyRY%`MkgKm9XvXrMZ7Ur+&@-O-G9BAkF>z zn0?K?Y#_yTl;-|@G}OP!Io|X?+%IDs&UxLZzL!yOL5ab|EQZ Z9yjNF^B!H&9ON&@tqVCZOM>sg{tpEr-C_U$ literal 9571 zcma)C30RG3`#+XK;-|ya@=Y8JibdHOYmbOf$uP7ZhC4$pc?5p!$+lEM===UAK)1@rf$$c zrQ@cd#`IF@x@oE9w7K(Ddc$M2gH-yXoa~)UowV1e8aaWJ#u}Bux@pq_{QZ5V153tz1C)3p`6DPyfs%GnaMY?H_%51c%Iive~dinSVs#?_1zozXRFlWX;S+`_a z+p1c9_1OBW$2OwJw(FTk^B|SQXjQxUo;EH{divG_YbOfoQpMZlUnfK^qi-gpXRjeR zSTY^o%+3Fz#N`;1x+&_*gSLIXfe#{vW1QAu9I!--?uK`v8r?25G}{?PyXB)+6xQ1S zlV`|i>)rw^e>RcEjJ}N?F&mZd?U7bQWaHwiH*oNhSGc31J1PGjOQDfLf^6m;>^d+J z=iFA%wU81#zGFDbLVA*+voWoj5-fBsh`>0=!5q;_q<$T zh<$U?(0`1@FhOMUU6`2f`Dq)(A^bV%MuN!I##g4nB$=8Z{e z_u8{notsQ{(%z7E+z!F_cWz?b!zf|%d^>s&{Rr1?48?M|gP{Y4v)ZKV<8kwNwQrd$ z&Myf|m(Iel!-eRuU8#?#ij(X9SivqsTTZ7Z}AqJ)%%f8eH$#dx95 zaBBT3iP;+kkD@NM(b%={3MOe<(hPkfpT4!|buCh`+~7z%a~|S`ttW-a3uRRO^mjb- zY!ao(;<0-1H0&DNnLhV5qJ!6KSUr_)LrAOE86SQxqxis!IC0-l(hjkw>F*Wj;pc~t z)`!*8_rek3t@b!f_t9s%!N(Jf->tt5B`sTzWvx?`Cj&HS@}3Z^cc;Enuo)S-_Nc`MQxzaYnjP@fhdY6Xz}{P_{poh50(&skrq8#&3PdnSy_NF3de$subCM zq9>u<_A9t%`WMXlDFiRfa3fbkbE=&!(05<%u=rvO$E6{H+~>0J^5J$YlRrf1{t}Xs zj>8CHTGAM_xb%6LTkJYlT0HHD3R%3@ayBV4QFi_Pxq((+H*Bs+N#?QXbGroJP& zE(*hxvHMWBQ46X_Xs7lelNHSs=iGxXY_p>E{dQyeZVR&gL%n;e#{z7NaPHnpG~F?f z&V)LWms2&1;inu%Bfo6Ilu>a^Pl|A6_@lEn36*n?Vo>rfG;QHZ)y{V@_;5$6wr$Vs z(*0A6QF#DZ{PgN(^z@hg%GwQ$ggdcg(Ee;Q+TFAp%TdHx?3_$%LrrkWuLCh5tqd1! z)Fv;Bc2r#Z2Nq^cq8EQoAnEg5wl!EO*7!MDg-X4tE zcR#ZI=3rF?p4=#-l!!Jo;`g2`--?leXtBLLt&aGE^@I1?- zHOci-IIfuIMk{BD&uiT!T-C>6eoxToT^bgSJ0tWgZAGVx>h^!P$Ko^VK`niUuo&g1 zXEA-yVk;WH(}X@w7WXI*>aI7o!F)9N#T!qA6bt1Q+Eliz92>bFM5oJtq1Rw{=If&m zp#SfY!oEsdrc3t=(;6FAk3pL(>=v^QrR(3{bqiLX-Jv|>+Ke{#JcZtWOr)eC_gFq+ zPVKiCGrn(lEZUkrMq5b2w_8G(eY9^~&+0$lC)SDkrc8#-C(}OfJ-Gf!OYHTcqY~TS z!8E@ZbT-(DH;YV!gj!9~|GP_r_ko{EkL;D_gi>nFXJ6-no3 zL$wv!t!+=^_ovk5JRwIZ#&6c^A&$)pz&-tRXhEeBv&|jYhe}2@p@-)zNbhOgT!=Nc z-JH$rE|-VVzE)Gzf2;HFlZ%J%xw2Zqp>F?*?`uI8j<4Bz+M+y&!Kw+Y=ho}DF@ADC zaSqhi(xXgIOPl!r?gt{oW3n6GY`dz#`%d$TxX&^K>7E&dnU}f__E%mP+Nj@MCsR66 zp+UBAeS{f#h1KoRUaM!|mLJzE$GkX?lMY9sbiZnFUI@O*>Kl5*o=%KCfX{YMX1d2_ zeOhZDiGl8$G5qr$<;sXhI4}2jgZHY2GwpYD$g!N{M(m<|S2IQr-)*4wuamb3NLfgx!Lp9XZkcNelq z8&B6G)(O|%m1EG@lgy{|^U1(z6EP#PIc_^#fxdUl>2=0aJh|9c*lymIhW(nzat-}O zlaj;tV`Q5VR2dM<#*P)YAl`NO}5u zKcrP-uVM|l7u}u~WmwYws>bx)_K|Gv8urhFr2CiBD__{t<_G$_D3#Sz|M{)BbP{Vd zrU>0%G$zX_y;%Rn`PHADj$VJ7sLwS~g7;e|)b&`-xOqBu_Z_P+O>pD8`i)=?%_XqQ4tPtC+vkMO~Ty72^&vT`zC|9g8Tz+r2~4 zz(_%jjETBr)tw)9?zw}O!;Y$J^`QEA)6gPxIB7_wi!^A&pn+`e>hGWHY-`Y?Dq(PC zOQeS{lhM4Tc(#zqT5qzUMbA6q%fk;?|0A3eQ9o{wFx2y=a5v8ktL-ePW#1imWX~Fg zw>b3*5B5(%-_u9%=AbRYtg?5w=SeD#bKHfIm$c}7yqt_OCeptCQFzzC?kv!@f42tf zp~ELLY&yIYS6cjy`Yrm?ktg;vY)doNhqte7NP6C0VNr_DZv#3Fv8KE(b!SHJB|ou# zO}}VG*GF7q`Ppr4PO0YMSXiXO0f)9@h01|48<*qV*w(Bkn!`p=%JlcbumfZ8fCFGo z2Rk~}ITH8S{Xp=zfcfh_SC{3{bj@X4?6N<=eyou7+$um>4DGiJHO%{Yf2 z!afUMuBrFbYcEk}fHCRqGp62S22y`B_5WhjS7Goyd@)y!wm%7^H=9^Kzn1L9n<0)$ z(ca;aCuUoRp;4_7z8G&Tal$M2}<}a%_=+O9Fe-AN#&zc%MmP zFFmOKeso>_7ry@M1wIQgrtM%!;=j-O=(~Gln#g40{onq5rv2@|&w3A0M`=2>y~4bx zrJ~*PmWl^Gnk(X6O%)4AnkuG#F;SGtjTAPyhKk66hKj6}a>a;?O%=I&n<`xOn<}na zH&Jxy)mX95(LiC?%0Q9Ot&!q|rJkaPTu<@tq^=^xR9B%;-%~1e6br9vE4rp>E1u8R zR`finrHFZ_-}ZBycYZp)IzrvwJ@cBEmRl3g?*>qLi=Cd z!iQ;Z!IR#=^iQwhhQll9t9S)-uGD~`ehq9JPz_{%geBK=Vhqekq7bf^1$094@Pgg1VKYD zfsyeg_)wV(PSLsGemEDV&d-G)zg#%?{34_dxd=+ji=fiF2v=WSfbD4);M$rC5H|P% z94a{vc}3^JbNP7)v_20vJD-PsH_kzF);Uqgns=K;gMA$EN_+wE=?2Rvt}Y3D@cHt!wK;7+63q~ zHvu}jB*6P-31DZC01CYXp!f06#^!@ZN&S(?RTc%?sv9B(5S)J1x>}xU44mP#)uc!`4>V` zo#VKj>r%<>od3VtrFf|yZs+4tzg*{bspNc){>DW1oZ zj*H(^i5KT_$#GKs-{NyTmmJ6a^Kq#^KF;x6mr8Dz`V+qm6ffyGkHdLfm-_uTo)pjh zNTn3V{d2z@&&Q>Dv%2*ESqB`){cKgD?MPyhe` From ab71f1d8df8ed699788957b747638a90531c7259 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 19 Oct 2021 19:41:10 +0100 Subject: [PATCH 7/7] Adding full model weights for the VGG16 features. --- torchvision/models/detection/ssd.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 1af8c45936b..5a068a0f0cc 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -23,7 +23,8 @@ backbone_urls = { # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth - "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth" + # Only the `features` weights have proper values, those on the `classifier` module are filled with nans. + "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth" } @@ -519,7 +520,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone: nn.Module, highres: bool, trainable_layers: int): +def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int): + backbone = backbone.features # Gather the indices of maxpools. These are the locations of output blocks. stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1] num_stages = len(stage_indices) @@ -599,7 +601,7 @@ def ssd300_vgg16( pretrained_backbone = False # Use custom backbones more appropriate for SSD - backbone = vgg.vgg16(pretrained=False, progress=progress).features # TODO: pass a full vgg instead + backbone = vgg.vgg16(pretrained=False, progress=progress) if pretrained_backbone: state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress) backbone.load_state_dict(state_dict)