Skip to content

[export] _fft_r2c does not support dynamic shapes #135087

Closed
@justinchuby

Description

@justinchuby

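`aten._fft_r2c` does not support dynamic shapes. The `stft` model below exports successfully with static input shapes, but export fails once the input dimensions are marked dynamic: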

import torch


class STFTModel(torch.nn.Module):
    """Wrap ``Tensor.stft`` with fixed FFT settings and a 320-sample Hann window.

    Settings: ``n_fft=512``, ``hop_length=160``, ``win_length=320``,
    ``pad_mode="constant"`` and ``return_complex=True``.
    """

    def __init__(self):
        super().__init__()
        # Stored as a plain tensor attribute, not a registered buffer, so
        # torch.export lifts it as a constant-tensor graph input.
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        """Return the complex STFT of ``signals`` using the fixed settings.

        Args:
            signals: Real-valued input tensor. The accompanying example feeds
                a ``[B, T]`` batch of audio signals.

        Returns:
            The complex spectrogram produced by ``signals.stft``. For a
            ``[2, 16000]`` float32 input this is ``complex64[2, 257, 101]``.
        """
        spectrogram = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            window=self._window,
            pad_mode="constant",
            return_complex=True,
        )
        return spectrogram


# Repro driver: export the same module twice, first with static shapes and
# then with both input dimensions marked dynamic.
m = STFTModel()

# Example input of shape [B, T] = [2, 16000] (a batch of audio signals).
input_signals = torch.randn([2, 16000])

# torch.export takes the example inputs as a tuple of positional arguments.
args = (input_signals,)

# 1) Static export: no dynamic_shapes are given, so the traced graph uses the
#    concrete example sizes.
ep = torch.export.export(
    m,
    args,
)

# Succeeds: prints the exported program.
print(ep)

# 2) Dynamic export: the list holds one dict for the single positional input,
#    keyed by dimension index. Dimension 0 (batch) and dimension 1 (time) are
#    each marked dynamic with torch.export.Dim.
#    NOTE(review): the Dim names "dim1"/"dim2" are offset from the indices 0/1
#    they describe; this affects only naming in messages.
# Fails (as reported): tracing stft with symbolic sizes reaches
# aten._fft_r2c, which the fake-tensor layer rejects with
# UnsupportedOperatorException. Dynamo then raises torch._dynamo.exc.Unsupported.
ep2 = torch.export.export(
    m,
    args,
    dynamic_shapes=[
        {
            0: torch.export.Dim("dim1"),
            1: torch.export.Dim("dim2"),
        },
    ],
)

# Only reached if the dynamic export above succeeds.
print(ep2)
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c__window: "f32[320]", signals: "f32[2, 16000]"):
             # File: /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:10 in forward, code: x = signals.stft(
            view: "f32[1, 2, 16000]" = torch.ops.aten.view.default(signals, [1, 2, 16000]);  signals = None
            pad: "f32[1, 2, 16512]" = torch.ops.aten.pad.default(view, [256, 256]);  view = None
            view_1: "f32[2, 16512]" = torch.ops.aten.view.default(pad, [2, 16512]);  pad = None
            stft: "c64[2, 257, 101]" = torch.ops.aten.stft.default(view_1, 512, 160, 320, c__window, False, None, True);  view_1 = c__window = None
            return (stft,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c__window'), target='_window', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='signals'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='stft'), target=None)])
Range constraints: {}

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1983, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1468, in wrap_fake_exception
    return fn()
           ^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1984, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2119, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
RuntimeError: Failed running call_method stft(*(FakeTensor(..., size=(s0, s1)),), **{'n_fft': 512, 'hop_length': 160, 'win_length': 320, 'return_complex': True, 'window': FakeTensor(..., size=(320,)), 'pad_mode': 'constant'}):
aten._fft_r2c.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 35, in <module>
    ep2 = torch.export.export(
          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/__init__.py", line 173, in export
    return _export(
           ^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1069, in wrapper
    raise e
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1042, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 96, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 2035, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1276, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1304, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 552, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1437, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 469, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
           ^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
    super().run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
    while self.step():
          ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 558, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2242, in CALL
    self._call(inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2236, in _call
    self.call_function(fn, args, kwargs)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 974, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 535, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1903, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1990, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2030, in get_fake_value
    unimplemented(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 289, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

from user code:
   File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 10, in forward
    x = signals.stft(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

cc @ezyang @eellison @bdhirsh @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions