[BC-breaking] Fix initialisation bug on FeaturePyramidNetwork #2954

Merged
merged 6 commits on Nov 9, 2020

Conversation

datumbox
Contributor

@datumbox datumbox commented Nov 3, 2020

Fixes #2326

The FeaturePyramidNetwork intended to initialize all the Conv2d layer weights using the following distribution:

for m in self.children():
    if isinstance(m, nn.Conv2d):
        nn.init.kaiming_uniform_(m.weight, a=1)
        nn.init.constant_(m.bias, 0)

Unfortunately, due to the use of children() instead of modules(), we only initialize the direct blocks of the FPN and not those of its nested modules. As a result, those nested modules fall back to PyTorch's default Conv2d initialization:

init.kaiming_uniform_(self.weight, a=math.sqrt(5))
if self.bias is not None:
    fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
    bound = 1 / math.sqrt(fan_in)
    init.uniform_(self.bias, -bound, bound)

This PR changes the use of children() to modules() to ensure that the weight initialization is consistent across all nested blocks of FeaturePyramidNetwork.
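The difference between the two iterators can be seen on a toy module (a hypothetical stand-in, not the actual torchvision class): children() yields only direct submodules, while modules() recurses into nested ones.

```python
import torch.nn as nn

# Hypothetical toy module, NOT the real FeaturePyramidNetwork:
# one Conv2d nested inside a Sequential, plus one direct Conv2d child.
class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.nested = nn.Sequential(nn.Conv2d(3, 8, 3))
        self.direct = nn.Conv2d(8, 8, 3)

m = Toy()
direct_convs = [x for x in m.children() if isinstance(x, nn.Conv2d)]
all_convs = [x for x in m.modules() if isinstance(x, nn.Conv2d)]
print(len(direct_convs))  # 1 -- misses the Conv2d inside the Sequential
print(len(all_convs))     # 2 -- modules() recurses into nested blocks

# The fixed initialization loop therefore reaches every Conv2d:
for mod in m.modules():
    if isinstance(mod, nn.Conv2d):
        nn.init.kaiming_uniform_(mod.weight, a=1)
        nn.init.constant_(mod.bias, 0)
```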

Though trivial in terms of code, the above change has side-effects:

  1. It increases the instability of unit-tests.
  2. It affects the validation metrics of the Detection pre-trained models.

The first issue is addressed in two separate PRs; see #2965 and #2966 for more information. To address the second, I retrained the Detection models on master and on this branch and compared their stats. Overall, fixing the bug and applying the initialization correctly in all blocks leads to small improvements for all models except keypointrcnn.

Master

fasterrcnn_resnet50_fpn: boxAP=37.6
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.376
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.583
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.405
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.212
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.413
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.489
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.313
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.495
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.518
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.317
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.556
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.669
retinanet_resnet50_fpn: boxAP=36.3
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.363
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.555
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.384
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.191
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.492
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.313
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.501
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.539
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.329
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.581
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.692
maskrcnn_resnet50_fpn: boxAP=38.5, maskAP=34.5
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.385
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.592
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.415
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.219
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.420
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.508
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.319
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.506
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.529
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.335
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.569
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.677
0: IoU metric: segm
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.345
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.559
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.366
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.159
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.372
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.517
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.295
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.457
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.476
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.280
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.514
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.637
keypointrcnn_resnet50_fpn: boxAP=55.7, kpAP=65.1
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.557
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.831
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.606
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.382
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.637
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.721
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.190
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.565
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.646
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.497
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.710
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.795
0: IoU metric: keypoints
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.651
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.862
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.711
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.602
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.732
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.718
0:  Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.907
0:  Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.771
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.669
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.788

Branch

fasterrcnn_resnet50_fpn: boxAP=38.0
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.380
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.587
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.411
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.214
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.413
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.500
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.316
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.496
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.521
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.327
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.557
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.664
retinanet_resnet50_fpn: boxAP=36.4
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.364
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.558
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.381
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.193
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.489
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.314
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.502
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.541
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.334
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.584
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.692
maskrcnn_resnet50_fpn: boxAP=38.7, maskAP=34.7
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.387
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.592
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.420
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.216
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.421
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.512
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.321
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.506
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.531
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.333
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.569
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.683
0: IoU metric: segm
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.347
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.559
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.371
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.152
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.373
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.519
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.297
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.458
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.478
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.278
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.514
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.642
keypointrcnn_resnet50_fpn: boxAP=55.5, kpAP=65.1
0: IoU metric: bbox
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.555
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.833
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.603
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.381
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.634
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.719
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.190
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.564
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.647
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.499
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.709
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.796
0: IoU metric: keypoints
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.651
0:  Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.865
0:  Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.706
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.605
0:  Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.730
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] = 0.719
0:  Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] = 0.911
0:  Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] = 0.768
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] = 0.670
0:  Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] = 0.787

The deterioration in keypointrcnn's accuracy might be due to:

  1. Random effects: Assume that sampling from distribution X is a better initialization scheme than sampling from distribution Y. On average, scheme X should yield better results than Y, but that does not mean every individual draw from X will beat every draw from Y; different seeds can yield different results.
  2. The new initialization is not beneficial for every model: Initializing all blocks from the same distribution X may not help every architecture, so for some models it might be preferable to keep the old scheme Y.

To fully investigate the situation, we would need to do multiple runs using different seeds for each initialization scheme and then conduct a statistical test to assess if our change has statistically significant effects on the accuracy. If the change turns out to have statistically significant negative results, we might want to support different init schemes per model.
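The kind of statistical check described above could be sketched as follows, using a hand-rolled Welch's t-statistic and made-up placeholder boxAP values (illustrative numbers only, not actual measurements from these runs):

```python
import math

def welch_t(xs, ys):
    """Welch's t-statistic for two independent samples with possibly unequal variances."""
    nx, ny = len(xs), len(ys)
    mx, my = sum(xs) / nx, sum(ys) / ny
    # Unbiased sample variances
    vx = sum((x - mx) ** 2 for x in xs) / (nx - 1)
    vy = sum((y - my) ** 2 for y in ys) / (ny - 1)
    return (mx - my) / math.sqrt(vx / nx + vy / ny)

# Hypothetical boxAP values from repeated runs with different seeds:
scheme_x = [37.9, 38.1, 38.0, 37.8, 38.2]  # new init applied to all blocks
scheme_y = [37.5, 37.7, 37.6, 37.4, 37.8]  # old init, direct children only
t = welch_t(scheme_x, scheme_y)
print(round(t, 2))  # 4.0 -- compare against the t-distribution's critical value
```

A large |t| relative to the critical value for the chosen significance level would indicate the accuracy difference is unlikely to be seed noise; in practice a library routine such as SciPy's ttest_ind with equal_var=False would replace the hand-rolled statistic.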

I believe that the above is not worth the effort because:

  1. The suggested change does not seem to hurt overall performance in most cases.
  2. Training multiple models is costly and time-consuming, and per-model init schemes are unlikely in this case to lead to significant improvements.
  3. Our goal is to provide a clean and consistent implementation rather than apply hacks to achieve SOTA performance.

Based on the above, I recommend merging the change into master.

@datumbox datumbox changed the title [WIP] Change children() to modules() to ensure init happens in all blocks. [WIP] Fix initialisation bug on FeaturePyramidNetwork Nov 3, 2020
@codecov

codecov bot commented Nov 6, 2020

Codecov Report

Merging #2954 into master will increase coverage by 0.03%.
The diff coverage is 100.00%.

@@            Coverage Diff             @@
##           master    #2954      +/-   ##
==========================================
+ Coverage   73.41%   73.44%   +0.03%     
==========================================
  Files          99       99              
  Lines        8812     8812              
  Branches     1389     1389              
==========================================
+ Hits         6469     6472       +3     
+ Misses       1917     1915       -2     
+ Partials      426      425       -1     
Impacted Files Coverage Δ
torchvision/ops/feature_pyramid_network.py 94.50% <100.00%> (+3.29%) ⬆️


@datumbox datumbox requested a review from fmassa November 8, 2020 14:04
Member

@fmassa fmassa left a comment


Thanks a lot for the PR and the investigation Vasilis!

I agree with all your points, I think it's not necessary for now to run more tests for keypoint rcnn.

We should at some point consider uploading newer weights, but we will need to handle versioning first.

@fmassa fmassa merged commit 8fe0b15 into pytorch:master Nov 9, 2020
@fmassa fmassa changed the title [WIP] Fix initialisation bug on FeaturePyramidNetwork [BC-breaking] Fix initialisation bug on FeaturePyramidNetwork Nov 9, 2020
@datumbox datumbox deleted the bugfix/pyramidnet_init branch November 9, 2020 11:07
bryant1410 pushed a commit to bryant1410/vision-1 that referenced this pull request Nov 22, 2020
* Change children() to modules() to ensure init happens in all blocks.

* Update expected values of all detection models.

* Revert "Update expected values of all detection models."

This reverts commit 050b64a

* Update expecting values.
vfdev-5 pushed a commit to Quansight/vision that referenced this pull request Dec 4, 2020
* Change children() to modules() to ensure init happens in all blocks.

* Update expected values of all detection models.

* Revert "Update expected values of all detection models."

This reverts commit 050b64a

* Update expecting values.
Successfully merging this pull request may close these issues.

Feature Pyramid Network code bug