Skip to content

Commit 1b056c5

Browse files
committed
Refs #130 Prevent reading ipiv array beyond the bound in ?laswp. Use laswp instead of laswp_oncopy in getrf.
1 parent e8306f6 commit 1b056c5

13 files changed

+2693
-674
lines changed

lapack/getrf/getrf_parallel.c

+3-2
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ static void inner_basic_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *ra
118118
min_jj = js + min_j - jjs;
119119
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
120120

121-
if (GEMM_UNROLL_N <= 8) {
121+
if (0 && GEMM_UNROLL_N <= 8) {
122122

123123
LASWP_NCOPY(min_jj, off + 1, off + k,
124124
c + (- off + jjs * lda) * COMPSIZE, lda,
@@ -245,7 +245,8 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
245245
min_jj = MIN(n_to, xxx + div_n) - jjs;
246246
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
247247

248-
if (GEMM_UNROLL_N <= 8) {
248+
if (0 && GEMM_UNROLL_N <= 8) {
249+
printf("helllo\n");
249250

250251
LASWP_NCOPY(min_jj, off + 1, off + k,
251252
b + (- off + jjs * lda) * COMPSIZE, lda,

lapack/getrf/getrf_parallel_omp.c

+11
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,21 @@ static void inner_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n,
7777
min_jj = js + min_j - jjs;
7878
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
7979

80+
#if 0
8081
LASWP_NCOPY(min_jj, off + 1, off + k,
8182
c + (- off + jjs * lda) * COMPSIZE, lda,
8283
ipiv, sb + k * (jjs - js) * COMPSIZE);
8384

85+
#else
86+
LASWP_PLUS(min_jj, off + 1, off + k, ZERO,
87+
#ifdef COMPLEX
88+
ZERO,
89+
#endif
90+
c + (- off + jjs * lda) * COMPSIZE, lda, NULL, 0, ipiv, 1);
91+
92+
GEMM_ONCOPY (k, min_jj, c + jjs * lda * COMPSIZE, lda, sb + (jjs - js) * k * COMPSIZE);
93+
#endif
94+
8495
for (is = 0; is < k; is += GEMM_P) {
8596
min_i = k - is;
8697
if (min_i > GEMM_P) min_i = GEMM_P;

lapack/getrf/getrf_single.c

+1-1
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
113113
min_jj = js + jmin - jjs;
114114
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
115115

116-
#if 0
116+
#if 1
117117
LASWP_PLUS(min_jj, j + offset + 1, j + jb + offset, ZERO,
118118
#ifdef COMPLEX
119119
ZERO,

lapack/laswp/generic/laswp_k_1.c

+88-9
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@
4848
int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG lda,
4949
FLOAT *dummy2, BLASLONG dumy3, blasint *ipiv, BLASLONG incx){
5050

51-
BLASLONG i, j, ip1, ip2;
51+
BLASLONG i, j, ip1, ip2, rows;
5252
blasint *piv;
5353
FLOAT *a1;
5454
FLOAT *b1, *b2;
@@ -58,13 +58,34 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
5858
k1 --;
5959

6060
#ifndef MINUS
61-
ipiv += k1
62-
;
61+
ipiv += k1;
6362
#else
6463
ipiv -= (k2 - 1) * incx;
6564
#endif
6665

6766
if (n <= 0) return 0;
67+
68+
rows = k2-k1;
69+
if (rows <=0) return 0;
70+
if (rows == 1) {
71+
//Only have 1 row
72+
ip1 = *ipiv;
73+
a1 = a + k1 + 1;
74+
b1 = a + ip1;
75+
76+
if(a1 == b1) return 0;
77+
78+
for(j=0; j<n; j++){
79+
A1 = *a1;
80+
B1 = *b1;
81+
*a1 = B1;
82+
*b1 = A1;
83+
84+
a1 += lda;
85+
b1 += lda;
86+
}
87+
return 0;
88+
}
6889

6990
j = n;
7091
if (j > 0) {
@@ -85,10 +106,11 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
85106
b1 = a + ip1;
86107
b2 = a + ip2;
87108

88-
i = ((k2 - k1) >> 1);
89-
90-
if (i > 0) {
91-
do {
109+
i = (rows >> 1);
110+
111+
i--;
112+
//Main Loop
113+
while (i > 0) {
92114
#ifdef OPTERON
93115
#ifndef MINUS
94116
asm volatile("prefetchw 2 * 128(%0)\n" : : "r"(a1));
@@ -172,12 +194,69 @@ int CNAME(BLASLONG n, BLASLONG k1, BLASLONG k2, FLOAT dummy1, FLOAT *a, BLASLONG
172194
a1 -= 2;
173195
#endif
174196
i --;
175-
} while (i > 0);
176197
}
198+
199+
//Loop Ending
200+
A1 = *a1;
201+
A2 = *a2;
202+
B1 = *b1;
203+
B2 = *b2;
204+
if (b1 == a1) {
205+
if (b2 == a1) {
206+
*a1 = A2;
207+
*a2 = A1;
208+
} else
209+
if (b2 != a2) {
210+
*a2 = B2;
211+
*b2 = A2;
212+
}
213+
} else
214+
if (b1 == a2) {
215+
if (b2 != a1) {
216+
if (b2 == a2) {
217+
*a1 = A2;
218+
*a2 = A1;
219+
} else {
220+
*a1 = A2;
221+
*a2 = B2;
222+
*b2 = A1;
223+
}
224+
}
225+
} else {
226+
if (b2 == a1) {
227+
*a1 = A2;
228+
*a2 = B1;
229+
*b1 = A1;
230+
} else
231+
if (b2 == a2) {
232+
*a1 = B1;
233+
*b1 = A1;
234+
} else
235+
if (b2 == b1) {
236+
*a1 = B1;
237+
*a2 = A1;
238+
*b1 = A2;
239+
} else {
240+
*a1 = B1;
241+
*a2 = B2;
242+
*b1 = A1;
243+
*b2 = A2;
244+
}
245+
}
246+
247+
#ifndef MINUS
248+
a1 += 2;
249+
#else
250+
a1 -= 2;
251+
#endif
177252

178-
i = ((k2 - k1) & 1);
253+
//Remain
254+
i = (rows & 1);
179255

180256
if (i > 0) {
257+
ip1 = *piv;
258+
b1 = a + ip1;
259+
181260
A1 = *a1;
182261
B1 = *b1;
183262
*a1 = B1;

0 commit comments

Comments
 (0)