
🐛 [Bug] Failure to compile swin/BERT with dynamic batch #1271

Closed
@Njuapp

Description

Bug Description

When compiling Swin Transformer or BERT with a dynamic batch dimension, Torch-TensorRT reports the errors below.

Swin Transformer reports this error:

Traceback (most recent call last):
  File "main.py", line 581, in <module>
    main(config)
  File "main.py", line 197, in main
    trt_ts_module = torch_tensorrt.compile(torch_script_module.float(), **compile_settings)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: Trying to create tensor with negative dimension -1: [-1]
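
For reference, the "negative dimension -1" message is the one PyTorch itself raises when a tensor is created with a -1 size. A dynamic batch dimension is typically carried as -1 during shape analysis, so the guess here (an assumption, not a confirmed code path inside Torch-TensorRT) is that such a placeholder is being materialized into a concrete tensor shape. A minimal sketch of the bare error in plain PyTorch:

import torch

# Creating a tensor with a -1 dimension reproduces the same message:
torch.zeros(-1)
# RuntimeError: Trying to create tensor with negative dimension -1: [-1]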

BERT reports this error:

Traceback (most recent call last):
  File "test_bert.py", line 47, in <module>
    trt_ts_module = torch_tensorrt.compile(torch_script_module.float(), **compile_settings)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/_compile.py", line 115, in compile
    return torch_tensorrt.ts.compile(ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch_tensorrt/ts/_compiler.py", line 113, in compile
    compiled_cpp_mod = _C.compile_graph(module._c, _parse_compile_spec(spec))
RuntimeError: upper bound and larger bound inconsistent with step sign
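
This message matches the bound check in PyTorch's arange: with a positive step, the end value must not be smaller than the start. If a dynamic batch placeholder (-1) reached an aten::arange call during compilation it would trip the same check; that connection is an assumption, but the bare error is easy to produce:

import torch

# end (-1) is below start (0) while the step is positive, so the consistency check fails:
torch.arange(0, -1, 1)
# RuntimeError: upper bound and larger bound inconsistent with step sign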

To Reproduce

  • Steps to reproduce the Swin Transformer compilation error:
    1. Download swin_tiny.ts
      NOTE: this TorchScript module was saved with torch 1.12 and may not be loadable with older torch versions
    2. Run the following code snippet:
import torch
import time
import numpy as np
import torch_tensorrt

torch_script_module = torch.jit.load('swin_tiny.ts')

print("COMPILE:FP32")
# Dynamic batch dimension: batch 1 to 32, fixed 3x224x224 input
compile_settings = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=[1, 3, 224, 224],
            opt_shape=[32, 3, 224, 224],
            max_shape=[32, 3, 224, 224],
            dtype=torch.float,
        )
    ],
    "require_full_compilation": False,
    "enabled_precisions": {torch.float},  # Run with FP32
    "truncate_long_and_double": True,
}

trt_ts_module = torch_tensorrt.compile(torch_script_module.float(), **compile_settings)
torch.jit.save(trt_ts_module, 'swin_fp32.trt.ts')
trt_ts_module = torch.jit.load('swin_fp32.trt.ts')

print("RUN: FP32")
warmup_time = 10
test_time = 100
batch_size = 8
x = np.random.randn(batch_size, 3, 224, 224).astype(np.float32)
x = torch.from_numpy(x).cuda().float()
for i in range(warmup_time):
    result_trt_ts = trt_ts_module(x)

torch.cuda.synchronize()
t1 = time.time()
for i in range(test_time):
    result_trt_ts = trt_ts_module(x)
torch.cuda.synchronize()
t2 = time.time()
print("Result shape: ", result_trt_ts.shape)
print("Cost: ", (t2 - t1) / test_time * 1000.0, "ms")

# Compare TensorRT output against the original TorchScript module
result_ts = torch_script_module(x)
diff = abs(result_ts - result_trt_ts)
print("diff after torch-trt: mean diff {0}, max diff {1}".format(diff.mean(), diff.max()))
  • Steps to reproduce the BERT compilation error (a static-shape sanity check is sketched after this snippet):
from transformers import BertModel, BertTokenizer, BertConfig
import numpy as np
import torch
import torch_tensorrt
import time

# Create a dummy input
test_batchsz = 16
tokens_tensor = torch.ones((test_batchsz, 192)).to(torch.int32).cuda()
segments_tensors = torch.ones((test_batchsz, 192)).to(torch.int32).cuda()

# If you are instantiating the model with `from_pretrained`, you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-chinese", torchscript=True)
model = model.eval().cuda()

# torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Graph)

print("VERSION:", torch_tensorrt.__version__)
print("TYPE:", type(model))

torch_script_module = torch.jit.trace(model, (tokens_tensor, segments_tensors))

# Compile FP32
print("COMPILE:FP32")
compile_settings = {
    "inputs": [
        torch_tensorrt.Input(  # input_ids: dynamic batch, fixed sequence length 192
            min_shape=[1, 192],
            opt_shape=[32, 192],
            max_shape=[32, 192],
            dtype=torch.int32,
        ),
        torch_tensorrt.Input(  # token_type_ids: same dynamic batch dimension
            min_shape=[1, 192],
            opt_shape=[32, 192],
            max_shape=[32, 192],
            dtype=torch.int32,
        ),
    ],
    "require_full_compilation": True,
    "enabled_precisions": {torch.float},  # Run with FP32
    "truncate_long_and_double": True,
}

trt_ts_module = torch_tensorrt.compile(torch_script_module.float(), **compile_settings)
torch.jit.save(trt_ts_module, 'bert_fp32.trt.ts')
trt_ts_module = torch.jit.load('bert_fp32.trt.ts')

print("RUN:FP32")
ts_result = torch_script_module(tokens_tensor, segments_tensors)
warmup_time = 10
test_time = 100
for i in range(warmup_time):
    trt_ts_result = trt_ts_module(tokens_tensor, segments_tensors)

torch.cuda.synchronize()
t1 = time.time()
for i in range(test_time):
    trt_ts_result = trt_ts_module(tokens_tensor, segments_tensors)
torch.cuda.synchronize()
t2 = time.time()
print("Cost: ", round(t2 - t1, 4) / test_time * 1000.0)
diff = abs(trt_ts_result[0] - ts_result[0])
print('output shape ', diff.shape, ', diff is {}'.format(diff.mean()))
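
For comparison, a static-shape sanity check (a sketch, not part of the failing setup above): setting min_shape, opt_shape and max_shape to the same value removes the dynamic batch dimension, and the expectation is that this isolates dynamic batch as the trigger. Shown for the Swin module; the same min == opt == max change applies to the BERT spec.

import torch
import torch_tensorrt

torch_script_module = torch.jit.load('swin_tiny.ts')

# Same settings as above, but min == opt == max so no dimension is dynamic.
static_settings = {
    "inputs": [
        torch_tensorrt.Input(
            min_shape=[32, 3, 224, 224],
            opt_shape=[32, 3, 224, 224],
            max_shape=[32, 3, 224, 224],
            dtype=torch.float,
        )
    ],
    "require_full_compilation": False,
    "enabled_precisions": {torch.float},
    "truncate_long_and_double": True,
}

# If this compiles while the dynamic spec fails, the problem is tied to the dynamic batch dimension.
trt_static = torch_tensorrt.compile(torch_script_module.float(), **static_settings)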

Expected behavior

Environment

Build information about Torch-TensorRT can be found by turning on debug messages
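
For example (a sketch; it mirrors the Level.Graph call commented out in the BERT snippet, with the Debug level assumed to print build information):

import torch_tensorrt

# Raise log verbosity so Torch-TensorRT prints build details during compilation.
torch_tensorrt.logging.set_reportable_log_level(torch_tensorrt.logging.Level.Debug)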

  • Torch-TensorRT Version (e.g. 1.0.0):
  • PyTorch Version (e.g. 1.0):
  • CPU Architecture:
  • OS (e.g., Linux):
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version:
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
