[export] _fft_r2c does not support dynamic shapes #135087
This needs a meta function. https://docs.google.com/document/d/1GgvOe7C8_NVOMLOCwDaYV1mXXyHMXY7ExoewHqooxrs/edit#heading=h.64r4npvq0w0 is the right reference for how to do it.
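To illustrate what a meta function has to compute without touching any data, here is a minimal pure-Python sketch of the output shape of a real-to-complex FFT. The helper name `fft_r2c_output_shape` is hypothetical, and its `dim`/`onesided` parameters only loosely mirror those of `aten::_fft_r2c`. A real meta also has to choose the complex output dtype and, as discussed further down this thread, the output strides.

```python
def fft_r2c_output_shape(shape, dim, onesided=True):
    """Hypothetical helper: output shape of a real-to-complex FFT over `dim`.

    With onesided=True only the last transformed dimension changes: it keeps
    n // 2 + 1 bins, the non-redundant half of a Hermitian spectrum.
    With onesided=False every dimension keeps its input size.
    """
    ndim = len(shape)
    out = list(shape)
    if onesided:
        last = dim[-1] % ndim  # normalize negative dims such as -1
        out[last] = shape[last] // 2 + 1
    return tuple(out)


# Usage: a batch of two length-16 signals transformed along the last dim.
print(fft_r2c_output_shape((2, 16), dim=[1]))  # (2, 9)
```

Because the rule uses only integer arithmetic (`// 2 + 1`), the same expression can in principle be evaluated on a symbolic size, which is what dynamic-shape export needs.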
@ezyang pytorch/torch/_meta_registrations.py, line 293 (at commit 2f53d57)
@justinchuby Is your PyTorch version old?
I verified with the latest nightly and got the same error:
Ah, it's not that the meta is wrong; it's that the meta produces incorrect strides, so it has been suppressed:
So the job is to figure out how to set up the strides correctly.
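Why can a meta be stride-incorrect even when its shape is right? Two dense tensors with identical shapes can lay out their dimensions in memory in different orders, which gives them different strides. The hypothetical helper `strides_for_order` below computes element strides for a dense layout in a given dimension order (outermost first). It is an illustration only and makes no claim about which order the MKL or cuFFT paths actually produce.

```python
def strides_for_order(shape, order):
    """Hypothetical helper: element strides of a dense tensor whose dims are
    stored in memory following `order`, outermost first, innermost last.

    Size-0 dims are treated as size 1 when accumulating the step, matching
    how contiguous strides are usually computed.
    """
    strides = [0] * len(shape)
    step = 1
    for d in reversed(order):  # innermost dimension gets stride 1
        strides[d] = step
        step *= max(shape[d], 1)
    return tuple(strides)


# Same shape (2, 9), two different memory layouts.
print(strides_for_order((2, 9), [0, 1]))  # (9, 1): row-major / contiguous
print(strides_for_order((2, 9), [1, 0]))  # (1, 2): dim 0 is innermost
```

A meta that always returned the contiguous (9, 1) strides would disagree with a kernel that produced the (1, 2) layout, even though both report shape (2, 9).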
Ran into this issue today trying to export https://github.com/jishengpeng/WavTokenizer. Hopefully this gets fixed soon 👍
@JulianMu16 and I would like to know whether this issue is still open. If so, we'd like to look into it.
I can confirm that as of today, the issue still exists. Hoping for a fix soon!
I gotta say, the FFT implementation is completely insane; there's gotta be a better way to do this than repeatedly restriding the output tensor in place. Anyway, this is a faithful translation of both the MKL and cuFFT paths to Python.
Fixes pytorch#135087
Signed-off-by: Edward Z. Yang <[email protected]>
Pull Request resolved: pytorch#145080
Approved by: https://github.com/Skylion007, https://github.com/albanD
ghstack dependencies: pytorch#145530
I guess this didn't make it into PyTorch 2.6?
fft_r2c does not support dynamic shapes:
cc @ezyang @eellison @bdhirsh @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4