
Unify onnx and JIT resize implementations #3654


Merged: 3 commits merged into pytorch:master from the refactoring/unify_resize branch on Apr 9, 2021

Conversation

@datumbox (Contributor) commented Apr 9, 2021

This PR unifies the two separate implementations of resize used by ONNX and JIT.

Unfortunately, a workaround is necessary for how we handle the scale_factor because:

  • The interpolate method needs it to be a float.
  • ONNX needs it to be a tensor so that the variable is in the computation graph.
  • JIT complains about its type if ONNX passes a tensor.

To resolve the problem we introduce a conditional fake cast to fool JIT and stop it from complaining about the ONNX branch.
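
For context, here is a minimal sketch of the fake-cast pattern (the helper names below are illustrative, not necessarily the exact torchvision code): the return annotation tells TorchScript it receives a float, while under ONNX tracing the value is never actually converted and therefore stays as a node in the exported graph.

import torch
import torchvision
from torch import Tensor

@torch.jit.unused
def _fake_cast_to_float(v: Tensor) -> float:
    # The annotation claims float so scripting is satisfied; during ONNX
    # tracing the value is never converted and remains a Tensor in the graph.
    return v

def _get_scale_factor(scale: Tensor) -> float:
    if torchvision._is_tracing():
        # ONNX branch: keep the tensor in the graph behind the fake cast.
        return _fake_cast_to_float(scale)
    # Eager / scripted branch: interpolate() needs a plain Python float.
    return scale.item()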

@datumbox datumbox force-pushed the refactoring/unify_resize branch from 281725a to 599a1a5 on April 9, 2021 12:47
@datumbox datumbox changed the title [WIP] Unify onnx and JIT resize implementations Unify onnx and JIT resize implementations Apr 9, 2021
@datumbox datumbox requested a review from fmassa April 9, 2021 14:12
@fmassa (Member) left a comment

Looks great, thanks!

@fmassa fmassa merged commit 1d0b43e into pytorch:master Apr 9, 2021
@datumbox datumbox deleted the refactoring/unify_resize branch April 9, 2021 14:36
facebook-github-bot pushed a commit that referenced this pull request Apr 13, 2021
Summary:
* Make two methods as similar as possible.

* Introducing conditional fake casting.

* Change the casting mechanism.

Reviewed By: NicolasHug

Differential Revision: D27706950

fbshipit-source-id: ef7503817cd64ffc8723fec89f1cd94647490eaf
@zhiqwang (Contributor) commented

Hi @datumbox

I found that this modification causes an issue: PyTorch's "true" division is not numerically equal to NumPy's, and I filed an issue at pytorch/pytorch#58234 .

scale_factor as computed on master:

scale = torch.min(self_min_size / min_size, self_max_size / max_size)

vs

scale_factor as computed on 0.9.0:

scale_factor = self_min_size / min_size
if max_size * scale_factor > self_max_size:
    scale_factor = self_max_size / max_size

This causes a numerical error between the exported ONNX model and the PyTorch model.
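
For what it's worth, here is a small illustrative snippet (hypothetical image size, with the default min_size=800 and max_size=1333) showing how the float32 division on master can differ from the float64 division, matching NumPy's default, that the 0.9.0 code effectively used:

import torch

# Hypothetical 427x640 image with the default min_size=800, max_size=1333.
h, w = 427, 640
self_min_size, self_max_size = 800.0, 1333.0

# master: im_shape is a tensor cast to float32, so the division is float32.
im_shape = torch.tensor([h, w])
min_size = torch.min(im_shape).to(dtype=torch.float32)
max_size = torch.max(im_shape).to(dtype=torch.float32)
scale_f32 = torch.min(self_min_size / min_size, self_max_size / max_size)

# 0.9.0: plain Python floats, i.e. float64, matching NumPy's default.
scale_f64 = min(self_min_size / float(min(h, w)), self_max_size / float(max(h, w)))

print(scale_f32.item(), scale_f64)  # the two values can differ in the last bits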

Besides, this phenomenon slightly affects the mAP of the existing object detection models. I recomputed the mAP of retinanet_resnet50_fpn as follows (for reference, the previous result was 0.364 in #2784):

IoU metric: bbox
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.363
 Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets=100 ] = 0.557
 Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets=100 ] = 0.382
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.193
 Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.400
 Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.490
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=  1 ] = 0.314
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 10 ] = 0.500
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets=100 ] = 0.540
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= small | maxDets=100 ] = 0.340
 Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=100 ] = 0.581
 Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets=100 ] = 0.696

Validated with:

train.py --dataset coco --model retinanet_resnet50_fpn --pretrained --test-only

@datumbox (Contributor, Author) commented

@zhiqwang Thanks for the heads up! Those two should be the same, so it's great that you already filed an issue on the PyTorch repo.

Concerning the perceived decrease in accuracy, this specific metric is known to be a bit flimsy and to fluctuate. Here is a benchmark on master that is newer than the one you posted: #2954 . Hopefully once the ticket is resolved, the mAP will return to its expected value.

Let me know if there is anything else you recommend doing in the meantime.

@zhiqwang (Contributor) commented May 14, 2021

Hi @datumbox

As mentioned in pytorch/pytorch#58234 (comment), this is because the data type has changed with this modification.

This is a little confusing and it's happening because PyTorch's default datatype is float32 while NumPy's is float64.

I've verified locally that it restores the previous behavior if we change the data type of im_shape to torch.float64. It's a little strange that the exported ONNX model and NumPy then show the same behavior (note that if im_shape is set to torch.float64 in ONNX, the numerical error between PyTorch and the exported ONNX model becomes larger). Do you think it is necessary to open a follow-up PR for this problem?

diff --git a/torchvision/models/detection/transform.py b/torchvision/models/detection/transform.py
index 55cfb483..a8525b41 100644
--- a/torchvision/models/detection/transform.py
+++ b/torchvision/models/detection/transform.py
@@ -27,9 +27,9 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size:
                             target: Optional[Dict[str, Tensor]],
                             fixed_size: Optional[Tuple[int, int]]) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
     if torchvision._is_tracing():
-        im_shape = _get_shape_onnx(image)
+        im_shape = _get_shape_onnx(image).to(torch.float32)
     else:
-        im_shape = torch.tensor(image.shape[-2:])
+        im_shape = torch.tensor(image.shape[-2:], dtype=torch.float64)

     size: Optional[List[int]] = None
     scale_factor: Optional[float] = None
@@ -37,8 +37,8 @@ def _resize_image_and_masks(image: Tensor, self_min_size: float, self_max_size:
     if fixed_size is not None:
         size = [fixed_size[1], fixed_size[0]]
     else:
-        min_size = torch.min(im_shape).to(dtype=torch.float32)
-        max_size = torch.max(im_shape).to(dtype=torch.float32)
+        min_size = torch.min(im_shape)
+        max_size = torch.max(im_shape)
         scale = torch.min(self_min_size / min_size, self_max_size / max_size)

@datumbox (Contributor, Author) commented

@zhiqwang Thanks for the top notch investigation as always.

Your proposed changes look great to me and since they fix the problem, it would be great if you could send a PR.

@datumbox (Contributor, Author) commented

@zhiqwang Just following up on this. We would like to bring in your proposed fix prior to the upcoming PyTorch release. I was wondering if you are still interested in sending the PR, since as far as I understand you pretty much have it ready to go. Otherwise, let me know if I can give you a hand.

@datumbox (Contributor, Author) commented

@zhiqwang Sorry for the back-to-back pings. Let me know if you are interested in sending the PR. We would like to have this patched as soon as possible, so if you are not available I'm happy to take your fix and merge it.

I would prefer if you do it instead because 1) you should get credit for finding and fixing the issue and 2) it would be great to have your confirmation that the problem is resolved after we merge.

Let me know, thanks! :)
