diff --git a/dpnp/backend/extensions/lapack/getrf.hpp b/dpnp/backend/extensions/lapack/getrf.hpp index 5fd9ecdcc49..952b244ef13 100644 --- a/dpnp/backend/extensions/lapack/getrf.hpp +++ b/dpnp/backend/extensions/lapack/getrf.hpp @@ -44,6 +44,7 @@ extern std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, py::list dev_info, + std::int64_t m, std::int64_t n, std::int64_t stride_a, std::int64_t stride_ipiv, diff --git a/dpnp/backend/extensions/lapack/getrf_batch.cpp b/dpnp/backend/extensions/lapack/getrf_batch.cpp index ec87c8b1f2a..446f565d6e4 100644 --- a/dpnp/backend/extensions/lapack/getrf_batch.cpp +++ b/dpnp/backend/extensions/lapack/getrf_batch.cpp @@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils; typedef sycl::event (*getrf_batch_impl_fn_ptr_t)( sycl::queue &, std::int64_t, + std::int64_t, char *, std::int64_t, std::int64_t, @@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t template static sycl::event getrf_batch_impl(sycl::queue &exec_q, + std::int64_t m, std::int64_t n, char *in_a, std::int64_t lda, @@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q, T *a = reinterpret_cast(in_a); const std::int64_t scratchpad_size = - mkl_lapack::getrf_batch_scratchpad_size(exec_q, n, n, lda, stride_a, + mkl_lapack::getrf_batch_scratchpad_size(exec_q, m, n, lda, stride_a, stride_ipiv, batch_size); T *scratchpad = nullptr; @@ -91,11 +93,11 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q, getrf_batch_event = mkl_lapack::getrf_batch( exec_q, - n, // The order of each square matrix in the batch; (0 ≤ n). + m, // The number of rows in each matrix in the batch; (0 ≤ m). // It must be a non-negative integer. n, // The number of columns in each matrix in the batch; (0 ≤ n). // It must be a non-negative integer. - a, // Pointer to the batch of square matrices, each of size (n x n). + a, // Pointer to the batch of input matrices, each of size (m x n). lda, // The leading dimension of each matrix in the batch. stride_a, // Stride between consecutive matrices in the batch. ipiv, // Pointer to the array of pivot indices for each matrix in @@ -179,6 +181,7 @@ std::pair const dpctl::tensor::usm_ndarray &a_array, const dpctl::tensor::usm_ndarray &ipiv_array, py::list dev_info, + std::int64_t m, std::int64_t n, std::int64_t stride_a, std::int64_t stride_ipiv, @@ -191,13 +194,13 @@ std::pair if (a_array_nd < 3) { throw py::value_error( "The input array has ndim=" + std::to_string(a_array_nd) + - ", but an array with ndim >= 3 is expected."); + ", but an array with ndim >= 3 is expected"); } if (ipiv_array_nd != 2) { throw py::value_error("The array of pivot indices has ndim=" + std::to_string(ipiv_array_nd) + - ", but a 2-dimensional array is expected."); + ", but a 2-dimensional array is expected"); } const int dev_info_size = py::len(dev_info); @@ -205,7 +208,7 @@ std::pair throw py::value_error("The size of 'dev_info' (" + std::to_string(dev_info_size) + ") does not match the expected batch size (" + - std::to_string(batch_size) + ")."); + std::to_string(batch_size) + ")"); } // check compatibility of execution queue and allocation queue @@ -241,7 +244,7 @@ std::pair if (getrf_batch_fn == nullptr) { throw py::value_error( "No getrf_batch implementation defined for the provided type " - "of the input matrix."); + "of the input matrix"); } auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); @@ -249,19 +252,26 @@ std::pair ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); if (ipiv_array_type_id != static_cast(dpctl_td_ns::typenum_t::INT64)) { - throw py::value_error("The type of 'ipiv_array' must be int64."); + throw py::value_error("The type of 'ipiv_array' must be int64"); + } + + const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw(); + if (ipiv_array_shape[0] != batch_size || + ipiv_array_shape[1] != std::min(m, n)) { + throw py::value_error( + "The shape of 'ipiv_array' must be (batch_size, min(m, n))"); } char *a_array_data = a_array.get_data(); - const std::int64_t lda = std::max(1UL, n); + const std::int64_t lda = std::max(1UL, m); char *ipiv_array_data = ipiv_array.get_data(); std::int64_t *d_ipiv = reinterpret_cast(ipiv_array_data); std::vector host_task_events; sycl::event getrf_batch_ev = getrf_batch_fn( - exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size, - dev_info, host_task_events, depends); + exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, + batch_size, dev_info, host_task_events, depends); sycl::event args_ev = dpctl::utils::keep_args_alive( exec_q, {a_array, ipiv_array}, host_task_events); diff --git a/dpnp/backend/extensions/lapack/lapack_py.cpp b/dpnp/backend/extensions/lapack/lapack_py.cpp index fb4dce4643b..83a0555f808 100644 --- a/dpnp/backend/extensions/lapack/lapack_py.cpp +++ b/dpnp/backend/extensions/lapack/lapack_py.cpp @@ -141,10 +141,10 @@ PYBIND11_MODULE(_lapack_impl, m) m.def("_getrf_batch", &lapack_ext::getrf_batch, "Call `getrf_batch` from OneMKL LAPACK library to return " - "the LU factorization of a batch of general n x n matrices", + "the LU factorization of a batch of general m x n matrices", py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"), - py::arg("dev_info_array"), py::arg("n"), py::arg("stride_a"), - py::arg("stride_ipiv"), py::arg("batch_size"), + py::arg("dev_info_array"), py::arg("m"), py::arg("n"), + py::arg("stride_a"), py::arg("stride_ipiv"), py::arg("batch_size"), py::arg("depends") = py::list()); m.def("_getri_batch", &lapack_ext::getri_batch, diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py index 838daac6630..803f8e7326c 100644 --- a/dpnp/linalg/dpnp_utils_linalg.py +++ b/dpnp/linalg/dpnp_utils_linalg.py @@ -246,6 +246,7 @@ def _batched_inv(a, res_type): ipiv_h.get_array(), dev_info, n, + n, a_stride, ipiv_stride, batch_size, @@ -327,6 +328,7 @@ def _batched_lu_factor(a, res_type): ipiv_h.get_array(), dev_info_h, n, + n, a_stride, ipiv_stride, batch_size, @@ -396,6 +398,131 @@ def _batched_lu_factor(a, res_type): return (out_a, out_ipiv, out_dev_info) +def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals + """SciPy-compatible LU factorization for batched inputs.""" + + # TODO: Find out at which array sizes the best performance is obtained + # getrf_batch can be slow on large GPU arrays. + # Use getrf_batch only on CPU. + # On GPU fall back to calling getrf per 2D slice. + use_batch = a.sycl_device.has_aspect_cpu + + a_sycl_queue = a.sycl_queue + a_usm_type = a.usm_type + _manager = dpu.SequentialOrderManager[a_sycl_queue] + + m, n = a.shape[-2:] + k = min(m, n) + orig_shape = a.shape + batch_shape = orig_shape[:-2] + + # handle empty input + if a.size == 0: + lu = dpnp.empty_like(a) + piv = dpnp.empty( + (*batch_shape, k), + dtype=dpnp.int64, + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + return lu, piv + + # get 3d input arrays by reshape + a = dpnp.reshape(a, (-1, m, n)) + batch_size = a.shape[0] + + # Move batch axis to the end (m, n, batch) in Fortran order: + # required by getrf_batch + # and ensures each a[..., i] is F-contiguous for getrf + a = dpnp.moveaxis(a, 0, -1) + + a_usm_arr = dpnp.get_usm_ndarray(a) + + # `a` must be copied because getrf/getrf_batch destroys the input matrix + a_h = dpnp.empty_like(a, order="F", dtype=res_type) + ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( + src=a_usm_arr, + dst=a_h.get_array(), + sycl_queue=a_sycl_queue, + depends=_manager.submitted_events, + ) + _manager.add_event_pair(ht_ev, copy_ev) + + ipiv_h = dpnp.empty( + (batch_size, k), + dtype=dpnp.int64, + order="C", + usm_type=a_usm_type, + sycl_queue=a_sycl_queue, + ) + + if use_batch: + dev_info_h = [0] * batch_size + + ipiv_stride = k + a_stride = a_h.strides[-1] + + # Call the LAPACK extension function _getrf_batch + # to perform LU decomposition of a batch of general matrices + ht_ev, getrf_ev = li._getrf_batch( + a_sycl_queue, + a_h.get_array(), + ipiv_h.get_array(), + dev_info_h, + m, + n, + a_stride, + ipiv_stride, + batch_size, + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) + + if any(dev_info_h): + diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0) + warn( + f"Diagonal number {diag_nums} are exactly zero. " + "Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) + else: + dev_info_vecs = [[0] for _ in range(batch_size)] + + # Sequential LU factorization using getrf per slice + for i in range(batch_size): + ht_ev, getrf_ev = li._getrf( + a_sycl_queue, + a_h[..., i].get_array(), + ipiv_h[i].get_array(), + dev_info_vecs[i], + depends=[copy_ev], + ) + _manager.add_event_pair(ht_ev, getrf_ev) + + diag_nums = ", ".join( + str(v) for info in dev_info_vecs for v in info if v > 0 + ) + if diag_nums: + warn( + f"Diagonal number {diag_nums} are exactly zero. " + "Singular matrix.", + RuntimeWarning, + stacklevel=2, + ) + + # Restore original shape: move batch axis back and reshape + a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape) + ipiv_h = ipiv_h.reshape((*batch_shape, k)) + + # oneMKL LAPACK uses 1-origin while SciPy uses 0-origin + ipiv_h -= 1 + + # Return a tuple containing the factorized matrix 'a_h', + # pivot indices 'ipiv_h' + return (a_h, ipiv_h) + + def _batched_solve(a, b, exec_q, res_usm_type, res_type): """ _batched_solve(a, b, exec_q, res_usm_type, res_type) @@ -2308,6 +2435,15 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): a_sycl_queue = a.sycl_queue a_usm_type = a.usm_type + if check_finite: + if not dpnp.isfinite(a).all(): + raise ValueError("array must not contain infs or NaNs") + + if a.ndim > 2: + # SciPy always copies each 2D slice, + # so `overwrite_a` is ignored here + return _batched_lu_factor_scipy(a, res_type) + # accommodate empty arrays if a.size == 0: lu = dpnp.empty_like(a) @@ -2316,13 +2452,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): ) return lu, piv - if check_finite: - if not dpnp.isfinite(a).all(): - raise ValueError("array must not contain infs or NaNs") - - if a.ndim > 2: - raise NotImplementedError("Batched matrices are not supported") - _manager = dpu.SequentialOrderManager[a_sycl_queue] a_usm_arr = dpnp.get_usm_ndarray(a) diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py index dd5daa99b74..2fd001d4d04 100644 --- a/dpnp/tests/test_linalg.py +++ b/dpnp/tests/test_linalg.py @@ -2017,9 +2017,122 @@ def test_check_finite_raises(self, bad): ValueError, dpnp.linalg.lu_factor, a_dp, check_finite=True ) - def test_batched_not_supported(self): - a_dp = dpnp.ones((2, 2, 2)) - assert_raises(NotImplementedError, dpnp.linalg.lu_factor, a_dp) + +class TestLuFactorBatched: + @staticmethod + def _apply_pivots_rows(A_dp, piv_dp): + m = A_dp.shape[0] + + if m == 0 or piv_dp.size == 0: + return A_dp + + rows = list(range(m)) + piv_np = dpnp.asnumpy(piv_dp) + for i, r in enumerate(piv_np): + if i != r: + rows[i], rows[r] = rows[r], rows[i] + + rows = dpnp.asarray(rows) + return A_dp[rows] + + @staticmethod + def _split_lu(lu, m, n): + L = dpnp.tril(lu, k=-1) + dpnp.fill_diagonal(L, 1) + L = L[:, : min(m, n)] + U = dpnp.triu(lu)[: min(m, n), :] + return L, U + + @pytest.mark.parametrize( + "shape", + [(2, 2, 2), (3, 4, 4), (2, 3, 5, 2), (4, 1, 3)], + ids=["(2,2,2)", "(3,4,4)", "(2,3,5,2)", "(4,1,3)"], + ) + @pytest.mark.parametrize("order", ["C", "F"]) + @pytest.mark.parametrize("dtype", get_all_dtypes(no_bool=True)) + def test_lu_factor_batched(self, shape, order, dtype): + a_np = generate_random_numpy_array(shape, dtype, order) + a_dp = dpnp.array(a_np, order=order) + + lu, piv = dpnp.linalg.lu_factor( + a_dp, check_finite=False, overwrite_a=False + ) + + assert lu.shape == a_dp.shape + m, n = shape[-2], shape[-1] + assert piv.shape == (*shape[:-2], min(m, n)) + assert piv.dtype == dpnp.int64 + + a_3d = a_dp.reshape((-1, m, n)) + lu_3d = lu.reshape((-1, m, n)) + piv_2d = piv.reshape((-1, min(m, n))) + for i in range(a_3d.shape[0]): + L, U = self._split_lu(lu_3d[i], m, n) + A_cast = a_3d[i].astype(L.dtype, copy=False) + PA = self._apply_pivots_rows(A_cast, piv_2d[i]) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize("dtype", get_float_complex_dtypes()) + @pytest.mark.parametrize("order", ["C", "F"]) + def test_overwrite(self, dtype, order): + a_dp = dpnp.arange(2 * 2 * 3, dtype=dtype).reshape(3, 2, 2, order=order) + a_dp_orig = a_dp.copy() + lu, piv = dpnp.linalg.lu_factor( + a_dp, overwrite_a=True, check_finite=False + ) + + assert lu is not a_dp + assert_allclose(a_dp, a_dp_orig) + + m = n = 2 + lu_3d = lu.reshape((-1, m, n)) + a_3d = a_dp.reshape((-1, m, n)) + piv_2d = piv.reshape((-1, min(m, n))) + for i in range(a_3d.shape[0]): + L, U = self._split_lu(lu_3d[i], m, n) + A_cast = a_3d[i].astype(L.dtype, copy=False) + PA = self._apply_pivots_rows(A_cast, piv_2d[i]) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + @pytest.mark.parametrize( + "shape", [(0, 2, 2), (2, 0, 2), (2, 2, 0), (0, 0, 0)] + ) + def test_empty_inputs(self, shape): + a = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F") + + lu, piv = dpnp.linalg.lu_factor(a, check_finite=False) + assert lu.shape == shape + m, n = shape[-2:] + assert piv.shape == (*shape[:-2], min(m, n)) + + def test_strided(self): + a = ( + dpnp.arange(5 * 3 * 3, dtype=dpnp.default_float_type()).reshape( + 5, 3, 3, order="F" + ) + + 0.1 + ) + a_stride = a[::2] + lu, piv = dpnp.linalg.lu_factor(a_stride, check_finite=False) + for i in range(a_stride.shape[0]): + L, U = self._split_lu(lu[i], 3, 3) + PA = self._apply_pivots_rows( + a_stride[i].astype(L.dtype, copy=False), piv[i] + ) + assert_allclose(L @ U, PA, rtol=1e-6, atol=1e-6) + + def test_singular_matrix(self): + a = dpnp.zeros((3, 2, 2), dtype=dpnp.default_float_type()) + a[0] = dpnp.array([[1.0, 2.0], [2.0, 4.0]]) + a[1] = dpnp.eye(2) + a[2] = dpnp.array([[1.0, 1.0], [1.0, 1.0]]) + with pytest.warns(RuntimeWarning, match="Singular matrix"): + dpnp.linalg.lu_factor(a, check_finite=False) + + def test_check_finite_raises(self): + a = dpnp.ones((2, 3, 3), dtype=dpnp.default_float_type(), order="F") + a[1, 0, 0] = dpnp.nan + assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True) class TestMatrixPower: diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py index 7c08263e672..d3d8e19439e 100644 --- a/dpnp/tests/test_sycl_queue.py +++ b/dpnp/tests/test_sycl_queue.py @@ -1572,7 +1572,7 @@ def test_lstsq(self, m, n, nrhs, device): @pytest.mark.parametrize( "data", - [[[1.0, 2.0], [3.0, 5.0]], [[]]], + [[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]], ) def test_lu_factor(self, data, device): a = dpnp.array(data, device=device) diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py index 6f886f8ec3c..34fd9bbc003 100644 --- a/dpnp/tests/test_usm_type.py +++ b/dpnp/tests/test_usm_type.py @@ -1451,7 +1451,7 @@ def test_lstsq(self, m, n, nrhs, usm_type, usm_type_other): @pytest.mark.parametrize( "data", - [[[1.0, 2.0], [3.0, 5.0]], [[]]], + [[[1.0, 2.0], [3.0, 5.0]], [[]], [[[1.0, 2.0], [3.0, 5.0]]], [[[]]]], ) def test_lu_factor(self, data, usm_type): a = dpnp.array(data, usm_type=usm_type)