
🐛 [Bug] failed correctness check when using F.interpolate(align_corners=False) #1558

Closed
@1559588143

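Bug Description

Compiling a module that consists of a single F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False) call with torch_tensorrt.fx.compile fails the correctness check that runs after lowering. The TensorRT engine builds, but its output does not match PyTorch's within the allowed tolerance. Log and traceback: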

INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_matmul at 0x000001F8996789D0> before/after graph to C:\Users\Holy\AppData\Local\Temp\tmpsek2ezsi, before/after are the same = True
INFO:torch_tensorrt.fx.passes.pass_utils:== Log pass <function fuse_permute_linear at 0x000001F899678790> before/after graph to C:\Users\Holy\AppData\Local\Temp\tmpp6a4seyh, before/after are the same = True

Supported node types in the model:
acc_ops.interpolate: ((), {'input': torch.float32})

Unsupported node types in the model:

Got 1 acc subgraphs and 0 non-acc subgraphs
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Now lowering submodule _run_on_acc_0
INFO:torch_tensorrt.fx.lower:split_name=_run_on_acc_0, input_specs=[InputTensorSpec(shape=torch.Size([1, 3, 256, 256]), dtype=torch.float32, device=device(type='cuda', index=0), shape_ranges=[], has_batch_dim=True)]
INFO:torch_tensorrt.fx.lower:Timing cache is used!
INFO:torch_tensorrt.fx.fx2trt:TRT INetwork construction elapsed time: 0:00:00
[12/17/2022-14:17:40] [TRT] [W] TensorRT was linked against cuDNN 8.6.0 but loaded cuDNN 8.5.0
INFO:torch_tensorrt.fx.fx2trt:Build TRT engine elapsed time: 0:00:00.809599
INFO:torch_tensorrt.fx.passes.lower_pass_manager_builder:Lowering submodule _run_on_acc_0 elapsed time 0:00:04.167018
Traceback (most recent call last):
  File "C:\Users\Holy\Downloads\test.py", line 22, in <module>
    trt_mod = compile(
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\lower.py", line 88, in compile
    return lowerer(module, input)
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\lower.py", line 323, in __call__
    return do_lower(module, inputs)
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\passes\pass_utils.py", line 155, in pass_with_validation
    raise e
  File "C:\Python310\lib\site-packages\torch_tensorrt\fx\passes\pass_utils.py", line 141, in pass_with_validation
    torch.testing.assert_close(x, y, **kwargs2)
  File "C:\Python310\lib\site-packages\torch\testing\_comparison.py", line 1342, in assert_close
    assert_equal(
  File "C:\Python310\lib\site-packages\torch\testing\_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Pass <function Lowerer.__call__.<locals>.do_lower at 0x000001F8996D3E20> failed correctness check due at output 0:
Tensor-likes are not close!

Mismatched elements: 353039 / 3145728 (11.2%)
Greatest absolute difference: 0.4991211108863354 at index (0, 2, 24, 4) (up to 0.1 allowed)
Greatest relative difference: 56.67618494203449 at index (0, 1, 4, 4) (up to 0.1 allowed)
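The summary lines above come from torch.testing.assert_close, under which an element passes when |actual - expected| <= atol + rtol * |expected|; the log reports 0.1 for both tolerances. The sketch below is a plain-Python illustration of that arithmetic, not the pass's actual code, and the helper names is_close and relative_difference are hypothetical. It also shows where the element count comes from: the repro upsamples a (1, 3, 256, 256) input by 4, which gives 1 * 3 * 1024 * 1024 outputs.

```python
import math


# Criterion documented for torch.testing.assert_close:
# an element passes when |actual - expected| <= atol + rtol * |expected|.
# Hypothetical helper; atol/rtol default to the 0.1 reported in the log.
def is_close(actual: float, expected: float, atol: float = 0.1, rtol: float = 0.1) -> bool:
    return abs(actual - expected) <= atol + rtol * abs(expected)


# Hypothetical helper: relative difference as |actual - expected| / |expected|,
# reported as infinity when the expected value is exactly zero.
def relative_difference(actual: float, expected: float) -> float:
    return abs(actual - expected) / abs(expected) if expected != 0 else math.inf


# Output shape of the repro: (1, 3, 256 * 4, 256 * 4) after scale_factor=4.
n, c, h, w = 1, 3, 256 * 4, 256 * 4
total = n * c * h * w
mismatched = 353039  # count copied from the log above

print(total)                               # 3145728
print(f"{100 * mismatched / total:.1f}%")  # 11.2%

# A tiny expected value turns a moderate absolute error into a huge relative one.
print(is_close(0.51, 0.01))                        # False
print(round(relative_difference(0.51, 0.01), 2))   # 50.0
```

Because the greatest absolute difference is about 0.499, a relative difference of 56.7 implies the reference value at index (0, 1, 4, 4) is below roughly 0.009 in magnitude (0.499 / 56.7); torch.rand inputs lie in [0, 1), so such small values are expected near dark pixels.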

To Reproduce

import torch
from torch import nn
from torch.nn import functional as F
from torch_tensorrt.fx import compile
from torch_tensorrt.fx.utils import LowerPrecision


class MyModule(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return F.interpolate(x, scale_factor=4, mode='bilinear', align_corners=False)


if __name__ == '__main__':
    device = torch.device('cuda')
    x = torch.rand(1, 3, 256, 256, dtype=torch.float32, device=device)

    with torch.inference_mode():
        mod = MyModule().eval().to(device)
        trt_mod = compile(
            mod,
            [x],
            min_acc_module_size=1,
            explicit_batch_dimension=True,
            lower_precision=LowerPrecision.FP32,
            dynamic_batch=False,
        )
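The align_corners flag changes how each output pixel is mapped back to an input coordinate before interpolation, so a backend that uses a different mapping will disagree with PyTorch even on identical inputs. Bilinear resizing is separable, so a 1-D linear resize is enough to show the mapping. The sketch below is a pure-Python illustration of PyTorch's documented conventions, not the TensorRT or torch_tensorrt implementation; source_coordinate and linear_resize_1d are hypothetical names. It uses in_size / out_size as the scale, which matches the repro since 1024 = 4 * 256.

```python
from typing import List


def source_coordinate(dst: int, in_size: int, out_size: int, align_corners: bool) -> float:
    """Map an output index to a fractional input coordinate."""
    if align_corners:
        # Corner pixels of input and output are aligned exactly.
        return dst * (in_size - 1) / (out_size - 1) if out_size > 1 else 0.0
    # Half-pixel convention (align_corners=False): pixel centers sit at i + 0.5,
    # and coordinates that fall before the first pixel are clamped to 0.
    return max((dst + 0.5) * in_size / out_size - 0.5, 0.0)


def linear_resize_1d(values: List[float], out_size: int, align_corners: bool) -> List[float]:
    """Linearly resample a 1-D signal to out_size samples."""
    in_size = len(values)
    out = []
    for dst in range(out_size):
        src = source_coordinate(dst, in_size, out_size, align_corners)
        i0 = min(int(src), in_size - 1)  # left neighbour (src is never negative)
        i1 = min(i0 + 1, in_size - 1)    # right neighbour, clamped at the edge
        frac = src - i0                  # interpolation weight of the right neighbour
        out.append((1 - frac) * values[i0] + frac * values[i1])
    return out


print(linear_resize_1d([0.0, 1.0], 4, align_corners=False))  # [0.0, 0.25, 0.75, 1.0]
print(linear_resize_1d([0.0, 1.0], 4, align_corners=True))
```

For comparison, ONNX Resize's "asymmetric" coordinate mode maps dst to dst * in_size / out_size with no half-pixel offset, which yields yet another set of values for the same input. This sketch only illustrates why the conventions disagree; it does not establish which one the FX converter used here.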

Environment

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1+cu117
  • CPU Architecture: x64
  • OS (e.g., Linux): Windows 10
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.10.9
  • CUDA version: 11.7
  • GPU models and configuration: RTX 3050
  • Any other relevant information:

Labels: bug (Something isn't working)