Skip to content

Commit c5b6736

Browse files
datumbox authored and vfdev-5 committed
Overwriting FrozenBN eps=0.0 if pretrained=True for detection models. (pytorch#2940)
* Overwriting FrozenBN eps=0.0 if pretrained=True for detection models.
* Moving the method to detection utils and adding comments.
1 parent 25689e8 commit c5b6736

File tree

6 files changed

+30
-5
lines changed

6 files changed

+30
-5
lines changed

test/test_models.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
import unittest
99
import random
1010

11-
from torchvision.ops.misc import FrozenBatchNorm2d
11+
from torchvision.models.detection._utils import overwrite_eps
1212

1313

1414
def set_rng_seed(seed):
@@ -151,9 +151,7 @@ def _test_detection_model(self, name, dev):
151151
kwargs["score_thresh"] = 0.013
152152
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
153153
if "keypointrcnn" in name or "retinanet" in name:
154-
for module in model.modules():
155-
if isinstance(module, FrozenBatchNorm2d):
156-
module.eps = 0
154+
overwrite_eps(model, 0.0)
157155
model.eval().to(device=dev)
158156
input_shape = (3, 300, 300)
159157
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests

torchvision/models/detection/_utils.py

Lines changed: 20 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,8 @@
33
import torch
44
from torch.jit.annotations import List, Tuple
55
from torch import Tensor
6-
import torchvision
6+
7+
from torchvision.ops.misc import FrozenBatchNorm2d
78

89

910
class BalancedPositiveNegativeSampler(object):
@@ -349,3 +350,21 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru
349350
if size_average:
350351
return loss.mean()
351352
return loss.sum()
353+
354+
355+
def overwrite_eps(model, eps):
    """
    Assign the given eps to every FrozenBatchNorm2d layer inside the model.

    The bug-fix at pytorch/vision#2933 introduced a BC-breaking change, and
    this helper exists to address it: the overwrite is meant to be applied
    only when pretrained weights are loaded, so that compatibility with
    previous versions is maintained.

    Arguments:
        model (nn.Module): The model whose FrozenBatchNorm2d layers are
            updated in place.
        eps (float): The eps value to store on each matching layer.
    """
    # Walk every submodule and keep only the frozen batch-norm layers.
    frozen_bn_layers = (
        layer for layer in model.modules()
        if isinstance(layer, FrozenBatchNorm2d)
    )
    for layer in frozen_bn_layers:
        layer.eps = eps

torchvision/models/detection/faster_rcnn.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torchvision.ops import misc as misc_nn_ops
88
from torchvision.ops import MultiScaleRoIAlign
99

10+
from ._utils import overwrite_eps
1011
from ..utils import load_state_dict_from_url
1112

1213
from .anchor_utils import AnchorGenerator
@@ -361,4 +362,5 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
361362
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_resnet50_fpn_coco'],
362363
progress=progress)
363364
model.load_state_dict(state_dict)
365+
overwrite_eps(model, 0.0)
364366
return model

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
from torchvision.ops import MultiScaleRoIAlign
55

6+
from ._utils import overwrite_eps
67
from ..utils import load_state_dict_from_url
78

89
from .faster_rcnn import FasterRCNN
@@ -332,4 +333,5 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
332333
state_dict = load_state_dict_from_url(model_urls[key],
333334
progress=progress)
334335
model.load_state_dict(state_dict)
336+
overwrite_eps(model, 0.0)
335337
return model

torchvision/models/detection/mask_rcnn.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torchvision.ops import misc as misc_nn_ops
88
from torchvision.ops import MultiScaleRoIAlign
99

10+
from ._utils import overwrite_eps
1011
from ..utils import load_state_dict_from_url
1112

1213
from .faster_rcnn import FasterRCNN
@@ -328,4 +329,5 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
328329
state_dict = load_state_dict_from_url(model_urls['maskrcnn_resnet50_fpn_coco'],
329330
progress=progress)
330331
model.load_state_dict(state_dict)
332+
overwrite_eps(model, 0.0)
331333
return model

torchvision/models/detection/retinanet.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from torch import Tensor
88
from torch.jit.annotations import Dict, List, Tuple
99

10+
from ._utils import overwrite_eps
1011
from ..utils import load_state_dict_from_url
1112

1213
from . import _utils as det_utils
@@ -628,4 +629,5 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
628629
state_dict = load_state_dict_from_url(model_urls['retinanet_resnet50_fpn_coco'],
629630
progress=progress)
630631
model.load_state_dict(state_dict)
632+
overwrite_eps(model, 0.0)
631633
return model

0 commit comments

Comments
 (0)