From b2d2b3a2d31fe0a2e051881d06239c427f14640d Mon Sep 17 00:00:00 2001
From: Vladislav Perevezentsev
Date: Thu, 4 Sep 2025 06:40:10 -0700
Subject: [PATCH 1/4] Implement of dpnp.linalg.lu_solve for 2D inputs

---
 dpnp/linalg/dpnp_iface_linalg.py |  71 ++++++++++++++++++++
 dpnp/linalg/dpnp_utils_linalg.py | 112 +++++++++++++++++++++++++++++++
 2 files changed, 183 insertions(+)

diff --git a/dpnp/linalg/dpnp_iface_linalg.py b/dpnp/linalg/dpnp_iface_linalg.py
index 1a46205ea82..d6ee9aef136 100644
--- a/dpnp/linalg/dpnp_iface_linalg.py
+++ b/dpnp/linalg/dpnp_iface_linalg.py
@@ -57,6 +57,7 @@
     dpnp_inv,
     dpnp_lstsq,
     dpnp_lu_factor,
+    dpnp_lu_solve,
     dpnp_matrix_power,
     dpnp_matrix_rank,
     dpnp_multi_dot,
@@ -81,6 +82,7 @@
     "inv",
     "lstsq",
     "lu_factor",
+    "lu_solve",
     "matmul",
     "matrix_norm",
     "matrix_power",
@@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
     return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)
 
 
+def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
+    """
+    Solve an equation system, a x = b, given the LU factorization of `a`.
+
+    For full documentation refer to :obj:`scipy.linalg.lu_solve`.
+
+    Parameters
+    ----------
+    (lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays}
+        LU factorization of matrix `a` ((M, N)) together with pivot indices.
+    b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
+        Right-hand side
+    trans : {0, 1, 2}, optional
+        Type of system to solve:
+
+        =====  =========
+        trans  system
+        =====  =========
+        0      a x   = b
+        1      a^T x = b
+        2      a^H x = b
+        =====  =========
+    overwrite_b : {None, bool}, optional
+        Whether to overwrite data in `b` (may increase performance).
+
+        Default: ``False``.
+    check_finite : {None, bool}, optional
+        Whether to check that the input matrix contains only finite numbers.
+        Disabling may give a performance gain, but may result in problems
+        (crashes, non-termination) if the inputs do contain infinities or NaNs.
+
+        Default: ``True``.
+
+    Returns
+    -------
+    x : {(M,), (M, K)} dpnp.ndarray
+        Solution to the system
+
+    Warning
+    -------
+    This function synchronizes in order to validate array elements
+    when ``check_finite=True``.
+
+    Examples
+    --------
+    >>> import dpnp as np
+    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
+    >>> b = np.array([1, 1, 1, 1])
+    >>> lu, piv = np.linalg.lu_factor(A)
+    >>> x = np.linalg.lu_solve((lu, piv), b)
+    >>> np.allclose(A @ x - b, np.zeros((4,)))
+    array(True)
+
+    """
+
+    (lu, piv) = lu_and_piv
+    dpnp.check_supported_arrays_type(lu, piv, b)
+    assert_stacked_2d(lu)
+
+    return dpnp_lu_solve(
+        lu,
+        piv,
+        b,
+        trans=trans,
+        overwrite_b=overwrite_b,
+        check_finite=check_finite,
+    )
+
+
 def matmul(x1, x2, /):
     """
     Computes the matrix product.
diff --git a/dpnp/linalg/dpnp_utils_linalg.py b/dpnp/linalg/dpnp_utils_linalg.py
index fdf46174bfc..04339a5f587 100644
--- a/dpnp/linalg/dpnp_utils_linalg.py
+++ b/dpnp/linalg/dpnp_utils_linalg.py
@@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
     return (a_h, ipiv_h)
 
 
+def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
+    """
+    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)
+
+    Solve an equation system (SciPy-compatible behavior).
+
+    This function mimics the behavior of `scipy.linalg.lu_solve` including
+    support for `trans`, `overwrite_b`, `check_finite`,
+    and 0-based pivot indexing.
+
+    """
+
+    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])
+
+    res_type = _common_type(lu, b)
+
+    # TODO: add broadcasting
+    if lu.shape[0] != b.shape[0]:
+        raise ValueError(
+            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
+        )
+
+    if b.size == 0:
+        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)
+
+    if lu.ndim > 2:
+        raise NotImplementedError("Batched matrices are not supported")
+
+    if check_finite:
+        if not dpnp.isfinite(lu).all():
+            raise ValueError(
+                "array must not contain infs or NaNs.\n"
+                "Note that when a singular matrix is given, unlike "
+                "dpnp.linalg.lu_factor returns an array containing NaN."
+            )
+        if not dpnp.isfinite(b).all():
+            raise ValueError("array must not contain infs or NaNs")
+
+    lu_usm_arr = dpnp.get_usm_ndarray(lu)
+    piv_usm_arr = dpnp.get_usm_ndarray(piv)
+    b_usm_arr = dpnp.get_usm_ndarray(b)
+
+    _manager = dpu.SequentialOrderManager[exec_q]
+    dep_evs = _manager.submitted_events
+
+    # work on an F-ordered copy of `lu` cast to the common result type
+    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)
+
+    # use DPCTL tensor function to fill the copy of the input array
+    # from the input array
+    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+        src=lu_usm_arr,
+        dst=lu_h.get_array(),
+        sycl_queue=lu.sycl_queue,
+        depends=dep_evs,
+    )
+    _manager.add_event_pair(ht_ev, lu_copy_ev)
+
+    # SciPy-compatible behavior
+    # Copy is required if:
+    # - overwrite_b is False (always copy),
+    # - dtype mismatch,
+    # - not F-contiguous,
+    # - not writeable
+    if not overwrite_b or _is_copy_required(b, res_type):
+        b_h = dpnp.empty_like(
+            b, order="F", dtype=res_type, usm_type=res_usm_type
+        )
+        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+            src=b_usm_arr,
+            dst=b_h.get_array(),
+            sycl_queue=b.sycl_queue,
+            depends=_manager.submitted_events,
+        )
+        _manager.add_event_pair(ht_ev, b_copy_ev)
+    else:
+        # input is suitable for in-place modification
+        b_h = b
+
+    # copy `piv`: the 1-based shift below must not modify the caller's array
+    piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)
+
+    # use DPCTL tensor function to fill the copy of the pivot array
+    # from the pivot array
+    ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
+        src=piv_usm_arr,
+        dst=piv_h.get_array(),
+        sycl_queue=piv.sycl_queue,
+        depends=dep_evs,
+    )
+    _manager.add_event_pair(ht_ev, piv_copy_ev)
+    # MKL lapack uses 1-origin while SciPy uses 0-origin
+    piv_h += 1
+
+    # Call the LAPACK extension function _getrs
+    # to solve the system of linear equations with an LU-factored
+    # coefficient square matrix, with multiple right-hand sides.
+    # Depend on all events the order manager has recorded, not only on the
+    # `b` copy: the `piv` copy is registered after it.
+    ht_ev, getrs_ev = li._getrs(
+        exec_q,
+        lu_h.get_array(),
+        piv_h.get_array(),
+        b_h.get_array(),
+        trans,
+        depends=_manager.submitted_events,
+    )
+    _manager.add_event_pair(ht_ev, getrs_ev)
+
+    return b_h
+
+
 def dpnp_matrix_power(a, n):
     """
     dpnp_matrix_power(a, n)

From a979b1a7848a800b3c2fb0bb9c1bce2e13cdd1a1 Mon Sep 17 00:00:00 2001
From: Vladislav Perevezentsev
Date: Thu, 4 Sep 2025 06:41:18 -0700
Subject: [PATCH 2/4] Add dpnp.linalg.lu_solve to generated docs

---
 doc/reference/linalg.rst | 1 +
 1 file changed, 1 insertion(+)

diff --git a/doc/reference/linalg.rst b/doc/reference/linalg.rst
index 107b5a86a5b..142c6052db8 100644
--- a/doc/reference/linalg.rst
+++ b/doc/reference/linalg.rst
@@ -86,6 +86,7 @@ Solving linear equations
    dpnp.linalg.solve
    dpnp.linalg.tensorsolve
    dpnp.linalg.lstsq
+   dpnp.linalg.lu_solve
    dpnp.linalg.inv
    dpnp.linalg.pinv
    dpnp.linalg.tensorinv

From 6541bf09830e576b8b2a8a3be07ee11038509263 Mon Sep 17 00:00:00 2001
From: Vladislav Perevezentsev
Date: Thu, 4 Sep 2025 06:42:46 -0700
Subject: [PATCH 3/4] Add TestLuSolve to test_linalg.py

---
 dpnp/tests/test_linalg.py | 209 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 209 insertions(+)

diff --git a/dpnp/tests/test_linalg.py b/dpnp/tests/test_linalg.py
index bef8159e6e9..259b35174a2 100644
--- a/dpnp/tests/test_linalg.py
+++ b/dpnp/tests/test_linalg.py
@@ -2135,6 +2135,215 @@ def test_check_finite_raises(self):
         assert_raises(ValueError, dpnp.linalg.lu_factor, a, check_finite=True)
 
 
+class TestLuSolve:
+    @staticmethod
+    def _make_nonsingular_np(shape, dtype, order):
+        A = generate_random_numpy_array(shape, dtype, order)
+        m, n = shape
+        k = min(m, n)
+        for i in range(k):
+            off = numpy.sum(numpy.abs(A[i, :n])) - numpy.abs(A[i, i])
+            A[i, i] = A.dtype.type(off + 1.0)
+        return A
+
+    @pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)])
+    @pytest.mark.parametrize("rhs_cols", [None, 1, 3])
+    @pytest.mark.parametrize("order", ["C", "F"])
+    @pytest.mark.parametrize(
+        "dtype", get_all_dtypes(no_bool=True, no_none=True)
+    )
+    def test_lu_solve(self, shape, rhs_cols, order, dtype):
+        a_np = self._make_nonsingular_np(shape, dtype, order)
+        a_dp = dpnp.array(a_np, order=order)
+
+        n = shape[0]
+        if rhs_cols is None:
+            b_np = generate_random_numpy_array((n,), dtype, order)
+        else:
+            b_np = generate_random_numpy_array((n, rhs_cols), dtype, order)
+        b_dp = dpnp.array(b_np, order=order)
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, trans=0, overwrite_b=False, check_finite=False
+        )
+
+        # check A @ x = b
+        Ax = a_dp @ x
+        assert dpnp.allclose(Ax, b_dp, rtol=1e-6, atol=1e-6)
+
+    @pytest.mark.parametrize("trans", [0, 1, 2])
+    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
+    def test_trans(self, trans, dtype):
+        n = 4
+        a_np = self._make_nonsingular_np((n, n), dtype, order="F")
+        a_dp = dpnp.array(a_np, order="F")
+        b_dp = dpnp.array(generate_random_numpy_array((n, 2), dtype, "F"))
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, trans=trans, overwrite_b=False, check_finite=False
+        )
+
+        if trans == 0:
+            lhs = a_dp @ x
+        elif trans == 1:
+            lhs = a_dp.T @ x
+        else:  # trans == 2
+            lhs = a_dp.conj().T @ x
+
+        assert dpnp.allclose(lhs, b_dp, rtol=1e-6, atol=1e-6)
+
+    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
+    def test_overwrite_inplace(self, dtype):
+        a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
+        b_dp = dpnp.array([1, 0], dtype=dtype, order="F")
+        b_orig = b_dp.copy()
+
+        lu, piv = dpnp.linalg.lu_factor(
+            a_dp, overwrite_a=False, check_finite=False
+        )
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, trans=0, overwrite_b=True, check_finite=False
+        )
+
+        assert x is b_dp
+        assert dpnp.allclose(a_dp @ x, b_orig, rtol=1e-6, atol=1e-6)
+
+    @pytest.mark.parametrize("dtype", get_float_complex_dtypes())
+    def test_overwrite_copy_special(self, dtype):
+        a_dp = dpnp.array([[4, 3], [6, 3]], dtype=dtype, order="F")
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+
+        # F-contig but dtype != res_type
+        b1 = dpnp.array([1, 0], dtype=dpnp.int32, order="F")
+        x1 = dpnp.linalg.lu_solve(
+            (lu, piv), b1, overwrite_b=True, check_finite=False
+        )
+        assert x1 is not b1
+
+        # F-contig, match dtype but read-only input
+        b2 = dpnp.array([1, 0], dtype=dtype, order="F")
+        b2.flags["WRITABLE"] = False
+        x2 = dpnp.linalg.lu_solve(
+            (lu, piv), b2, overwrite_b=True, check_finite=False
+        )
+        assert x2 is not b2
+
+        for x in (x1, x2):
+            assert dpnp.allclose(
+                a_dp @ x,
+                dpnp.array([1, 0], dtype=x.dtype),
+                rtol=1e-6,
+                atol=1e-6,
+            )
+
+    @pytest.mark.parametrize(
+        "dtype_a", get_all_dtypes(no_bool=True, no_none=True)
+    )
+    @pytest.mark.parametrize(
+        "dtype_b", get_all_dtypes(no_bool=True, no_none=True)
+    )
+    def test_diff_type(self, dtype_a, dtype_b):
+        a_np = self._make_nonsingular_np((3, 3), dtype_a, order="F")
+        a_dp = dpnp.array(a_np, order="F")
+
+        b_np = generate_random_numpy_array((3,), dtype_b, order="F")
+        b_dp = dpnp.array(b_np, order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
+        assert dpnp.allclose(
+            a_dp @ x, b_dp.astype(x.dtype, copy=False), rtol=1e-6, atol=1e-6
+        )
+
+    def test_strided_rhs(self):
+        n = 7
+        a_np = self._make_nonsingular_np(
+            (n, n), dpnp.default_float_type(), order="F"
+        )
+        a_dp = dpnp.array(a_np, order="F")
+
+        rhs_full = (
+            dpnp.arange(n * n, dtype=dpnp.default_float_type()).reshape(
+                n, n, order="F"
+            )
+            + 1.0
+        )
+        b_dp = rhs_full[:, ::2][:, :3]
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, overwrite_b=False, check_finite=False
+        )
+
+        assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
+
+    @pytest.mark.skip("Not implemented yet")
+    @pytest.mark.parametrize(
+        "b_shape",
+        [
+            (4,),
+            (4, 1),
+            (4, 3),
+            # (1, 4, 3),
+            # (2, 4, 3),
+            # (1, 1, 4, 3)
+        ],
+    )
+    def test_broadcast_rhs(self, b_shape):
+        dtype = dpnp.default_float_type()
+
+        a_np = self._make_nonsingular_np((4, 4), dtype, order="F")
+        a_dp = dpnp.array(a_np, order="F")
+
+        b_np = generate_random_numpy_array(b_shape, dtype, order="F")
+        b_dp = dpnp.array(b_np, order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve(
+            (lu, piv), b_dp, overwrite_b=True, check_finite=False
+        )
+
+        assert x.shape == b_dp.shape
+
+        assert dpnp.allclose(a_dp @ x, b_dp, rtol=1e-6, atol=1e-6)
+
+    @pytest.mark.parametrize("shape", [(0, 0), (0, 5), (5, 5)])
+    @pytest.mark.parametrize("rhs_cols", [None, 0, 3])
+    def test_empty_shapes(self, shape, rhs_cols):
+        a_dp = dpnp.empty(shape, dtype=dpnp.default_float_type(), order="F")
+        if min(shape) > 0:
+            for i in range(min(shape)):
+                a_dp[i, i] = a_dp.dtype.type(1.0)
+
+        n = shape[0]
+        if rhs_cols is None:
+            b_shape = (n,)
+        else:
+            b_shape = (n, rhs_cols)
+        b_dp = dpnp.empty(b_shape, dtype=dpnp.default_float_type(), order="F")
+
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+        x = dpnp.linalg.lu_solve((lu, piv), b_dp, check_finite=False)
+
+        assert x.shape == b_shape
+
+    @pytest.mark.parametrize("bad", [numpy.inf, -numpy.inf, numpy.nan])
+    def test_check_finite_raises(self, bad):
+        a_dp = dpnp.array([[1.0, 0.0], [0.0, 1.0]], order="F")
+        lu, piv = dpnp.linalg.lu_factor(a_dp, check_finite=False)
+
+        b_bad = dpnp.array([1.0, bad], order="F")
+        assert_raises(
+            ValueError,
+            dpnp.linalg.lu_solve,
+            (lu, piv),
+            b_bad,
+            check_finite=True,
+        )
+
+
 class TestMatrixPower:
     @pytest.mark.parametrize("dtype", get_all_dtypes())
     @pytest.mark.parametrize(

From 6a3bae191504de9f893bc7892a363b096e983948 Mon Sep 17 00:00:00 2001
From: Vladislav Perevezentsev
Date: Thu, 4 Sep 2025 07:06:47 -0700
Subject: [PATCH 4/4] Add sycl_queue and usm_type tests

---
 dpnp/tests/test_sycl_queue.py | 15 +++++++++++++++
 dpnp/tests/test_usm_type.py   | 18 ++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/dpnp/tests/test_sycl_queue.py b/dpnp/tests/test_sycl_queue.py
index d3d8e19439e..d947e5fe51e 100644
--- a/dpnp/tests/test_sycl_queue.py
+++ b/dpnp/tests/test_sycl_queue.py
@@ -9,6 +9,7 @@
 from numpy.testing import assert_array_equal, assert_raises
 
 import dpnp
+import dpnp.linalg
 from dpnp.dpnp_array import dpnp_array
 from dpnp.dpnp_utils import get_usm_allocations
 
@@ -1582,6 +1583,20 @@ def test_lu_factor(self, data, device):
             param_queue = param.sycl_queue
             assert_sycl_queue_equal(param_queue, a.sycl_queue)
 
+    @pytest.mark.parametrize(
+        "data",
+        [[1.0, 2.0], numpy.empty((2, 0))],
+    )
+    def test_lu_solve(self, data, device):
+        a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
+        lu, piv = dpnp.linalg.lu_factor(a)
+        b = dpnp.array(data, device=device)
+
+        result = dpnp.linalg.lu_solve((lu, piv), b)
+
+        assert_sycl_queue_equal(result.sycl_queue, a.sycl_queue)
+        assert_sycl_queue_equal(result.sycl_queue, b.sycl_queue)
+
     @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
     def test_matrix_power(self, n, device):
         x = dpnp.array([[1.0, 2.0], [3.0, 5.0]], device=device)
diff --git a/dpnp/tests/test_usm_type.py b/dpnp/tests/test_usm_type.py
index 34fd9bbc003..2edffc2175c 100644
--- a/dpnp/tests/test_usm_type.py
+++ b/dpnp/tests/test_usm_type.py
@@ -1461,6 +1461,24 @@ def test_lu_factor(self, data, usm_type):
         for param in result:
             assert param.usm_type == a.usm_type
 
+    @pytest.mark.parametrize("usm_type_rhs", list_of_usm_types)
+    @pytest.mark.parametrize(
+        "data",
+        [[1.0, 2.0], numpy.empty((2, 0))],
+    )
+    def test_lu_solve(self, data, usm_type, usm_type_rhs):
+        a = dpnp.array([[1.0, 2.0], [3.0, 5.0]], usm_type=usm_type)
+        lu, piv = dpnp.linalg.lu_factor(a)
+        b = dpnp.array(data, usm_type=usm_type_rhs)
+
+        result = dpnp.linalg.lu_solve((lu, piv), b)
+
+        assert lu.usm_type == usm_type
+        assert b.usm_type == usm_type_rhs
+        assert result.usm_type == du.get_coerced_usm_type(
+            [usm_type, usm_type_rhs]
+        )
+
     @pytest.mark.parametrize("n", [-1, 0, 1, 2, 3])
     def test_matrix_power(self, n, usm_type):
         a = dpnp.array([[1, 2], [3, 5]], usm_type=usm_type)