Skip to content

Commit 4624ff2

Browse files
committed
fix based on code review
1 parent 0ff92e4 commit 4624ff2

File tree

7 files changed

+127
-132
lines changed

7 files changed

+127
-132
lines changed

examples/tridiagonal.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,8 @@ fn factorize() -> Result<(), error::LinalgError> {
2727
fn main() {
2828
solve().unwrap();
2929
factorize().unwrap();
30+
match arr2(&[[0.0]]).extract_tridiagonal() {
31+
Ok(_) => {}
32+
Err(err) => println!("{}", err),
33+
}
3034
}

src/error.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@ pub enum LinalgError {
1717
InvalidStride { s0: Ixs, s1: Ixs },
1818
/// Memory is not aligned continously
1919
MemoryNotCont,
20+
/// Obj cannot be made from a (rows, cols) matrix
21+
NotStandardShape {
22+
obj: &'static str,
23+
rows: i32,
24+
cols: i32,
25+
},
2026
/// Strides of the array is not supported
2127
Shape(ShapeError),
2228
}
@@ -34,6 +40,11 @@ impl fmt::Display for LinalgError {
3440
write!(f, "invalid stride: s0={}, s1={}", s0, s1)
3541
}
3642
LinalgError::MemoryNotCont => write!(f, "Memory is not contiguous"),
43+
LinalgError::NotStandardShape { obj, rows, cols } => write!(
44+
f,
45+
"{} cannot be made from a ({}, {}) matrix",
46+
obj, rows, cols
47+
),
3748
LinalgError::Shape(err) => write!(f, "Shape Error: {}", err),
3849
}
3950
}

src/lapack/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ pub trait Lapack:
4141
+ Eig_
4242
+ Eigh_
4343
+ Triangular_
44-
+ TriDiagonal_
44+
+ Tridiagonal_
4545
{
4646
}
4747

src/lapack/tridiagonal.rs

Lines changed: 20 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
//! for tridiagonal matrix
33
44
use lapacke;
5-
use ndarray::*;
65
use num_traits::Zero;
76

87
use super::NormType;
@@ -11,20 +10,18 @@ use super::{into_result, Pivot, Transpose};
1110
use crate::error::*;
1211
use crate::layout::MatrixLayout;
1312
use crate::opnorm::*;
14-
use crate::tridiagonal::{LUFactorizedTriDiagonal, TriDiagonal};
13+
use crate::tridiagonal::{LUFactorizedTridiagonal, Tridiagonal};
1514
use crate::types::*;
1615

1716
/// Wraps `*gttrf`, `*gtcon` and `*gttrs`
18-
pub trait TriDiagonal_: Scalar + Sized {
17+
pub trait Tridiagonal_: Scalar + Sized {
1918
/// Computes the LU factorization of a tridiagonal `m x n` matrix `a` using
2019
/// partial pivoting with row interchanges.
21-
unsafe fn lu_tridiagonal(
22-
a: &mut TriDiagonal<Self>,
23-
) -> Result<(Array1<Self>, Self::Real, Pivot)>;
20+
unsafe fn lu_tridiagonal(a: &mut Tridiagonal<Self>) -> Result<(Vec<Self>, Self::Real, Pivot)>;
2421
/// Estimates the the reciprocal of the condition number of the tridiagonal matrix in 1-norm.
25-
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real>;
22+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real>;
2623
unsafe fn solve_tridiagonal(
27-
lu: &LUFactorizedTriDiagonal<Self>,
24+
lu: &LUFactorizedTridiagonal<Self>,
2825
bl: MatrixLayout,
2926
t: Transpose,
3027
b: &mut [Self],
@@ -33,37 +30,30 @@ pub trait TriDiagonal_: Scalar + Sized {
3330

3431
macro_rules! impl_tridiagonal {
3532
($scalar:ty, $gttrf:path, $gtcon:path, $gttrs:path) => {
36-
impl TriDiagonal_ for $scalar {
33+
impl Tridiagonal_ for $scalar {
3734
unsafe fn lu_tridiagonal(
38-
a: &mut TriDiagonal<Self>,
39-
) -> Result<(Array1<Self>, Self::Real, Pivot)> {
35+
a: &mut Tridiagonal<Self>,
36+
) -> Result<(Vec<Self>, Self::Real, Pivot)> {
4037
let (n, _) = a.l.size();
4138
let anom = a.opnorm_one()?;
42-
let dl = a.dl.as_slice_mut().unwrap();
43-
let d = a.d.as_slice_mut().unwrap();
44-
let du = a.du.as_slice_mut().unwrap();
4539
let mut du2 = vec![Zero::zero(); (n - 2) as usize];
4640
let mut ipiv = vec![0; n as usize];
47-
let info = $gttrf(n, dl, d, du, &mut du2, &mut ipiv);
48-
into_result(info, (arr1(&du2), anom, ipiv))
41+
let info = $gttrf(n, &mut a.dl, &mut a.d, &mut a.du, &mut du2, &mut ipiv);
42+
into_result(info, (du2, anom, ipiv))
4943
}
5044

51-
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTriDiagonal<Self>) -> Result<Self::Real> {
45+
unsafe fn rcond_tridiagonal(lu: &LUFactorizedTridiagonal<Self>) -> Result<Self::Real> {
5246
let (n, _) = lu.a.l.size();
53-
let dl = lu.a.dl.as_slice().unwrap();
54-
let d = lu.a.d.as_slice().unwrap();
55-
let du = lu.a.du.as_slice().unwrap();
56-
let du2 = lu.du2.as_slice().unwrap();
5747
let ipiv = &lu.ipiv;
5848
let anorm = lu.anom;
5949
let mut rcond = Self::Real::zero();
6050
let info = $gtcon(
6151
NormType::One as u8,
6252
n,
63-
dl,
64-
d,
65-
du,
66-
du2,
53+
&lu.a.dl,
54+
&lu.a.d,
55+
&lu.a.du,
56+
&lu.du2,
6757
ipiv,
6858
anorm,
6959
&mut rcond,
@@ -72,28 +62,24 @@ macro_rules! impl_tridiagonal {
7262
}
7363

7464
unsafe fn solve_tridiagonal(
75-
lu: &LUFactorizedTriDiagonal<Self>,
65+
lu: &LUFactorizedTridiagonal<Self>,
7666
bl: MatrixLayout,
7767
t: Transpose,
7868
b: &mut [Self],
7969
) -> Result<()> {
8070
let (n, _) = lu.a.l.size();
8171
let (_, nrhs) = bl.size();
82-
let dl = lu.a.dl.as_slice().unwrap();
83-
let d = lu.a.d.as_slice().unwrap();
84-
let du = lu.a.du.as_slice().unwrap();
85-
let du2 = lu.du2.as_slice().unwrap();
8672
let ipiv = &lu.ipiv;
8773
let ldb = bl.lda();
8874
let info = $gttrs(
8975
lu.a.l.lapacke_layout(),
9076
t as u8,
9177
n,
9278
nrhs,
93-
dl,
94-
d,
95-
du,
96-
du2,
79+
&lu.a.dl,
80+
&lu.a.d,
81+
&lu.a.du,
82+
&lu.du2,
9783
ipiv,
9884
b,
9985
ldb,

src/opnorm.rs

Lines changed: 7 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use ndarray::*;
55
use crate::convert::*;
66
use crate::error::*;
77
use crate::layout::*;
8-
use crate::tridiagonal::TriDiagonal;
8+
use crate::tridiagonal::Tridiagonal;
99
use crate::types::*;
1010

1111
pub use crate::lapack::NormType;
@@ -49,7 +49,7 @@ where
4949
}
5050
}
5151

52-
impl<A> OperationNorm for TriDiagonal<A>
52+
impl<A> OperationNorm for Tridiagonal<A>
5353
where
5454
A: Scalar + Lapack,
5555
{
@@ -73,12 +73,7 @@ where
7373
let zu: Array1<A> = Array::zeros(1);
7474
let dl = stack![Axis(0), self.dl.to_owned(), zl];
7575
let du = stack![Axis(0), zu, self.du.to_owned()];
76-
let arr = stack![
77-
Axis(0),
78-
into_row(du),
79-
into_row(self.d.to_owned()),
80-
into_row(dl)
81-
];
76+
let arr = stack![Axis(0), into_row(du), into_row(arr1(&self.d)), into_row(dl)];
8277
arr
8378
}
8479
// opnorm_inf() calculates muximum row sum.
@@ -93,12 +88,7 @@ where
9388
let zu: Array1<A> = Array::zeros(1);
9489
let dl = stack![Axis(0), zl, self.dl.to_owned()];
9590
let du = stack![Axis(0), self.du.to_owned(), zu];
96-
let arr = stack![
97-
Axis(1),
98-
into_col(dl),
99-
into_col(self.d.to_owned()),
100-
into_col(du)
101-
];
91+
let arr = stack![Axis(1), into_col(dl), into_col(arr1(&self.d)), into_col(du)];
10292
arr
10393
}
10494
// opnorm_fro() calculates square root of sum of squares.
@@ -108,9 +98,9 @@ where
10898
NormType::Frobenius => {
10999
let arr = stack![
110100
Axis(1),
111-
into_row(self.dl.to_owned()),
112-
into_row(self.d.to_owned()),
113-
into_row(self.du.to_owned())
101+
into_row(arr1(&self.dl)),
102+
into_row(arr1(&self.d)),
103+
into_row(arr1(&self.du))
114104
];
115105
arr
116106
}

0 commit comments

Comments
 (0)