-
Notifications
You must be signed in to change notification settings - Fork 23
Implement dpnp.linalg.lu_solve()
2D inputs
#2575
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: pass_trans_to_getrs
Are you sure you want to change the base?
Conversation
dpnp.linalg.solve()
2D inputs dpnp.linalg.lu_solve()
2D inputs
Array API standard conformance tests for dpnp=0.19.0dev3=py313h509198e_27 ran successfully. |
View rendered docs @ https://intelpython.github.io/dpnp/pull/2575/index.html |
@@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True): | |||
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite) | |||
|
|||
|
|||
def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True): | |||
""" | |||
Solve an equation system, a x = b, given the LU factorization of `a` |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solve an equation system, a x = b, given the LU factorization of `a` | |
Solve an equation system, a x = b, given the LU factorization of `a`. |
|
||
Parameters | ||
---------- | ||
(lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Now it is rendered in a wrong way in the documentation
(lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} | |
lu, piv : {tuple of dpnp.ndarrays or usm_ndarrays} |
Parameters | ||
---------- | ||
(lu, piv) : {tuple of dpnp.ndarrays or usm_ndarrays} | ||
LU factorization of matrix `a` ((M, N)) together with pivot indices. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we need double brackets here?
LU factorization of matrix `a` ((M, N)) together with pivot indices. | |
LU factorization of matrix `a` (M, N) together with pivot indices. |
A[i, i] = A.dtype.type(off + 1.0) | ||
return A | ||
|
||
@pytest.mark.parametrize("shape", [(1, 1), (2, 2), (3, 3), (5, 5)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I failed to find a test with (M, N) shapes, only with (M, M)
_manager = dpu.SequentialOrderManager[exec_q] | ||
dep_evs = _manager.submitted_events | ||
|
||
# oneMKL LAPACK getrf overwrites `a`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# oneMKL LAPACK getrf overwrites `a`. | |
# oneMKL LAPACK getrf overwrites `lu`. |
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||
src=b_usm_arr, | ||
dst=b_h.get_array(), | ||
sycl_queue=b.sycl_queue, | ||
depends=_manager.submitted_events, | ||
) | ||
_manager.add_event_pair(ht_ev, dep_ev) | ||
dep_ev = [dep_ev] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't we need to override dep_evs
with list of copy events?
ht_ev, dep_ev = ti._copy_usm_ndarray_into_usm_ndarray( | |
src=b_usm_arr, | |
dst=b_h.get_array(), | |
sycl_queue=b.sycl_queue, | |
depends=_manager.submitted_events, | |
) | |
_manager.add_event_pair(ht_ev, dep_ev) | |
dep_ev = [dep_ev] | |
ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | |
src=b_usm_arr, | |
dst=b_h.get_array(), | |
sycl_queue=b.sycl_queue, | |
depends=dep_evs, | |
) | |
_manager.add_event_pair(ht_ev, b_copy_ev) | |
dep_evs = [lu_copy_ev, b_copy_ev] |
else: | ||
# input is suitable for in-place modification | ||
b_h = b | ||
dep_ev = _manager.submitted_events |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
and then here:
dep_ev = _manager.submitted_events | |
dep_evs = [lu_copy_ev] |
b_h = b | ||
dep_ev = _manager.submitted_events | ||
|
||
# oneMKL LAPACK getrf overwrites `a`. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
# oneMKL LAPACK getrf overwrites `a`. | |
# oneMKL LAPACK getrf overwrites `piv`. |
# oneMKL LAPACK getrf overwrites `a`. | ||
piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type) | ||
|
||
# use DPCTL tensor function to fill the сopy of the pivot array | ||
# from the pivot array | ||
ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray( | ||
src=piv_usm_arr, | ||
dst=piv_h.get_array(), | ||
sycl_queue=piv.sycl_queue, | ||
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, piv_copy_ev) | ||
# MKL lapack uses 1-origin while SciPy uses 0-origin | ||
piv_h += 1 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it'd make sense to do before line#2581, to build a list of copy events dep_evs = [lu_copy_ev, piv_copy_ev, b_copy_ev]
depends=dep_evs, | ||
) | ||
_manager.add_event_pair(ht_ev, piv_copy_ev) | ||
# MKL lapack uses 1-origin while SciPy uses 0-origin |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess SciPy also uses MKL to call getrs
, so it seems unclear for me.
This PR suggests adding
dpnp.linalg.lu_solve()
for 2D arrays similar to scipy.linalg.lu_solveSupport for ND inputs will be added in the next phase.