diff --git a/test/test_models.py b/test/test_models.py index acff816852b..b37fb176a2b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -6,9 +6,10 @@ import numpy as np from torchvision import models import unittest -import traceback import random +from torchvision.ops.misc import FrozenBatchNorm2d + def set_rng_seed(seed): torch.manual_seed(seed) @@ -149,6 +150,10 @@ def _test_detection_model(self, name, dev): if "retinanet" in name: kwargs["score_thresh"] = 0.013 model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) + if "keypointrcnn" in name or "retinanet" in name: + for module in model.modules(): + if isinstance(module, FrozenBatchNorm2d): + module.eps = 0 model.eval().to(device=dev) input_shape = (3, 300, 300) # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests diff --git a/test/test_ops.py b/test/test_ops.py index 7c13de4dedc..79294ed173e 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -623,10 +623,10 @@ def test_frozenbatchnorm2d_eps(self): running_var=torch.rand(sample_size[1]), num_batches_tracked=torch.tensor(100)) - # Check that default eps is zero for backward-compatibility + # Check that default eps is equal to the one of BN fbn = ops.misc.FrozenBatchNorm2d(sample_size[1]) fbn.load_state_dict(state_dict, strict=False) - bn = torch.nn.BatchNorm2d(sample_size[1], eps=0).eval() + bn = torch.nn.BatchNorm2d(sample_size[1]).eval() bn.load_state_dict(state_dict) # Difference is expected to fall in an acceptable range self.assertTrue(torch.allclose(fbn(x), bn(x), atol=1e-6)) diff --git a/torchvision/ops/misc.py b/torchvision/ops/misc.py index 3b52c0d8c4d..3e9f13c9daf 100644 --- a/torchvision/ops/misc.py +++ b/torchvision/ops/misc.py @@ -51,7 +51,7 @@ class FrozenBatchNorm2d(torch.nn.Module): def __init__( self, num_features: int, - eps: float = 0., + eps: float = 1e-5, n: Optional[int] = None, ): # n=None for backward-compatibility