Skip to content

🐛 [Bug] RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:92] Expected to find type float for value scale_factor.1 but get nothing. #1378

Closed
@zhaozhiming37

Description

@zhaozhiming37

Bug Description

To Reproduce

Steps to reproduce the behavior:

import cv2
import numpy as np
import torch
import torch_tensorrt
from torch import nn, Tensor
from typing import Tuple
import torch.nn.functional as F


class Resize(nn.Module):
    """Bilinearly rescale a batch by one factor so it fits inside ``size``.

    The factor is the smaller of the two per-axis ratios between the target
    ``size`` and the input's dims 2 and 3 (height and width in the
    N x C x H x W layout that ``F.interpolate`` operates on), so both axes
    are scaled equally and the aspect ratio is kept.
    """

    def __init__(
            self,
            size: Tuple[int, int] = (640, 640)
    ):
        super().__init__()
        # Target size is stored as an int32 tensor and read per element
        # in forward().
        self.size = torch.as_tensor(size, dtype=torch.int32)

    def forward(self, images: Tensor) -> Tensor:
        """Return ``images`` resized by the shape-dependent scale factor.

        Args:
            images: 4-D tensor; dims 2 and 3 are the axes being resized.

        Returns:
            The bilinearly interpolated tensor.
        """
        spatial_dims = images.shape[2:4]
        ratio_dim2 = float(self.size[0]) / spatial_dims[0]
        ratio_dim3 = float(self.size[1]) / spatial_dims[1]
        # The tighter ratio wins so neither axis exceeds its target.
        scale_factor = min(ratio_dim2, ratio_dim3)
        return F.interpolate(
            images,
            size=None,
            scale_factor=scale_factor,
            mode='bilinear',
            recompute_scale_factor=True,
            align_corners=False,
        )


if __name__ == "__main__":
    # Repro driver: compile Resize with Torch-TensorRT, run it once, then
    # round-trip the compiled module through torch.jit.save / torch.jit.load.
    img_path = "images/0017.jpg"
    bgr_image = cv2.imread(img_path)
    # cv2.imread reports failure by returning None rather than raising.
    if bgr_image is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")

    # OpenCV returns H x W x C. Resize reads dims 2 and 3 as the spatial axes
    # (the N x C x H x W layout F.interpolate expects), so move channels
    # forward after adding the batch axis.
    images = (
        torch.as_tensor(bgr_image[np.newaxis], dtype=torch.float)
        .permute(0, 3, 1, 2)
        .contiguous()
    )

    model = Resize().eval()
    inputs = [
        torch_tensorrt.Input(
            # Dynamic-batch alternative (N x C x H x W):
            # min_shape=[1, 3, 720, 1280],
            # opt_shape=[2, 3, 720, 1280],
            # max_shape=[4, 3, 720, 1280],
            shape=(1, 3, 720, 1280),
            dtype=torch.float,
        )
    ]
    # torch.float is FP32; add torch.half to the set to allow FP16 kernels.
    enabled_precisions = {torch.float}

    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )
    input_data = images.to("cuda")
    result = trt_ts_module(input_data)
    torch.jit.save(trt_ts_module, "./resize.ts")

    # Reload the serialized module and run it on the same CUDA tensor.
    trt_ts_module = torch.jit.load("./resize.ts")
    result = trt_ts_module(input_data)

Error output:

Traceback (most recent call last):
  File "/opt/project/trt.py", line 84, in <module>
    trt_ts_module = torch_tensorrt.compile(
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 111, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 116, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: [Error thrown at core/partitioning/shape_analysis.cpp:92] Expected to find type float for value scale_factor.1 but get nothing. 

Expected behavior

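I expect `torch_tensorrt.compile` to support a dynamic `scale_factor`, i.e. one computed at runtime from the input tensor's shape, when it is passed to `F.interpolate`.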

Environment

nvcr.io/nvidia/pytorch:22.08-py3

Additional context

Metadata

Metadata

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions