Skip to content

Commit de8a10d

Browse files
committed
BUG: core: use blas_ilp64 also for *_matmul, *_dot, and *_vdot
Changing these to support ILP64 blas was missed in numpygh-15012
1 parent b7f42ea commit de8a10d

File tree

4 files changed

+53
-38
lines changed

4 files changed

+53
-38
lines changed

numpy/core/src/multiarray/arraytypes.c.src

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3535,17 +3535,17 @@ NPY_NO_EXPORT void
35353535
npy_intp n, void *NPY_UNUSED(ignore))
35363536
{
35373537
#if defined(HAVE_CBLAS)
3538-
int is1b = blas_stride(is1, sizeof(@type@));
3539-
int is2b = blas_stride(is2, sizeof(@type@));
3538+
CBLAS_INT is1b = blas_stride(is1, sizeof(@type@));
3539+
CBLAS_INT is2b = blas_stride(is2, sizeof(@type@));
35403540

35413541
if (is1b && is2b)
35423542
{
35433543
double sum = 0.; /* double for stability */
35443544

35453545
while (n > 0) {
3546-
int chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
3546+
CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
35473547

3548-
sum += cblas_@prefix@dot(chunk,
3548+
sum += CBLAS_FUNC(cblas_@prefix@dot)(chunk,
35493549
(@type@ *) ip1, is1b,
35503550
(@type@ *) ip2, is2b);
35513551
/* use char strides here */
@@ -3584,17 +3584,17 @@ NPY_NO_EXPORT void
35843584
char *op, npy_intp n, void *NPY_UNUSED(ignore))
35853585
{
35863586
#if defined(HAVE_CBLAS)
3587-
int is1b = blas_stride(is1, sizeof(@ctype@));
3588-
int is2b = blas_stride(is2, sizeof(@ctype@));
3587+
CBLAS_INT is1b = blas_stride(is1, sizeof(@ctype@));
3588+
CBLAS_INT is2b = blas_stride(is2, sizeof(@ctype@));
35893589

35903590
if (is1b && is2b) {
35913591
double sum[2] = {0., 0.}; /* double for stability */
35923592

35933593
while (n > 0) {
3594-
int chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
3594+
CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
35953595
@type@ tmp[2];
35963596

3597-
cblas_@prefix@dotu_sub((int)n, ip1, is1b, ip2, is2b, tmp);
3597+
CBLAS_FUNC(cblas_@prefix@dotu_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
35983598
sum[0] += (double)tmp[0];
35993599
sum[1] += (double)tmp[1];
36003600
/* use char strides here */

numpy/core/src/multiarray/common.h

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,11 @@ blas_stride(npy_intp stride, unsigned itemsize)
303303
*/
304304
if (stride > 0 && npy_is_aligned((void *)stride, itemsize)) {
305305
stride /= itemsize;
306+
#ifndef HAVE_BLAS_ILP64
306307
if (stride <= INT_MAX) {
308+
#else
309+
if (stride <= NPY_MAX_INT64) {
310+
#endif
307311
return stride;
308312
}
309313
}
@@ -314,7 +318,11 @@ blas_stride(npy_intp stride, unsigned itemsize)
314318
* Define a chunksize for CBLAS. CBLAS counts in integers.
315319
*/
316320
#if NPY_MAX_INTP > INT_MAX
317-
# define NPY_CBLAS_CHUNK (INT_MAX / 2 + 1)
321+
# ifndef HAVE_BLAS_ILP64
322+
# define NPY_CBLAS_CHUNK (INT_MAX / 2 + 1)
323+
# else
324+
# define NPY_CBLAS_CHUNK (NPY_MAX_INT64 / 2 + 1)
325+
# endif
318326
#else
319327
# define NPY_CBLAS_CHUNK NPY_MAX_INTP
320328
#endif

numpy/core/src/multiarray/vdot.c

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@ CFLOAT_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
1515
char *op, npy_intp n, void *NPY_UNUSED(ignore))
1616
{
1717
#if defined(HAVE_CBLAS)
18-
int is1b = blas_stride(is1, sizeof(npy_cfloat));
19-
int is2b = blas_stride(is2, sizeof(npy_cfloat));
18+
CBLAS_INT is1b = blas_stride(is1, sizeof(npy_cfloat));
19+
CBLAS_INT is2b = blas_stride(is2, sizeof(npy_cfloat));
2020

2121
if (is1b && is2b) {
2222
double sum[2] = {0., 0.}; /* double for stability */
2323

2424
while (n > 0) {
25-
int chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
25+
CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
2626
float tmp[2];
2727

28-
cblas_cdotc_sub((int)n, ip1, is1b, ip2, is2b, tmp);
28+
CBLAS_FUNC(cblas_cdotc_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
2929
sum[0] += (double)tmp[0];
3030
sum[1] += (double)tmp[1];
3131
/* use char strides here */
@@ -66,17 +66,17 @@ CDOUBLE_vdot(char *ip1, npy_intp is1, char *ip2, npy_intp is2,
6666
char *op, npy_intp n, void *NPY_UNUSED(ignore))
6767
{
6868
#if defined(HAVE_CBLAS)
69-
int is1b = blas_stride(is1, sizeof(npy_cdouble));
70-
int is2b = blas_stride(is2, sizeof(npy_cdouble));
69+
CBLAS_INT is1b = blas_stride(is1, sizeof(npy_cdouble));
70+
CBLAS_INT is2b = blas_stride(is2, sizeof(npy_cdouble));
7171

7272
if (is1b && is2b) {
7373
double sum[2] = {0., 0.}; /* double for stability */
7474

7575
while (n > 0) {
76-
int chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
76+
CBLAS_INT chunk = n < NPY_CBLAS_CHUNK ? n : NPY_CBLAS_CHUNK;
7777
double tmp[2];
7878

79-
cblas_zdotc_sub((int)n, ip1, is1b, ip2, is2b, tmp);
79+
CBLAS_FUNC(cblas_zdotc_sub)((CBLAS_INT)n, ip1, is1b, ip2, is2b, tmp);
8080
sum[0] += (double)tmp[0];
8181
sum[1] += (double)tmp[1];
8282
/* use char strides here */

numpy/core/src/umath/matmul.c.src

Lines changed: 28 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,11 @@
3131
* -1 to be conservative, in case blas internally uses a for loop with an
3232
* inclusive upper bound
3333
*/
34+
#ifndef HAVE_BLAS_ILP64
3435
#define BLAS_MAXSIZE (NPY_MAX_INT - 1)
36+
#else
37+
#define BLAS_MAXSIZE (NPY_MAX_INT64 - 1)
38+
#endif
3539

3640
/*
3741
* Determine if a 2d matrix can be used by BLAS
@@ -84,25 +88,25 @@ NPY_NO_EXPORT void
8488
* op: data in c order, m shape
8589
*/
8690
enum CBLAS_ORDER order;
87-
int M, N, lda;
91+
CBLAS_INT M, N, lda;
8892

8993
assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE);
9094
assert (is_blasable2d(is2_n, sizeof(@typ@), n, 1, sizeof(@typ@)));
91-
M = (int)m;
92-
N = (int)n;
95+
M = (CBLAS_INT)m;
96+
N = (CBLAS_INT)n;
9397

9498
if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
9599
order = CblasColMajor;
96-
lda = (int)(is1_m / sizeof(@typ@));
100+
lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
97101
}
98102
else {
99103
/* If not ColMajor, caller should have ensured we are RowMajor */
100104
/* will not assert in release mode */
101105
order = CblasRowMajor;
102106
assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
103-
lda = (int)(is1_n / sizeof(@typ@));
107+
lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
104108
}
105-
cblas_@prefix@gemv(order, CblasTrans, N, M, @step1@, ip1, lda, ip2,
109+
CBLAS_FUNC(cblas_@prefix@gemv)(order, CblasTrans, N, M, @step1@, ip1, lda, ip2,
106110
is2_n / sizeof(@typ@), @step0@, op, op_m / sizeof(@typ@));
107111
}
108112

@@ -117,37 +121,37 @@ NPY_NO_EXPORT void
117121
*/
118122
enum CBLAS_ORDER order = CblasRowMajor;
119123
enum CBLAS_TRANSPOSE trans1, trans2;
120-
int M, N, P, lda, ldb, ldc;
124+
CBLAS_INT M, N, P, lda, ldb, ldc;
121125
assert(m <= BLAS_MAXSIZE && n <= BLAS_MAXSIZE && p <= BLAS_MAXSIZE);
122-
M = (int)m;
123-
N = (int)n;
124-
P = (int)p;
126+
M = (CBLAS_INT)m;
127+
N = (CBLAS_INT)n;
128+
P = (CBLAS_INT)p;
125129

126130
assert(is_blasable2d(os_m, os_p, m, p, sizeof(@typ@)));
127-
ldc = (int)(os_m / sizeof(@typ@));
131+
ldc = (CBLAS_INT)(os_m / sizeof(@typ@));
128132

129133
if (is_blasable2d(is1_m, is1_n, m, n, sizeof(@typ@))) {
130134
trans1 = CblasNoTrans;
131-
lda = (int)(is1_m / sizeof(@typ@));
135+
lda = (CBLAS_INT)(is1_m / sizeof(@typ@));
132136
}
133137
else {
134138
/* If not ColMajor, caller should have ensured we are RowMajor */
135139
/* will not assert in release mode */
136140
assert(is_blasable2d(is1_n, is1_m, n, m, sizeof(@typ@)));
137141
trans1 = CblasTrans;
138-
lda = (int)(is1_n / sizeof(@typ@));
142+
lda = (CBLAS_INT)(is1_n / sizeof(@typ@));
139143
}
140144

141145
if (is_blasable2d(is2_n, is2_p, n, p, sizeof(@typ@))) {
142146
trans2 = CblasNoTrans;
143-
ldb = (int)(is2_n / sizeof(@typ@));
147+
ldb = (CBLAS_INT)(is2_n / sizeof(@typ@));
144148
}
145149
else {
146150
/* If not ColMajor, caller should have ensured we are RowMajor */
147151
/* will not assert in release mode */
148152
assert(is_blasable2d(is2_p, is2_n, p, n, sizeof(@typ@)));
149153
trans2 = CblasTrans;
150-
ldb = (int)(is2_p / sizeof(@typ@));
154+
ldb = (CBLAS_INT)(is2_p / sizeof(@typ@));
151155
}
152156
/*
153157
* Use syrk if we have a case of a matrix times its transpose.
@@ -162,12 +166,14 @@ NPY_NO_EXPORT void
162166
) {
163167
npy_intp i,j;
164168
if (trans1 == CblasNoTrans) {
165-
cblas_@prefix@syrk(order, CblasUpper, trans1, P, N, @step1@,
166-
ip1, lda, @step0@, op, ldc);
169+
CBLAS_FUNC(cblas_@prefix@syrk)(
170+
order, CblasUpper, trans1, P, N, @step1@,
171+
ip1, lda, @step0@, op, ldc);
167172
}
168173
else {
169-
cblas_@prefix@syrk(order, CblasUpper, trans1, P, N, @step1@,
170-
ip1, ldb, @step0@, op, ldc);
174+
CBLAS_FUNC(cblas_@prefix@syrk)(
175+
order, CblasUpper, trans1, P, N, @step1@,
176+
ip1, ldb, @step0@, op, ldc);
171177
}
172178
/* Copy the triangle */
173179
for (i = 0; i < P; i++) {
@@ -178,8 +184,9 @@ NPY_NO_EXPORT void
178184

179185
}
180186
else {
181-
cblas_@prefix@gemm(order, trans1, trans2, M, P, N, @step1@, ip1, lda,
182-
ip2, ldb, @step0@, op, ldc);
187+
CBLAS_FUNC(cblas_@prefix@gemm)(
188+
order, trans1, trans2, M, P, N, @step1@, ip1, lda,
189+
ip2, ldb, @step0@, op, ldc);
183190
}
184191
}
185192

0 commit comments

Comments
 (0)