1
1
//! Implement linear solver and inverse matrix
2
2
3
3
use super :: * ;
4
- use crate :: { error:: * , layout:: MatrixLayout } ;
4
+ use crate :: { error:: * , layout:: * } ;
5
5
use cauchy:: * ;
6
+ use num_traits:: Zero ;
6
7
7
8
#[ derive( Debug , Clone , Copy ) ]
8
9
#[ repr( u8 ) ]
@@ -12,9 +13,8 @@ pub enum Diag {
12
13
}
13
14
14
15
/// Wraps `*trtri` and `*trtrs`
15
- pub trait Triangular_ : Sized {
16
- unsafe fn inv_triangular ( l : MatrixLayout , uplo : UPLO , d : Diag , a : & mut [ Self ] ) -> Result < ( ) > ;
17
- unsafe fn solve_triangular (
16
+ pub trait Triangular_ : Scalar {
17
+ fn solve_triangular (
18
18
al : MatrixLayout ,
19
19
bl : MatrixLayout ,
20
20
uplo : UPLO ,
@@ -27,50 +27,66 @@ pub trait Triangular_: Sized {
27
27
macro_rules! impl_triangular {
28
28
( $scalar: ty, $trtri: path, $trtrs: path) => {
29
29
impl Triangular_ for $scalar {
30
- unsafe fn inv_triangular(
31
- l: MatrixLayout ,
32
- uplo: UPLO ,
33
- diag: Diag ,
34
- a: & mut [ Self ] ,
35
- ) -> Result <( ) > {
36
- let ( n, _) = l. size( ) ;
37
- let lda = l. lda( ) ;
38
- $trtri( l. lapacke_layout( ) , uplo as u8 , diag as u8 , n, a, lda) . as_lapack_result( ) ?;
39
- Ok ( ( ) )
40
- }
41
-
42
- unsafe fn solve_triangular(
43
- al: MatrixLayout ,
44
- bl: MatrixLayout ,
30
+ fn solve_triangular(
31
+ a_layout: MatrixLayout ,
32
+ b_layout: MatrixLayout ,
45
33
uplo: UPLO ,
46
34
diag: Diag ,
47
35
a: & [ Self ] ,
48
- mut b: & mut [ Self ] ,
36
+ b: & mut [ Self ] ,
49
37
) -> Result <( ) > {
50
- let ( n, _) = al. size( ) ;
51
- let lda = al. lda( ) ;
52
- let ( _, nrhs) = bl. size( ) ;
53
- let ldb = bl. lda( ) ;
54
- $trtrs(
55
- al. lapacke_layout( ) ,
56
- uplo as u8 ,
57
- Transpose :: No as u8 ,
58
- diag as u8 ,
59
- n,
60
- nrhs,
61
- a,
62
- lda,
63
- & mut b,
64
- ldb,
65
- )
66
- . as_lapack_result( ) ?;
38
+ // Transpose if a is C-continuous
39
+ let mut a_t = None ;
40
+ let a_layout = match a_layout {
41
+ MatrixLayout :: C { .. } => {
42
+ a_t = Some ( vec![ Self :: zero( ) ; a. len( ) ] ) ;
43
+ transpose( a_layout, a, a_t. as_mut( ) . unwrap( ) )
44
+ }
45
+ MatrixLayout :: F { .. } => a_layout,
46
+ } ;
47
+
48
+ // Transpose if b is C-continuous
49
+ let mut b_t = None ;
50
+ let b_layout = match b_layout {
51
+ MatrixLayout :: C { .. } => {
52
+ b_t = Some ( vec![ Self :: zero( ) ; b. len( ) ] ) ;
53
+ transpose( b_layout, b, b_t. as_mut( ) . unwrap( ) )
54
+ }
55
+ MatrixLayout :: F { .. } => b_layout,
56
+ } ;
57
+
58
+ let ( m, n) = a_layout. size( ) ;
59
+ let ( n_, nrhs) = b_layout. size( ) ;
60
+ assert_eq!( n, n_) ;
61
+
62
+ let mut info = 0 ;
63
+ unsafe {
64
+ $trtrs(
65
+ uplo as u8 ,
66
+ Transpose :: No as u8 ,
67
+ diag as u8 ,
68
+ m,
69
+ nrhs,
70
+ a_t. as_ref( ) . map( |v| v. as_slice( ) ) . unwrap_or( a) ,
71
+ a_layout. lda( ) ,
72
+ b_t. as_mut( ) . map( |v| v. as_mut_slice( ) ) . unwrap_or( b) ,
73
+ b_layout. lda( ) ,
74
+ & mut info,
75
+ ) ;
76
+ }
77
+ info. as_lapack_result( ) ?;
78
+
79
+ // Re-transpose b
80
+ if let Some ( b_t) = b_t {
81
+ transpose( b_layout, & b_t, b) ;
82
+ }
67
83
Ok ( ( ) )
68
84
}
69
85
}
70
86
} ;
71
87
} // impl_triangular!
72
88
73
- impl_triangular ! ( f64 , lapacke :: dtrtri, lapacke :: dtrtrs) ;
74
- impl_triangular ! ( f32 , lapacke :: strtri, lapacke :: strtrs) ;
75
- impl_triangular ! ( c64, lapacke :: ztrtri, lapacke :: ztrtrs) ;
76
- impl_triangular ! ( c32, lapacke :: ctrtri, lapacke :: ctrtrs) ;
89
+ impl_triangular ! ( f64 , lapack :: dtrtri, lapack :: dtrtrs) ;
90
+ impl_triangular ! ( f32 , lapack :: strtri, lapack :: strtrs) ;
91
+ impl_triangular ! ( c64, lapack :: ztrtri, lapack :: ztrtrs) ;
92
+ impl_triangular ! ( c32, lapack :: ctrtri, lapack :: ctrtrs) ;
0 commit comments