From d921f557205b3555f76969819bb90d80738badf8 Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Wed, 21 May 2025 10:20:47 -0700 Subject: [PATCH] fix a bug for repeated axes in N-D c2c FFT --- CHANGELOG.md | 1 + mkl_fft/_fft_utils.py | 200 +++++++------------------------ mkl_fft/_mkl_fft.py | 36 ++++++ mkl_fft/_pydfti.pyx | 3 - mkl_fft/interfaces/_numpy_fft.py | 2 + mkl_fft/tests/helper.py | 5 + mkl_fft/tests/test_fftnd.py | 23 +++- 7 files changed, 105 insertions(+), 165 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fe3221a9..544df000 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed * Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185) +* Fixed a bug for a case when a repeated indices is passed for axes keyword in N-dimensional FFT [gh-215](https://github.com/IntelPython/mkl_fft/pull/215) ## [2.0.0] - 2025-06-03 diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index a012e310..93ba03c7 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -44,23 +44,19 @@ def _check_norm(norm): ) -def _check_shapes_for_direct(xs, shape, axes): +def _check_shapes_for_direct(s, shape, axes): if len(axes) > 7: # Intel MKL supports up to 7D return False - if not (len(xs) == len(shape)): - # full-dimensional transform + if len(s) != len(shape): + # not a full-dimensional transform return False - if not (len(set(axes)) == len(axes)): + if len(set(axes)) != len(axes): # repeated axes return False - for xsi, ai in zip(xs, axes): - try: - sh_ai = shape[ai] - except IndexError: - raise ValueError("Invalid axis (%d) specified" % ai) - - if not (xsi == sh_ai): - return False + new_shape = tuple(shape[ax] for ax in axes) + if tuple(s) != new_shape: + # trimming or padding is needed + return False return True @@ -78,30 +74,6 @@ def _compute_fwd_scale(norm, n, shape): return np.sqrt(fsc) -def _cook_nd_args(a, s=None, axes=None, invreal=False): - if s is None: - shapeless = True - if axes is None: - s = list(a.shape) - else: - try: - s = [a.shape[i] for i in axes] - except IndexError: - # fake s designed to trip the ValueError further down - s = range(len(axes) + 1) - pass - else: - shapeless = False - s = list(s) - if axes is None: - axes = list(range(-len(s), 0)) - if len(s) != len(axes): - raise ValueError("Shape and axes have different lengths.") - if invreal and shapeless: - s[-1] = (a.shape[axes[-1]] - 1) * 2 - return s, axes - - # copied from scipy.fft module # https://github.com/scipy/scipy/blob/main/scipy/fft/_pocketfft/helper.py def _datacopied(arr, original): @@ -129,89 +101,7 @@ def _flat_to_multi(ind, shape): return m_ind -# copied from scipy.fftpack.helper -def _init_nd_shape_and_axes(x, shape, axes): - """Handle shape and axes arguments for n-dimensional transforms. - Returns the shape and axes in a standard form, taking into account negative - values and checking for various potential errors. - Parameters - ---------- - x : array_like - The input array. - shape : int or array_like of ints or None - The shape of the result. If both `shape` and `axes` (see below) are - None, `shape` is ``x.shape``; if `shape` is None but `axes` is - not None, then `shape` is ``scipy.take(x.shape, axes, axis=0)``. - If `shape` is -1, the size of the corresponding dimension of `x` is - used. - axes : int or array_like of ints or None - Axes along which the calculation is computed. - The default is over all axes. - Negative indices are automatically converted to their positive - counterpart. - Returns - ------- - shape : array - The shape of the result. It is a 1D integer array. - axes : array - The shape of the result. It is a 1D integer array. - """ - x = np.asarray(x) - noshape = shape is None - noaxes = axes is None - - if noaxes: - axes = np.arange(x.ndim, dtype=np.intc) - else: - axes = np.atleast_1d(axes) - - if axes.size == 0: - axes = axes.astype(np.intc) - - if not axes.ndim == 1: - raise ValueError("when given, axes values must be a scalar or vector") - if not np.issubdtype(axes.dtype, np.integer): - raise ValueError("when given, axes values must be integers") - - axes = np.where(axes < 0, axes + x.ndim, axes) - - if axes.size != 0 and (axes.max() >= x.ndim or axes.min() < 0): - raise ValueError("axes exceeds dimensionality of input") - if axes.size != 0 and np.unique(axes).shape != axes.shape: - raise ValueError("all axes must be unique") - - if not noshape: - shape = np.atleast_1d(shape) - elif np.isscalar(x): - shape = np.array([], dtype=np.intc) - elif noaxes: - shape = np.array(x.shape, dtype=np.intc) - else: - shape = np.take(x.shape, axes) - - if shape.size == 0: - shape = shape.astype(np.intc) - - if shape.ndim != 1: - raise ValueError("when given, shape values must be a scalar or vector") - if not np.issubdtype(shape.dtype, np.integer): - raise ValueError("when given, shape values must be integers") - if axes.shape != shape.shape: - raise ValueError( - "when given, axes and shape arguments have to be of the same length" - ) - - shape = np.where(shape == -1, np.array(x.shape)[axes], shape) - if shape.size != 0 and (shape < 1).any(): - raise ValueError(f"invalid number of data points ({shape}) specified") - - return shape, axes - - def _iter_complementary(x, axes, func, kwargs, result): - if axes is None: - # s and axes are None, direct N-D FFT - return func(x, **kwargs, out=result) x_shape = x.shape nd = x.ndim r = list(range(nd)) @@ -260,9 +150,6 @@ def _iter_fftnd( direction=+1, scale_function=lambda ind: 1.0, ): - a = np.asarray(a) - s, axes = _init_nd_shape_and_axes(a, s, axes) - # Combine the two, but in reverse, to end with the first axis given. axes_and_s = list(zip(axes, s))[::-1] # We try to use in-place calculations where possible, which is @@ -309,13 +196,14 @@ def _output_dtype(dt): def _pad_array(arr, s, axes): """Pads array arr with zeros to attain shape s associated with axes""" arr_shape = arr.shape + new_shape = tuple(arr_shape[ax] for ax in axes) + if tuple(s) == new_shape: + return arr + no_padding = True pad_widths = [(0, 0)] * len(arr_shape) for si, ai in zip(s, axes): - try: - shp_i = arr_shape[ai] - except IndexError: - raise ValueError(f"Invalid axis {ai} specified") + shp_i = arr_shape[ai] if si > shp_i: no_padding = False pad_widths[ai] = (0, si - shp_i) @@ -345,14 +233,14 @@ def _trim_array(arr, s, axes): """ arr_shape = arr.shape + new_shape = tuple(arr_shape[ax] for ax in axes) + if tuple(s) == new_shape: + return arr + no_trim = True ind = [slice(None, None, None)] * len(arr_shape) for si, ai in zip(s, axes): - try: - shp_i = arr_shape[ai] - except IndexError: - raise ValueError(f"Invalid axis {ai} specified") - if si < shp_i: + if si < arr_shape[ai]: no_trim = False ind[ai] = slice(None, si, None) if no_trim: @@ -383,16 +271,11 @@ def _c2c_fftnd_impl( if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") + x = np.asarray(x) valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64] # _direct_fftnd requires complex type, and full-dimensional transform - if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: - _direct = s is None and axes is None - if _direct: - _direct = x.ndim <= 7 # Intel MKL only supports FFT up to 7D - if not _direct: - xs, xa = _cook_nd_args(x, s, axes) - if _check_shapes_for_direct(xs, x.shape, xa): - _direct = True + if x.size != 0 and x.ndim > 1: + _direct = _check_shapes_for_direct(s, x.shape, axes) _direct = _direct and x.dtype in valid_dtypes else: _direct = False @@ -405,14 +288,23 @@ def _c2c_fftnd_impl( out=out, ) else: - if s is None and x.dtype in valid_dtypes: - x = np.asarray(x) + new_shape = tuple(x.shape[ax] for ax in axes) + if ( + tuple(s) == new_shape + and x.dtype in valid_dtypes + and len(set(axes)) == len(axes) + ): if out is None: res = np.empty_like(x, dtype=_output_dtype(x.dtype)) else: _validate_out_array(out, x, _output_dtype(x.dtype)) res = out + # MKL is capable of doing batch N-D FFT, it is not required to + # manually loop over the batches as done in _iter_complementary and + # it is the reason for bad performance mentioned in the gh-issue-#67 + # TODO: implement a batch N-D FFT using MKL + # _iter_complementary performs batches of N-D FFT return _iter_complementary( x, axes, @@ -434,14 +326,9 @@ def _c2c_fftnd_impl( def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None): a = np.asarray(x) - no_trim = (s is None) and (axes is None) - s, axes = _cook_nd_args(a, s, axes) - axes = [ax + a.ndim if ax < 0 else ax for ax in axes] la = axes[-1] - # trim array, so that rfft avoids doing unnecessary computations - if not no_trim: - a = _trim_array(a, s, axes) + a = _trim_array(a, s, axes) # last axis is not included since we calculate r2c FFT separately # and not in the loop @@ -453,13 +340,11 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None): a = _r2c_fft1d_impl(a, n=s[-1], axis=la, fsc=fsc, out=res) res = a if len(s) > 1: - len_axes = len(axes) if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2: - if not no_trim: - ss = list(s) - ss[-1] = a.shape[la] - a = _pad_array(a, tuple(ss), axes) + ss = list(s) + ss[-1] = a.shape[la] + a = _pad_array(a, tuple(ss), axes) # a series of ND c2c FFTs along last axis ss, aa = _remove_axis(s, axes, -1) ind = [slice(None, None, 1)] * len(s) @@ -494,17 +379,12 @@ def _r2c_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None): def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None): a = np.asarray(x) - no_trim = (s is None) and (axes is None) - s, axes = _cook_nd_args(a, s, axes, invreal=True) - axes = [ax + a.ndim if ax < 0 else ax for ax in axes] la = axes[-1] - if not no_trim: - a = _trim_array(a, s, axes) if len(s) > 1: len_axes = len(axes) if len(set(axes)) == len_axes and len_axes == a.ndim and len_axes > 2: - if not no_trim: - a = _pad_array(a, s, axes) + a = _trim_array(a, s, axes) + a = _pad_array(a, s, axes) # a series of ND c2c FFTs along last axis # due to need to write into a, we must copy a = a if _datacopied(a, x) else a.copy() @@ -521,8 +401,8 @@ def _c2r_fftnd_impl(x, s=None, axes=None, fsc=1.0, out=None): tind = tuple(ind) a_inp = a[tind] # out has real dtype and cannot be used in intermediate steps - # ss and aa are reversed since np.irfftn uses forward order but - # np.ifftn uses reverse order see numpy-gh-28950 + # ss and aa are reversed since np.fft.irfftn uses forward order + # but np.fft.ifftn uses reverse order see numpy-gh-28950 _ = _c2c_fftnd_impl( a_inp, s=ss[::-1], axes=aa[::-1], out=a_inp, direction=-1 ) diff --git a/mkl_fft/_mkl_fft.py b/mkl_fft/_mkl_fft.py index 1cd49b97..d59be8e0 100644 --- a/mkl_fft/_mkl_fft.py +++ b/mkl_fft/_mkl_fft.py @@ -24,6 +24,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +import numpy as np + from ._fft_utils import ( _c2c_fftnd_impl, _c2r_fftnd_impl, @@ -50,6 +52,36 @@ ] +# copied with modifications from: +# https://github.com/numpy/numpy/blob/main/numpy/fft/_pocketfft.py +def _cook_nd_args(a, s=None, axes=None, invreal=False): + if s is None: + shapeless = True + if axes is None: + s = list(a.shape) + else: + s = np.take(a.shape, axes) + else: + shapeless = False + s = list(s) + if axes is None: + if not shapeless: + raise ValueError("If s is not None, axes must not be None either.") + axes = list(range(-len(s), 0)) + if len(s) != len(axes): + raise ValueError("Shape and axes have different lengths.") + if invreal and shapeless: + s[-1] = (a.shape[axes[-1]] - 1) * 2 + if None in s: + raise ValueError("s must contain only int.") + # use the whole input array along axis `i` if `s[i] == -1` + s = [a.shape[_a] if _s == -1 else _s for _s, _a in zip(s, axes)] + + # make axes positive + axes = [ax + a.ndim if ax < 0 else ax for ax in axes] + return s, axes + + def fft(x, n=None, axis=-1, norm=None, out=None): fsc = _compute_fwd_scale(norm, n, x.shape[axis]) return _c2c_fft1d_impl(x, n=n, axis=axis, out=out, direction=+1, fsc=fsc) @@ -70,11 +102,13 @@ def ifft2(x, s=None, axes=(-2, -1), norm=None, out=None): def fftn(x, s=None, axes=None, norm=None, out=None): fsc = _compute_fwd_scale(norm, s, x.shape) + s, axes = _cook_nd_args(x, s, axes) return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=+1, fsc=fsc) def ifftn(x, s=None, axes=None, norm=None, out=None): fsc = _compute_fwd_scale(norm, s, x.shape) + s, axes = _cook_nd_args(x, s, axes) return _c2c_fftnd_impl(x, s=s, axes=axes, out=out, direction=-1, fsc=fsc) @@ -98,9 +132,11 @@ def irfft2(x, s=None, axes=(-2, -1), norm=None, out=None): def rfftn(x, s=None, axes=None, norm=None, out=None): fsc = _compute_fwd_scale(norm, s, x.shape) + s, axes = _cook_nd_args(x, s, axes) return _r2c_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc) def irfftn(x, s=None, axes=None, norm=None, out=None): fsc = _compute_fwd_scale(norm, s, x.shape) + s, axes = _cook_nd_args(x, s, axes, invreal=True) return _c2r_fftnd_impl(x, s=s, axes=axes, out=out, fsc=fsc) diff --git a/mkl_fft/_pydfti.pyx b/mkl_fft/_pydfti.pyx index d9158da2..f6bd464d 100644 --- a/mkl_fft/_pydfti.pyx +++ b/mkl_fft/_pydfti.pyx @@ -763,9 +763,6 @@ def _direct_fftnd( cdef cnp.ndarray f_arr "ffnd_arrayObject" cdef int in_place, x_type, f_type - if direction not in [-1, +1]: - raise ValueError("Direction of FFT should +1 or -1") - # convert x to ndarray, ensure that strides are multiples of itemsize x_arr = PyArray_CheckFromAny( x, NULL, 0, 0, diff --git a/mkl_fft/interfaces/_numpy_fft.py b/mkl_fft/interfaces/_numpy_fft.py index d83fe0b8..268dbc34 100644 --- a/mkl_fft/interfaces/_numpy_fft.py +++ b/mkl_fft/interfaces/_numpy_fft.py @@ -98,6 +98,8 @@ def _cook_nd_args(a, s=None, axes=None, invreal=False): # use the whole input array along axis `i` if `s[i] == -1 or None` s = [a.shape[_a] if _s in [-1, None] else _s for _s, _a in zip(s, axes)] + # make axes positive + axes = [ax + a.ndim if ax < 0 else ax for ax in axes] return s, axes diff --git a/mkl_fft/tests/helper.py b/mkl_fft/tests/helper.py index 6a596644..c0903e45 100644 --- a/mkl_fft/tests/helper.py +++ b/mkl_fft/tests/helper.py @@ -32,3 +32,8 @@ np.lib.NumpyVersion(np.__version__) < "2.0.0", reason="Requires NumPy >= 2.0.0", ) + +requires_numpy_2_1 = pytest.mark.skipif( + np.lib.NumpyVersion(np.__version__) < "2.1.0", + reason="Requires NumPy >= 2.1.0", +) diff --git a/mkl_fft/tests/test_fftnd.py b/mkl_fft/tests/test_fftnd.py index 007f0bdf..27ec3228 100644 --- a/mkl_fft/tests/test_fftnd.py +++ b/mkl_fft/tests/test_fftnd.py @@ -31,7 +31,7 @@ import mkl_fft -from .helper import requires_numpy_2 +from .helper import requires_numpy_2, requires_numpy_2_1 reps_64 = (2**11) * np.finfo(np.float64).eps reps_32 = (2**11) * np.finfo(np.float32).eps @@ -267,7 +267,7 @@ def test_s_axes_out(dtype, s, axes, func): @pytest.mark.parametrize("dtype", [complex, float]) @pytest.mark.parametrize("axes", [(2, 0, 2, 0), (0, 1, 1), (2, 0, 1, 3, 2, 1)]) -@pytest.mark.parametrize("func", ["rfftn", "irfftn"]) +@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn", "irfftn"]) def test_repeated_axes(dtype, axes, func): shape = (2, 3, 4, 5) if dtype is complex and func != "rfftn": @@ -282,6 +282,25 @@ def test_repeated_axes(dtype, axes, func): assert_allclose(r1, r2, rtol=rtol, atol=atol) +@requires_numpy_2_1 +@pytest.mark.parametrize("dtype", [complex, float]) +@pytest.mark.parametrize("axes", [(2, 3, 3, 2), (0, 0, 3, 3)]) +@pytest.mark.parametrize("s", [(5, 4, 3, 3), (7, 8, 10, 9)]) +@pytest.mark.parametrize("func", ["fftn", "ifftn", "rfftn", "irfftn"]) +def test_repeated_axes_with_s(dtype, axes, s, func): + shape = (2, 3, 4, 5) + if dtype is complex and func != "rfftn": + x = np.random.random(shape) + 1j * np.random.random(shape) + else: + x = np.random.random(shape) + + r1 = getattr(np.fft, func)(x, axes=axes, s=s) + r2 = getattr(mkl_fft, func)(x, axes=axes, s=s) + + rtol, atol = _get_rtol_atol(x) + assert_allclose(r1, r2, rtol=rtol, atol=atol) + + @requires_numpy_2 @pytest.mark.parametrize("axes", [None, (0, 1), (0, 2), (1, 2)]) @pytest.mark.parametrize("func", ["fftn", "ifftn"])