@@ -64,6 +64,69 @@ pub trait Solve_: Scalar + Sized {
64
64
fn solve ( l : MatrixLayout , t : Transpose , a : & [ Self ] , p : & Pivot , b : & mut [ Self ] ) -> Result < ( ) > ;
65
65
}
66
66
67
+ pub struct InvWork < T : Scalar > {
68
+ pub layout : MatrixLayout ,
69
+ pub work : Vec < MaybeUninit < T > > ,
70
+ }
71
+
72
+ pub trait InvWorkImpl : Sized {
73
+ type Elem : Scalar ;
74
+ fn new ( layout : MatrixLayout ) -> Result < Self > ;
75
+ fn calc ( & mut self , a : & mut [ Self :: Elem ] , p : & Pivot ) -> Result < ( ) > ;
76
+ }
77
+
78
+ macro_rules! impl_inv_work {
79
+ ( $s: ty, $tri: path) => {
80
+ impl InvWorkImpl for InvWork <$s> {
81
+ type Elem = $s;
82
+
83
+ fn new( layout: MatrixLayout ) -> Result <Self > {
84
+ let ( n, _) = layout. size( ) ;
85
+ let mut info = 0 ;
86
+ let mut work_size = [ Self :: Elem :: zero( ) ] ;
87
+ unsafe {
88
+ $tri(
89
+ & n,
90
+ std:: ptr:: null_mut( ) ,
91
+ & layout. lda( ) ,
92
+ std:: ptr:: null( ) ,
93
+ AsPtr :: as_mut_ptr( & mut work_size) ,
94
+ & ( -1 ) ,
95
+ & mut info,
96
+ )
97
+ } ;
98
+ info. as_lapack_result( ) ?;
99
+ let lwork = work_size[ 0 ] . to_usize( ) . unwrap( ) ;
100
+ let work = vec_uninit( lwork) ;
101
+ Ok ( InvWork { layout, work } )
102
+ }
103
+
104
+ fn calc( & mut self , a: & mut [ Self :: Elem ] , ipiv: & Pivot ) -> Result <( ) > {
105
+ let lwork = self . work. len( ) . to_i32( ) . unwrap( ) ;
106
+ let mut info = 0 ;
107
+ unsafe {
108
+ $tri(
109
+ & self . layout. len( ) ,
110
+ AsPtr :: as_mut_ptr( a) ,
111
+ & self . layout. lda( ) ,
112
+ ipiv. as_ptr( ) ,
113
+ AsPtr :: as_mut_ptr( & mut self . work) ,
114
+ & lwork,
115
+ & mut info,
116
+ )
117
+ } ;
118
+ info. as_lapack_result( ) ?;
119
+ Ok ( ( ) )
120
+ }
121
+ }
122
+ } ;
123
+ }
124
+
125
+ impl_inv_work ! ( c64, lapack_sys:: zgetri_) ;
126
+ impl_inv_work ! ( c32, lapack_sys:: cgetri_) ;
127
+ impl_inv_work ! ( f64 , lapack_sys:: dgetri_) ;
128
+ impl_inv_work ! ( f32 , lapack_sys:: sgetri_) ;
129
+
67
130
macro_rules! impl_solve {
68
131
( $scalar: ty, $getrf: path, $getri: path, $getrs: path) => {
69
132
impl Solve_ for $scalar {
0 commit comments