Skip to content

avoid memory-overlap between input and output arrays #216

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: revisit_overwrite_x
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed
* Fixed a bug for N-D FFTs when both `s` and `out` are given [gh-185](https://github.com/IntelPython/mkl_fft/pull/185)
* Fixed a bug that occurred when the input and output arrays have overlapping memory [gh-216](https://github.com/IntelPython/mkl_fft/pull/216)

## [2.0.0] - 2025-06-03

Expand Down
11 changes: 8 additions & 3 deletions mkl_fft/_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,10 @@ def _c2c_fftnd_impl(
raise ValueError("Direction of FFT should +1 or -1")

valid_dtypes = [np.complex64, np.complex128, np.float32, np.float64]
inplace_FFT = 0
if x.dtype not in valid_dtypes:
x = x.astype(np.complex128, copy=True)
inplace_FFT = 1
# _direct_fftnd requires complex type, and full-dimensional transform
if isinstance(x, np.ndarray) and x.size != 0 and x.ndim > 1:
_direct = s is None and axes is None
Expand All @@ -393,7 +397,7 @@ def _c2c_fftnd_impl(
xs, xa = _cook_nd_args(x, s, axes)
if _check_shapes_for_direct(xs, x.shape, xa):
_direct = True
_direct = _direct and x.dtype in valid_dtypes
_direct = _direct
else:
_direct = False

Expand All @@ -402,10 +406,11 @@ def _c2c_fftnd_impl(
x,
direction=direction,
fsc=fsc,
in_place=inplace_FFT,
out=out,
)
else:
if s is None and x.dtype in valid_dtypes:
if s is None:
x = np.asarray(x)
if out is None:
res = np.empty_like(x, dtype=_output_dtype(x.dtype))
Expand All @@ -417,7 +422,7 @@ def _c2c_fftnd_impl(
x,
axes,
_direct_fftnd,
{"direction": direction, "fsc": fsc},
{"direction": direction, "fsc": fsc, "in_place": inplace_FFT},
res,
)
else:
Expand Down
112 changes: 67 additions & 45 deletions mkl_fft/_pydfti.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import sys

import numpy as np

if np.lib.NumpyVersion(np.__version__) >= "2.0.0a0":
if np.lib.NumpyVersion(np.__version__) >= "2.0.0":
from numpy._core._multiarray_tests import internal_overlap
else:
from numpy.core._multiarray_tests import internal_overlap
Expand Down Expand Up @@ -389,9 +389,7 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
x_arr = _process_arguments(x, n, axis, &axis_, &n_, &in_place, &xnd, 0)
x_type = cnp.PyArray_TYPE(x_arr)

if out is not None:
in_place = 0
elif x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
if x_type is cnp.NPY_CFLOAT or x_type is cnp.NPY_CDOUBLE:
# we can operate in place if requested.
if in_place:
if not cnp.PyArray_ISONESEGMENT(x_arr):
Expand All @@ -416,6 +414,29 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
x_type = cnp.PyArray_TYPE(x_arr)
in_place = 1

f_arr = None
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
f_type = cnp.NPY_CFLOAT
else:
f_type = cnp.NPY_CDOUBLE

if out is not None:
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
_validate_out_array(out, x, out_dtype, axis=axis_, n=n_)
if x is out:
in_place = 1
elif (
_get_element_strides(x) == _get_element_strides(out)
and not np.shares_memory(x, out)
):
# out array that is used in OneMKL c2c FFT must have the same stride
# as input array and must have no common elements with input array.
# If these conditions are not met, we need to allocate a new array,
# which is done later.
# TODO: check to see if the same stride condition can be relaxed
f_arr = <cnp.ndarray> out
in_place = 0

if in_place:
_cache_capsule = _tls_dfti_cache_capsule()
_cache = <DftiCache *>cpython.pycapsule.PyCapsule_GetPointer(
Expand Down Expand Up @@ -453,25 +474,14 @@ def _c2c_fft1d_impl(x, n=None, axis=-1, direction=+1, double fsc=1.0, out=None):
ind[axis_] = slice(0, n_, None)
x_arr = x_arr[tuple(ind)]

return x_arr
else:
if x_type is cnp.NPY_FLOAT or x_type is cnp.NPY_CFLOAT:
f_type = cnp.NPY_CFLOAT
if out is not None:
out[...] = x_arr
return out
else:
f_type = cnp.NPY_CDOUBLE

if out is None:
return x_arr
else:
if f_arr is None:
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
else:
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
_validate_out_array(out, x, out_dtype, axis=axis_, n=n_)
# out array that is used in OneMKL c2c FFT must have the exact same
# stride as input array. If not, we need to allocate a new array.
# TODO: check to see if this condition can be relaxed
if _get_element_strides(x) == _get_element_strides(out):
f_arr = <cnp.ndarray> out
else:
f_arr = _allocate_result(x_arr, n_, axis_, f_type)

# call out-of-place FFT
_cache_capsule = _tls_dfti_cache_capsule()
Expand Down Expand Up @@ -612,9 +622,10 @@ def _r2c_fft1d_impl(
# be compared directly.
# TODO: currently instead of this condition, we check both input
# and output to be c_contig or f_contig, relax this condition
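# In addition, input and output data sets must have no common elements.
# NOTE(review): `and` binds tighter than `or`, so the condition below only
# applies the shares_memory check to the f_contig case; it likely needs
# parentheses: `(c_contig or f_contig) and not np.shares_memory(x, out)`.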
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
if c_contig or f_contig:
if c_contig or f_contig and not np.shares_memory(x, out):
f_arr = <cnp.ndarray> out
else:
f_arr = _allocate_result(x_arr, f_shape, axis_, f_type)
Expand Down Expand Up @@ -715,9 +726,10 @@ def _c2r_fft1d_impl(
# strides cannot be compared directly.
# TODO: currently instead of this condition, we check both input
# and output to be c_contig or f_contig, relax this condition
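# Also, input and output data sets must have no common elements.
# NOTE(review): `and` binds tighter than `or`, so the condition below only
# applies the shares_memory check to the f_contig case; it likely needs
# parentheses: `(c_contig or f_contig) and not np.shares_memory(x, out)`.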
c_contig = x.flags.c_contiguous and out.flags.c_contiguous
f_contig = x.flags.f_contiguous and out.flags.f_contiguous
if c_contig or f_contig:
if c_contig or f_contig and not np.shares_memory(x, out):
f_arr = <cnp.ndarray> out
else:
f_arr = _allocate_result(x_arr, n_, axis_, f_type)
Expand Down Expand Up @@ -755,13 +767,13 @@ def _c2r_fft1d_impl(


def _direct_fftnd(
x, direction=+1, double fsc=1.0, out=None
x, direction=+1, double fsc=1.0, in_place=0, out=None
):
"""Perform n-dimensional FFT over all axes"""
cdef int err
cdef cnp.ndarray x_arr "xxnd_arrayObject"
cdef cnp.ndarray f_arr "ffnd_arrayObject"
cdef int in_place, x_type, f_type
cdef int x_type, f_type

if direction not in [-1, +1]:
raise ValueError("Direction of FFT should +1 or -1")
Expand All @@ -779,7 +791,7 @@ def _direct_fftnd(
raise ValueError("An input argument x is not an array-like object")

# a copy was made, so we can work in place.
in_place = 1 if _datacopied(x_arr, x) else 0
in_place = 1 if _datacopied(x_arr, x) else in_place

x_type = cnp.PyArray_TYPE(x_arr)
if (
Expand All @@ -798,15 +810,35 @@ def _direct_fftnd(
assert x_type == cnp.NPY_CDOUBLE
in_place = 1

if out is not None:
in_place = 0

if in_place:
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_CFLOAT:
in_place = 1
else:
in_place = 0

f_arr = None
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
f_type = cnp.NPY_CDOUBLE
else:
f_type = cnp.NPY_CFLOAT

if out is not None:
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
_validate_out_array(out, x, out_dtype)
if x is out:
in_place = 1
elif (
_get_element_strides(x) == _get_element_strides(out)
and not np.shares_memory(x, out)
):
# out array that is used in OneMKL c2c FFT must have the same stride
# as input array and must have no common elements with input array.
# If these conditions are not met, we need to allocate a new array,
# which is done later.
# TODO: check to see if the same stride condition can be relaxed
f_arr = <cnp.ndarray> out
in_place = 0

if in_place:
if x_type == cnp.NPY_CDOUBLE:
if direction == 1:
Expand All @@ -821,24 +853,14 @@ def _direct_fftnd(
else:
raise ValueError("An input argument x is not complex type array")

return x_arr
else:
if x_type == cnp.NPY_CDOUBLE or x_type == cnp.NPY_DOUBLE:
f_type = cnp.NPY_CDOUBLE
if out is not None:
out[...] = x_arr
return out
else:
f_type = cnp.NPY_CFLOAT
if out is None:
return x_arr
else:
if f_arr is None:
f_arr = _allocate_result(x_arr, -1, 0, f_type)
else:
out_dtype = np.dtype(cnp.PyArray_DescrFromType(f_type))
_validate_out_array(out, x, out_dtype)
# out array that is used in OneMKL c2c FFT must have the exact same
# stride as input array. If not, we need to allocate a new array.
# TODO: check to see if this condition can be relaxed
if _get_element_strides(x) == _get_element_strides(out):
f_arr = <cnp.ndarray> out
else:
f_arr = _allocate_result(x_arr, -1, 0, f_type)

if x_type == cnp.NPY_CDOUBLE:
if direction == 1:
Expand Down
Loading