[export] _fft_r2c does not support dynamic shapes #135087

Closed
justinchuby opened this issue Sep 4, 2024 · 10 comments
Labels: good first issue, module: meta tensors, oncall: export, oncall: pt2, triaged

justinchuby (Collaborator) commented Sep 4, 2024

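Exporting a model that calls stft fails when dynamic shapes are requested, because aten._fft_r2c.default does not support dynamic shapes. The same model exports successfully with static shapes. Repro: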

import torch


class STFTModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._window = torch.hann_window(window_length=320)

    def forward(self, signals: torch.Tensor) -> torch.Tensor:
        x = signals.stft(
            n_fft=512,
            hop_length=160,
            win_length=320,
            return_complex=True,
            window=self._window,
            pad_mode="constant",
        )
        return x


m = STFTModel()

# Shape [B, T] audio signals
input_signals = torch.randn([2, 16000])

args = (input_signals,)
ep = torch.export.export(
    m,
    args,
)

# Successful
print(ep)

# Dynamic axis
# Fails
ep2 = torch.export.export(
    m,
    args,
    dynamic_shapes=[
        {
            0: torch.export.Dim("dim1"),
            1: torch.export.Dim("dim2"),
        },
    ],
)

print(ep2)

Output (the static export prints its graph, then the dynamic export fails):

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, c__window: "f32[320]", signals: "f32[2, 16000]"):
             # File: /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:10 in forward, code: x = signals.stft(
            view: "f32[1, 2, 16000]" = torch.ops.aten.view.default(signals, [1, 2, 16000]);  signals = None
            pad: "f32[1, 2, 16512]" = torch.ops.aten.pad.default(view, [256, 256]);  view = None
            view_1: "f32[2, 16512]" = torch.ops.aten.view.default(pad, [2, 16512]);  pad = None
            stft: "c64[2, 257, 101]" = torch.ops.aten.stft.default(view_1, 512, 160, 320, c__window, False, None, True);  view_1 = c__window = None
            return (stft,)
            
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.CONSTANT_TENSOR: 4>, arg=TensorArgument(name='c__window'), target='_window', persistent=True), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='signals'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='stft'), target=None)])
Range constraints: {}
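
For reference, the static shapes in the graph above (pad to 16512, output c64[2, 257, 101]) follow from the usual STFT framing arithmetic. The sketch below is plain Python, not PyTorch code; the helper name stft_output_shape is hypothetical, and it assumes center=True and onesided=True, which are the defaults the repro relies on.

```python
def stft_output_shape(batch, samples, n_fft, hop_length, center=True, onesided=True):
    """Return (batch, freq_bins, frames) for a batched real-input STFT.

    Hypothetical helper for illustration; it mirrors the shapes printed
    in the exported graph, not any PyTorch internal.
    """
    # With center=True the signal is padded by n_fft // 2 on each side
    # (16000 + 2 * 256 = 16512 in the repro).
    padded = samples + 2 * (n_fft // 2) if center else samples
    if padded < n_fft:
        raise ValueError("signal is shorter than one FFT frame")
    # One frame at offset 0, then one more per full hop that still fits.
    frames = 1 + (padded - n_fft) // hop_length
    # A real-to-complex FFT keeps only the non-redundant half of the spectrum.
    freq_bins = n_fft // 2 + 1 if onesided else n_fft
    return batch, freq_bins, frames


print(stft_output_shape(2, 16000, n_fft=512, hop_length=160))  # (2, 257, 101)
```

With a dynamic time dimension s1 and these parameters, the frame count would be 1 + s1 // 160, so the output shape of _fft_r2c depends on a symbolic size.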

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1983, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1468, in wrap_fake_exception
    return fn()
           ^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1984, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2119, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
RuntimeError: Failed running call_method stft(*(FakeTensor(..., size=(s0, s1)),), **{'n_fft': 512, 'hop_length': 160, 'win_length': 320, 'return_complex': True, 'window': FakeTensor(..., size=(320,)), 'pad_mode': 'constant'}):
aten._fft_r2c.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 35, in <module>
    ep2 = torch.export.export(
          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/__init__.py", line 173, in export
    return _export(
           ^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1069, in wrapper
    raise e
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1042, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 96, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 2035, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1276, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1304, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 552, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1437, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 469, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
           ^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
    super().run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
    while self.step():
          ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 558, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2242, in CALL
    self._call(inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2236, in _call
    self.call_function(fn, args, kwargs)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 974, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 535, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1903, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1990, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2030, in get_fake_value
    unimplemented(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 289, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

from user code:
   File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 10, in forward
    x = signals.stft(

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
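
The two variables named in the last line of the error can also be set from Python rather than the shell. This is a minimal sketch: the helper name enable_dynamo_debug_logging is hypothetical, and it assumes the variables are read when torch is imported, so setting them before the import is the safe ordering.

```python
import os


def enable_dynamo_debug_logging(env=None):
    """Set the variables suggested by the error message and return the mapping.

    Hypothetical helper for illustration. Pass a plain dict to preview the
    settings without touching the real process environment.
    """
    env = os.environ if env is None else env
    env["TORCH_LOGS"] = "+dynamo"
    env["TORCHDYNAMO_VERBOSE"] = "1"
    return env


# Preview on a scratch dict; call with no argument before `import torch`
# to apply the settings to the current process.
print(enable_dynamo_debug_logging({}))
```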

cc @ezyang @eellison @bdhirsh @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

justinchuby (Collaborator, Author) commented Sep 4, 2024

Logs

-1 0
V0903 20:27:01.657000 67762 torch/_dynamo/convert_frame.py:1203] skipping: _wrapped_call_impl (reason: in skipfiles, file: /Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py)
-1 0
V0903 20:27:01.657000 67762 torch/_dynamo/convert_frame.py:1203] skipping: _call_impl (reason: in skipfiles, file: /Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0] torchdynamo start compiling forward /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:9, stack (elided 4 frames):
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 35, in <module>
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     ep2 = torch.export.export(
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/__init__.py", line 173, in export
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return _export(
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1042, in wrapper
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     ep = fn(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 96, in wrapper
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return fn(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 2035, in _export
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     export_artifact = export_func(  # type: ignore[operator]
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1276, in _strict_export
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return _strict_export_lower_to_aten_ir(
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1304, in _strict_export_lower_to_aten_ir
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     gm_torch_level = _export_to_torch_ir(
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 552, in _export_to_torch_ir
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     gm_torch_level, _ = torch._dynamo.export(
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1437, in inner
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     result_traced = opt_f(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return self._call_impl(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return forward_call(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 469, in _fn
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return fn(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]   File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0]     return self._call_impl(*args, **kwargs)
V0903 20:27:01.663000 67762 torch/_dynamo/convert_frame.py:845] [0/0] 
I0903 20:27:01.664000 67762 torch/_dynamo/logging.py:57] [0/0] Step 1: torchdynamo start tracing forward /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:9
V0903 20:27:01.665000 67762 torch/fx/experimental/symbolic_shapes.py:2506] [0/0] create_env
V0903 20:27:01.682000 67762 torch/_dynamo/output_graph.py:2068] [0/0] create_graph_input L_signals_ L['signals']
V0903 20:27:01.686000 67762 torch/_dynamo/variables/builder.py:2573] [0/0] wrap_to_fake L['signals'] (2, 16000) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.DYNAMIC: 0>, <DimDynamic.DYNAMIC: 0>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>, <DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo]), StrictMinMaxConstraint(warn_only=False, vr=VR[0, int_oo])], constraint_strides=[None, None], view_base_context=None, tensor_source=LocalSource(local_name='signals', cell_or_freevar=False), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
I0903 20:27:01.716000 67762 torch/fx/experimental/symbolic_shapes.py:3567] [0/0] create_symbol s0 = 2 for L['signals'].size()[0] [2, int_oo] (_dynamo/variables/builder.py:2581 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s0"
I0903 20:27:01.718000 67762 torch/fx/experimental/symbolic_shapes.py:3567] [0/0] create_symbol s1 = 16000 for L['signals'].size()[1] [2, int_oo] (_dynamo/variables/builder.py:2581 in <lambda>), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1"
V0903 20:27:01.719000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.719000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.719000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval False == False [statically known]
V0903 20:27:01.721000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval True == True [statically known]
V0903 20:27:01.721000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:9 in forward (STFTModel.forward)
V0903 20:27:01.721000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]         def forward(self, signals: torch.Tensor) -> torch.Tensor:
V0903 20:27:01.722000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE RESUME 0 []
V0903 20:27:01.722000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:10 in forward (STFTModel.forward)
V0903 20:27:01.722000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]             x = signals.stft(
V0903 20:27:01.722000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_FAST signals []
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_METHOD stft [TensorVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:11 in forward (STFTModel.forward)
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 n_fft=512,
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_CONST 512 [NullVariable(), GetAttrVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:12 in forward (STFTModel.forward)
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 hop_length=160,
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_CONST 160 [NullVariable(), GetAttrVariable(), ConstantVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:13 in forward (STFTModel.forward)
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 win_length=320,
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_CONST 320 [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:14 in forward (STFTModel.forward)
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 return_complex=True,  # doesn't affect errors
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_CONST True [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:15 in forward (STFTModel.forward)
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 window=self._window,
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_FAST self [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_ATTR _window [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), NNModuleVariable()]
V0903 20:27:01.723000 67762 torch/_dynamo/variables/builder.py:2573] [0/0] wrap_to_fake L['self']._window (320,) StatefulSymbolicContext(dynamic_sizes=[<DimDynamic.STATIC: 2>], dynamic_strides=[<DimDynamic.INFER_STRIDE: 4>], constraint_sizes=[None], constraint_strides=[None], view_base_context=None, tensor_source=NNModuleSource(base=AttrSource(base=LocalSource(local_name='self', cell_or_freevar=False), member='_window')), shape_env_to_source_to_symbol_cache={}) <class 'torch.Tensor'>
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:16 in forward (STFTModel.forward)
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]                 pad_mode="constant",  # aten.reflection_pad1d unsupported op
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE LOAD_CONST constant [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), TensorVariable()]
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source] TRACE starts_line /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:10 in forward (STFTModel.forward)
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:840] [0/0] [__trace_source]             x = signals.stft(
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE KW_NAMES ('n_fft', 'hop_length', 'win_length', 'return_complex', 'window', 'pad_mode') [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), TensorVariable(), ConstantVariable()]
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE PRECALL 6 [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), TensorVariable(), ConstantVariable()]
V0903 20:27:01.724000 67762 torch/_dynamo/symbolic_convert.py:863] [0/0] [__trace_bytecode] TRACE CALL 6 [NullVariable(), GetAttrVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), ConstantVariable(), TensorVariable(), ConstantVariable()]
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call] TRACE FX call stft from /Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py:10 in forward (STFTModel.forward)
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]         x = signals.stft(
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ~~~~~~~~~~~~^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             n_fft=512,
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             hop_length=160,
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             win_length=320,
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             return_complex=True,  # doesn't affect errors
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             window=self._window,
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             pad_mode="constant",  # aten.reflection_pad1d unsupported op
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]         )
V0903 20:27:01.724000 67762 torch/_dynamo/output_graph.py:1942] [0/0] [__trace_call]         ^
V0903 20:27:01.726000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval False == False [statically known]
V0903 20:27:01.727000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.729000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Ne(s1, 1) == True [statically known]
V0903 20:27:01.730000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Ne(s0, 1) == True [statically known]
V0903 20:27:01.730000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval True == True [statically known]
V0903 20:27:01.731000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Eq(s0, 1) == False [statically known]
V0903 20:27:01.734000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.735000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.735000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.736000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.736000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.740000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.740000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.741000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.742000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.748000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.748000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.754000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.756000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Ne(Mod(1, s0), 0) == True [statically known]
V0903 20:27:01.758000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.759000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
I0903 20:27:01.765000 67762 torch/fx/experimental/symbolic_shapes.py:5113] [0/0] eval Ne((s1//160) + 1, 1) [guard added] at src/torch_onnx/test2.py:10 in forward (_subclasses/fake_impls.py:1084 in infer_size), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne((s1//160) + 1, 1)"
V0903 20:27:01.765000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.765000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.776000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Eq(512*s0*((s1//160)) + 512*s0, 0) == False [statically known]
I0903 20:27:01.777000 67762 torch/fx/experimental/symbolic_shapes.py:5113] [0/0] eval Ne((s1//160) + 1, 1) [guard added] at src/torch_onnx/test2.py:10 in forward (_subclasses/fake_impls.py:1194 in fast_binary_impl), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne((s1//160) + 1, 1)"
V0903 20:27:01.778000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval (s1//160) + 1 < 0 == False [statically known]
V0903 20:27:01.779000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.780000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.781000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.781000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.782000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Eq((s1//160) + 1, 0) == False [statically known]
V0903 20:27:01.784000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval Ne(512*s0*((s1//160)) + 512*s0, 0) == True [statically known]
V0903 20:27:01.784000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.785000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.786000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.786000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.789000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.789000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.790000 67762 torch/fx/experimental/symbolic_shapes.py:5209] [0/0] eval (s1//160) + 1 < 1 == False [statically known]
V0903 20:27:01.790000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert True == True [statically known]
V0903 20:27:01.790000 67762 torch/fx/experimental/symbolic_shapes.py:5367] [0/0] runtime_assert (s1//160) + 1 >= 0 == True [statically known]
V0903 20:27:01.798000 67762 torch/_dynamo/symbolic_convert.py:879] [0/0] empty checkpoint
Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
torch._subclasses.fake_tensor.UnsupportedOperatorException: aten._fft_r2c.default

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1983, in get_fake_value
    ret_val = wrap_fake_exception(
              ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1468, in wrap_fake_exception
    return fn()
           ^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 1984, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2119, in run_node
    raise RuntimeError(make_error_message(e)).with_traceback(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2103, in run_node
    return getattr(args[0], node.target)(*args[1:], **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_tensor.py", line 839, in stft
    return torch.stft(
           ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/functional.py", line 704, in stft
    return _VF.stft(  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_refs/__init__.py", line 3396, in stft
    out = torch.fft.rfft(input, dim=-1, norm=norm)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/utils/_stats.py", line 21, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
    return self._cached_dispatch_impl(func, types, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
    output = self._dispatch_impl(func, types, args, kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 1993, in _dispatch_impl
    op_impl_out = op_impl(self, func, *args, **kwargs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_subclasses/fake_impls.py", line 244, in wordaround_stride_incorrect_op
    raise UnsupportedOperatorException(func)
RuntimeError: Failed running call_method stft(*(FakeTensor(..., size=(s0, s1)),), **{'n_fft': 512, 'hop_length': 160, 'win_length': 320, 'return_complex': True, 'window': FakeTensor(..., size=(320,)), 'pad_mode': 'constant'}):
aten._fft_r2c.default

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 35, in <module>
    ep2 = torch.export.export(
          ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/__init__.py", line 173, in export
    return _export(
           ^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1069, in wrapper
    raise e
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1042, in wrapper
    ep = fn(*args, **kwargs)
         ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/exported_program.py", line 96, in wrapper
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 2035, in _export
    export_artifact = export_func(  # type: ignore[operator]
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1276, in _strict_export
    return _strict_export_lower_to_aten_ir(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 1304, in _strict_export_lower_to_aten_ir
    gm_torch_level = _export_to_torch_ir(
                     ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/export/_trace.py", line 552, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
                        ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 1437, in inner
    result_traced = opt_f(*args, **kwargs)
                    ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/eval_frame.py", line 469, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1747, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 1238, in __call__
    return self._torchdynamo_orig_callable(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 514, in __call__
    return _compile(
           ^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 902, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 653, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_utils_internal.py", line 85, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 622, in transform
    tracer.run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2731, in run
    super().run()
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 958, in run
    while self.step():
          ^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 870, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 558, in wrapper
    return inner_fn(self, inst)
           ^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2242, in CALL
    self._call(inst)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2236, in _call
    self.call_function(fn, args, kwargs)
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in call_function
    self.push(fn.call_function(self, args, kwargs))  # type: ignore[arg-type]
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/misc.py", line 974, in call_function
    return self.obj.call_method(tx, self.name, args, kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/tensor.py", line 535, in call_method
    return wrap_fx_proxy(
           ^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1903, in wrap_fx_proxy
    return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/variables/builder.py", line 1990, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 2030, in get_fake_value
    unimplemented(
  File "/Users/justinc/Documents/GitHub/torch-onnx/venv/lib/python3.11/site-packages/torch/_dynamo/exc.py", line 289, in unimplemented
    raise Unsupported(msg, case_name=case_name)
torch._dynamo.exc.Unsupported: unsupported operator: aten._fft_r2c.default (see https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 for how to fix)

from user code:
   File "/Users/justinc/Documents/GitHub/torch-onnx/src/torch_onnx/test2.py", line 10, in forward
    x = signals.stft(

I0903 20:27:01.819000 67762 torch/_dynamo/utils.py:371] TorchDynamo compilation metrics:
I0903 20:27:01.819000 67762 torch/_dynamo/utils.py:371] Function, Runtimes (s)
I0903 20:27:01.819000 67762 torch/_dynamo/utils.py:371] _compile.compile_inner, 0.0000
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats constrain_symbol_range: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats evaluate_expr: CacheInfo(hits=98, misses=15, maxsize=256, currsize=15)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _simplify_floor_div: CacheInfo(hits=0, misses=0, maxsize=None, currsize=0)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_guard_rel: CacheInfo(hits=1, misses=1, maxsize=256, currsize=1)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _find: CacheInfo(hits=53, misses=2, maxsize=None, currsize=2)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats has_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats size_hint: CacheInfo(hits=0, misses=0, maxsize=256, currsize=0)
V0903 20:27:01.819000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats simplify: CacheInfo(hits=3, misses=16, maxsize=None, currsize=16)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _update_divisible: CacheInfo(hits=7, misses=1, maxsize=None, currsize=1)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats replace: CacheInfo(hits=1628, misses=53, maxsize=None, currsize=53)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats _maybe_evaluate_static: CacheInfo(hits=29, misses=19, maxsize=None, currsize=19)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_implications: CacheInfo(hits=1, misses=1, maxsize=None, currsize=1)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats get_axioms: CacheInfo(hits=15, misses=4, maxsize=None, currsize=4)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats safe_expand: CacheInfo(hits=291, misses=71, maxsize=256, currsize=71)
V0903 20:27:01.820000 67762 torch/fx/experimental/symbolic_shapes.py:122] lru_cache_stats uninteresting_files: CacheInfo(hits=38, misses=1, maxsize=None, currsize=1)
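
The symbolic expression `(s1//160) + 1` that appears in the guards above appears to be the STFT frame count for a signal of length `s1`. A minimal sketch of that arithmetic in plain Python, assuming `center=True` (the `torch.stft` default, which pads `n_fft // 2` samples on each side) and the `n_fft=512`, `hop_length=160` values from the repro. The helper name `stft_num_frames` is hypothetical, not a PyTorch API:

```python
def stft_num_frames(length, n_fft=512, hop_length=160, center=True):
    """Number of STFT frames for a 1-D signal with `length` samples."""
    # With center=True the signal is padded by n_fft // 2 on both sides.
    padded = length + 2 * (n_fft // 2) if center else length
    return 1 + (padded - n_fft) // hop_length


# For even n_fft and center=True this reduces to length // hop_length + 1.
print(stft_num_frames(16000))  # 101
```

For the `[2, 16000]` input in the repro this gives 101 frames, so the guard `Ne((s1//160) + 1, 1)` only records that there is more than one frame.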

@ezyang
Contributor

ezyang commented Sep 5, 2024

Need a meta for it. https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 is the right reference for how to do it.

@angelayi angelayi added the triaged This issue has been looked at a team member, and triaged and prioritized into an appropriate module label Sep 5, 2024
@Rajveer100

@ezyang
_fft_r2c already has a registered meta:

def meta_fft_r2c(self, dim, normalization, onesided):
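
Only the signature is quoted above. As a rough illustration of the shape half of such a meta function, here is a sketch in plain Python, assuming the usual convention that a one-sided real-to-complex FFT keeps `n // 2 + 1` bins along the last transformed axis. The helper `r2c_output_shape` is hypothetical and is not the registered implementation; in particular it says nothing about strides:

```python
def r2c_output_shape(input_shape, dim, onesided):
    """Output shape of a real-to-complex FFT over the axes listed in `dim`."""
    out = list(input_shape)
    if onesided:
        # Only the last transformed axis is halved; the others keep their size.
        last = dim[-1]
        out[last] = input_shape[last] // 2 + 1
    return tuple(out)


# Frames of length n_fft=512 from the repro: 512 // 2 + 1 = 257 frequency bins.
print(r2c_output_shape((2, 101, 512), dim=[2], onesided=True))  # (2, 101, 257)
```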

@ezyang
Contributor

ezyang commented Sep 14, 2024

@justinchuby is your PyTorch version old?

@justinchuby
Collaborator Author

I verified with the latest nightly and got the same error:

Collecting environment information...
PyTorch version: 2.6.0.dev20240914
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 17.0.6
CMake version: version 3.28.3
Libc version: N/A

Python version: 3.11.8 (main, Feb  6 2024, 21:21:21) [Clang 15.0.0 (clang-1500.1.0.2.5)] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] model-explorer-onnx==0.2.4
[pip3] mypy==1.11.1
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.4
[pip3] onnx==1.16.1
[pip3] onnxruntime==1.18.1
[pip3] onnxscript==0.1.0.dev20240904
[pip3] torch==2.6.0.dev20240914
[pip3] torch-onnx==0.1.19
[pip3] torchaudio==2.5.0.dev20240914
[pip3] torchvision==0.20.0.dev20240914
[conda] Could not collect

@ezyang
Contributor

ezyang commented Sep 21, 2024

Ah, it's not that the meta is wrong, it is that the meta is stride incorrect and so it's been suppressed:

def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False

So the job is to figure out how to set up strides correctly.
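
To see which ops the predicate above flags, here is a self-contained sketch that mirrors its logic without importing torch. `StubOp`, `FFT_C2C`, and the name strings are hypothetical stand-ins for the real operator overload objects; only the control flow is taken from the snippet:

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class StubOp:
    """Stand-in for an operator overload: a namespace plus a qualified name."""
    namespace: str
    qualified_name: str

    def name(self):
        return self.qualified_name


# Plays the role of aten._fft_c2c.default, the one FFT op that is exempted.
FFT_C2C = StubOp("aten", "aten::_fft_c2c")


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is FFT_C2C:
        return False
    # Every other aten/prims op with "fft" in its name is treated as suspect.
    return "fft" in op.name()


print(stride_incorrect_op(StubOp("aten", "aten::_fft_r2c")))  # True
print(stride_incorrect_op(FFT_C2C))                           # False
```

Under these assumptions `_fft_r2c` is flagged, which is consistent with the `UnsupportedOperatorException` raised from `wordaround_stride_incorrect_op` in the traceback even though a meta is registered.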

More backstory in #106319; see also #106623 and #106622.

@xenova

xenova commented Nov 26, 2024

Ran into this issue today trying to export https://github.com/jishengpeng/WavTokenizer. Hopefully this gets fixed soon 👍

@juanceresa

@JulianMu16 and I would like to know whether this issue is still available. If so, we want to look into it.

@kabyanil

I can confirm that as of today, the issue still exists. Hoping for a fix soon!

@ezyang ezyang self-assigned this Jan 24, 2025
nWEIdia pushed a commit to nWEIdia/pytorch that referenced this issue Jan 27, 2025
I gotta say, the FFT implementation is completely insane, there's gotta be a better way to do this than repeatedly inplace restriding the output tensor. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.

Fixes pytorch#135087

Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: pytorch#145530
@Froskekongen

Froskekongen commented Mar 14, 2025

I guess this didn't make it into pytorch 2.6?
