From 6bfb1c736f05c411ee02add9aeadebf6afbf4add Mon Sep 17 00:00:00 2001 From: Vahid Tavanashad Date: Fri, 16 May 2025 09:02:00 -0700 Subject: [PATCH] avoid memory-overlap between input and output arrays --- CHANGELOG.md | 1 + mkl_fft/_fft_utils.py | 11 +++-- mkl_fft/_pydfti.pyx | 112 +++++++++++++++++++++++++----------------- 3 files changed, 76 insertions(+), 48 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fe3221a..b54186d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed * Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185) +* Fixed a bug when there is overlapping memory of input and output arrays [gh-216](https://github.com/IntelPython/mkl_fft/pull/216) ## [2.0.0] - 2025-06-03 diff --git a/mkl_fft/_fft_utils.py b/mkl_fft/_fft_utils.py index a012e31..e772f55 100644 --- a/mkl_fft/_fft_utils.py +++ b/mkl_fft/_fft_utils.py @@ -384,6 +384,10 @@ def _c2c_fftnd_impl( raise ValueError("Direction of FFT should +1 or -1") valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64] + inplace_FFT = 0 + if x.dtype not in valid_dtypes: + x = x.astype(np.complex128, copy=True) + inplace_FFT = 1 # _direct_fftnd requires complex type, and full-dimensional transform if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1: _direct = s is None and axes is None @@ -393,7 +397,7 @@ def _c2c_fftnd_impl( xs, xa = _cook_nd_args(x, s, axes) if _check_shapes_for_direct(xs, x.shape, xa): _direct = True - _direct = _direct and x.dtype in valid_dtypes + _direct = _direct else: _direct = False @@ -402,10 +406,11 @@ def _c2c_fftnd_impl( x, direction=direction, fsc=fsc, + in_place=inplace_FFT, out=out, ) else: - if s is None and x.dtype in valid_dtypes: + if s is None: x = np.asarray(x) if out is None: res = np.empty_like(x, dtype=_output_dtype(x.dtype)) @@ -417,7 +422,7 @@ def _c2c_fftnd_impl( x, axes, _direct_fftnd, - {"direction": direction, "fsc": fsc}, + {"direction": direction, "fsc": fsc, "in_place": inplace_FFT}, res, ) else: diff --git a/mkl_fft/_pydfti.pyx b/mkl_fft/_pydfti.pyx index d9158da..57970ce 100644 --- a/mkl_fft/_pydfti.pyx +++ b/mkl_fft/_pydfti.pyx @@ -29,7 +29,7 @@ import sys import numpy as np -if np.lib.NumpyVersion(np.__version__) >= "2.0.0a0": +if np.lib.NumpyVersion(np.__version__) >= "2.0.0": from numpy._core._multiarray_tests import internal_overlap else: from numpy.core._multiarray_tests import internal_overlap @@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None): x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0) x_type = cnp.PyArray_TYPE(x_arr) - if out is not None: - in_place = 0 - elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE: + if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE: # we can operate in place if requested. if in_place: if not cnp.PyArray_ISONESEGMENT(x_arr): @@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None): x_type = cnp.PyArray_TYPE(x_arr) in_place = 1 + f_arr = None + if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT: + f_type = cnp.NPY_CFLOAT + else: + f_type = cnp.NPY_CDOUBLE + + if out is not None: + out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type)) + _validate_out_array(out, x, out_dtype, axis=axis_, n=n_) + if x is out: + in_place = 1 + elif ( + _get_element_strides(x) == _get_element_strides(out) + and not np.shares_memory(x, out) + ): + # out array that is used in OneMKL c2c FFT must have the same stride + # as input array and must have no common elements with input array. + # If these conditions are not met, we need to allocate a new array, + # which is done later. + # TODO: check to see if the same stride condition can be relaxed + f_arr = out + in_place = 0 + if in_place: _cache_capsule = _tls_dfti_cache_capsule() _cache = cpython.pycapsule.PyCapsule_GetPointer( @@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None): ind[axis_] = slice(0, n_, None) x_arr = x_arr[tuple(ind)] - return x_arr - else: - if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT: - f_type = cnp.NPY_CFLOAT + if out is not None: + out[...] = x_arr + return out else: - f_type = cnp.NPY_CDOUBLE - - if out is None: + return x_arr + else: + if f_arr is None: f_arr = _allocate_result(x_arr, n_, axis_, f_type) - else: - out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type)) - _validate_out_array(out, x, out_dtype, axis=axis_, n=n_) - # out array that is used in OneMKL c2c FFT must have the exact same - # stride as input array. If not, we need to allocate a new array. - # TODO: check to see if this condition can be relaxed - if _get_element_strides(x) == _get_element_strides(out): - f_arr = out - else: - f_arr = _allocate_result(x_arr, n_, axis_, f_type) # call out-of-place FFT _cache_capsule = _tls_dfti_cache_capsule() @@ -612,9 +622,10 @@ def _r2c_fft1d_impl( # be compared directly. # TODO: currently instead of this condition, we check both input # and output to be c_contig or f_contig, relax this condition + # In addition, input and output data sets must have no common elements c_contig = x.flags.c_contiguous and out.flags.c_contiguous f_contig = x.flags.f_contiguous and out.flags.f_contiguous - if c_contig or f_contig: + if c_contig or f_contig and not np.shares_memory(x, out): f_arr = out else: f_arr = _allocate_result(x_arr, f_shape, axis_, f_type) @@ -715,9 +726,10 @@ def _c2r_fft1d_impl( # strides cannot be compared directly. # TODO: currently instead of this condition, we check both input # and output to be c_contig or f_contig, relax this condition + # Also input and output data sets must have no common elements c_contig = x.flags.c_contiguous and out.flags.c_contiguous f_contig = x.flags.f_contiguous and out.flags.f_contiguous - if c_contig or f_contig: + if c_contig or f_contig and not np.shares_memory(x, out): f_arr = out else: f_arr = _allocate_result(x_arr, n_, axis_, f_type) @@ -755,13 +767,13 @@ def _c2r_fft1d_impl( def _direct_fftnd( - x, direction=+1, double fsc=1.0, out=None + x, direction=+1, double fsc=1.0, in_place=0, out=None ): """Perform n-dimensional FFT over all axes""" cdef int err cdef cnp.ndarray x_arr "xxnd_arrayObject" cdef cnp.ndarray f_arr "ffnd_arrayObject" - cdef int in_place, x_type, f_type + cdef int x_type, f_type if direction not in [-1, +1]: raise ValueError("Direction of FFT should +1 or -1") @@ -779,7 +791,7 @@ def _direct_fftnd( raise ValueError("An input argument x is not an array-like object") # a copy was made, so we can work in place. - in_place = 1 if _datacopied(x_arr, x) else 0 + in_place = 1 if _datacopied(x_arr, x) else in_place x_type = cnp.PyArray_TYPE(x_arr) if ( @@ -798,15 +810,35 @@ def _direct_fftnd( assert x_type == cnp.NPY_CDOUBLE in_place = 1 - if out is not None: - in_place = 0 - if in_place: if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT: in_place = 1 else: in_place = 0 + f_arr = None + if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE: + f_type = cnp.NPY_CDOUBLE + else: + f_type = cnp.NPY_CFLOAT + + if out is not None: + out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type)) + _validate_out_array(out, x, out_dtype) + if x is out: + in_place = 1 + elif ( + _get_element_strides(x) == _get_element_strides(out) + and not np.shares_memory(x, out) + ): + # out array that is used in OneMKL c2c FFT must have the same stride + # as input array and must have no common elements with input array. + # If these conditions are not met, we need to allocate a new array, + # which is done later. + # TODO: check to see if the same stride condition can be relaxed + f_arr = out + in_place = 0 + if in_place: if x_type == cnp.NPY_CDOUBLE: if direction == 1: @@ -821,24 +853,14 @@ def _direct_fftnd( else: raise ValueError("An input argument x is not complex type array") - return x_arr - else: - if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE: - f_type = cnp.NPY_CDOUBLE + if out is not None: + out[...] = x_arr + return out else: - f_type = cnp.NPY_CFLOAT - if out is None: + return x_arr + else: + if f_arr is None: f_arr = _allocate_result(x_arr, -1, 0, f_type) - else: - out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type)) - _validate_out_array(out, x, out_dtype) - # out array that is used in OneMKL c2c FFT must have the exact same - # stride as input array. If not, we need to allocate a new array. - # TODO: check to see if this condition can be relaxed - if _get_element_strides(x) == _get_element_strides(out): - f_arr = out - else: - f_arr = _allocate_result(x_arr, -1, 0, f_type) if x_type == cnp.NPY_CDOUBLE: if direction == 1: