Skip to content

Commit 77235f4

Browse files
committed
InvWork
1 parent 07ab31d commit 77235f4

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

lax/src/solve.rs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,69 @@ pub trait Solve_: Scalar + Sized {
6464
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
6565
}
6666

67+
pub struct InvWork<T: Scalar> {
68+
pub layout: MatrixLayout,
69+
pub work: Vec<MaybeUninit<T>>,
70+
}
71+
72+
pub trait InvWorkImpl: Sized {
73+
type Elem: Scalar;
74+
fn new(layout: MatrixLayout) -> Result<Self>;
75+
fn calc(&mut self, a: &mut [Self::Elem], p: &Pivot) -> Result<()>;
76+
}
77+
78+
macro_rules! impl_inv_work {
79+
($s:ty, $tri:path) => {
80+
impl InvWorkImpl for InvWork<$s> {
81+
type Elem = $s;
82+
83+
fn new(layout: MatrixLayout) -> Result<Self> {
84+
let (n, _) = layout.size();
85+
let mut info = 0;
86+
let mut work_size = [Self::Elem::zero()];
87+
unsafe {
88+
$tri(
89+
&n,
90+
std::ptr::null_mut(),
91+
&layout.lda(),
92+
std::ptr::null(),
93+
AsPtr::as_mut_ptr(&mut work_size),
94+
&(-1),
95+
&mut info,
96+
)
97+
};
98+
info.as_lapack_result()?;
99+
let lwork = work_size[0].to_usize().unwrap();
100+
let work = vec_uninit(lwork);
101+
Ok(InvWork { layout, work })
102+
}
103+
104+
fn calc(&mut self, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
105+
let lwork = self.work.len().to_i32().unwrap();
106+
let mut info = 0;
107+
unsafe {
108+
$tri(
109+
&self.layout.len(),
110+
AsPtr::as_mut_ptr(a),
111+
&self.layout.lda(),
112+
ipiv.as_ptr(),
113+
AsPtr::as_mut_ptr(&mut self.work),
114+
&lwork,
115+
&mut info,
116+
)
117+
};
118+
info.as_lapack_result()?;
119+
Ok(())
120+
}
121+
}
122+
};
123+
}
124+
125+
impl_inv_work!(c64, lapack_sys::zgetri_);
126+
impl_inv_work!(c32, lapack_sys::cgetri_);
127+
impl_inv_work!(f64, lapack_sys::dgetri_);
128+
impl_inv_work!(f32, lapack_sys::sgetri_);
129+
67130
macro_rules! impl_solve {
68131
($scalar:ty, $getrf:path, $getri:path, $getrs:path) => {
69132
impl Solve_ for $scalar {

0 commit comments

Comments
 (0)