-
Notifications
You must be signed in to change notification settings - Fork 23
Implement dpnp.linalg.lu_solve()
2D inputs
#2575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pass_trans_to_getrs
Are you sure you want to change the base?
Changes from all commits
b2d2b3a
a979b1a
6541bf0
6a3bae1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -57,6 +57,7 @@ | |||||
dpnp_inv, | ||||||
dpnp_lstsq, | ||||||
dpnp_lu_factor, | ||||||
dpnp_lu_solve, | ||||||
dpnp_matrix_power, | ||||||
dpnp_matrix_rank, | ||||||
dpnp_multi_dot, | ||||||
|
@@ -81,6 +82,7 @@ | |||||
"inv", | ||||||
"lstsq", | ||||||
"lu_factor", | ||||||
"lu_solve", | ||||||
"matmul", | ||||||
"matrix_norm", | ||||||
"matrix_power", | ||||||
|
@@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True): | |||||
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite) | ||||||
|
||||||
|
||||||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): | ||||||
""" | ||||||
Solve an equation system, a x = b, given the LU factorization of `a` | ||||||
|
||||||
For full documentation refer to :obj:`scipy.linalg.lu_solve`. | ||||||
|
||||||
Parameters | ||||||
---------- | ||||||
(lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now it is rendered in a wrong way in the documentation
Suggested change
|
||||||
LU factorization of matrix `a` ((M, N)) together with pivot indices. | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need double brackets here?
Suggested change
|
||||||
b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray} | ||||||
Right-hand side | ||||||
trans : {0, 1, 2} , optional | ||||||
Type of system to solve: | ||||||
|
||||||
===== ========= | ||||||
trans system | ||||||
===== ========= | ||||||
0 a x = b | ||||||
1 a^T x = b | ||||||
2 a^H x = b | ||||||
===== ========= | ||||||
overwrite_b : {None, bool}, optional | ||||||
Whether to overwrite data in `b` (may increase performance). | ||||||
|
||||||
Default: ``False``. | ||||||
check_finite : {None, bool}, optional | ||||||
Whether to check that the input matrix contains only finite numbers. | ||||||
Disabling may give a performance gain, but may result in problems | ||||||
(crashes, non-termination) if the inputs do contain infinities or NaNs. | ||||||
|
||||||
Default: ``True``. | ||||||
|
||||||
Returns | ||||||
------- | ||||||
x : {(M,), (M, K)} dpnp.ndarray | ||||||
Solution to the system | ||||||
|
||||||
Warning | ||||||
------- | ||||||
This function synchronizes in order to validate array elements | ||||||
when ``check_finite=True``. | ||||||
|
||||||
Examples | ||||||
-------- | ||||||
>>> import dpnp as np | ||||||
>>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]]) | ||||||
>>> b = np.array([1, 1, 1, 1]) | ||||||
>>> lu, piv = np.linalg.lu_factor(A) | ||||||
>>> x = np.linalg.lu_solve((lu, piv), b) | ||||||
>>> np.allclose(A @ x - b, np.zeros((4,))) | ||||||
array(True) | ||||||
|
||||||
""" | ||||||
|
||||||
(lu, piv) = lu_and_piv | ||||||
dpnp.check_supported_arrays_type(lu, piv, b) | ||||||
assert_stacked_2d(lu) | ||||||
|
||||||
return dpnp_lu_solve( | ||||||
lu, | ||||||
piv, | ||||||
b, | ||||||
trans=trans, | ||||||
overwrite_b=overwrite_b, | ||||||
check_finite=check_finite, | ||||||
) | ||||||
|
||||||
|
||||||
def matmul(x1, x2, /): | ||||||
""" | ||||||
Computes the matrix product. | ||||||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True): | |||||||||||||||||||||||||||||||||
return (a_h, ipiv_h) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True): | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
Solve an equation system (SciPy-compatible behavior). | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
This function mimics the behavior of `scipy.linalg.lu_solve` including | ||||||||||||||||||||||||||||||||||
support for `trans`, `overwrite_b`, `check_finite`, | ||||||||||||||||||||||||||||||||||
and 0-based pivot indexing. | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
res_usm_type, exec_q = get_usm_allocations([lu, piv, b]) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
res_type = _common_type(lu, b) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# TODO: add broadcasting | ||||||||||||||||||||||||||||||||||
if lu.shape[0] != b.shape[0]: | ||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||
f"Shapes of lu {lu.shape} and b {b.shape} are incompatible" | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if b.size == 0: | ||||||||||||||||||||||||||||||||||
return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if lu.ndim > 2: | ||||||||||||||||||||||||||||||||||
raise NotImplementedError("Batched matrices are not supported") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
if check_finite: | ||||||||||||||||||||||||||||||||||
if not dpnp.isfinite(lu).all(): | ||||||||||||||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||||||||||||||
"array must not contain infs or NaNs.\n" | ||||||||||||||||||||||||||||||||||
"Note that when a singular matrix is given, unlike " | ||||||||||||||||||||||||||||||||||
"dpnp.linalg.lu_factor returns an array containing NaN." | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
if not dpnp.isfinite(b).all(): | ||||||||||||||||||||||||||||||||||
raise ValueError("array must not contain infs or NaNs") | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
lu_usm_arr = dpnp.get_usm_ndarray(lu) | ||||||||||||||||||||||||||||||||||
piv_usm_arr = dpnp.get_usm_ndarray(piv) | ||||||||||||||||||||||||||||||||||
b_usm_arr = dpnp.get_usm_ndarray(b) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
_manager = dpu.SequentialOrderManager[exec_q] | ||||||||||||||||||||||||||||||||||
dep_evs = _manager.submitted_events | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# oneMKL LAPACK getrf overwrites `a`. | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# use DPCTL tensor function to fill the сopy of the input array | ||||||||||||||||||||||||||||||||||
# from the input array | ||||||||||||||||||||||||||||||||||
ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||||||||||||||||||||||||||||||||||
src=lu_usm_arr, | ||||||||||||||||||||||||||||||||||
dst=lu_h.get_array(), | ||||||||||||||||||||||||||||||||||
sycl_queue=lu.sycl_queue, | ||||||||||||||||||||||||||||||||||
depends=dep_evs, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
_manager.add_event_pair(ht_ev, lu_copy_ev) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# SciPy-compatible behavior | ||||||||||||||||||||||||||||||||||
# Copy is required if: | ||||||||||||||||||||||||||||||||||
# - overwrite_a is False (always copy), | ||||||||||||||||||||||||||||||||||
# - dtype mismatch, | ||||||||||||||||||||||||||||||||||
# - not F-contiguous,s | ||||||||||||||||||||||||||||||||||
# - not writeable | ||||||||||||||||||||||||||||||||||
if not overwrite_b or _is_copy_required(b, res_type): | ||||||||||||||||||||||||||||||||||
b_h = dpnp.empty_like( | ||||||||||||||||||||||||||||||||||
b, order="F", dtype=res_type, usm_type=res_usm_type | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||||||||||||||||||||||||||||||||||
src=b_usm_arr, | ||||||||||||||||||||||||||||||||||
dst=b_h.get_array(), | ||||||||||||||||||||||||||||||||||
sycl_queue=b.sycl_queue, | ||||||||||||||||||||||||||||||||||
depends=_manager.submitted_events, | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we need to depend on
Suggested change
|
||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
_manager.add_event_pair(ht_ev, dep_ev) | ||||||||||||||||||||||||||||||||||
dep_ev = [dep_ev] | ||||||||||||||||||||||||||||||||||
Comment on lines
+2585
to
+2592
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't we need to override
Suggested change
|
||||||||||||||||||||||||||||||||||
else: | ||||||||||||||||||||||||||||||||||
# input is suitable for in-place modification | ||||||||||||||||||||||||||||||||||
b_h = b | ||||||||||||||||||||||||||||||||||
dep_ev = _manager.submitted_events | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and then here:
Suggested change
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# oneMKL LAPACK getrf overwrites `a`. | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# use DPCTL tensor function to fill the сopy of the pivot array | ||||||||||||||||||||||||||||||||||
# from the pivot array | ||||||||||||||||||||||||||||||||||
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||||||||||||||||||||||||||||||||||
src=piv_usm_arr, | ||||||||||||||||||||||||||||||||||
dst=piv_h.get_array(), | ||||||||||||||||||||||||||||||||||
sycl_queue=piv.sycl_queue, | ||||||||||||||||||||||||||||||||||
depends=dep_evs, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
_manager.add_event_pair(ht_ev, piv_copy_ev) | ||||||||||||||||||||||||||||||||||
# MKL lapack uses 1-origin while SciPy uses 0-origin | ||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess SciPy also uses MKL to call |
||||||||||||||||||||||||||||||||||
piv_h += 1 | ||||||||||||||||||||||||||||||||||
Comment on lines
+2598
to
+2611
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it'd make sense to do before line#2581, to build a list of copy events |
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
# Call the LAPACK extension function _getrs | ||||||||||||||||||||||||||||||||||
# to solve the system of linear equations with an LU-factored | ||||||||||||||||||||||||||||||||||
# coefficient square matrix, with multiple right-hand sides. | ||||||||||||||||||||||||||||||||||
ht_ev, getrs_ev = li._getrs( | ||||||||||||||||||||||||||||||||||
exec_q, | ||||||||||||||||||||||||||||||||||
lu_h.get_array(), | ||||||||||||||||||||||||||||||||||
piv_h.get_array(), | ||||||||||||||||||||||||||||||||||
b_h.get_array(), | ||||||||||||||||||||||||||||||||||
trans, | ||||||||||||||||||||||||||||||||||
depends=dep_ev, | ||||||||||||||||||||||||||||||||||
) | ||||||||||||||||||||||||||||||||||
_manager.add_event_pair(ht_ev, getrs_ev) | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
return b_h | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
def dpnp_matrix_power(a, n): | ||||||||||||||||||||||||||||||||||
""" | ||||||||||||||||||||||||||||||||||
dpnp_matrix_power(a, n) | ||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.