@@ -28,7 +28,7 @@ use libc::c_int;
28
28
#[ cfg( feature = "blas" ) ]
29
29
use cblas_sys as blas_sys;
30
30
#[ cfg( feature = "blas" ) ]
31
- use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT } ;
31
+ use cblas_sys:: { CblasNoTrans , CblasTrans , CBLAS_LAYOUT , CBLAS_TRANSPOSE } ;
32
32
33
33
/// len of vector before we use blas
34
34
#[ cfg( feature = "blas" ) ]
@@ -400,40 +400,33 @@ fn mat_mul_impl<A>(
400
400
// Compute A B -> C
401
401
// We require for BLAS compatibility that:
402
402
// A, B, C are contiguous (stride=1) in their fastest dimension,
403
- // but it can be either first or second axis (either rowmajor /"c" or colmajor /"f") .
403
+ // but they can be either row major /"c" or col major /"f".
404
404
//
405
405
// The "normal case" is CblasRowMajor for cblas.
406
- // Select CblasRowMajor, CblasColMajor to fit C's memory order.
406
+ // Select CblasRowMajor / CblasColMajor to fit C's memory order.
407
407
//
408
- // Apply transpose to A, B as needed if they differ from the normal case.
408
+ // Apply transpose to A, B as needed if they differ from the row major case.
409
409
// If C is CblasColMajor then transpose both A, B (again!)
410
410
411
- let ( a_layout, a_axis, b_layout, b_axis, c_layout) =
412
- match ( get_blas_compatible_layout ( a) ,
413
- get_blas_compatible_layout ( b) ,
414
- get_blas_compatible_layout ( c) )
411
+ let ( a_layout, b_layout, c_layout) =
412
+ if let ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout) ) =
413
+ ( get_blas_compatible_layout ( a) ,
414
+ get_blas_compatible_layout ( b) ,
415
+ get_blas_compatible_layout ( c) )
415
416
{
416
- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: C ) ) => {
417
- ( a_layout, a_layout. lead_axis ( ) ,
418
- b_layout, b_layout. lead_axis ( ) , c_layout)
419
- } ,
420
- ( Some ( a_layout) , Some ( b_layout) , Some ( c_layout @ MemoryOrder :: F ) ) => {
421
- // CblasColMajor is the "other case"
422
- // Mark a, b as having layouts opposite of what they were detected as, which
423
- // ends up with the correct transpose setting w.r.t col major
424
- ( a_layout. opposite ( ) , a_layout. lead_axis ( ) ,
425
- b_layout. opposite ( ) , b_layout. lead_axis ( ) , c_layout)
426
- } ,
427
- _ => break ' blas_block,
417
+ ( a_layout, b_layout, c_layout)
418
+ } else {
419
+ break ' blas_block;
428
420
} ;
429
421
430
- let a_trans = a_layout. to_cblas_transpose ( ) ;
431
- let lda = blas_stride ( & a, a_axis) ;
422
+ let cblas_layout = c_layout. to_cblas_layout ( ) ;
423
+ let a_trans = a_layout. to_cblas_transpose_for ( cblas_layout) ;
424
+ let lda = blas_stride ( & a, a_layout) ;
432
425
433
- let b_trans = b_layout. to_cblas_transpose ( ) ;
434
- let ldb = blas_stride ( & b, b_axis ) ;
426
+ let b_trans = b_layout. to_cblas_transpose_for ( cblas_layout ) ;
427
+ let ldb = blas_stride ( & b, b_layout ) ;
435
428
436
- let ldc = blas_stride ( & c, c_layout. lead_axis ( ) ) ;
429
+ let ldc = blas_stride ( & c, c_layout) ;
437
430
438
431
macro_rules! gemm_scalar_cast {
439
432
( f32 , $var: ident) => {
@@ -457,7 +450,7 @@ fn mat_mul_impl<A>(
457
450
// Where Op is notrans/trans/conjtrans
458
451
unsafe {
459
452
blas_sys:: $gemm(
460
- c_layout . to_cblas_layout ( ) ,
453
+ cblas_layout ,
461
454
a_trans,
462
455
b_trans,
463
456
m as blas_index, // m, rows of Op(a)
@@ -696,16 +689,8 @@ unsafe fn general_mat_vec_mul_impl<A, S1, S2>(
696
689
// may be arbitrary.
697
690
let a_trans = CblasNoTrans ;
698
691
699
- let ( a_stride, cblas_layout) = match layout {
700
- MemoryOrder :: C => {
701
- ( a. strides( ) [ 0 ] . max( k as isize ) as blas_index,
702
- CBLAS_LAYOUT :: CblasRowMajor )
703
- }
704
- MemoryOrder :: F => {
705
- ( a. strides( ) [ 1 ] . max( m as isize ) as blas_index,
706
- CBLAS_LAYOUT :: CblasColMajor )
707
- }
708
- } ;
692
+ let a_stride = blas_stride( & a, layout) ;
693
+ let cblas_layout = layout. to_cblas_layout( ) ;
709
694
710
695
// Low addr in memory pointers required for x, y
711
696
let x_offset = offset_from_low_addr_ptr_to_logical_ptr( & x. dim, & x. strides) ;
@@ -832,64 +817,68 @@ where
832
817
true
833
818
}
834
819
835
- #[ cfg( feature = "blas" ) ]
836
820
#[ derive( Copy , Clone ) ]
837
821
#[ cfg_attr( test, derive( PartialEq , Eq , Debug ) ) ]
838
- enum MemoryOrder
822
+ enum BlasOrder
839
823
{
840
824
C ,
841
825
F ,
842
826
}
843
827
844
828
#[ cfg( feature = "blas" ) ]
845
- impl MemoryOrder
829
+ impl BlasOrder
846
830
{
847
- #[ inline]
848
- /// Axis of leading stride (opposite of contiguous axis)
849
- fn lead_axis ( self ) -> usize
831
+ fn transpose ( self ) -> Self
850
832
{
851
833
match self {
852
- MemoryOrder :: C => 0 ,
853
- MemoryOrder :: F => 1 ,
834
+ Self :: C => Self :: F ,
835
+ Self :: F => Self :: C ,
854
836
}
855
837
}
856
838
857
- /// Get opposite memory order
858
839
#[ inline]
859
- fn opposite ( self ) -> Self
840
+ /// Axis of leading stride (opposite of contiguous axis)
841
+ fn get_blas_lead_axis ( self ) -> usize
860
842
{
861
843
match self {
862
- MemoryOrder :: C => MemoryOrder :: F ,
863
- MemoryOrder :: F => MemoryOrder :: C ,
844
+ Self :: C => 0 ,
845
+ Self :: F => 1 ,
864
846
}
865
847
}
866
848
867
- fn to_cblas_transpose ( self ) -> cblas_sys :: CBLAS_TRANSPOSE
849
+ fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
868
850
{
869
851
match self {
870
- MemoryOrder :: C => CblasNoTrans ,
871
- MemoryOrder :: F => CblasTrans ,
852
+ Self :: C => CBLAS_LAYOUT :: CblasRowMajor ,
853
+ Self :: F => CBLAS_LAYOUT :: CblasColMajor ,
872
854
}
873
855
}
874
856
875
- fn to_cblas_layout ( self ) -> CBLAS_LAYOUT
857
+ /// When using cblas_sgemm (etc) with C matrix using `for_layout`,
858
+ /// how should this `self` matrix be transposed
859
+ fn to_cblas_transpose_for ( self , for_layout : CBLAS_LAYOUT ) -> CBLAS_TRANSPOSE
876
860
{
877
- match self {
878
- MemoryOrder :: C => CBLAS_LAYOUT :: CblasRowMajor ,
879
- MemoryOrder :: F => CBLAS_LAYOUT :: CblasColMajor ,
861
+ let effective_order = match for_layout {
862
+ CBLAS_LAYOUT :: CblasRowMajor => self ,
863
+ CBLAS_LAYOUT :: CblasColMajor => self . transpose ( ) ,
864
+ } ;
865
+
866
+ match effective_order {
867
+ Self :: C => CblasNoTrans ,
868
+ Self :: F => CblasTrans ,
880
869
}
881
870
}
882
871
}
883
872
884
873
#[ cfg( feature = "blas" ) ]
885
- fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : MemoryOrder ) -> bool
874
+ fn is_blas_2d ( dim : & Ix2 , stride : & Ix2 , order : BlasOrder ) -> bool
886
875
{
887
876
let ( m, n) = dim. into_pattern ( ) ;
888
877
let s0 = stride[ 0 ] as isize ;
889
878
let s1 = stride[ 1 ] as isize ;
890
879
let ( inner_stride, outer_stride, inner_dim, outer_dim) = match order {
891
- MemoryOrder :: C => ( s1, s0, m, n) ,
892
- MemoryOrder :: F => ( s0, s1, n, m) ,
880
+ BlasOrder :: C => ( s1, s0, m, n) ,
881
+ BlasOrder :: F => ( s0, s1, n, m) ,
893
882
} ;
894
883
895
884
if !( inner_stride == 1 || outer_dim == 1 ) {
@@ -920,13 +909,13 @@ fn is_blas_2d(dim: &Ix2, stride: &Ix2, order: MemoryOrder) -> bool
920
909
921
910
/// Get BLAS compatible layout if any (C or F, preferring the former)
922
911
#[ cfg( feature = "blas" ) ]
923
- fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < MemoryOrder >
912
+ fn get_blas_compatible_layout < S > ( a : & ArrayBase < S , Ix2 > ) -> Option < BlasOrder >
924
913
where S : Data
925
914
{
926
- if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C ) {
927
- Some ( MemoryOrder :: C )
928
- } else if is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F ) {
929
- Some ( MemoryOrder :: F )
915
+ if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C ) {
916
+ Some ( BlasOrder :: C )
917
+ } else if is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F ) {
918
+ Some ( BlasOrder :: F )
930
919
} else {
931
920
None
932
921
}
@@ -937,10 +926,10 @@ where S: Data
937
926
///
938
927
/// Return leading stride (lda, ldb, ldc) of array
939
928
#[ cfg( feature = "blas" ) ]
940
- fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , axis : usize ) -> blas_index
929
+ fn blas_stride < S > ( a : & ArrayBase < S , Ix2 > , order : BlasOrder ) -> blas_index
941
930
where S : Data
942
931
{
943
- debug_assert ! ( axis <= 1 ) ;
932
+ let axis = order . get_blas_lead_axis ( ) ;
944
933
let other_axis = 1 - axis;
945
934
let len_this = a. shape ( ) [ axis] ;
946
935
let len_other = a. shape ( ) [ other_axis] ;
@@ -968,7 +957,7 @@ where
968
957
if !same_type :: < A , S :: Elem > ( ) {
969
958
return false ;
970
959
}
971
- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: C )
960
+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: C )
972
961
}
973
962
974
963
#[ cfg( test) ]
@@ -982,7 +971,7 @@ where
982
971
if !same_type :: < A , S :: Elem > ( ) {
983
972
return false ;
984
973
}
985
- is_blas_2d ( & a. dim , & a. strides , MemoryOrder :: F )
974
+ is_blas_2d ( & a. dim , & a. strides , BlasOrder :: F )
986
975
}
987
976
988
977
#[ cfg( test) ]
@@ -1096,7 +1085,7 @@ mod blas_tests
1096
1085
if stride < N {
1097
1086
assert_eq ! ( get_blas_compatible_layout( & m) , None ) ;
1098
1087
} else {
1099
- assert_eq ! ( get_blas_compatible_layout( & m) , Some ( MemoryOrder :: C ) ) ;
1088
+ assert_eq ! ( get_blas_compatible_layout( & m) , Some ( BlasOrder :: C ) ) ;
1100
1089
}
1101
1090
}
1102
1091
}
0 commit comments