Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/reference/linalg.rst
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ Solving linear equations
dpnp.linalg.solve
dpnp.linalg.tensorsolve
dpnp.linalg.lstsq
dpnp.linalg.lu_solve
dpnp.linalg.inv
dpnp.linalg.pinv
dpnp.linalg.tensorinv
Expand Down
71 changes: 71 additions & 0 deletions dpnp/linalg/dpnp_iface_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
dpnp_inv,
dpnp_lstsq,
dpnp_lu_factor,
dpnp_lu_solve,
dpnp_matrix_power,
dpnp_matrix_rank,
dpnp_multi_dot,
Expand All @@ -81,6 +82,7 @@
"inv",
"lstsq",
"lu_factor",
"lu_solve",
"matmul",
"matrix_norm",
"matrix_power",
Expand Down Expand Up @@ -966,6 +968,75 @@ def lu_factor(a, overwrite_a=False, check_finite=True):
return dpnp_lu_factor(a, overwrite_a=overwrite_a, check_finite=check_finite)


def lu_solve(lu_and_piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    Solve an equation system, a x = b, given the LU factorization of `a`.

    For full documentation refer to :obj:`scipy.linalg.lu_solve`.

    Parameters
    ----------
    lu_and_piv : {tuple of dpnp.ndarrays or usm_ndarrays}
        Pair ``(lu, piv)``: LU factorization of matrix `a` (M, N) together
        with pivot indices.
    b : {(M,), (..., M, K)} {dpnp.ndarray, usm_ndarray}
        Right-hand side.
    trans : {0, 1, 2}, optional
        Type of system to solve:

        =====  =========
        trans  system
        =====  =========
        0      a x   = b
        1      a^T x = b
        2      a^H x = b
        =====  =========

        Default: ``0``.
    overwrite_b : {None, bool}, optional
        Whether to overwrite data in `b` (may increase performance).

        Default: ``False``.
    check_finite : {None, bool}, optional
        Whether to check that the input matrix contains only finite numbers.
        Disabling may give a performance gain, but may result in problems
        (crashes, non-termination) if the inputs do contain infinities or NaNs.

        Default: ``True``.

    Returns
    -------
    x : {(M,), (M, K)} dpnp.ndarray
        Solution to the system.

    Warning
    -------
    This function synchronizes in order to validate array elements
    when ``check_finite=True``.

    Examples
    --------
    >>> import dpnp as np
    >>> A = np.array([[2, 5, 8, 7], [5, 2, 2, 8], [7, 5, 6, 6], [5, 4, 4, 8]])
    >>> b = np.array([1, 1, 1, 1])
    >>> lu, piv = np.linalg.lu_factor(A)
    >>> x = np.linalg.lu_solve((lu, piv), b)
    >>> np.allclose(A @ x - b, np.zeros((4,)))
    array(True)

    """

    # Split the factorization pair; anything that is not a 2-item iterable
    # fails here, before any array validation takes place.
    lu, piv = lu_and_piv

    # All three operands must be supported array types, and the factors
    # must have at least two dimensions.
    dpnp.check_supported_arrays_type(lu, piv, b)
    assert_stacked_2d(lu)

    solver_options = {
        "trans": trans,
        "overwrite_b": overwrite_b,
        "check_finite": check_finite,
    }
    return dpnp_lu_solve(lu, piv, b, **solver_options)


def matmul(x1, x2, /):
"""
Computes the matrix product.
Expand Down
112 changes: 112 additions & 0 deletions dpnp/linalg/dpnp_utils_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2514,6 +2514,118 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
return (a_h, ipiv_h)


def dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True):
    """
    dpnp_lu_solve(lu, piv, b, trans=0, overwrite_b=False, check_finite=True)

    Solve an equation system (SciPy-compatible behavior).

    This function mimics the behavior of `scipy.linalg.lu_solve` including
    support for `trans`, `overwrite_b`, `check_finite`,
    and 0-based pivot indexing.

    Parameters
    ----------
    lu : {dpnp.ndarray, usm_ndarray}
        LU factors of the coefficient matrix. Only 2-D input is supported.
    piv : {dpnp.ndarray, usm_ndarray}
        0-based pivot indices that accompany `lu`.
    b : {dpnp.ndarray, usm_ndarray}
        Right-hand side; its first dimension must match that of `lu`.
    trans : {0, 1, 2}, optional
        Selector of the system to solve, forwarded unchanged to the
        LAPACK extension function.
    overwrite_b : bool, optional
        Allow solving in place inside `b` when no copy is required.
    check_finite : bool, optional
        Validate that `lu` and `b` contain only finite values.

    Returns
    -------
    out : dpnp.ndarray
        The solution. This is `b` itself when it was overwritten in place,
        a new F-ordered array of the common result dtype otherwise, or an
        empty array when `b` has no elements.

    Raises
    ------
    ValueError
        If the leading dimensions of `lu` and `b` differ, or if
        ``check_finite`` is set and `lu` or `b` holds infs or NaNs.
    NotImplementedError
        If `lu` has more than two dimensions.

    """

    res_usm_type, exec_q = get_usm_allocations([lu, piv, b])

    res_type = _common_type(lu, b)

    # TODO: add broadcasting
    if lu.shape[0] != b.shape[0]:
        raise ValueError(
            f"Shapes of lu {lu.shape} and b {b.shape} are incompatible"
        )

    if b.size == 0:
        return dpnp.empty_like(b, dtype=res_type, usm_type=res_usm_type)

    if lu.ndim > 2:
        raise NotImplementedError("Batched matrices are not supported")

    if check_finite:
        if not dpnp.isfinite(lu).all():
            # NOTE(review): the second sentence of this message looks
            # truncated (a function name seems to be missing after
            # "unlike"); left untouched until the intended wording is
            # confirmed.
            raise ValueError(
                "array must not contain infs or NaNs.\n"
                "Note that when a singular matrix is given, unlike "
                "dpnp.linalg.lu_factor returns an array containing NaN."
            )
        if not dpnp.isfinite(b).all():
            raise ValueError("array must not contain infs or NaNs")

    lu_usm_arr = dpnp.get_usm_ndarray(lu)
    piv_usm_arr = dpnp.get_usm_ndarray(piv)
    b_usm_arr = dpnp.get_usm_ndarray(b)

    _manager = dpu.SequentialOrderManager[exec_q]
    # Events already submitted before this call; every input copy below
    # only has to wait for these, so the copies do not serialize on
    # each other.
    dep_evs = _manager.submitted_events

    # Hand the LAPACK call an F-ordered copy of `lu` in the result dtype,
    # so the caller's array is never touched.
    lu_h = dpnp.empty_like(lu, order="F", dtype=res_type, usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the input array
    # from the input array
    ht_ev, lu_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=lu_usm_arr,
        dst=lu_h.get_array(),
        sycl_queue=lu.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, lu_copy_ev)

    # Copy events the solver has to wait for explicitly.
    copy_evs = [lu_copy_ev]

    # SciPy-compatible behavior
    # Copy is required if:
    # - overwrite_b is False (always copy),
    # - `_is_copy_required` reports that `b` cannot be used in place
    #   (presumably dtype mismatch, not F-contiguous, or not writeable;
    #   confirm against the helper)
    if not overwrite_b or _is_copy_required(b, res_type):
        b_h = dpnp.empty_like(
            b, order="F", dtype=res_type, usm_type=res_usm_type
        )
        ht_ev, b_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
            src=b_usm_arr,
            dst=b_h.get_array(),
            sycl_queue=b.sycl_queue,
            depends=dep_evs,
        )
        _manager.add_event_pair(ht_ev, b_copy_ev)
        copy_evs.append(b_copy_ev)
    else:
        # input is suitable for in-place modification
        b_h = b

    # Work on a private copy of the pivots, since they are shifted below
    # and the caller's `piv` must stay intact.
    piv_h = dpnp.empty_like(piv, order="F", usm_type=res_usm_type)

    # use DPCTL tensor function to fill the copy of the pivot array
    # from the pivot array
    ht_ev, piv_copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
        src=piv_usm_arr,
        dst=piv_h.get_array(),
        sycl_queue=piv.sycl_queue,
        depends=dep_evs,
    )
    _manager.add_event_pair(ht_ev, piv_copy_ev)

    # The pivots received here are 0-based (SciPy convention), while the
    # LAPACK routine expects 1-based indices.
    piv_h += 1

    # The solver must not start before `lu_h`, `b_h` and the shifted
    # `piv_h` are all ready. The events tracked by the manager at this
    # point cover the pivot increment (presumably registered by the
    # in-place add; confirm), and the copy events are listed explicitly
    # so the dependency does not rely on how many events the manager
    # retains.
    getrs_deps = [*_manager.submitted_events, *copy_evs, piv_copy_ev]

    # Call the LAPACK extension function _getrs
    # to solve the system of linear equations with an LU-factored
    # coefficient square matrix, with multiple right-hand sides.
    ht_ev, getrs_ev = li._getrs(
        exec_q,
        lu_h.get_array(),
        piv_h.get_array(),
        b_h.get_array(),
        trans,
        depends=getrs_deps,
    )
    _manager.add_event_pair(ht_ev, getrs_ev)

    return b_h


def dpnp_matrix_power(a, n):
"""
dpnp_matrix_power(a, n)
Expand Down
Loading
Loading