Skip to content

Commit 8c2eac0

Browse files
committed
blas: Refactor and simplify gemm call further
Further clarify transpose logic by putting it into BlasOrder methods.
1 parent 7226d39 commit 8c2eac0

File tree

1 file changed

+57
-68
lines changed

1 file changed

+57
-68
lines changed

src/linalg/impl_linalg.rs

Lines changed: 57 additions & 68 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ use libc::c_int;
2828
#[cfg(feature = "blas")]
2929
use cblas_sys as blas_sys;
3030
#[cfg(feature = "blas")]
31-
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT};
31+
use cblas_sys::{CblasNoTrans, CblasTrans, CBLAS_LAYOUT, CBLAS_TRANSPOSE};
3232

3333
/// len of vector before we use blas
3434
#[cfg(feature = "blas")]
@@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
400400
// Compute A B -> C
401401
// We require for BLAS compatibility that:
402402
// A, B, C are contiguous (stride=1) in their fastest dimension,
403-
// but it can be either first or second axis (either rowmajor/"c" or colmajor/"f").
403+
// but they can be either row major/"c" or col major/"f".
404404
//
405405
// The "normal case" is CblasRowMajor for cblas.
406-
// Select CblasRowMajor, CblasColMajor to fit C's memory order.
406+
// Select CblasRowMajor / CblasColMajor to fit C's memory order.
407407
//
408-
// Apply transpose to A, B as needed if they differ from the normal case.
408+
// Apply transpose to A, B as needed if they differ from the row major case.
409409
// If C is CblasColMajor then transpose both A, B (again!)
410410

411-
let (a_layout, a_axis, b_layout, b_axis, c_layout) =
412-
match (get_blas_compatible_layout(a),
413-
get_blas_compatible_layout(b),
414-
get_blas_compatible_layout(c))
411+
let (a_layout, b_layout, c_layout) =
412+
if let (Some(a_layout), Some(b_layout), Some(c_layout)) =
413+
(get_blas_compatible_layout(a),
414+
get_blas_compatible_layout(b),
415+
get_blas_compatible_layout(c))
415416
{
416-
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::C)) => {
417-
(a_layout, a_layout.lead_axis(),
418-
b_layout, b_layout.lead_axis(), c_layout)
419-
},
420-
(Some(a_layout), Some(b_layout), Some(c_layout @ MemoryOrder::F)) => {
421-
// CblasColMajor is the "other case"
422-
// Mark a, b as having layouts opposite of what they were detected as, which
423-
// ends up with the correct transpose setting w.r.t col major
424-
(a_layout.opposite(), a_layout.lead_axis(),
425-
b_layout.opposite(), b_layout.lead_axis(), c_layout)
426-
},
427-
_ => break 'blas_block,
417+
(a_layout, b_layout, c_layout)
418+
} else {
419+
break 'blas_block;
428420
};
429421

430-
let a_trans = a_layout.to_cblas_transpose();
431-
let lda = blas_stride(&a, a_axis);
422+
let cblas_layout = c_layout.to_cblas_layout();
423+
let a_trans = a_layout.to_cblas_transpose_for(cblas_layout);
424+
let lda = blas_stride(&a, a_layout);
432425

433-
let b_trans = b_layout.to_cblas_transpose();
434-
let ldb = blas_stride(&b, b_axis);
426+
let b_trans = b_layout.to_cblas_transpose_for(cblas_layout);
427+
let ldb = blas_stride(&b, b_layout);
435428

436-
let ldc = blas_stride(&c, c_layout.lead_axis());
429+
let ldc = blas_stride(&c, c_layout);
437430

438431
macro_rules! gemm_scalar_cast {
439432
(f32, $var:ident) => {
@@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
457450
// Where Op is notrans/trans/conjtrans
458451
unsafe {
459452
blas_sys::$gemm(
460-
c_layout.to_cblas_layout(),
453+
cblas_layout,
461454
a_trans,
462455
b_trans,
463456
m as blas_index, // m, rows of Op(a)
@@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
696689
// may be arbitrary.
697690
let a_trans = CblasNoTrans;
698691

699-
let (a_stride, cblas_layout) = match layout {
700-
MemoryOrder::C => {
701-
(a.strides()[0].max(k as isize) as blas_index,
702-
CBLAS_LAYOUT::CblasRowMajor)
703-
}
704-
MemoryOrder::F => {
705-
(a.strides()[1].max(m as isize) as blas_index,
706-
CBLAS_LAYOUT::CblasColMajor)
707-
}
708-
};
692+
let a_stride = blas_stride(&a, layout);
693+
let cblas_layout = layout.to_cblas_layout();
709694

710695
// Low addr in memory pointers required for x, y
711696
let x_offset = offset_from_low_addr_ptr_to_logical_ptr(&x.dim, &x.strides);
@@ -832,64 +817,68 @@ where
832817
true
833818
}
834819

835-
#[cfg(feature = "blas")]
836820
#[derive(Copy, Clone)]
837821
#[cfg_attr(test, derive(PartialEq, Eq, Debug))]
838-
enum MemoryOrder
822+
enum BlasOrder
839823
{
840824
C,
841825
F,
842826
}
843827

844828
#[cfg(feature = "blas")]
845-
impl MemoryOrder
829+
impl BlasOrder
846830
{
847-
#[inline]
848-
/// Axis of leading stride (opposite of contiguous axis)
849-
fn lead_axis(self) -> usize
831+
fn transpose(self) -> Self
850832
{
851833
match self {
852-
MemoryOrder::C => 0,
853-
MemoryOrder::F => 1,
834+
Self::C => Self::F,
835+
Self::F => Self::C,
854836
}
855837
}
856838

857-
/// Get opposite memory order
858839
#[inline]
859-
fn opposite(self) -> Self
840+
/// Axis of leading stride (opposite of contiguous axis)
841+
fn get_blas_lead_axis(self) -> usize
860842
{
861843
match self {
862-
MemoryOrder::C => MemoryOrder::F,
863-
MemoryOrder::F => MemoryOrder::C,
844+
Self::C => 0,
845+
Self::F => 1,
864846
}
865847
}
866848

867-
fn to_cblas_transpose(self) -> cblas_sys::CBLAS_TRANSPOSE
849+
fn to_cblas_layout(self) -> CBLAS_LAYOUT
868850
{
869851
match self {
870-
MemoryOrder::C => CblasNoTrans,
871-
MemoryOrder::F => CblasTrans,
852+
Self::C => CBLAS_LAYOUT::CblasRowMajor,
853+
Self::F => CBLAS_LAYOUT::CblasColMajor,
872854
}
873855
}
874856

875-
fn to_cblas_layout(self) -> CBLAS_LAYOUT
857+
/// When calling cblas_sgemm (etc.) with the C matrix laid out as `for_layout`,
858+
/// return how this `self` matrix must be transposed.
859+
fn to_cblas_transpose_for(self, for_layout: CBLAS_LAYOUT) -> CBLAS_TRANSPOSE
876860
{
877-
match self {
878-
MemoryOrder::C => CBLAS_LAYOUT::CblasRowMajor,
879-
MemoryOrder::F => CBLAS_LAYOUT::CblasColMajor,
861+
let effective_order = match for_layout {
862+
CBLAS_LAYOUT::CblasRowMajor => self,
863+
CBLAS_LAYOUT::CblasColMajor => self.transpose(),
864+
};
865+
866+
match effective_order {
867+
Self::C => CblasNoTrans,
868+
Self::F => CblasTrans,
880869
}
881870
}
882871
}
883872

884873
#[cfg(feature = "blas")]
885-
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
874+
fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: BlasOrder) -> bool
886875
{
887876
let (m, n) = dim.into_pattern();
888877
let s0 = stride[0] as isize;
889878
let s1 = stride[1] as isize;
890879
let (inner_stride, outer_stride, inner_dim, outer_dim) = match order {
891-
MemoryOrder::C => (s1, s0, m, n),
892-
MemoryOrder::F => (s0, s1, n, m),
880+
BlasOrder::C => (s1, s0, m, n),
881+
BlasOrder::F => (s0, s1, n, m),
893882
};
894883

895884
if !(inner_stride == 1 || outer_dim == 1) {
@@ -920,13 +909,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
920909

921910
/// Get BLAS compatible layout if any (C or F, preferring the former)
922911
#[cfg(feature = "blas")]
923-
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<MemoryOrder>
912+
fn get_blas_compatible_layout<S>(a: &ArrayBase<S, Ix2>) -> Option<BlasOrder>
924913
where S: Data
925914
{
926-
if is_blas_2d(&a.dim, &a.strides, MemoryOrder::C) {
927-
Some(MemoryOrder::C)
928-
} else if is_blas_2d(&a.dim, &a.strides, MemoryOrder::F) {
929-
Some(MemoryOrder::F)
915+
if is_blas_2d(&a.dim, &a.strides, BlasOrder::C) {
916+
Some(BlasOrder::C)
917+
} else if is_blas_2d(&a.dim, &a.strides, BlasOrder::F) {
918+
Some(BlasOrder::F)
930919
} else {
931920
None
932921
}
@@ -937,10 +926,10 @@ where S: Data
937926
///
938927
/// Return leading stride (lda, ldb, ldc) of array
939928
#[cfg(feature = "blas")]
940-
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, axis: usize) -> blas_index
929+
fn blas_stride<S>(a: &ArrayBase<S, Ix2>, order: BlasOrder) -> blas_index
941930
where S: Data
942931
{
943-
debug_assert!(axis <= 1);
932+
let axis = order.get_blas_lead_axis();
944933
let other_axis = 1 - axis;
945934
let len_this = a.shape()[axis];
946935
let len_other = a.shape()[other_axis];
@@ -968,7 +957,7 @@ where
968957
if !same_type::<A, S::Elem>() {
969958
return false;
970959
}
971-
is_blas_2d(&a.dim, &a.strides, MemoryOrder::C)
960+
is_blas_2d(&a.dim, &a.strides, BlasOrder::C)
972961
}
973962

974963
#[cfg(test)]
@@ -982,7 +971,7 @@ where
982971
if !same_type::<A, S::Elem>() {
983972
return false;
984973
}
985-
is_blas_2d(&a.dim, &a.strides, MemoryOrder::F)
974+
is_blas_2d(&a.dim, &a.strides, BlasOrder::F)
986975
}
987976

988977
#[cfg(test)]
@@ -1096,7 +1085,7 @@ mod blas_tests
10961085
if stride < N {
10971086
assert_eq!(get_blas_compatible_layout(&m), None);
10981087
} else {
1099-
assert_eq!(get_blas_compatible_layout(&m), Some(MemoryOrder::C));
1088+
assert_eq!(get_blas_compatible_layout(&m), Some(BlasOrder::C));
11001089
}
11011090
}
11021091
}

0 commit comments

Comments
 (0)