9
9
# define ffloat float
10
10
#endif
11
11
12
- TH_EXTERNC void dswap_ (int * n , double * x , int * incx , double * y , int * incy );
13
- TH_EXTERNC void sswap_ (int * n , float * x , int * incx , float * y , int * incy );
14
- TH_EXTERNC void dscal_ (int * n , double * a , double * x , int * incx );
15
- TH_EXTERNC void sscal_ (int * n , float * a , float * x , int * incx );
16
- TH_EXTERNC void dcopy_ (int * n , double * x , int * incx , double * y , int * incy );
17
- TH_EXTERNC void scopy_ (int * n , float * x , int * incx , float * y , int * incy );
18
- TH_EXTERNC void daxpy_ (int * n , double * a , double * x , int * incx , double * y , int * incy );
19
- TH_EXTERNC void saxpy_ (int * n , float * a , float * x , int * incx , float * y , int * incy );
20
- TH_EXTERNC double ddot_ (int * n , double * x , int * incx , double * y , int * incy );
21
- TH_EXTERNC ffloat sdot_ (int * n , float * x , int * incx , float * y , int * incy );
22
- TH_EXTERNC void dgemv_ (char * trans , int * m , int * n , double * alpha , double * a , int * lda , double * x , int * incx , double * beta , double * y , int * incy );
23
- TH_EXTERNC void sgemv_ (char * trans , int * m , int * n , float * alpha , float * a , int * lda , float * x , int * incx , float * beta , float * y , int * incy );
24
- TH_EXTERNC void dger_ (int * m , int * n , double * alpha , double * x , int * incx , double * y , int * incy , double * a , int * lda );
25
- TH_EXTERNC void sger_ (int * m , int * n , float * alpha , float * x , int * incx , float * y , int * incy , float * a , int * lda );
26
- TH_EXTERNC void dgemm_ (char * transa , char * transb , int * m , int * n , int * k , double * alpha , double * a , int * lda , double * b , int * ldb , double * beta , double * c , int * ldc );
27
- TH_EXTERNC void sgemm_ (char * transa , char * transb , int * m , int * n , int * k , float * alpha , float * a , int * lda , float * b , int * ldb , float * beta , float * c , int * ldc );
12
+ // define MKL_LP64 to get 32bit ints on 64bit platforms
13
+ #ifndef MKL_LP64
14
+ // 64bit ints
15
+ #ifdef WIN32
16
+ #define BLAS_INT __int64
17
+ #else
18
+ #define BLAS_INT long
19
+ #endif
20
+ #else
21
+ // 32bit ints
22
+ #define BLAS_INT int
23
+ #endif
28
24
29
25
26
+ TH_EXTERNC void dswap_ (BLAS_INT * n , double * x , BLAS_INT * incx , double * y , BLAS_INT * incy );
27
+ TH_EXTERNC void sswap_ (BLAS_INT * n , float * x , BLAS_INT * incx , float * y , BLAS_INT * incy );
28
+ TH_EXTERNC void dscal_ (BLAS_INT * n , double * a , double * x , BLAS_INT * incx );
29
+ TH_EXTERNC void sscal_ (BLAS_INT * n , float * a , float * x , BLAS_INT * incx );
30
+ TH_EXTERNC void dcopy_ (BLAS_INT * n , double * x , BLAS_INT * incx , double * y , BLAS_INT * incy );
31
+ TH_EXTERNC void scopy_ (BLAS_INT * n , float * x , BLAS_INT * incx , float * y , BLAS_INT * incy );
32
+ TH_EXTERNC void daxpy_ (BLAS_INT * n , double * a , double * x , BLAS_INT * incx , double * y , BLAS_INT * incy );
33
+ TH_EXTERNC void saxpy_ (BLAS_INT * n , float * a , float * x , BLAS_INT * incx , float * y , BLAS_INT * incy );
34
+ TH_EXTERNC double ddot_ (BLAS_INT * n , double * x , BLAS_INT * incx , double * y , BLAS_INT * incy );
35
+ TH_EXTERNC ffloat sdot_ (BLAS_INT * n , float * x , BLAS_INT * incx , float * y , BLAS_INT * incy );
36
+ TH_EXTERNC void dgemv_ (char * trans , BLAS_INT * m , BLAS_INT * n , double * alpha , double * a , BLAS_INT * lda , double * x , BLAS_INT * incx , double * beta , double * y , BLAS_INT * incy );
37
+ TH_EXTERNC void sgemv_ (char * trans , BLAS_INT * m , BLAS_INT * n , float * alpha , float * a , BLAS_INT * lda , float * x , BLAS_INT * incx , float * beta , float * y , BLAS_INT * incy );
38
+ TH_EXTERNC void dger_ (BLAS_INT * m , BLAS_INT * n , double * alpha , double * x , BLAS_INT * incx , double * y , BLAS_INT * incy , double * a , BLAS_INT * lda );
39
+ TH_EXTERNC void sger_ (BLAS_INT * m , BLAS_INT * n , float * alpha , float * x , BLAS_INT * incx , float * y , BLAS_INT * incy , float * a , BLAS_INT * lda );
40
+ TH_EXTERNC void dgemm_ (char * transa , char * transb , BLAS_INT * m , BLAS_INT * n , BLAS_INT * k , double * alpha , double * a , BLAS_INT * lda , double * b , BLAS_INT * ldb , double * beta , double * c , BLAS_INT * ldc );
41
+ TH_EXTERNC void sgemm_ (char * transa , char * transb , BLAS_INT * m , BLAS_INT * n , BLAS_INT * k , float * alpha , float * a , BLAS_INT * lda , float * b , BLAS_INT * ldb , float * beta , float * c , BLAS_INT * ldc );
42
+
30
43
31
44
void THBlas_ (swap )(long n , real * x , long incx , real * y , long incy )
32
45
{
@@ -39,9 +52,9 @@ void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
39
52
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
40
53
if ( (n <= INT_MAX ) && (incx <= INT_MAX ) && (incy <= INT_MAX ) )
41
54
{
42
- int i_n = (int )n ;
43
- int i_incx = (int )incx ;
44
- int i_incy = (int )incy ;
55
+ BLAS_INT i_n = (BLAS_INT )n ;
56
+ BLAS_INT i_incx = (BLAS_INT )incx ;
57
+ BLAS_INT i_incy = (BLAS_INT )incy ;
45
58
46
59
#if defined(TH_REAL_IS_DOUBLE )
47
60
dswap_ (& i_n , x , & i_incx , y , & i_incy );
@@ -70,8 +83,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx)
70
83
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
71
84
if ( (n <= INT_MAX ) && (incx <= INT_MAX ) )
72
85
{
73
- int i_n = (int )n ;
74
- int i_incx = (int )incx ;
86
+ BLAS_INT i_n = (BLAS_INT )n ;
87
+ BLAS_INT i_incx = (BLAS_INT )incx ;
75
88
76
89
#if defined(TH_REAL_IS_DOUBLE )
77
90
dscal_ (& i_n , & a , x , & i_incx );
@@ -99,9 +112,9 @@ void THBlas_(copy)(long n, real *x, long incx, real *y, long incy)
99
112
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
100
113
if ( (n <= INT_MAX ) && (incx <= INT_MAX ) && (incy <= INT_MAX ) )
101
114
{
102
- int i_n = (int )n ;
103
- int i_incx = (int )incx ;
104
- int i_incy = (int )incy ;
115
+ BLAS_INT i_n = (BLAS_INT )n ;
116
+ BLAS_INT i_incx = (BLAS_INT )incx ;
117
+ BLAS_INT i_incy = (BLAS_INT )incy ;
105
118
106
119
#if defined(TH_REAL_IS_DOUBLE )
107
120
dcopy_ (& i_n , x , & i_incx , y , & i_incy );
@@ -129,9 +142,9 @@ void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy)
129
142
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
130
143
if ( (n <= INT_MAX ) && (incx <= INT_MAX ) && (incy <= INT_MAX ) )
131
144
{
132
- int i_n = (int )n ;
133
- int i_incx = (int )incx ;
134
- int i_incy = (int )incy ;
145
+ BLAS_INT i_n = (BLAS_INT )n ;
146
+ BLAS_INT i_incx = (BLAS_INT )incx ;
147
+ BLAS_INT i_incy = (BLAS_INT )incy ;
135
148
136
149
#if defined(TH_REAL_IS_DOUBLE )
137
150
daxpy_ (& i_n , & a , x , & i_incx , y , & i_incy );
@@ -159,9 +172,9 @@ real THBlas_(dot)(long n, real *x, long incx, real *y, long incy)
159
172
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
160
173
if ( (n <= INT_MAX ) && (incx <= INT_MAX ) && (incy <= INT_MAX ) )
161
174
{
162
- int i_n = (int )n ;
163
- int i_incx = (int )incx ;
164
- int i_incy = (int )incy ;
175
+ BLAS_INT i_n = (BLAS_INT )n ;
176
+ BLAS_INT i_incx = (BLAS_INT )incx ;
177
+ BLAS_INT i_incy = (BLAS_INT )incy ;
165
178
166
179
#if defined(TH_REAL_IS_DOUBLE )
167
180
return (real ) ddot_ (& i_n , x , & i_incx , y , & i_incy );
@@ -190,11 +203,11 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re
190
203
(incx > 0 ) && (incx <= INT_MAX ) &&
191
204
(incy > 0 ) && (incy <= INT_MAX ) )
192
205
{
193
- int i_m = (int )m ;
194
- int i_n = (int )n ;
195
- int i_lda = (int )lda ;
196
- int i_incx = (int )incx ;
197
- int i_incy = (int )incy ;
206
+ BLAS_INT i_m = (BLAS_INT )m ;
207
+ BLAS_INT i_n = (BLAS_INT )n ;
208
+ BLAS_INT i_lda = (BLAS_INT )lda ;
209
+ BLAS_INT i_incx = (BLAS_INT )incx ;
210
+ BLAS_INT i_incy = (BLAS_INT )incy ;
198
211
199
212
#if defined(TH_REAL_IS_DOUBLE )
200
213
dgemv_ (& trans , & i_m , & i_n , & alpha , a , & i_lda , x , & i_incx , & beta , y , & i_incy );
@@ -245,11 +258,11 @@ void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long
245
258
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
246
259
if ( (m <= INT_MAX ) && (n <= INT_MAX ) && (lda <= INT_MAX ) && (incx <= INT_MAX ) && (incy <= INT_MAX ) )
247
260
{
248
- int i_m = (int )m ;
249
- int i_n = (int )n ;
250
- int i_lda = (int )lda ;
251
- int i_incx = (int )incx ;
252
- int i_incy = (int )incy ;
261
+ BLAS_INT i_m = (BLAS_INT )m ;
262
+ BLAS_INT i_n = (BLAS_INT )n ;
263
+ BLAS_INT i_lda = (BLAS_INT )lda ;
264
+ BLAS_INT i_incx = (BLAS_INT )incx ;
265
+ BLAS_INT i_incy = (BLAS_INT )incy ;
253
266
254
267
#if defined(TH_REAL_IS_DOUBLE )
255
268
dger_ (& i_m , & i_n , & alpha , x , & i_incx , y , & i_incy , a , & i_lda );
@@ -304,12 +317,12 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha,
304
317
#if defined(USE_BLAS ) && (defined(TH_REAL_IS_DOUBLE ) || defined(TH_REAL_IS_FLOAT ))
305
318
if ( (m <= INT_MAX ) && (n <= INT_MAX ) && (k <= INT_MAX ) && (lda <= INT_MAX ) && (ldb <= INT_MAX ) && (ldc <= INT_MAX ) )
306
319
{
307
- int i_m = (int )m ;
308
- int i_n = (int )n ;
309
- int i_k = (int )k ;
310
- int i_lda = (int )lda ;
311
- int i_ldb = (int )ldb ;
312
- int i_ldc = (int )ldc ;
320
+ BLAS_INT i_m = (BLAS_INT )m ;
321
+ BLAS_INT i_n = (BLAS_INT )n ;
322
+ BLAS_INT i_k = (BLAS_INT )k ;
323
+ BLAS_INT i_lda = (BLAS_INT )lda ;
324
+ BLAS_INT i_ldb = (BLAS_INT )ldb ;
325
+ BLAS_INT i_ldc = (BLAS_INT )ldc ;
313
326
314
327
#if defined(TH_REAL_IS_DOUBLE )
315
328
dgemm_ (& transa , & transb , & i_m , & i_n , & i_k , & alpha , a , & i_lda , b , & i_ldb , & beta , c , & i_ldc );
0 commit comments