Skip to content

Commit ca19337

Browse files
elikosansoumith
authored andcommitted
make it compile on Windows + use ilp64 MKL (#981)
1 parent 0166825 commit ca19337

File tree

5 files changed

+82
-57
lines changed

5 files changed

+82
-57
lines changed

doc/tensor.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@
44
The `Tensor` class is probably the most important class in
55
`Torch`. Almost every package depends on this class. It is *__the__*
66
class for handling numeric data. As with pretty much anything in
7-
[Torch7](./../index.md), tensors are
7+
[Torch7](./index.md), tensors are
88
[serializable](file.md#torch.File.serialization).
99

1010
__Multi-dimensional matrix__
1111

12-
A `Tensor` is a potentially multi-dimensional matrix. The number of
13-
dimensions is unlimited that can be created using
14-
[LongStorage](storage.md) with more dimensions.
12+
A `Tensor` is a multi-dimensional matrix. The number of
13+
dimensions is unlimited (up to what can be created using
14+
[LongStorage](storage.md)).
1515

1616
Example:
1717
```lua

lib/TH/cmake/FindBLAS.cmake

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,32 +242,44 @@ endif()
242242
# Determine if blas was compiled with the f2c conventions
243243
IF (BLAS_LIBRARIES)
244244
SET(CMAKE_REQUIRED_LIBRARIES ${BLAS_LIBRARIES})
245+
245246
CHECK_C_SOURCE_RUNS("
246247
#include <stdlib.h>
247248
#include <stdio.h>
248249
float x[4] = { 1, 2, 3, 4 };
249250
float y[4] = { .1, .01, .001, .0001 };
250-
int four = 4;
251-
int one = 1;
251+
#ifdef WIN32
252+
typedef __int64 BLINT;
253+
#else
254+
typedef long BLINT;
255+
#endif
256+
BLINT four = 4;
257+
BLINT one = 1;
252258
extern double sdot_();
253259
int main() {
254-
int i;
255260
double r = sdot_(&four, x, &one, y, &one);
256261
exit((float)r != (float).1234);
257262
}" BLAS_F2C_DOUBLE_WORKS )
263+
258264
CHECK_C_SOURCE_RUNS("
259265
#include <stdlib.h>
260266
#include <stdio.h>
261267
float x[4] = { 1, 2, 3, 4 };
262268
float y[4] = { .1, .01, .001, .0001 };
263-
int four = 4;
264-
int one = 1;
269+
#ifdef WIN32
270+
typedef __int64 BLINT;
271+
#else
272+
typedef long BLINT;
273+
#endif
274+
BLINT four = 4;
275+
BLINT one = 1;
265276
extern float sdot_();
266277
int main() {
267278
int i;
268279
double r = sdot_(&four, x, &one, y, &one);
269280
exit((float)r != (float).1234);
270281
}" BLAS_F2C_FLOAT_WORKS )
282+
271283
IF (BLAS_F2C_DOUBLE_WORKS AND NOT BLAS_F2C_FLOAT_WORKS)
272284
MESSAGE(STATUS "This BLAS uses the F2C return conventions")
273285
SET(BLAS_F2C TRUE)

lib/TH/cmake/FindMKL.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ CHECK_TYPE_SIZE("void*" SIZE_OF_VOIDP)
4141
IF ("${SIZE_OF_VOIDP}" EQUAL 8)
4242
SET(mklvers "em64t")
4343
SET(iccvers "intel64")
44-
SET(mkl64s "_lp64")
44+
SET(mkl64s "_ilp64")
4545
ELSE ("${SIZE_OF_VOIDP}" EQUAL 8)
4646
SET(mklvers "32")
4747
SET(iccvers "ia32")

lib/TH/cmake/FindSSE.cmake

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ SET(AVX2_CODE "
7373
7474
int main()
7575
{
76-
__m256i a;
76+
__m256i a = {0};
7777
a = _mm256_abs_epi16(a);
7878
return 0;
7979
}

lib/TH/generic/THBlas.c

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,24 +9,37 @@
99
# define ffloat float
1010
#endif
1111

12-
TH_EXTERNC void dswap_(int *n, double *x, int *incx, double *y, int *incy);
13-
TH_EXTERNC void sswap_(int *n, float *x, int *incx, float *y, int *incy);
14-
TH_EXTERNC void dscal_(int *n, double *a, double *x, int *incx);
15-
TH_EXTERNC void sscal_(int *n, float *a, float *x, int *incx);
16-
TH_EXTERNC void dcopy_(int *n, double *x, int *incx, double *y, int *incy);
17-
TH_EXTERNC void scopy_(int *n, float *x, int *incx, float *y, int *incy);
18-
TH_EXTERNC void daxpy_(int *n, double *a, double *x, int *incx, double *y, int *incy);
19-
TH_EXTERNC void saxpy_(int *n, float *a, float *x, int *incx, float *y, int *incy);
20-
TH_EXTERNC double ddot_(int *n, double *x, int *incx, double *y, int *incy);
21-
TH_EXTERNC ffloat sdot_(int *n, float *x, int *incx, float *y, int *incy);
22-
TH_EXTERNC void dgemv_(char *trans, int *m, int *n, double *alpha, double *a, int *lda, double *x, int *incx, double *beta, double *y, int *incy);
23-
TH_EXTERNC void sgemv_(char *trans, int *m, int *n, float *alpha, float *a, int *lda, float *x, int *incx, float *beta, float *y, int *incy);
24-
TH_EXTERNC void dger_(int *m, int *n, double *alpha, double *x, int *incx, double *y, int *incy, double *a, int *lda);
25-
TH_EXTERNC void sger_(int *m, int *n, float *alpha, float *x, int *incx, float *y, int *incy, float *a, int *lda);
26-
TH_EXTERNC void dgemm_(char *transa, char *transb, int *m, int *n, int *k, double *alpha, double *a, int *lda, double *b, int *ldb, double *beta, double *c, int *ldc);
27-
TH_EXTERNC void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float *alpha, float *a, int *lda, float *b, int *ldb, float *beta, float *c, int *ldc);
12+
// define MKL_LP64 to get 32bit ints on 64bit platforms
13+
#ifndef MKL_LP64
14+
// 64bit ints
15+
#ifdef WIN32
16+
#define BLAS_INT __int64
17+
#else
18+
#define BLAS_INT long
19+
#endif
20+
#else
21+
// 32bit ints
22+
#define BLAS_INT int
23+
#endif
2824

2925

26+
TH_EXTERNC void dswap_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy);
27+
TH_EXTERNC void sswap_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy);
28+
TH_EXTERNC void dscal_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx);
29+
TH_EXTERNC void sscal_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx);
30+
TH_EXTERNC void dcopy_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy);
31+
TH_EXTERNC void scopy_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy);
32+
TH_EXTERNC void daxpy_(BLAS_INT *n, double *a, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy);
33+
TH_EXTERNC void saxpy_(BLAS_INT *n, float *a, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy);
34+
TH_EXTERNC double ddot_(BLAS_INT *n, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy);
35+
TH_EXTERNC ffloat sdot_(BLAS_INT *n, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy);
36+
TH_EXTERNC void dgemv_(char *trans, BLAS_INT *m, BLAS_INT *n, double *alpha, double *a, BLAS_INT *lda, double *x, BLAS_INT *incx, double *beta, double *y, BLAS_INT *incy);
37+
TH_EXTERNC void sgemv_(char *trans, BLAS_INT *m, BLAS_INT *n, float *alpha, float *a, BLAS_INT *lda, float *x, BLAS_INT *incx, float *beta, float *y, BLAS_INT *incy);
38+
TH_EXTERNC void dger_(BLAS_INT *m, BLAS_INT *n, double *alpha, double *x, BLAS_INT *incx, double *y, BLAS_INT *incy, double *a, BLAS_INT *lda);
39+
TH_EXTERNC void sger_(BLAS_INT *m, BLAS_INT *n, float *alpha, float *x, BLAS_INT *incx, float *y, BLAS_INT *incy, float *a, BLAS_INT *lda);
40+
TH_EXTERNC void dgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, double *alpha, double *a, BLAS_INT *lda, double *b, BLAS_INT *ldb, double *beta, double *c, BLAS_INT *ldc);
41+
TH_EXTERNC void sgemm_(char *transa, char *transb, BLAS_INT *m, BLAS_INT *n, BLAS_INT *k, float *alpha, float *a, BLAS_INT *lda, float *b, BLAS_INT *ldb, float *beta, float *c, BLAS_INT *ldc);
42+
3043

3144
void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
3245
{
@@ -39,9 +52,9 @@ void THBlas_(swap)(long n, real *x, long incx, real *y, long incy)
3952
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
4053
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
4154
{
42-
int i_n = (int)n;
43-
int i_incx = (int)incx;
44-
int i_incy = (int)incy;
55+
BLAS_INT i_n = (BLAS_INT)n;
56+
BLAS_INT i_incx = (BLAS_INT)incx;
57+
BLAS_INT i_incy = (BLAS_INT)incy;
4558

4659
#if defined(TH_REAL_IS_DOUBLE)
4760
dswap_(&i_n, x, &i_incx, y, &i_incy);
@@ -70,8 +83,8 @@ void THBlas_(scal)(long n, real a, real *x, long incx)
7083
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
7184
if( (n <= INT_MAX) && (incx <= INT_MAX) )
7285
{
73-
int i_n = (int)n;
74-
int i_incx = (int)incx;
86+
BLAS_INT i_n = (BLAS_INT)n;
87+
BLAS_INT i_incx = (BLAS_INT)incx;
7588

7689
#if defined(TH_REAL_IS_DOUBLE)
7790
dscal_(&i_n, &a, x, &i_incx);
@@ -99,9 +112,9 @@ void THBlas_(copy)(long n, real *x, long incx, real *y, long incy)
99112
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
100113
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
101114
{
102-
int i_n = (int)n;
103-
int i_incx = (int)incx;
104-
int i_incy = (int)incy;
115+
BLAS_INT i_n = (BLAS_INT)n;
116+
BLAS_INT i_incx = (BLAS_INT)incx;
117+
BLAS_INT i_incy = (BLAS_INT)incy;
105118

106119
#if defined(TH_REAL_IS_DOUBLE)
107120
dcopy_(&i_n, x, &i_incx, y, &i_incy);
@@ -129,9 +142,9 @@ void THBlas_(axpy)(long n, real a, real *x, long incx, real *y, long incy)
129142
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
130143
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
131144
{
132-
int i_n = (int)n;
133-
int i_incx = (int)incx;
134-
int i_incy = (int)incy;
145+
BLAS_INT i_n = (BLAS_INT)n;
146+
BLAS_INT i_incx = (BLAS_INT)incx;
147+
BLAS_INT i_incy = (BLAS_INT)incy;
135148

136149
#if defined(TH_REAL_IS_DOUBLE)
137150
daxpy_(&i_n, &a, x, &i_incx, y, &i_incy);
@@ -159,9 +172,9 @@ real THBlas_(dot)(long n, real *x, long incx, real *y, long incy)
159172
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
160173
if( (n <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
161174
{
162-
int i_n = (int)n;
163-
int i_incx = (int)incx;
164-
int i_incy = (int)incy;
175+
BLAS_INT i_n = (BLAS_INT)n;
176+
BLAS_INT i_incx = (BLAS_INT)incx;
177+
BLAS_INT i_incy = (BLAS_INT)incy;
165178

166179
#if defined(TH_REAL_IS_DOUBLE)
167180
return (real) ddot_(&i_n, x, &i_incx, y, &i_incy);
@@ -190,11 +203,11 @@ void THBlas_(gemv)(char trans, long m, long n, real alpha, real *a, long lda, re
190203
(incx > 0) && (incx <= INT_MAX) &&
191204
(incy > 0) && (incy <= INT_MAX) )
192205
{
193-
int i_m = (int)m;
194-
int i_n = (int)n;
195-
int i_lda = (int)lda;
196-
int i_incx = (int)incx;
197-
int i_incy = (int)incy;
206+
BLAS_INT i_m = (BLAS_INT)m;
207+
BLAS_INT i_n = (BLAS_INT)n;
208+
BLAS_INT i_lda = (BLAS_INT)lda;
209+
BLAS_INT i_incx = (BLAS_INT)incx;
210+
BLAS_INT i_incy = (BLAS_INT)incy;
198211

199212
#if defined(TH_REAL_IS_DOUBLE)
200213
dgemv_(&trans, &i_m, &i_n, &alpha, a, &i_lda, x, &i_incx, &beta, y, &i_incy);
@@ -245,11 +258,11 @@ void THBlas_(ger)(long m, long n, real alpha, real *x, long incx, real *y, long
245258
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
246259
if( (m <= INT_MAX) && (n <= INT_MAX) && (lda <= INT_MAX) && (incx <= INT_MAX) && (incy <= INT_MAX) )
247260
{
248-
int i_m = (int)m;
249-
int i_n = (int)n;
250-
int i_lda = (int)lda;
251-
int i_incx = (int)incx;
252-
int i_incy = (int)incy;
261+
BLAS_INT i_m = (BLAS_INT)m;
262+
BLAS_INT i_n = (BLAS_INT)n;
263+
BLAS_INT i_lda = (BLAS_INT)lda;
264+
BLAS_INT i_incx = (BLAS_INT)incx;
265+
BLAS_INT i_incy = (BLAS_INT)incy;
253266

254267
#if defined(TH_REAL_IS_DOUBLE)
255268
dger_(&i_m, &i_n, &alpha, x, &i_incx, y, &i_incy, a, &i_lda);
@@ -304,12 +317,12 @@ void THBlas_(gemm)(char transa, char transb, long m, long n, long k, real alpha,
304317
#if defined(USE_BLAS) && (defined(TH_REAL_IS_DOUBLE) || defined(TH_REAL_IS_FLOAT))
305318
if( (m <= INT_MAX) && (n <= INT_MAX) && (k <= INT_MAX) && (lda <= INT_MAX) && (ldb <= INT_MAX) && (ldc <= INT_MAX) )
306319
{
307-
int i_m = (int)m;
308-
int i_n = (int)n;
309-
int i_k = (int)k;
310-
int i_lda = (int)lda;
311-
int i_ldb = (int)ldb;
312-
int i_ldc = (int)ldc;
320+
BLAS_INT i_m = (BLAS_INT)m;
321+
BLAS_INT i_n = (BLAS_INT)n;
322+
BLAS_INT i_k = (BLAS_INT)k;
323+
BLAS_INT i_lda = (BLAS_INT)lda;
324+
BLAS_INT i_ldb = (BLAS_INT)ldb;
325+
BLAS_INT i_ldc = (BLAS_INT)ldc;
313326

314327
#if defined(TH_REAL_IS_DOUBLE)
315328
dgemm_(&transa, &transb, &i_m, &i_n, &i_k, &alpha, a, &i_lda, b, &i_ldb, &beta, c, &i_ldc);

0 commit comments

Comments
 (0)