
Conversation

zhiqwang
Contributor

@zhiqwang zhiqwang commented May 19, 2021

Hi @datumbox This is a follow-up PR to #3654.

Sorry for not updating this thread for so long; I actually found a bug in my proposal in #3654 (comment).

Note that if im_shape is set to torch.float64 in ONNX, the numerical error between PyTorch and the exported ONNX model becomes larger.

I wrote a script (below) to compare PyTorch and ONNXRuntime; it shows that the results of PyTorch and ONNXRuntime match only when both are set to the same dtype. However, the torch.nn.functional.interpolate call in

image = torch.nn.functional.interpolate(image[None], size=size, scale_factor=scale_factor, mode='bilinear',
recompute_scale_factor=recompute_scale_factor, align_corners=False)[0]

will introduce some numerical error between PyTorch and ONNXRuntime no matter how we set this data type, and the error appears to be random.

import torch
import torchvision
from torch import nn, Tensor

@torch.jit.unused
def _get_shape_onnx(image: Tensor) -> Tensor:
    from torch.onnx import operators
    return operators.shape_as_tensor(image)[-2:]

@torch.jit.unused
def _fake_cast_onnx(v: Tensor) -> float:
    # ONNX requires a tensor but here we fake its type for JIT.
    return v

class TestFloatPytorchVsONNX(nn.Module):
    def __init__(self, max_size):
        super().__init__()
        self.max_size = max_size

    def forward(self, image: Tensor) -> float:
        # Use the ONNX-friendly shape operator when tracing, a plain tensor otherwise.
        if torchvision._is_tracing():
            im_shape = _get_shape_onnx(image)
        else:
            im_shape = torch.tensor(image.shape[-2:])

        max_size = torch.max(im_shape).to(dtype=torch.float64)
        scale = self.max_size / max_size

        # When tracing, keep the scale as a tensor (faked as float for JIT);
        # otherwise extract the Python scalar.
        if torchvision._is_tracing():
            scale_factor = _fake_cast_onnx(scale)
        else:
            scale_factor = scale.item()

        return scale_factor

test_float_pytorch_vs_onnx = TestFloatPytorchVsONNX(416)
image = torch.rand(3, 1080, 810)
scale_factor = test_float_pytorch_vs_onnx(image)

print(scale_factor)
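
For completeness, here is a minimal sketch of how the ONNXRuntime side of the comparison could look. This is not part of the original script; it assumes onnxruntime is installed, and the file name "scale.onnx" is arbitrary.

import onnxruntime

# Trace/export the module above; during export torchvision._is_tracing() is True,
# so the ONNX branch of forward() is taken.
torch.onnx.export(test_float_pytorch_vs_onnx, (image,), "scale.onnx",
                  input_names=["image"], output_names=["scale_factor"],
                  opset_version=11)

ort_session = onnxruntime.InferenceSession("scale.onnx")
ort_scale = ort_session.run(None, {"image": image.numpy()})[0]

# Compare the exported scale factor against the eager PyTorch value printed above.
print(ort_scale, abs(float(ort_scale) - scale_factor))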

@datumbox datumbox added the bug label May 19, 2021
Contributor

@datumbox datumbox left a comment


@zhiqwang Thanks for submitting the PR.

I noticed that your earlier patch modified the im_shape type. Is this not needed?

But the torch.nn.functional.interpolate call will introduce some numerical error between PyTorch and ONNXRuntime no matter how we set this data type, and the error appears to be random.

Does this mean that this patch does not fully address the problem? If that's the case, should we wait until the issue you raised on the PyTorch side is addressed?

cc @prabhat00155 and @fmassa who have more ONNX experience than me.

@datumbox datumbox requested review from fmassa and prabhat00155 May 19, 2021 18:17
@zhiqwang
Contributor Author

zhiqwang commented May 19, 2021

Hi @datumbox

I noticed that your earlier patch modified the im_shape type. Is this not needed?

That's not needed. The reason I modified the type of im_shape in the earlier patch was just to make the cast take effect before torch.div (/).

scale = torch.min(self_min_size / min_size, self_max_size / max_size)
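
As a rough illustration of that point (the shape values and names below are made up for the example), the division only happens in fp64 if the cast is applied before it:

im_shape = torch.tensor([1080, 810])             # int64 shape tensor
max_size = torch.max(im_shape)
print((416 / max_size.to(torch.float32)).dtype)  # torch.float32
print((416 / max_size.to(torch.float64)).dtype)  # torch.float64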

Does this mean that this patch does not fully address the problem? If that's the case, should we wait until the issue you raised on the PyTorch side is addressed?

I think this patch can't fully address the problem. In other words, the current patch can't guarantee that the whole model produces consistent accuracy between PyTorch and ONNXRuntime. The previous "fix" may just have been a coincidence (because torch.nn.functional.interpolate and other ops introduce some numerical error between PT and ORT, and these are probably what influences the inference accuracy gap between PT and ORT).

This type change will not have a big impact on the PyTorch side. I tested the mAP of faster-rcnn and retinanet on the master branch and with this patch. (I can't connect to my cloud machine right now, so I'm posting the summary below.)

              faster-rcnn   retinanet
master        36.9          36.3
after patch   37.0          36.3

Most importantly, I think this modification may affect downstream work; the torch.float64 here may not be friendly to other users.

For the time being, I think this modification is not necessary, and you can feel free to close this PR.

@fmassa
Member

fmassa commented May 20, 2021

If I understand this correctly, the difference is because we now perform the division in fp32 while before it was in fp64, so the scale might be slightly different.

IMO this is not a huge issue, and some numerical differences might be expected between different frameworks.
Before, ONNX would use fp32 in its implementation, while PyTorch would use fp64 (as the value was cast to a Python scalar).

I'm not sure how well ONNX handles fp64, so I'd be tempted to keep things as is. The 0.1 drop seems to be bad luck due to the rounding of the division.
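
As a small illustrative sketch (numbers chosen arbitrarily), the same division carried out in fp32 and fp64 can differ in the last bits, which is enough to shift a rounding decision downstream:

import torch

scale32 = torch.tensor(416.0, dtype=torch.float32) / torch.tensor(1080.0, dtype=torch.float32)
scale64 = torch.tensor(416.0, dtype=torch.float64) / torch.tensor(1080.0, dtype=torch.float64)
print(scale32.item(), scale64.item(), abs(scale32.item() - scale64.item()))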

@datumbox
Contributor

@zhiqwang Thanks a lot for the investigation, detailed information and PR.

Given that the patch does not reliably resolve the reported difference and it introduces fp64, which might cause issues, I will close the PR.

Again, thanks a lot for sending your contribution on such short notice, and if you have any concerns please let me know.

@datumbox datumbox closed this May 20, 2021
@zhiqwang zhiqwang deleted the restore-previous-type branch May 20, 2021 14:09