Skip to content

Commit 77b4dbd

Browse files
authored
Merge pull request #1506 from martin-frbg/issue1497
Fix thread races and infinite looping on systems with many cpus
2 parents: bc4c3bc + 8ec28ff; commit: 77b4dbd

File tree

1 file changed

+86
-7
lines changed

1 file changed

+86
-7
lines changed

lapack/getrf/getrf_parallel.c

Lines changed: 86 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -67,6 +67,26 @@ double sqrt(double);
6767
#undef GETRF_FACTOR
6868
#define GETRF_FACTOR 1.00
6969

70+
71+
#if defined(USE_PTHREAD_LOCK)
72+
static pthread_mutex_t getrf_lock = PTHREAD_MUTEX_INITIALIZER;
73+
#elif defined(USE_PTHREAD_SPINLOCK)
74+
static pthread_spinlock_t getrf_lock = 0;
75+
#else
76+
static BLASULONG getrf_lock = 0UL;
77+
#endif
78+
79+
#if defined(USE_PTHREAD_LOCK)
80+
static pthread_mutex_t getrf_flag_lock = PTHREAD_MUTEX_INITIALIZER;
81+
#elif defined(USE_PTHREAD_SPINLOCK)
82+
static pthread_spinlock_t getrf_flag_lock = 0;
83+
#else
84+
static BLASULONG getrf_flag_lock = 0UL;
85+
#endif
86+
87+
88+
89+
7090
static __inline BLASLONG FORMULA1(BLASLONG M, BLASLONG N, BLASLONG IS, BLASLONG BK, BLASLONG T) {
7191

7292
double m = (double)(M - IS - BK);
@@ -217,6 +237,8 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
217237

218238
blasint *ipiv = (blasint *)args -> c;
219239

240+
BLASLONG jw;
241+
220242
volatile BLASLONG *flag = (volatile BLASLONG *)args -> d;
221243

222244
if (args -> a == NULL) {
@@ -245,8 +267,20 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
245267
for (xxx = n_from, bufferside = 0; xxx < n_to; xxx += div_n, bufferside ++) {
246268

247269
for (i = 0; i < args -> nthreads; i++)
270+
#if 1
271+
{
272+
LOCK_COMMAND(&getrf_lock);
273+
jw = job[mypos].working[i][CACHE_LINE_SIZE * bufferside];
274+
UNLOCK_COMMAND(&getrf_lock);
275+
do {
276+
LOCK_COMMAND(&getrf_lock);
277+
jw = job[mypos].working[i][CACHE_LINE_SIZE * bufferside];
278+
UNLOCK_COMMAND(&getrf_lock);
279+
} while (jw);
280+
}
281+
#else
248282
while (job[mypos].working[i][CACHE_LINE_SIZE * bufferside]) {};
249-
283+
#endif
250284
for(jjs = xxx; jjs < MIN(n_to, xxx + div_n); jjs += min_jj){
251285
min_jj = MIN(n_to, xxx + div_n) - jjs;
252286
if (min_jj > GEMM_UNROLL_N) min_jj = GEMM_UNROLL_N;
@@ -283,18 +317,23 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
283317
b + (is + jjs * lda) * COMPSIZE, lda, is);
284318
}
285319
}
286-
287320
MB;
288-
for (i = 0; i < args -> nthreads; i++)
321+
for (i = 0; i < args -> nthreads; i++) {
322+
LOCK_COMMAND(&getrf_lock);
289323
job[mypos].working[i][CACHE_LINE_SIZE * bufferside] = (BLASLONG)buffer[bufferside];
290-
324+
UNLOCK_COMMAND(&getrf_lock);
325+
}
291326
}
292327

328+
LOCK_COMMAND(&getrf_flag_lock);
293329
flag[mypos * CACHE_LINE_SIZE] = 0;
330+
UNLOCK_COMMAND(&getrf_flag_lock);
294331

295332
if (m == 0) {
296333
for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
334+
LOCK_COMMAND(&getrf_lock);
297335
job[mypos].working[mypos][CACHE_LINE_SIZE * xxx] = 0;
336+
UNLOCK_COMMAND(&getrf_lock);
298337
}
299338
}
300339

@@ -318,7 +357,18 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
318357
for (xxx = range_n[current], bufferside = 0; xxx < range_n[current + 1]; xxx += div_n, bufferside ++) {
319358

320359
if ((current != mypos) && (!is)) {
360+
#if 1
361+
LOCK_COMMAND(&getrf_lock);
362+
jw = job[current].working[mypos][CACHE_LINE_SIZE * bufferside];
363+
UNLOCK_COMMAND(&getrf_lock);
364+
do {
365+
LOCK_COMMAND(&getrf_lock);
366+
jw = job[current].working[mypos][CACHE_LINE_SIZE * bufferside];
367+
UNLOCK_COMMAND(&getrf_lock);
368+
} while (jw == 0);
369+
#else
321370
while(job[current].working[mypos][CACHE_LINE_SIZE * bufferside] == 0) {};
371+
#endif
322372
}
323373

324374
KERNEL_OPERATION(min_i, MIN(range_n[current + 1] - xxx, div_n), k,
@@ -327,7 +377,9 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
327377

328378
MB;
329379
if (is + min_i >= m) {
380+
LOCK_COMMAND(&getrf_lock);
330381
job[current].working[mypos][CACHE_LINE_SIZE * bufferside] = 0;
382+
UNLOCK_COMMAND(&getrf_lock);
331383
}
332384
}
333385

@@ -339,7 +391,18 @@ static int inner_advanced_thread(blas_arg_t *args, BLASLONG *range_m, BLASLONG *
339391

340392
for (i = 0; i < args -> nthreads; i++) {
341393
for (xxx = 0; xxx < DIVIDE_RATE; xxx++) {
394+
#if 1
395+
LOCK_COMMAND(&getrf_lock);
396+
jw = job[mypos].working[i][CACHE_LINE_SIZE *xxx];
397+
UNLOCK_COMMAND(&getrf_lock);
398+
do {
399+
LOCK_COMMAND(&getrf_lock);
400+
jw = job[mypos].working[i][CACHE_LINE_SIZE *xxx];
401+
UNLOCK_COMMAND(&getrf_lock);
402+
} while(jw != 0);
403+
#else
342404
while (job[mypos].working[i][CACHE_LINE_SIZE * xxx] ) {};
405+
#endif
343406
}
344407
}
345408

@@ -374,6 +437,7 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
374437
BLASLONG i, j, k, is, bk;
375438

376439
BLASLONG num_cpu;
440+
BLASLONG f;
377441

378442
#ifdef _MSC_VER
379443
BLASLONG flag[MAX_CPU_NUMBER * CACHE_LINE_SIZE];
@@ -501,11 +565,13 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
501565
if (mm >= nn) {
502566

503567
width = blas_quickdivide(nn + args -> nthreads - num_cpu, args -> nthreads - num_cpu - 1);
568+
if (width == 0) width = nn;
504569
if (nn < width) width = nn;
505570
nn -= width;
506571
range_N[num_cpu + 1] = range_N[num_cpu] + width;
507572

508573
width = blas_quickdivide(mm + args -> nthreads - num_cpu, args -> nthreads - num_cpu - 1);
574+
if (width == 0) width = mm;
509575
if (mm < width) width = mm;
510576
if (nn <= 0) width = mm;
511577
mm -= width;
@@ -514,11 +580,13 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
514580
} else {
515581

516582
width = blas_quickdivide(mm + args -> nthreads - num_cpu, args -> nthreads - num_cpu - 1);
583+
if (width == 0) width = mm;
517584
if (mm < width) width = mm;
518585
mm -= width;
519586
range_M[num_cpu + 1] = range_M[num_cpu] + width;
520587

521588
width = blas_quickdivide(nn + args -> nthreads - num_cpu, args -> nthreads - num_cpu - 1);
589+
if (width == 0) width = nn;
522590
if (nn < width) width = nn;
523591
if (mm <= 0) width = nn;
524592
nn -= width;
@@ -561,7 +629,6 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
561629
range_n_new[1] = offset + is + bk;
562630

563631
if (num_cpu > 0) {
564-
565632
queue[num_cpu - 1].next = NULL;
566633

567634
exec_blas_async(0, &queue[0]);
@@ -572,8 +639,20 @@ blasint CNAME(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *sa,
572639

573640
if (iinfo && !info) info = iinfo + is;
574641

575-
for (i = 0; i < num_cpu; i ++) while (flag[i * CACHE_LINE_SIZE]) {};
576-
642+
for (i = 0; i < num_cpu; i ++) {
643+
#if 1
644+
LOCK_COMMAND(&getrf_flag_lock);
645+
f=flag[i*CACHE_LINE_SIZE];
646+
UNLOCK_COMMAND(&getrf_flag_lock);
647+
while (f!=0) {
648+
LOCK_COMMAND(&getrf_flag_lock);
649+
f=flag[i*CACHE_LINE_SIZE];
650+
UNLOCK_COMMAND(&getrf_flag_lock);
651+
};
652+
#else
653+
while (flag[i*CACHE_LINE_SIZE]) {};
654+
#endif
655+
}
577656
TRSM_ILTCOPY(bk, bk, a + (is + is * lda) * COMPSIZE, lda, 0, sb);
578657

579658
} else {

0 commit comments

Comments
 (0)