
🐛 [Bug] torch.arange causes SpecViolationError during torch_tensorrt.save #3189

Closed
@Qi-Zha0

Bug Description

Calling torch.arange on a shape-derived value inside forward compiles fine with ir="dynamo", but torch_tensorrt.save then fails with a SpecViolationError. The arange result appears to be constant-folded into a frozen parameter (_frozen_param0), and the re-export performed by save produces a renamed output (_frozen_param0_1) that is no longer in the exported program's user_output list, so the verifier rejects it. A minimal example follows.

To Reproduce

Minimum example:

import torch
import torch_tensorrt


class Mod(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # 1-D float tensor whose length is the input's last dimension
        x_embed = torch.arange(1, x.shape[-1] + 1, dtype=torch.float32, device=x.device)
        return x_embed


# Compilation succeeds (the subgraph is skipped and left in Torch, per the warning below)
ep = torch_tensorrt.compile(Mod(), ir="dynamo", inputs=(torch.randn(1, 1, 128, 128),))
# Re-exporting the compiled module fails with SpecViolationError
torch_tensorrt.save(ep, "test.ep", inputs=(torch.randn(1, 1, 128, 128),))

Error:

WARNING:torch_tensorrt.dynamo.conversion.aten_ops_converters:Unable to import quantization op. Please install modelopt library (https://github.com/NVIDIA/TensorRT-Model-Optimizer?tab=readme-ov-file#installation) to add support for compiling quantized models
INFO:torch_tensorrt.dynamo._compiler:Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, debug=False, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, refit=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/timing_cache.bin')

WARNING:torch_tensorrt.dynamo._compiler:0 supported operations detected in subgraph containing 0 computational nodes. Skipping this subgraph, since min_block_size was detected to be 5
Traceback (most recent call last):
  File "/home/user/project/project_subdirectory/scripts/debug.py", line 16, in <module>
    torch_tensorrt.save(ep, "test.ep", inputs=(torch.randn(1, 1, 128, 128)))
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/_compile.py", line 461, in save
    exp_program = export(module, inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py", line 33, in export
    exp_program = create_trt_exp_program(patched_module)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch_tensorrt/dynamo/_exporter.py", line 328, in create_trt_exp_program
    trt_exp_program = ExportedProgram(
                      ^^^^^^^^^^^^^^^^
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 246, in __init__
    self.verifier().check(self)
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/_export/verifier.py", line 155, in check
    _verify_exported_program_signature(ep)
  File "/home/user/project/.venv/lib/python3.11/site-packages/torch/_export/verifier.py", line 421, in _verify_exported_program_signature
    raise SpecViolationError(
torch._export.verifier.SpecViolationError: User output _frozen_param0_1 is not in the correct order or is not found in the exported program's user_output list: ('_frozen_param0',). 
WARNING:py.warnings:/usr/lib/python3.11/tempfile.py:1073: ResourceWarning: Implicitly cleaning up <TemporaryDirectory '/tmp/tmpn6njzoc7'>
  _warnings.warn(warn_message, ResourceWarning)

Expected behavior

torch_tensorrt.save should serialize the compiled module without raising a SpecViolationError.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.4
  • PyTorch Version (e.g. 1.0): 2.4
  • CPU Architecture:
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source):
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.11
  • CUDA version:
  • GPU models and configuration:
  • Any other relevant information:

Additional context
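
A quick way to narrow this down might be to export and save the same module with plain torch.export, without torch_tensorrt in the loop, to see whether the constant produced by torch.arange survives an ordinary export/save round trip. A minimal sketch (untested here), assuming the Mod class from the repro above:

import torch

# Sketch: export the module with torch.export alone, bypassing torch_tensorrt,
# to see how the arange-derived constant output is recorded in the signature.
mod = Mod()
example_inputs = (torch.randn(1, 1, 128, 128),)

plain_ep = torch.export.export(mod, example_inputs)
print(plain_ep.graph_signature.output_specs)  # how the constant output appears in the signature
torch.export.save(plain_ep, "plain_export.pt2")

If the plain export path round-trips cleanly, the renaming of _frozen_param0 to _frozen_param0_1 most likely comes from create_trt_exp_program in torch_tensorrt/dynamo/_exporter.py (the frame in the traceback above) rather than from torch.export itself.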
