1 change: 1 addition & 0 deletions dpnp/backend/extensions/lapack/getrf.hpp
@@ -44,6 +44,7 @@ extern std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
py::list dev_info,
+    std::int64_t m,
std::int64_t n,
std::int64_t stride_a,
std::int64_t stride_ipiv,
32 changes: 21 additions & 11 deletions dpnp/backend/extensions/lapack/getrf_batch.cpp
@@ -46,6 +46,7 @@ namespace type_utils = dpctl::tensor::type_utils;
typedef sycl::event (*getrf_batch_impl_fn_ptr_t)(
sycl::queue &,
std::int64_t,
+    std::int64_t,
char *,
std::int64_t,
std::int64_t,
@@ -61,6 +62,7 @@ static getrf_batch_impl_fn_ptr_t

template <typename T>
static sycl::event getrf_batch_impl(sycl::queue &exec_q,
+    std::int64_t m,
std::int64_t n,
char *in_a,
std::int64_t lda,
@@ -77,7 +79,7 @@ static sycl::event getrf_batch_impl(sycl::queue &exec_q,
T *a = reinterpret_cast<T *>(in_a);

const std::int64_t scratchpad_size =
-        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
+        mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, m, n, lda, stride_a,
stride_ipiv, batch_size);
T *scratchpad = nullptr;

@@ -91,11 +93,11 @@

getrf_batch_event = mkl_lapack::getrf_batch(
exec_q,
-        n, // The order of each square matrix in the batch; (0 ≤ n).
+        m, // The number of rows in each matrix in the batch; (0 ≤ m).
// It must be a non-negative integer.
n, // The number of columns in each matrix in the batch; (0 ≤ n).
// It must be a non-negative integer.
-        a, // Pointer to the batch of square matrices, each of size (n x n).
+        a, // Pointer to the batch of input matrices, each of size (m x n).
lda, // The leading dimension of each matrix in the batch.
stride_a, // Stride between consecutive matrices in the batch.
ipiv, // Pointer to the array of pivot indices for each matrix in
@@ -179,6 +181,7 @@ std::pair<sycl::event, sycl::event>
const dpctl::tensor::usm_ndarray &a_array,
const dpctl::tensor::usm_ndarray &ipiv_array,
py::list dev_info,
+    std::int64_t m,
std::int64_t n,
std::int64_t stride_a,
std::int64_t stride_ipiv,
@@ -191,21 +194,21 @@ std::pair<sycl::event, sycl::event>
if (a_array_nd < 3) {
throw py::value_error(
"The input array has ndim=" + std::to_string(a_array_nd) +
-            ", but an array with ndim >= 3 is expected.");
+            ", but an array with ndim >= 3 is expected");
}

if (ipiv_array_nd != 2) {
throw py::value_error("The array of pivot indices has ndim=" +
std::to_string(ipiv_array_nd) +
-                              ", but a 2-dimensional array is expected.");
+                              ", but a 2-dimensional array is expected");
}

const int dev_info_size = py::len(dev_info);
if (dev_info_size != batch_size) {
throw py::value_error("The size of 'dev_info' (" +
std::to_string(dev_info_size) +
") does not match the expected batch size (" +
std::to_string(batch_size) + ").");
std::to_string(batch_size) + ")");
}

// check compatibility of execution queue and allocation queue
@@ -241,27 +244,34 @@ std::pair<sycl::event, sycl::event>
if (getrf_batch_fn == nullptr) {
throw py::value_error(
"No getrf_batch implementation defined for the provided type "
"of the input matrix.");
"of the input matrix");
}

auto ipiv_types = dpctl_td_ns::usm_ndarray_types();
int ipiv_array_type_id =
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum());

if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) {
throw py::value_error("The type of 'ipiv_array' must be int64.");
throw py::value_error("The type of 'ipiv_array' must be int64");
}

+    const py::ssize_t *ipiv_array_shape = ipiv_array.get_shape_raw();
+    if (ipiv_array_shape[0] != batch_size ||
+        ipiv_array_shape[1] != std::min(m, n)) {
+        throw py::value_error(
+            "The shape of 'ipiv_array' must be (batch_size, min(m, n))");
+    }

char *a_array_data = a_array.get_data();
-    const std::int64_t lda = std::max<size_t>(1UL, n);
+    const std::int64_t lda = std::max<size_t>(1UL, m);

char *ipiv_array_data = ipiv_array.get_data();
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data);

std::vector<sycl::event> host_task_events;
sycl::event getrf_batch_ev = getrf_batch_fn(
-        exec_q, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv, batch_size,
-        dev_info, host_task_events, depends);
+        exec_q, m, n, a_array_data, lda, stride_a, d_ipiv, stride_ipiv,
+        batch_size, dev_info, host_task_events, depends);

sycl::event args_ev = dpctl::utils::keep_args_alive(
exec_q, {a_array, ipiv_array}, host_task_events);
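Note: for intuition on the new (m, n) bookkeeping above, here is a minimal NumPy stand-in (hypothetical sizes, not the dpnp code itself) showing why lda = max(1, m) for column-major storage, and why a contiguous Fortran-order batch has stride_a = lda * n and min(m, n) pivot entries per matrix:

import numpy as np

# Hypothetical batch of four 3x5 matrices, batch axis last, Fortran order,
# mirroring the (m, n, batch) layout handed to getrf_batch.
m, n, batch = 3, 5, 4
a = np.asfortranarray(np.random.rand(m, n, batch))

lda = max(1, m)  # leading dimension of each column-major matrix
stride_a = a.strides[2] // a.itemsize  # elements between consecutive matrices
ipiv_stride = min(m, n)  # pivot entries consumed per matrix

assert stride_a == lda * n  # contiguous F-order batch
assert a[..., 0].flags.f_contiguous  # each 2D slice is column-major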
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/lapack/lapack_py.cpp
@@ -141,10 +141,10 @@ PYBIND11_MODULE(_lapack_impl, m)

m.def("_getrf_batch", &lapack_ext::getrf_batch,
"Call `getrf_batch` from OneMKL LAPACK library to return "
"the LU factorization of a batch of general n x n matrices",
"the LU factorization of a batch of general m x n matrices",
py::arg("sycl_queue"), py::arg("a_array"), py::arg("ipiv_array"),
py::arg("dev_info_array"), py::arg("n"), py::arg("stride_a"),
py::arg("stride_ipiv"), py::arg("batch_size"),
py::arg("dev_info_array"), py::arg("m"), py::arg("n"),
py::arg("stride_a"), py::arg("stride_ipiv"), py::arg("batch_size"),
py::arg("depends") = py::list());

m.def("_getri_batch", &lapack_ext::getri_batch,
143 changes: 136 additions & 7 deletions dpnp/linalg/dpnp_utils_linalg.py
@@ -246,6 +246,7 @@ def _batched_inv(a, res_type):
ipiv_h.get_array(),
dev_info,
+        n,
n,
a_stride,
ipiv_stride,
batch_size,
@@ -327,6 +328,7 @@ def _batched_lu_factor(a, res_type):
ipiv_h.get_array(),
dev_info_h,
+        n,
n,
a_stride,
ipiv_stride,
batch_size,
@@ -396,6 +398,131 @@ def _batched_lu_factor(a, res_type):
return (out_a, out_ipiv, out_dev_info)


def _batched_lu_factor_scipy(a, res_type): # pylint: disable=too-many-locals
"""SciPy-compatible LU factorization for batched inputs."""

    # TODO: Find out at which array sizes the best performance is obtained.
    # getrf_batch can be slow on large GPU arrays.
    # Use getrf_batch only on CPU.
    # On GPU, fall back to calling getrf per 2D slice.
use_batch = a.sycl_device.has_aspect_cpu

a_sycl_queue = a.sycl_queue
a_usm_type = a.usm_type
_manager = dpu.SequentialOrderManager[a_sycl_queue]

m, n = a.shape[-2:]
k = min(m, n)
orig_shape = a.shape
batch_shape = orig_shape[:-2]

# handle empty input
if a.size == 0:
lu = dpnp.empty_like(a)
piv = dpnp.empty(
(*batch_shape, k),
dtype=dpnp.int64,
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)
return lu, piv

    # get a 3D input array by reshape
a = dpnp.reshape(a, (-1, m, n))
batch_size = a.shape[0]

# Move batch axis to the end (m, n, batch) in Fortran order:
# required by getrf_batch
# and ensures each a[..., i] is F-contiguous for getrf
a = dpnp.moveaxis(a, 0, -1)

a_usm_arr = dpnp.get_usm_ndarray(a)

# `a` must be copied because getrf/getrf_batch destroys the input matrix
a_h = dpnp.empty_like(a, order="F", dtype=res_type)
ht_ev, copy_ev = ti._copy_usm_ndarray_into_usm_ndarray(
src=a_usm_arr,
dst=a_h.get_array(),
sycl_queue=a_sycl_queue,
depends=_manager.submitted_events,
)
_manager.add_event_pair(ht_ev, copy_ev)

ipiv_h = dpnp.empty(
(batch_size, k),
dtype=dpnp.int64,
order="C",
usm_type=a_usm_type,
sycl_queue=a_sycl_queue,
)

if use_batch:
dev_info_h = [0] * batch_size

ipiv_stride = k
a_stride = a_h.strides[-1]

# Call the LAPACK extension function _getrf_batch
# to perform LU decomposition of a batch of general matrices
ht_ev, getrf_ev = li._getrf_batch(
a_sycl_queue,
a_h.get_array(),
ipiv_h.get_array(),
dev_info_h,
m,
n,
a_stride,
ipiv_stride,
batch_size,
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, getrf_ev)

        if any(v > 0 for v in dev_info_h):
            diag_nums = ", ".join(str(v) for v in dev_info_h if v > 0)
            warn(
                f"Diagonal numbers {diag_nums} are exactly zero. "
                "Singular matrix.",
RuntimeWarning,
stacklevel=2,
)
else:
dev_info_vecs = [[0] for _ in range(batch_size)]

# Sequential LU factorization using getrf per slice
for i in range(batch_size):
ht_ev, getrf_ev = li._getrf(
a_sycl_queue,
a_h[..., i].get_array(),
ipiv_h[i].get_array(),
dev_info_vecs[i],
depends=[copy_ev],
)
_manager.add_event_pair(ht_ev, getrf_ev)

diag_nums = ", ".join(
str(v) for info in dev_info_vecs for v in info if v > 0
)
if diag_nums:
warn(
f"Diagonal number {diag_nums} are exactly zero. "
"Singular matrix.",
RuntimeWarning,
stacklevel=2,
)

# Restore original shape: move batch axis back and reshape
a_h = dpnp.moveaxis(a_h, -1, 0).reshape(orig_shape)
ipiv_h = ipiv_h.reshape((*batch_shape, k))

# oneMKL LAPACK uses 1-origin while SciPy uses 0-origin
ipiv_h -= 1

    # Return a tuple containing the factorized matrix 'a_h'
    # and the pivot indices 'ipiv_h'
return (a_h, ipiv_h)


def _batched_solve(a, b, exec_q, res_usm_type, res_type):
"""
_batched_solve(a, b, exec_q, res_usm_type, res_type)
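As an aside on the moveaxis comment in _batched_lu_factor_scipy above: moving the batch axis to the end and copying in Fortran order makes every 2D slice a[..., i] F-contiguous, which suits both the batched path and the per-slice getrf fallback. A minimal sketch with NumPy as a stand-in (hypothetical shape, not the dpnp code itself):

import numpy as np

batch, m, n = 4, 3, 5
a = np.random.rand(batch, m, n)  # (batch, m, n), C order

# Batch axis last, then a Fortran-order copy: shape (m, n, batch).
a_f = np.asfortranarray(np.moveaxis(a, 0, -1))

# Every slice along the last axis is now F-contiguous, so it can be
# passed to a per-slice LU factorization without further copies.
assert all(a_f[..., i].flags.f_contiguous for i in range(batch))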
Expand Down Expand Up @@ -2308,6 +2435,15 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
a_sycl_queue = a.sycl_queue
a_usm_type = a.usm_type

+    if check_finite:
+        if not dpnp.isfinite(a).all():
+            raise ValueError("array must not contain infs or NaNs")
+
+    if a.ndim > 2:
+        # SciPy always copies each 2D slice,
+        # so `overwrite_a` is ignored here
+        return _batched_lu_factor_scipy(a, res_type)

# accommodate empty arrays
if a.size == 0:
lu = dpnp.empty_like(a)
Expand All @@ -2316,13 +2452,6 @@ def dpnp_lu_factor(a, overwrite_a=False, check_finite=True):
)
return lu, piv

-    if check_finite:
-        if not dpnp.isfinite(a).all():
-            raise ValueError("array must not contain infs or NaNs")
-
-    if a.ndim > 2:
-        raise NotImplementedError("Batched matrices are not supported")

_manager = dpu.SequentialOrderManager[a_sycl_queue]
a_usm_arr = dpnp.get_usm_ndarray(a)

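On the pivot-origin note in _batched_lu_factor_scipy: oneMKL getrf reports Fortran-style 1-based pivot rows, while SciPy's lu_factor convention is 0-based, which is what the ipiv_h -= 1 adjustment produces. A short sketch of how such 0-based pivots are consumed, using SciPy itself as the reference (hypothetical input):

import numpy as np
from scipy.linalg import lu_factor

a = np.random.rand(4, 4)
lu, piv = lu_factor(a)  # piv is 0-based: row i was swapped with row piv[i]

# Replay the recorded row swaps to recover the permutation of a = P @ L @ U.
perm = np.arange(4)
for i, j in enumerate(piv):
    perm[i], perm[j] = perm[j], perm[i]

L = np.tril(lu, k=-1) + np.eye(4)
U = np.triu(lu)
assert np.allclose(a[perm], L @ U)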