diff --git a/lax/src/layout.rs b/lax/src/layout.rs index 13caca2c..aa9fe110 100644 --- a/lax/src/layout.rs +++ b/lax/src/layout.rs @@ -1,44 +1,76 @@ //! Memory layout of matrices +//! +//! Different from ndarray format which consists of shape and strides, +//! matrix format in LAPACK consists of row or column size and leading dimension. +//! +//! ndarray format and stride +//! -------------------------- +//! +//! Let us consider 3-dimensional array for explaining ndarray structure. +//! The address of `(x,y,z)`-element in ndarray satisfies following relation: +//! +//! ```text +//! shape = [Nx, Ny, Nz] +//! where Nx > 0, Ny > 0, Nz > 0 +//! stride = [Sx, Sy, Sz] +//! +//! &data[(x, y, z)] = &data[(0, 0, 0)] + Sx*x + Sy*y + Sz*z +//! for x < Nx, y < Ny, z < Nz +//! ``` +//! +//! The array is called +//! +//! - C-continuous if `[Sx, Sy, Sz] = [Nz*Ny, Nz, 1]` +//! - F(Fortran)-continuous if `[Sx, Sy, Sz] = [1, Nx, Nx*Ny]` +//! +//! Strides of ndarray `[Sx, Sy, Sz]` take arbitrary value, +//! e.g. it can be non-ordered `Sy > Sx > Sz`, or can be negative `Sx < 0`. +//! If the minimum of `[Sx, Sy, Sz]` equals to `1`, +//! the value of elements fills `data` memory region and called "continuous". +//! Non-continuous ndarray is useful to get sub-array without copying data. +//! +//! Matrix layout for LAPACK +//! ------------------------- +//! +//! LAPACK interface focuses on the linear algebra operations for F-continuous 2-dimensional array. +//! Under this restriction, stride becomes far simpler; we only have to consider the case `[1, S]` +//! This `S` for a matrix `A` is called "leading dimension of the array A" in LAPACK document, and denoted by `lda`. +//! -pub type LDA = i32; -pub type LEN = i32; -pub type Col = i32; -pub type Row = i32; - -#[derive(Debug, Clone, Copy, PartialEq)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum MatrixLayout { - C((Row, LDA)), - F((Col, LDA)), + C { row: i32, lda: i32 }, + F { col: i32, lda: i32 }, } impl MatrixLayout { - pub fn size(&self) -> (Row, Col) { + pub fn size(&self) -> (i32, i32) { match *self { - MatrixLayout::C((row, lda)) => (row, lda), - MatrixLayout::F((col, lda)) => (lda, col), + MatrixLayout::C { row, lda } => (row, lda), + MatrixLayout::F { col, lda } => (lda, col), } } - pub fn resized(&self, row: Row, col: Col) -> MatrixLayout { + pub fn resized(&self, row: i32, col: i32) -> MatrixLayout { match *self { - MatrixLayout::C(_) => MatrixLayout::C((row, col)), - MatrixLayout::F(_) => MatrixLayout::F((col, row)), + MatrixLayout::C { .. } => MatrixLayout::C { row, lda: col }, + MatrixLayout::F { .. } => MatrixLayout::F { col, lda: row }, } } - pub fn lda(&self) -> LDA { + pub fn lda(&self) -> i32 { std::cmp::max( 1, match *self { - MatrixLayout::C((_, lda)) | MatrixLayout::F((_, lda)) => lda, + MatrixLayout::C { lda, .. } | MatrixLayout::F { lda, .. } => lda, }, ) } - pub fn len(&self) -> LEN { + pub fn len(&self) -> i32 { match *self { - MatrixLayout::C((row, _)) => row, - MatrixLayout::F((col, _)) => col, + MatrixLayout::C { row, .. } => row, + MatrixLayout::F { col, .. } => col, } } @@ -48,8 +80,8 @@ impl MatrixLayout { pub fn lapacke_layout(&self) -> lapacke::Layout { match *self { - MatrixLayout::C(_) => lapacke::Layout::RowMajor, - MatrixLayout::F(_) => lapacke::Layout::ColumnMajor, + MatrixLayout::C { .. } => lapacke::Layout::RowMajor, + MatrixLayout::F { .. } => lapacke::Layout::ColumnMajor, } } @@ -59,8 +91,8 @@ impl MatrixLayout { pub fn toggle_order(&self) -> Self { match *self { - MatrixLayout::C((row, col)) => MatrixLayout::F((col, row)), - MatrixLayout::F((col, row)) => MatrixLayout::C((row, col)), + MatrixLayout::C { row, lda } => MatrixLayout::F { lda: row, col: lda }, + MatrixLayout::F { col, lda } => MatrixLayout::C { row: lda, lda: col }, } } } diff --git a/lax/src/opnorm.rs b/lax/src/opnorm.rs index 305629c1..4786fd6e 100644 --- a/lax/src/opnorm.rs +++ b/lax/src/opnorm.rs @@ -15,8 +15,8 @@ macro_rules! impl_opnorm { impl OperatorNorm_ for $scalar { unsafe fn opnorm(t: NormType, l: MatrixLayout, a: &[Self]) -> Self::Real { match l { - MatrixLayout::F((col, lda)) => $lange(cm, t as u8, lda, col, a, lda), - MatrixLayout::C((row, lda)) => { + MatrixLayout::F { col, lda } => $lange(cm, t as u8, lda, col, a, lda), + MatrixLayout::C { row, lda } => { $lange(cm, t.transpose() as u8, lda, row, a, lda) } } diff --git a/lax/src/solveh.rs b/lax/src/solveh.rs index 2f851d42..01e90f13 100644 --- a/lax/src/solveh.rs +++ b/lax/src/solveh.rs @@ -58,8 +58,8 @@ macro_rules! impl_solveh { let (n, _) = l.size(); let nrhs = 1; let ldb = match l { - MatrixLayout::C(_) => 1, - MatrixLayout::F(_) => n, + MatrixLayout::C { .. } => 1, + MatrixLayout::F { .. } => n, }; $trs( l.lapacke_layout(), diff --git a/ndarray-linalg/src/convert.rs b/ndarray-linalg/src/convert.rs index a4f37d59..307eb477 100644 --- a/ndarray-linalg/src/convert.rs +++ b/ndarray-linalg/src/convert.rs @@ -36,11 +36,11 @@ where S: DataOwned, { match l { - MatrixLayout::C((row, col)) => { - Ok(ArrayBase::from_shape_vec((row as usize, col as usize), a)?) + MatrixLayout::C { row, lda } => { + Ok(ArrayBase::from_shape_vec((row as usize, lda as usize), a)?) } - MatrixLayout::F((col, row)) => Ok(ArrayBase::from_shape_vec( - (row as usize, col as usize).f(), + MatrixLayout::F { col, lda } => Ok(ArrayBase::from_shape_vec( + (lda as usize, col as usize).f(), a, )?), } @@ -52,11 +52,11 @@ where S: DataOwned, { match l { - MatrixLayout::C((row, col)) => unsafe { - ArrayBase::uninitialized((row as usize, col as usize)) + MatrixLayout::C { row, lda } => unsafe { + ArrayBase::uninitialized((row as usize, lda as usize)) }, - MatrixLayout::F((col, row)) => unsafe { - ArrayBase::uninitialized((row as usize, col as usize).f()) + MatrixLayout::F { col, lda } => unsafe { + ArrayBase::uninitialized((lda as usize, col as usize).f()) }, } } diff --git a/ndarray-linalg/src/eigh.rs b/ndarray-linalg/src/eigh.rs index 1fc7a031..b0438a99 100644 --- a/ndarray-linalg/src/eigh.rs +++ b/ndarray-linalg/src/eigh.rs @@ -96,8 +96,8 @@ where let layout = self.square_layout()?; // XXX Force layout to be Fortran (see #146) match layout { - MatrixLayout::C(_) => self.swap_axes(0, 1), - MatrixLayout::F(_) => {} + MatrixLayout::C { .. } => self.swap_axes(0, 1), + MatrixLayout::F { .. } => {} } let s = unsafe { A::eigh(true, self.square_layout()?, uplo, self.as_allocated_mut()?)? }; Ok((ArrayBase::from(s), self)) @@ -116,14 +116,14 @@ where let layout = self.0.square_layout()?; // XXX Force layout to be Fortran (see #146) match layout { - MatrixLayout::C(_) => self.0.swap_axes(0, 1), - MatrixLayout::F(_) => {} + MatrixLayout::C { .. } => self.0.swap_axes(0, 1), + MatrixLayout::F { .. } => {} } let layout = self.1.square_layout()?; match layout { - MatrixLayout::C(_) => self.1.swap_axes(0, 1), - MatrixLayout::F(_) => {} + MatrixLayout::C { .. } => self.1.swap_axes(0, 1), + MatrixLayout::F { .. } => {} } let s = unsafe { diff --git a/ndarray-linalg/src/layout.rs b/ndarray-linalg/src/layout.rs index 9e43369b..cadbcf2d 100644 --- a/ndarray-linalg/src/layout.rs +++ b/ndarray-linalg/src/layout.rs @@ -1,4 +1,4 @@ -//! Memory layout of matrices +//! Convert ndarray into LAPACK-compatible matrix format use super::error::*; use ndarray::*; @@ -28,10 +28,16 @@ where let shape = self.shape(); let strides = self.strides(); if shape[0] == strides[1] as usize { - return Ok(MatrixLayout::F((self.ncols() as i32, self.nrows() as i32))); + return Ok(MatrixLayout::F { + col: self.ncols() as i32, + lda: self.nrows() as i32, + }); } if shape[1] == strides[0] as usize { - return Ok(MatrixLayout::C((self.nrows() as i32, self.ncols() as i32))); + return Ok(MatrixLayout::C { + row: self.nrows() as i32, + lda: self.ncols() as i32, + }); } Err(LinalgError::InvalidStride { s0: strides[0], diff --git a/ndarray-linalg/src/least_squares.rs b/ndarray-linalg/src/least_squares.rs index 9fab0062..4e629a4e 100644 --- a/ndarray-linalg/src/least_squares.rs +++ b/ndarray-linalg/src/least_squares.rs @@ -735,7 +735,7 @@ mod tests { fn test_incompatible_shape_error_on_mismatching_layout() { let a: Array2 = array![[1., 2.], [4., 5.], [3., 4.]]; let b = array![[1.], [2.]].t().to_owned(); - assert_eq!(b.layout().unwrap(), MatrixLayout::F((2, 1))); + assert_eq!(b.layout().unwrap(), MatrixLayout::F { col: 2, lda: 1 }); let res = a.least_squares(&b); match res { diff --git a/ndarray-linalg/tests/layout.rs b/ndarray-linalg/tests/layout.rs index 89938330..2c253f5e 100644 --- a/ndarray-linalg/tests/layout.rs +++ b/ndarray-linalg/tests/layout.rs @@ -6,26 +6,26 @@ use ndarray_linalg::*; fn layout_c_3x1() { let a: Array2 = Array::zeros((3, 1)); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 1))); + assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 1 }); } #[test] fn layout_f_3x1() { let a: Array2 = Array::zeros((3, 1).f()); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), MatrixLayout::F((1, 3))); + assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 1, lda: 3 }); } #[test] fn layout_c_3x2() { let a: Array2 = Array::zeros((3, 2)); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), MatrixLayout::C((3, 2))); + assert_eq!(a.layout().unwrap(), MatrixLayout::C { row: 3, lda: 2 }); } #[test] fn layout_f_3x2() { let a: Array2 = Array::zeros((3, 2).f()); println!("a = {:?}", &a); - assert_eq!(a.layout().unwrap(), MatrixLayout::F((2, 3))); + assert_eq!(a.layout().unwrap(), MatrixLayout::F { col: 2, lda: 3 }); }