From 2f1f57868d042fc1a66e244f8dc95675967c5515 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 3 May 2021 09:56:31 +0100 Subject: [PATCH 01/14] Add experimental resnet50 backbone. --- ...odelTester.test_ssd512_resnet50_expect.pkl | Bin 0 -> 6925 bytes test/test_models.py | 1 + torchvision/models/detection/ssd.py | 116 +++++++++++++++++- 3 files changed, 115 insertions(+), 2 deletions(-) create mode 100644 test/expect/ModelTester.test_ssd512_resnet50_expect.pkl diff --git a/test/expect/ModelTester.test_ssd512_resnet50_expect.pkl b/test/expect/ModelTester.test_ssd512_resnet50_expect.pkl new file mode 100644 index 0000000000000000000000000000000000000000..da8cf57765245377de2061102abb0b99d3f530a0 GIT binary patch literal 6925 zcmeI1dsI_L8o)1vM+~nN5P2zril87*At(s5k-8PAAgG8a zqzHFClu1+q#>h2-F;+1q5sYc10X9IA zqF_xnOJprd)Jq&NqQAG7NF-u}iH72474t^8c$nBSmoXO$av6*KM6AnF#SD*NtcWgo z+!Co=$&BE1_0DKkN^-*U30V)AkiW0Lr!Ql}skiM_Z->>}7ZCL$Rm`Xe#vv_kQkd9~ zzwf})g)$wA!a3}v3ujN6^jkQOJ_nNVLSb0nBHFV)z-bbP*M_5d?ftcU-RAeI_WoLb z49Rf;^n39Gemmo&;qPVMuj@QD@8$m%@n($EZzbNA0uP0}`WETS9mPrY;c8v$$)G~PYzCXldZiCZd3@p#9ggrB}A-XIV_+ou#1w5ypBq)pR>J3MRf(MzH)g-n8SyBs>`(*WS#%nn)2e+#|}ZVG8Do(Xtn z&LtBGbGFcgQR58=k9&@Ui3WcUIUV3h|EuL*NNA=mL|CptxKHe?e3ohybay$D73V?J z`yZf|cd%;WNyMZp`$Zd~Q5*H$vQg-G3u#*YML#Pf0&a$$N1`?vu7e z!hLT|CluaTL^!`I0v#D!2?^!%kl1z~Sa!}v_j=aD$F+0O{<9?zT}30O*?ACbh>-Qo zb&$5%8!1*6g2LYi@fMYVj;)zDdgKD-3>O5Hi5A)iw$#I3NV z=|`yRQo`!|4zTGchVqMjL?_HV#Fei;Z{%V2TMTBy9e z1-_6ML1OW8h>NZQtI|wJNUw+7eH&o?;~jA7d=Yqlegry?`oU+H^T2eEKRlXJ06rJy zLGZ%Oa4&fwG_+R0*)M~Dw{#6S9iyRpbTJ6|9rT>bYoW`zo4z);7#?rWq@6#m1S836 zdgF{;5VXRGmZxuk{W(sw$@Wqx__sT4dAR~kioNM5;~ikT`uz~tS_JL46(NYJfQ%nT z&|~B@eAqpPj;PB5+sYeo!Z89)WwpZjb|pAGD<$iM*BNIk3oG=`vyQMU$qJn)^Mbq4 zmWW@lnJ(m6A)^U9X&j3hS)1GCnb6~mn%{f~fBMb`6|d1nHh2C;o}IRl zIEE=|S>poa_uV51k6Mb(|Da3yX~xF9;FyivBdig%Cl9rhI}U0O?I=MDvU*5-C(bh( z)XvR}^A*a5P-!*w}o-VPsj!=#EnFJ7@eo3fCriF#1 z{ESeI?zp*!;4!0oFtHx(G*s61#{a@Um%bQD*5s!GE|wYEK zS#MDnrW!>E!{7vSh@K)+0V8pR5%cGh_9=v1ob0^8#d|f&2Or8vdpDsP<+^5(az3FN zjr15s$}WUzlyTaYlvfd|QGU0Q#8U+!+rCpk%IgT#XnW&Fq})uXMmxWH$o5wwHDY>W zR@=iv@n-kual|i%bMU@A*dLGfF6gc2)>7MhFOv7$mi3$DU#w3v){A*+_+hf9&E@Av zY;bnGST?C76FFxpq1Ofxt*B%5wQo8aS9-EW4fXVrycS`HP()FPevf} z%Uw$+2xCe81VI9N5VJ0*HE|=lSFMJMcS@00eKoPSF)V}VYhqb+WjQ+bSOVXy%0a6| zIi$Ti-75&sb-+tex;U+UftP{PMZ*eH-BNrEmz0#lL4MN_S5)*dMK> zryOX8`tU7u$x0_!7pxD)xlW}0?K48+zjKd-w8yv%i1C>Z^JmWi>~d8ZKrXgHwbJN46h1)vB@O%D6d9tdDPh*yB)r3d?h{{{hd% B#lrvq literal 0 HcmV?d00001 diff --git a/test/test_models.py b/test/test_models.py index c23087a05e9..ab62720f8f8 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -45,6 +45,7 @@ def get_available_video_models(): "keypointrcnn_resnet50_fpn": lambda x: x[1], "retinanet_resnet50_fpn": lambda x: x[1], "ssd300_vgg16": lambda x: x[1], + "ssd512_resnet50": lambda x: x[1], } diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index fcb79d6e651..b0cc9fe4705 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -10,14 +10,15 @@ from .anchor_utils import DefaultBoxGenerator from .backbone_utils import _validate_trainable_layers from .transform import GeneralizedRCNNTransform -from .. import vgg +from .. 
import vgg, resnet from ..utils import load_state_dict_from_url from ...ops import boxes as box_ops -__all__ = ['SSD', 'ssd300_vgg16'] +__all__ = ['SSD', 'ssd300_vgg16', 'ssd512_resnet50'] model_urls = { 'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', + 'ssd512_resnet50_coco': None, # TODO: add weights } backbone_urls = { @@ -562,3 +563,114 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) model.load_state_dict(state_dict) return model + + +class SSDFeatureExtractorResNet(nn.Module): + def __init__(self, backbone: resnet.ResNet): + super().__init__() + + self.features = nn.Sequential( + backbone.conv1, + backbone.bn1, + backbone.relu, + backbone.maxpool, + backbone.layer1, + backbone.layer2, + backbone.layer3, + backbone.layer4, + ) + + # Patch last block's strides to get valid output sizes + for m in self.features[-1][0].modules(): + if hasattr(m, 'stride'): + m.stride = 1 + + backbone_out_channels = self.features[-1][-1].bn3.num_features + extra = nn.ModuleList([ + nn.Sequential( + nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 256, kernel_size=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=2, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ) + ]) + _xavier_init(extra) + self.extra = extra + + def forward(self, x: Tensor) -> Dict[str, Tensor]: + x = self.features(x) + output = [x] + + for block in self.extra: + x = block(x) + output.append(x) + + return OrderedDict([(str(i), v) for i, v in enumerate(output)]) + + +def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int): + backbone = resnet.__dict__[backbone_name](pretrained=pretrained) + + assert 0 <= trainable_layers <= 5 + layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] + if trainable_layers == 5: + layers_to_train.append('bn1') + for name, parameter in backbone.named_parameters(): + if all([not name.startswith(layer) for layer in layers_to_train]): + parameter.requires_grad_(False) + + return SSDFeatureExtractorResNet(backbone) + + +def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91, + pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) + + if pretrained: + pretrained_backbone = False + + backbone = _resnet_extractor("resnet50", 
pretrained_backbone, trainable_backbone_layers) + anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]]) + model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs) + if pretrained: + weights_name = 'ssd512_resnet50_coco' + if model_urls.get(weights_name, None) is None: + raise ValueError("No checkpoint is available for model {}".format(weights_name)) + state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress) + model.load_state_dict(state_dict) + return model From 0c17b0a89cd81d64d2bfbce8b854683b8272ad40 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 3 May 2021 19:57:22 +0100 Subject: [PATCH 02/14] Passing custom scales (necessary after master merge). --- torchvision/models/detection/ssd.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 95110d47bfd..d8fd1cc1fd6 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -667,7 +667,8 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes pretrained_backbone = False backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]]) + anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], + scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]) model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs) if pretrained: weights_name = 'ssd512_resnet50_coco' From b6406807fbd23b7d457f04db176a0545740c4eb5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 3 May 2021 20:57:04 +0100 Subject: [PATCH 03/14] Add experimental FPN-style resnet50 backbone. 
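
There is no top-down pathway or lateral connections here, so "FPN-style" only
means that the extractor now taps several ResNet stages (strides 8, 16 and 32)
instead of the single stride-patched layer4 output. A minimal sketch of that
idea using plain torchvision modules (illustrative only, not the new extractor
class itself):

    import torch
    from collections import OrderedDict
    from torchvision.models import resnet50

    backbone = resnet50(pretrained=False)
    stages = torch.nn.ModuleList([
        torch.nn.Sequential(backbone.conv1, backbone.bn1, backbone.relu,
                            backbone.maxpool, backbone.layer1, backbone.layer2),  # stride 8
        backbone.layer3,  # stride 16
        backbone.layer4,  # stride 32
    ])

    x = torch.rand(1, 3, 512, 512)
    outputs = []
    for stage in stages:
        x = stage(x)
        outputs.append(x)
    feature_maps = OrderedDict((str(i), v) for i, v in enumerate(outputs))
    # Expected sizes for a 512x512 input: 64x64, 32x32 and 16x16 with
    # 512, 1024 and 2048 channels; the extra blocks in the diff add the
    # remaining, smaller pyramid levels on top.
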
--- torchvision/models/detection/ssd.py | 62 +++++++---------------------- 1 file changed, 15 insertions(+), 47 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index d8fd1cc1fd6..5080f4b16a1 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -572,70 +572,39 @@ def __init__(self, backbone: resnet.ResNet): super().__init__() self.features = nn.Sequential( - backbone.conv1, - backbone.bn1, - backbone.relu, - backbone.maxpool, - backbone.layer1, - backbone.layer2, + nn.Sequential( + backbone.conv1, + backbone.bn1, + backbone.relu, + backbone.maxpool, + backbone.layer1, + backbone.layer2, + ), backbone.layer3, backbone.layer4, ) - # Patch last block's strides to get valid output sizes - for m in self.features[-1][0].modules(): - if hasattr(m, 'stride'): - m.stride = 1 - backbone_out_channels = self.features[-1][-1].bn3.num_features extra = nn.ModuleList([ nn.Sequential( - nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), - nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), - nn.BatchNorm2d(512), - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(512, 256, kernel_size=1, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), - nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), - nn.BatchNorm2d(512), - nn.ReLU(inplace=True), - ), - nn.Sequential( - nn.Conv2d(512, 128, kernel_size=1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False), + nn.Conv2d(backbone_out_channels, 256, kernel_size=3, padding=1, stride=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=3, bias=False), + nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), - nn.Sequential( - nn.Conv2d(256, 128, kernel_size=1, bias=False), - nn.BatchNorm2d(128), - nn.ReLU(inplace=True), - nn.Conv2d(128, 256, kernel_size=2, bias=False), - nn.BatchNorm2d(256), - nn.ReLU(inplace=True), - ) ]) _xavier_init(extra) self.extra = extra def forward(self, x: Tensor) -> Dict[str, Tensor]: - x = self.features(x) - output = [x] + output = [] + for block in self.features: + x = block(x) + output.append(x) for block in self.extra: x = block(x) @@ -667,8 +636,7 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes pretrained_backbone = False backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], - scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]) + anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2]], min_ratio=0.04) model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs) if pretrained: weights_name = 'ssd512_resnet50_coco' From 36163dc93957d3aa908e69253ce18e40d7b064bf Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Mon, 3 May 2021 21:40:28 +0100 Subject: [PATCH 04/14] Add experimental VGG-style resnet50 backbone. 
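
"VGG-style" means the extra layers follow the original SSD recipe: a 1x1
channel-reduction conv followed by a 3x3 conv (stride 2, or unpadded for the
smallest levels), each with BatchNorm and ReLU, halving the feature map at
every step. A sketch of one such block; the helper name is illustrative, the
channel sizes match the first block in the diff:

    import torch.nn as nn

    def extra_block(in_ch: int, mid_ch: int, out_ch: int) -> nn.Sequential:
        # 1x1 bottleneck followed by a stride-2 3x3 conv, as in the SSD paper
        return nn.Sequential(
            nn.Conv2d(in_ch, mid_ch, kernel_size=1, bias=False),
            nn.BatchNorm2d(mid_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_ch, out_ch, kernel_size=3, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    # e.g. the first extra block sits on top of layer4's 2048 channels:
    block = extra_block(2048, 256, 512)
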
--- torchvision/models/detection/ssd.py | 30 ++++++++++++++++++++++++++--- 1 file changed, 27 insertions(+), 3 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 5080f4b16a1..d1aec2fa39d 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -587,12 +587,34 @@ def __init__(self, backbone: resnet.ResNet): backbone_out_channels = self.features[-1][-1].bn3.num_features extra = nn.ModuleList([ nn.Sequential( - nn.Conv2d(backbone_out_channels, 256, kernel_size=3, padding=1, stride=2, bias=False), + nn.Conv2d(backbone_out_channels, 256, kernel_size=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(512, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, padding=1, stride=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), nn.Sequential( - nn.Conv2d(256, 256, kernel_size=3, padding=1, stride=2, bias=False), + nn.Conv2d(256, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=3, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + ), + nn.Sequential( + nn.Conv2d(256, 128, kernel_size=1, bias=False), + nn.BatchNorm2d(128), + nn.ReLU(inplace=True), + nn.Conv2d(128, 256, kernel_size=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), @@ -636,7 +658,9 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes pretrained_backbone = False backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2]], min_ratio=0.04) + anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2], [2]], + scales=[0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9, 1.06], + steps=[8, 16, 32, 64, 128, 256, 512]) model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs) if pretrained: weights_name = 'ssd512_resnet50_coco' From 2c0f46d86488082fe0f892e4d5298063697e6de5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 4 May 2021 11:28:07 +0100 Subject: [PATCH 05/14] Add a highres option to support both the 300 and 512 versions. 
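
With the new flag the same extractor can back both a 300- and a 512-pixel
variant, mirroring the highres switch of the VGG16 extractor: the last extra
block is only appended when highres is True. A sketch of the conditional
append (the helper name is illustrative; channel sizes are copied from the
diff):

    import torch.nn as nn

    def append_highres_block(extra: nn.ModuleList, highres: bool) -> nn.ModuleList:
        if highres:
            extra.append(nn.Sequential(
                nn.Conv2d(256, 128, kernel_size=1, bias=False),
                nn.BatchNorm2d(128),
                nn.ReLU(inplace=True),
                # an unpadded kernel_size=2 conv turns the last 2x2 map
                # (for a 512x512 input) into the final 1x1 pyramid level
                nn.Conv2d(128, 256, kernel_size=2, bias=False),
                nn.BatchNorm2d(256),
                nn.ReLU(inplace=True),
            ))
        return extra

    extras = append_highres_block(nn.ModuleList(), highres=True)
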
--- torchvision/models/detection/ssd.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index d1aec2fa39d..86d044b6091 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -568,7 +568,7 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i class SSDFeatureExtractorResNet(nn.Module): - def __init__(self, backbone: resnet.ResNet): + def __init__(self, backbone: resnet.ResNet, highres: bool): super().__init__() self.features = nn.Sequential( @@ -610,15 +610,16 @@ def __init__(self, backbone: resnet.ResNet): nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), - nn.Sequential( + ]) + if highres: + extra.append(nn.Sequential( nn.Conv2d(256, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), - ), - ]) + )) _xavier_init(extra) self.extra = extra @@ -635,8 +636,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int): - backbone = resnet.__dict__[backbone_name](pretrained=pretrained) +def _resnet_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): + backbone = resnet.__dict__[backbone_name](pretrained=pretrained, progress=progress) assert 0 <= trainable_layers <= 5 layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] @@ -646,7 +647,7 @@ def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: in if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) - return SSDFeatureExtractorResNet(backbone) + return SSDFeatureExtractorResNet(backbone, highres) def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91, @@ -657,7 +658,7 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes if pretrained: pretrained_backbone = False - backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers) + backbone = _resnet_extractor("resnet50", True, progress, pretrained_backbone, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2], [2]], scales=[0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9, 1.06], steps=[8, 16, 32, 64, 128, 256, 512]) From eef01bc190a7224555c297a17f5350d80cb55564 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 6 May 2021 10:18:56 +0100 Subject: [PATCH 06/14] Select best performing prototype. This reverts commits b6406807fbd23b7d457f04db176a0545740c4eb5, 36163dc93957d3aa908e69253ce18e40d7b064bf and 2c0f46d86488082fe0f892e4d5298063697e6de5. 
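
Benchmarking favoured the layout from the first commit: run the full ResNet as
a single feature block, patch the strides of layer4's first block to 1 so the
backbone output stays at stride 16, and keep the five VGG-style extra blocks.
A standalone sketch of that stride-patching trick (weights omitted; the shape
comment assumes a 512x512 input):

    import torch
    from torchvision.models import resnet50

    backbone = resnet50(pretrained=False)
    for m in backbone.layer4[0].modules():
        if hasattr(m, 'stride'):
            m.stride = 1

    features = torch.nn.Sequential(
        backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
        backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4,
    )
    out = features(torch.rand(1, 3, 512, 512))
    print(out.shape)  # expected: torch.Size([1, 2048, 32, 32]), i.e. stride 16
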
--- torchvision/models/detection/ssd.py | 55 ++++++++++++++++------------- 1 file changed, 31 insertions(+), 24 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 86d044b6091..d8fd1cc1fd6 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -568,22 +568,25 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i class SSDFeatureExtractorResNet(nn.Module): - def __init__(self, backbone: resnet.ResNet, highres: bool): + def __init__(self, backbone: resnet.ResNet): super().__init__() self.features = nn.Sequential( - nn.Sequential( - backbone.conv1, - backbone.bn1, - backbone.relu, - backbone.maxpool, - backbone.layer1, - backbone.layer2, - ), + backbone.conv1, + backbone.bn1, + backbone.relu, + backbone.maxpool, + backbone.layer1, + backbone.layer2, backbone.layer3, backbone.layer4, ) + # Patch last block's strides to get valid output sizes + for m in self.features[-1][0].modules(): + if hasattr(m, 'stride'): + m.stride = 1 + backbone_out_channels = self.features[-1][-1].bn3.num_features extra = nn.ModuleList([ nn.Sequential( @@ -594,6 +597,14 @@ def __init__(self, backbone: resnet.ResNet, highres: bool): nn.BatchNorm2d(512), nn.ReLU(inplace=True), ), + nn.Sequential( + nn.Conv2d(512, 256, kernel_size=1, bias=False), + nn.BatchNorm2d(256), + nn.ReLU(inplace=True), + nn.Conv2d(256, 512, kernel_size=3, padding=1, stride=2, bias=False), + nn.BatchNorm2d(512), + nn.ReLU(inplace=True), + ), nn.Sequential( nn.Conv2d(512, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128), @@ -610,24 +621,21 @@ def __init__(self, backbone: resnet.ResNet, highres: bool): nn.BatchNorm2d(256), nn.ReLU(inplace=True), ), - ]) - if highres: - extra.append(nn.Sequential( + nn.Sequential( nn.Conv2d(256, 128, kernel_size=1, bias=False), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 256, kernel_size=2, bias=False), nn.BatchNorm2d(256), nn.ReLU(inplace=True), - )) + ) + ]) _xavier_init(extra) self.extra = extra def forward(self, x: Tensor) -> Dict[str, Tensor]: - output = [] - for block in self.features: - x = block(x) - output.append(x) + x = self.features(x) + output = [x] for block in self.extra: x = block(x) @@ -636,8 +644,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _resnet_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): - backbone = resnet.__dict__[backbone_name](pretrained=pretrained, progress=progress) +def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: int): + backbone = resnet.__dict__[backbone_name](pretrained=pretrained) assert 0 <= trainable_layers <= 5 layers_to_train = ['layer4', 'layer3', 'layer2', 'layer1', 'conv1'][:trainable_layers] @@ -647,7 +655,7 @@ def _resnet_extractor(backbone_name: str, highres: bool, progress: bool, pretrai if all([not name.startswith(layer) for layer in layers_to_train]): parameter.requires_grad_(False) - return SSDFeatureExtractorResNet(backbone, highres) + return SSDFeatureExtractorResNet(backbone) def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91, @@ -658,10 +666,9 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes if pretrained: pretrained_backbone = False - backbone = _resnet_extractor("resnet50", True, progress, pretrained_backbone, trainable_backbone_layers) - anchor_generator = DefaultBoxGenerator([[2], [2, 
3], [2, 3], [2, 3], [2], [2], [2]], - scales=[0.04, 0.1, 0.26, 0.42, 0.58, 0.74, 0.9, 1.06], - steps=[8, 16, 32, 64, 128, 256, 512]) + backbone = _resnet_extractor("resnet50", pretrained_backbone, trainable_backbone_layers) + anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], + scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05]) model = SSD(backbone, anchor_generator, (512, 512), num_classes, **kwargs) if pretrained: weights_name = 'ssd512_resnet50_coco' From 9cf7c5d9389f2987f94e948e6742f54b09d106ec Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 6 May 2021 10:40:50 +0100 Subject: [PATCH 07/14] Adding documentation. --- torchvision/models/detection/ssd.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index d8fd1cc1fd6..ce305f05396 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -528,13 +528,13 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): """ - Constructs an SSD model with a VGG16 backbone. See `SSD` for more details. + Constructs an SSD model with input size 300x300 and a VGG16 backbone. See `SSD` for more details. Example: >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) >>> model.eval() - >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)] + >>> x = [torch.rand(3, 300, 300), torch.rand(3, 500, 400)] >>> predictions = model(x) Args: @@ -660,6 +660,24 @@ def _resnet_extractor(backbone_name: str, pretrained: bool, trainable_layers: in def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes: int = 91, pretrained_backbone: bool = True, trainable_backbone_layers: Optional[int] = None, **kwargs: Any): + """ + Constructs an SSD model with input size 512x512 and a ResNet50 backbone. See `SSD` for more details. + + Example: + + >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) + >>> model.eval() + >>> x = [torch.rand(3, 512, 512), torch.rand(3, 750, 600)] + >>> predictions = model(x) + + Args: + pretrained (bool): If True, returns a model pre-trained on COCO train2017 + progress (bool): If True, displays a progress bar of the download to stderr + num_classes (int): number of output classes of the model (including the background) + pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet + trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. + Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. + """ trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) From d419eeabf63c80c25fdc03d36630bcdf90a93a6e Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Thu, 6 May 2021 11:52:08 +0100 Subject: [PATCH 08/14] Adding weights. 
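
With the checkpoint URL in place, pretrained=True becomes usable end to end.
The snippet below simply mirrors the docstring example updated in this series
and assumes the published weights are reachable:

    import torch
    import torchvision

    model = torchvision.models.detection.ssd512_resnet50(pretrained=True)
    model.eval()
    x = [torch.rand(3, 512, 512), torch.rand(3, 750, 600)]
    with torch.no_grad():
        predictions = model(x)
    # each element of `predictions` is a dict with 'boxes', 'labels' and 'scores'
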
--- docs/source/models.rst | 3 +++ references/detection/README.md | 8 ++++++++ torchvision/models/detection/ssd.py | 4 ++-- 3 files changed, 13 insertions(+), 2 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index c70cd07979f..cad727236f2 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -427,6 +427,7 @@ Faster R-CNN MobileNetV3-Large FPN 32.8 - - Faster R-CNN MobileNetV3-Large 320 FPN 22.8 - - RetinaNet ResNet-50 FPN 36.4 - - SSD VGG16 25.1 - - +SSD ResNet-50 30.2 - - Mask R-CNN ResNet-50 FPN 37.9 34.6 - ====================================== ======= ======== =========== @@ -486,6 +487,7 @@ Faster R-CNN MobileNetV3-Large FPN 0.1020 0.0415 Faster R-CNN MobileNetV3-Large 320 FPN 0.0978 0.0376 0.6 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 SSD VGG16 0.2093 0.0744 1.5 +SSD ResNet-50 0.2316 0.0772 3.0 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 ====================================== =================== ================== =========== @@ -509,6 +511,7 @@ SSD ------------ .. autofunction:: torchvision.models.detection.ssd300_vgg16 +.. autofunction:: torchvision.models.detection.ssd512_resnet50 Mask R-CNN diff --git a/references/detection/README.md b/references/detection/README.md index e4d52869d35..064fb91287d 100644 --- a/references/detection/README.md +++ b/references/detection/README.md @@ -56,6 +56,14 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ --weight-decay 0.0005 --data-augmentation ssd ``` +### SSD ResNet-50 +``` +python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ + --dataset coco --model ssd512_resnet50 --epochs 120\ + --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\ + --weight-decay 0.0005 --data-augmentation ssd +``` + ### Mask R-CNN ``` diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index ce305f05396..36da06a8bc5 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -18,7 +18,7 @@ model_urls = { 'ssd300_vgg16_coco': 'https://download.pytorch.org/models/ssd300_vgg16_coco-b556d3b4.pth', - 'ssd512_resnet50_coco': None, # TODO: add weights + 'ssd512_resnet50_coco': 'https://download.pytorch.org/models/ssd512_resnet50_coco-d6d7edbb.pth', } backbone_urls = { @@ -665,7 +665,7 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes Example: - >>> model = torchvision.models.detection.ssd300_vgg16(pretrained=True) + >>> model = torchvision.models.detection.ssd512_resnet50(pretrained=True) >>> model.eval() >>> x = [torch.rand(3, 512, 512), torch.rand(3, 750, 600)] >>> predictions = model(x) From ea1e2c4b6f4f2b0e61fbd6035ebdb5b00e6634f9 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 11:36:38 +0100 Subject: [PATCH 09/14] Fix not implemented for half exception --- torchvision/models/detection/anchor_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchvision/models/detection/anchor_utils.py b/torchvision/models/detection/anchor_utils.py index 3e0740036c8..06ecc551442 100644 --- a/torchvision/models/detection/anchor_utils.py +++ b/torchvision/models/detection/anchor_utils.py @@ -206,8 +206,8 @@ def _grid_default_boxes(self, grid_sizes: List[List[int]], image_size: List[int] else: y_f_k, x_f_k = f_k - shifts_x = (torch.arange(0, f_k[1], dtype=dtype) + 0.5) / x_f_k - shifts_y = (torch.arange(0, f_k[0], dtype=dtype) + 0.5) / y_f_k + shifts_x = ((torch.arange(0, f_k[1]) + 
0.5) / x_f_k).to(dtype=dtype) + shifts_y = ((torch.arange(0, f_k[0]) + 0.5) / y_f_k).to(dtype=dtype) shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x) shift_x = shift_x.reshape(-1) shift_y = shift_y.reshape(-1) From 777126d56154ccfd87090ab0b92b195977f0fbb5 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 16:04:34 +0100 Subject: [PATCH 10/14] Apply recommendations from code review. --- docs/source/models.rst | 6 +++--- torchvision/models/detection/ssd.py | 15 +++++++++++++-- torchvision/models/detection/ssdlite.py | 4 ++++ 3 files changed, 20 insertions(+), 5 deletions(-) diff --git a/docs/source/models.rst b/docs/source/models.rst index 0cfa96e53c0..2ca354ff298 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -504,20 +504,20 @@ Faster R-CNN RetinaNet ------------- +--------- .. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn SSD ------------- +--- .. autofunction:: torchvision.models.detection.ssd300_vgg16 .. autofunction:: torchvision.models.detection.ssd512_resnet50 SSDlite ------------- +------- .. autofunction:: torchvision.models.detection.ssdlite320_mobilenet_v3_large diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index 36da06a8bc5..bccfc126460 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -545,6 +545,9 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the argument.") + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) @@ -556,8 +559,13 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], steps=[8, 16, 32, 64, 100, 300]) - model = SSD(backbone, anchor_generator, (300, 300), num_classes, - image_mean=[0.48235, 0.45882, 0.40784], image_std=[1., 1., 1.], **kwargs) + + defaults = { + "image_mean": [0.48235, 0.45882, 0.40784], + "image_std": [1., 1., 1.], + } + kwargs = {**defaults, **kwargs} + model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) if pretrained: weights_name = 'ssd300_vgg16_coco' if model_urls.get(weights_name, None) is None: @@ -678,6 +686,9 @@ def ssd512_resnet50(pretrained: bool = False, progress: bool = True, num_classes trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block. Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. 
""" + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the argument.") + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 5, 5) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 412434dabd7..0a7a48e7cb2 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -1,4 +1,5 @@ import torch +import warnings from collections import OrderedDict from functools import partial @@ -186,6 +187,9 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable. norm_layer (callable, optional): Module specifying the normalization layer to use. """ + if "size" in kwargs: + warnings.warn("The size of the model is already fixed; ignoring the argument.") + trainable_backbone_layers = _validate_trainable_layers( pretrained or pretrained_backbone, trainable_backbone_layers, 6, 6) From 87d0153947ed4eb4adc1782066666ea978df21bc Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 16:12:50 +0100 Subject: [PATCH 11/14] Updating docs. --- torchvision/models/detection/ssdlite.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 0a7a48e7cb2..79d39f897bd 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -167,7 +167,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru norm_layer: Optional[Callable[..., nn.Module]] = None, **kwargs: Any): """ - Constructs an SSDlite model with a MobileNetV3 Large backbone. See `SSD` for more details. + Constructs an SSDlite model with input size 320x320 and a MobileNetV3 Large backbone. See `SSD` for more details. Example: From 8b2715dae6da9786ba397ed9b0c85a34d2f3ec7d Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 16:36:03 +0100 Subject: [PATCH 12/14] Change the way we rescale to [-1, 1] --- torchvision/models/detection/ssdlite.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index 79d39f897bd..db1fd496cc3 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -95,8 +95,7 @@ def __init__(self, in_channels: List[int], num_anchors: List[int], norm_layer: C class SSDLiteFeatureExtractorMobileNet(nn.Module): - def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], rescaling: bool, - **kwargs: Any): + def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., nn.Module], **kwargs: Any): super().__init__() # non-public config parameters min_depth = kwargs.pop('_min_depth', 16) @@ -118,13 +117,8 @@ def __init__(self, backbone: nn.Module, c4_pos: int, norm_layer: Callable[..., n _normal_init(extra) self.extra = extra - self.rescaling = rescaling def forward(self, x: Tensor) -> Dict[str, Tensor]: - # Rescale from [0, 1] to [-1, -1] - if self.rescaling: - x = 2.0 * x - 1.0 - # Get feature maps from backbone and extra. Can't be refactored due to JIT limitations. 
output = [] for block in self.features: @@ -139,7 +133,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, trainable_layers: int, - norm_layer: Callable[..., nn.Module], rescaling: bool, **kwargs: Any): + norm_layer: Callable[..., nn.Module], **kwargs: Any): backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs).features if not pretrained: @@ -159,7 +153,7 @@ def _mobilenet_extractor(backbone_name: str, progress: bool, pretrained: bool, t for parameter in b.parameters(): parameter.requires_grad_(False) - return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, rescaling, **kwargs) + return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs) def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = True, num_classes: int = 91, @@ -196,14 +190,14 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru if pretrained: pretrained_backbone = False - # Enable [-1, 1] rescaling and reduced tail if no pretrained backbone is selected - rescaling = reduce_tail = not pretrained_backbone + # Enable reduced tail if no pretrained backbone is selected + reduce_tail = not pretrained_backbone if norm_layer is None: norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03) backbone = _mobilenet_extractor("mobilenet_v3_large", progress, pretrained_backbone, trainable_backbone_layers, - norm_layer, rescaling, _reduced_tail=reduce_tail, _width_mult=1.0) + norm_layer, _reduced_tail=reduce_tail, _width_mult=1.0) size = (320, 320) anchor_generator = DefaultBoxGenerator([[2, 3] for _ in range(6)], min_ratio=0.2, max_ratio=0.95) @@ -216,8 +210,9 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru "nms_thresh": 0.55, "detections_per_img": 300, "topk_candidates": 300, - "image_mean": [0., 0., 0.], - "image_std": [1., 1., 1.], + # The following mean/std rescale the data from [0, 1] to [-1, -1] + "image_mean": [0.5, 0.5, 0.5], + "image_std": [0.5, 0.5, 0.5], } kwargs = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, size, num_classes, From d08fc10e7301d0c75120f111a8f26967466bb6e8 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 16:46:01 +0100 Subject: [PATCH 13/14] Change the way we rescale input on SSD300+VGG16 --- torchvision/models/detection/ssd.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/torchvision/models/detection/ssd.py b/torchvision/models/detection/ssd.py index bccfc126460..0538be1672b 100644 --- a/torchvision/models/detection/ssd.py +++ b/torchvision/models/detection/ssd.py @@ -411,7 +411,7 @@ def postprocess_detections(self, head_outputs: Dict[str, Tensor], image_anchors: class SSDFeatureExtractorVGG(nn.Module): - def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool): + def __init__(self, backbone: nn.Module, highres: bool): super().__init__() _, _, maxpool3_pos, maxpool4_pos, _ = (i for i, layer in enumerate(backbone) if isinstance(layer, nn.MaxPool2d)) @@ -477,13 +477,8 @@ def __init__(self, backbone: nn.Module, highres: bool, rescaling: bool): fc, )) self.extra = extra - self.rescaling = rescaling def forward(self, x: Tensor) -> Dict[str, Tensor]: - # Undo the 0-1 scaling of toTensor. Necessary for some backbones. 
- if self.rescaling: - x *= 255 - # L2 regularization + Rescaling of 1st block's feature map x = self.features(x) rescaled = self.scale_weight.view(1, -1, 1, 1) * F.normalize(x) @@ -497,8 +492,7 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]: return OrderedDict([(str(i), v) for i, v in enumerate(output)]) -def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int, - rescaling: bool): +def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int): if backbone_name in backbone_urls: # Use custom backbones more appropriate for SSD arch = backbone_name.split('_')[0] @@ -522,7 +516,7 @@ def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained for parameter in b.parameters(): parameter.requires_grad_(False) - return SSDFeatureExtractorVGG(backbone, highres, rescaling) + return SSDFeatureExtractorVGG(backbone, highres) def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: int = 91, @@ -555,14 +549,15 @@ def ssd300_vgg16(pretrained: bool = False, progress: bool = True, num_classes: i # no need to download the backbone if pretrained is set pretrained_backbone = False - backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers, True) + backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers) anchor_generator = DefaultBoxGenerator([[2], [2, 3], [2, 3], [2, 3], [2], [2]], scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05], steps=[8, 16, 32, 64, 100, 300]) defaults = { + # Rescale the input in a way compatible to the backbone "image_mean": [0.48235, 0.45882, 0.40784], - "image_std": [1., 1., 1.], + "image_std": [1.0 / 255.0, 1.0 / 255.0, 1.0 / 255.0], # undo the 0-1 scaling of toTensor } kwargs = {**defaults, **kwargs} model = SSD(backbone, anchor_generator, (300, 300), num_classes, **kwargs) From 61ae2927c949e360847dfd048bb4b7a3df071953 Mon Sep 17 00:00:00 2001 From: Vasilis Vryniotis Date: Tue, 11 May 2021 16:59:37 +0100 Subject: [PATCH 14/14] Add comment. --- torchvision/models/detection/ssdlite.py | 1 + 1 file changed, 1 insertion(+) diff --git a/torchvision/models/detection/ssdlite.py b/torchvision/models/detection/ssdlite.py index db1fd496cc3..8498a78d6dd 100644 --- a/torchvision/models/detection/ssdlite.py +++ b/torchvision/models/detection/ssdlite.py @@ -210,6 +210,7 @@ def ssdlite320_mobilenet_v3_large(pretrained: bool = False, progress: bool = Tru "nms_thresh": 0.55, "detections_per_img": 300, "topk_candidates": 300, + # Rescale the input in a way compatible to the backbone: # The following mean/std rescale the data from [0, 1] to [-1, -1] "image_mean": [0.5, 0.5, 0.5], "image_std": [0.5, 0.5, 0.5],
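
A quick check of the comment added above: with mean = std = 0.5, the
(x - mean) / std normalization performed by GeneralizedRCNNTransform maps
[0, 1] to [-1, 1], which is exactly the `2.0 * x - 1.0` rescaling removed from
the feature extractor earlier in this series (illustration only):

    import torch

    x = torch.rand(3, 320, 320)
    mean = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
    std = torch.tensor([0.5, 0.5, 0.5]).view(-1, 1, 1)
    assert torch.allclose((x - mean) / std, 2.0 * x - 1.0)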