From 233a388d39afbc4edd4ad89326bd2f12aec10fcf Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Fri, 30 Oct 2020 16:59:06 +0000 Subject: [PATCH 1/2] Overwriting FrozenBN eps=0.0 if pretrained=True for detection models. --- test/test_models.py | 6 ++---- torchvision/models/_utils.py | 8 ++++++++ torchvision/models/detection/faster_rcnn.py | 2 ++ torchvision/models/detection/keypoint_rcnn.py | 2 ++ torchvision/models/detection/mask_rcnn.py | 2 ++ torchvision/models/detection/retinanet.py | 2 ++ 6 files changed, 18 insertions(+), 4 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index b37fb176a2b..f7c684b4272 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -8,7 +8,7 @@ import unittest import random -from torchvision.ops.misc import FrozenBatchNorm2d +from torchvision.models._utils import overwrite_eps def set_rng_seed(seed): @@ -151,9 +151,7 @@ def _test_detection_model(self, name, dev): kwargs["score_thresh"] = 0.013 model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) if "keypointrcnn" in name or "retinanet" in name: - for module in model.modules(): - if isinstance(module, FrozenBatchNorm2d): - module.eps = 0 + overwrite_eps(model, 0.0) model.eval().to(device=dev) input_shape = (3, 300, 300) # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 928a6f1c559..71788b7df3f 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -4,6 +4,8 @@ from torch import nn from typing import Dict +from ..ops.misc import FrozenBatchNorm2d + class IntermediateLayerGetter(nn.ModuleDict): """ @@ -65,3 +67,9 @@ def forward(self, x): out_name = self.return_layers[name] out[out_name] = x return out + + +def overwrite_eps(model, eps): + for module in model.modules(): + if isinstance(module, FrozenBatchNorm2d): + module.eps = eps diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 117d985eca6..5ef0e910b7f 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -7,6 +7,7 @@ from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign +from .._utils import overwrite_eps from ..utils import load_state_dict_from_url from .anchor_utils import AnchorGenerator @@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True, state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'], progress=progress) model.load_state_dict(state_dict) + overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index f1f4ad26680..d91ac5ba8e6 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -3,6 +3,7 @@ from torchvision.ops import MultiScaleRoIAlign +from .._utils import overwrite_eps from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN @@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, state_dict = load_state_dict_from_url(model_urls[key], progress=progress) model.load_state_dict(state_dict) + overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 668d8ab8122..8a0acf88e06 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -7,6 +7,7 @@ from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign +from .._utils import overwrite_eps from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN @@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'], progress=progress) model.load_state_dict(state_dict) + overwrite_eps(model, 0.0) return model diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index d128ecb5699..784f8c390bf 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -7,6 +7,7 @@ from torch import Tensor from torch.jit.annotations import Dict, List, Tuple +from .._utils import overwrite_eps from ..utils import load_state_dict_from_url from . import _utils as det_utils @@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True, state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'], progress=progress) model.load_state_dict(state_dict) + overwrite_eps(model, 0.0) return model From 07bc7a086dbc89278e9c1c838ed7ef75f32bd712 Mon Sep 17 00:00:00 2001 From: Vasileios Vryniotis Date: Tue, 3 Nov 2020 13:09:38 +0000 Subject: [PATCH 2/2] Moving the method to detection utils and adding comments. --- test/test_models.py | 2 +- torchvision/models/_utils.py | 8 ------- torchvision/models/detection/_utils.py | 21 ++++++++++++++++++- torchvision/models/detection/faster_rcnn.py | 2 +- torchvision/models/detection/keypoint_rcnn.py | 2 +- torchvision/models/detection/mask_rcnn.py | 2 +- torchvision/models/detection/retinanet.py | 2 +- 7 files changed, 25 insertions(+), 14 deletions(-) diff --git a/test/test_models.py b/test/test_models.py index f7c684b4272..e27021c4337 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -8,7 +8,7 @@ import unittest import random -from torchvision.models._utils import overwrite_eps +from torchvision.models.detection._utils import overwrite_eps def set_rng_seed(seed): diff --git a/torchvision/models/_utils.py b/torchvision/models/_utils.py index 71788b7df3f..928a6f1c559 100644 --- a/torchvision/models/_utils.py +++ b/torchvision/models/_utils.py @@ -4,8 +4,6 @@ from torch import nn from typing import Dict -from ..ops.misc import FrozenBatchNorm2d - class IntermediateLayerGetter(nn.ModuleDict): """ @@ -67,9 +65,3 @@ def forward(self, x): out_name = self.return_layers[name] out[out_name] = x return out - - -def overwrite_eps(model, eps): - for module in model.modules(): - if isinstance(module, FrozenBatchNorm2d): - module.eps = eps diff --git a/torchvision/models/detection/_utils.py b/torchvision/models/detection/_utils.py index cf4567daa8d..9834fb6b92a 100644 --- a/torchvision/models/detection/_utils.py +++ b/torchvision/models/detection/_utils.py @@ -3,7 +3,8 @@ import torch from torch.jit.annotations import List, Tuple from torch import Tensor -import torchvision + +from torchvision.ops.misc import FrozenBatchNorm2d class BalancedPositiveNegativeSampler(object): @@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru if size_average: return loss.mean() return loss.sum() + + +def overwrite_eps(model, eps): + """ + This method overwrites the default eps values of all the + FrozenBatchNorm2d layers of the model with the provided value. + This is necessary to address the BC-breaking change introduced + by the bug-fix at pytorch/vision#2933. The overwrite is applied + only when the pretrained weights are loaded to maintain compatibility + with previous versions. + + Arguments: + model (nn.Module): The model on which we perform the overwrite. + eps (float): The new value of eps. + """ + for module in model.modules(): + if isinstance(module, FrozenBatchNorm2d): + module.eps = eps diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py index 5ef0e910b7f..256b11e40bc 100644 --- a/torchvision/models/detection/faster_rcnn.py +++ b/torchvision/models/detection/faster_rcnn.py @@ -7,7 +7,7 @@ from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign -from .._utils import overwrite_eps +from ._utils import overwrite_eps from ..utils import load_state_dict_from_url from .anchor_utils import AnchorGenerator diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index d91ac5ba8e6..fb480119144 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -3,7 +3,7 @@ from torchvision.ops import MultiScaleRoIAlign -from .._utils import overwrite_eps +from ._utils import overwrite_eps from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 8a0acf88e06..f933180e4fa 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -7,7 +7,7 @@ from torchvision.ops import misc as misc_nn_ops from torchvision.ops import MultiScaleRoIAlign -from .._utils import overwrite_eps +from ._utils import overwrite_eps from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 784f8c390bf..fc05106a807 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -7,7 +7,7 @@ from torch import Tensor from torch.jit.annotations import Dict, List, Tuple -from .._utils import overwrite_eps +from ._utils import overwrite_eps from ..utils import load_state_dict_from_url from . import _utils as det_utils