Skip to content

Commit 3670da0

Browse files
author
Zhiyong Dang
committed
Fix race condition in blas_server_omp.c
Change-Id: Ic896276cd073d6b41930c7c5a29d66348cd1725d
1 parent 8a3b6fa commit 3670da0

File tree

4 files changed

+79
-26
lines changed

4 files changed

+79
-26
lines changed

Makefile.system

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ endif
184184

185185
endif
186186

187+
ifndef NUM_PARALLEL
188+
NUM_PARALLEL = 1
189+
endif
190+
187191
ifndef NUM_THREADS
188192
NUM_THREADS = $(NUM_CORES)
189193
endif
@@ -961,6 +965,8 @@ endif
961965

962966
CCOMMON_OPT += -DMAX_CPU_NUMBER=$(NUM_THREADS)
963967

968+
CCOMMON_OPT += -DMAX_PARALLEL_NUMBER=$(NUM_PARALLEL)
969+
964970
ifdef USE_SIMPLE_THREADED_LEVEL3
965971
CCOMMON_OPT += -DUSE_SIMPLE_THREADED_LEVEL3
966972
endif

cmake/system.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ if (NOT CMAKE_CROSSCOMPILING)
9696

9797
endif()
9898

99+
if (NOT DEFINED NUM_PARALLEL)
100+
set(NUM_PARALLEL 1)
101+
endif()
102+
99103
if (NOT DEFINED NUM_THREADS)
100104
if (DEFINED NUM_CORES AND NOT NUM_CORES EQUAL 0)
101105
# HT?
@@ -224,6 +228,8 @@ endif ()
224228

225229
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_CPU_NUMBER=${NUM_THREADS}")
226230

231+
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_PARALLEL_NUMBER=${NUM_PARALLEL}")
232+
227233
if (USE_SIMPLE_THREADED_LEVEL3)
228234
set(CCOMMON_OPT "${CCOMMON_OPT} -DUSE_SIMPLE_THREADED_LEVEL3")
229235
endif ()

common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ extern "C" {
179179

180180
#define ALLOCA_ALIGN 63UL
181181

182-
#define NUM_BUFFERS (MAX_CPU_NUMBER * 2)
182+
#define NUM_BUFFERS (MAX_CPU_NUMBER * 2 * MAX_PARALLEL_NUMBER)
183183

184184
#ifdef NEEDBUNDERSCORE
185185
#define BLASFUNC(FUNC) FUNC##_

driver/others/blas_server_omp.c

Lines changed: 66 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,13 @@
3636
/* or implied, of The University of Texas at Austin. */
3737
/*********************************************************************/
3838

39+
#if _STDC_VERSION__ >= 201112L
40+
#ifndef _Atomic
41+
#define _Atomic volatile
42+
#endif
43+
#include <stdatomic.h>
44+
#endif
45+
#include <stdbool.h>
3946
#include <stdio.h>
4047
#include <stdlib.h>
4148
//#include <sys/mman.h>
@@ -49,11 +56,16 @@
4956

5057
int blas_server_avail = 0;
5158

52-
static void * blas_thread_buffer[MAX_CPU_NUMBER];
59+
static void * blas_thread_buffer[MAX_PARALLEL_NUMBER][MAX_CPU_NUMBER];
60+
#if _STDC_VERSION__ >= 201112L
61+
static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
62+
#else
63+
static _Bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
64+
#endif
5365

5466
void goto_set_num_threads(int num_threads) {
5567

56-
int i=0;
68+
int i=0, j=0;
5769

5870
if (num_threads < 1) num_threads = blas_num_threads;
5971

@@ -68,15 +80,17 @@ void goto_set_num_threads(int num_threads) {
6880
omp_set_num_threads(blas_cpu_number);
6981

7082
//adjust buffer for each thread
71-
for(i=0; i<blas_cpu_number; i++){
72-
if(blas_thread_buffer[i]==NULL){
73-
blas_thread_buffer[i]=blas_memory_alloc(2);
83+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
84+
for(j=0; j<blas_cpu_number; j++){
85+
if(blas_thread_buffer[i][j]==NULL){
86+
blas_thread_buffer[i][j]=blas_memory_alloc(2);
87+
}
7488
}
75-
}
76-
for(; i<MAX_CPU_NUMBER; i++){
77-
if(blas_thread_buffer[i]!=NULL){
78-
blas_memory_free(blas_thread_buffer[i]);
79-
blas_thread_buffer[i]=NULL;
89+
for(; j<MAX_CPU_NUMBER; j++){
90+
if(blas_thread_buffer[i][j]!=NULL){
91+
blas_memory_free(blas_thread_buffer[i][j]);
92+
blas_thread_buffer[i][j]=NULL;
93+
}
8094
}
8195
}
8296
#if defined(ARCH_MIPS64)
@@ -92,30 +106,34 @@ void openblas_set_num_threads(int num_threads) {
92106

93107
int blas_thread_init(void){
94108

95-
int i=0;
109+
int i=0, j=0;
96110

97111
blas_get_cpu_number();
98112

99113
blas_server_avail = 1;
100114

101-
for(i=0; i<blas_num_threads; i++){
102-
blas_thread_buffer[i]=blas_memory_alloc(2);
103-
}
104-
for(; i<MAX_CPU_NUMBER; i++){
105-
blas_thread_buffer[i]=NULL;
115+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
116+
for(j=0; j<blas_num_threads; j++){
117+
blas_thread_buffer[i][j]=blas_memory_alloc(2);
118+
}
119+
for(; j<MAX_CPU_NUMBER; j++){
120+
blas_thread_buffer[i][j]=NULL;
121+
}
106122
}
107123

108124
return 0;
109125
}
110126

111127
int BLASFUNC(blas_thread_shutdown)(void){
112-
int i=0;
128+
int i=0, j=0;
113129
blas_server_avail = 0;
114130

115-
for(i=0; i<MAX_CPU_NUMBER; i++){
116-
if(blas_thread_buffer[i]!=NULL){
117-
blas_memory_free(blas_thread_buffer[i]);
118-
blas_thread_buffer[i]=NULL;
131+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
132+
for(j=0; j<MAX_CPU_NUMBER; j++){
133+
if(blas_thread_buffer[i][j]!=NULL){
134+
blas_memory_free(blas_thread_buffer[i][j]);
135+
blas_thread_buffer[i][j]=NULL;
136+
}
119137
}
120138
}
121139

@@ -206,7 +224,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
206224
}
207225
}
208226

209-
static void exec_threads(blas_queue_t *queue){
227+
static void exec_threads(blas_queue_t *queue, int buf_index){
210228

211229
void *buffer, *sa, *sb;
212230
int pos=0, release_flag=0;
@@ -223,7 +241,7 @@ static void exec_threads(blas_queue_t *queue){
223241
if ((sa == NULL) && (sb == NULL) && ((queue -> mode & BLAS_PTHREAD) == 0)) {
224242

225243
pos = omp_get_thread_num();
226-
buffer = blas_thread_buffer[pos];
244+
buffer = blas_thread_buffer[buf_index][pos];
227245

228246
//fallback
229247
if(buffer==NULL) {
@@ -291,7 +309,7 @@ static void exec_threads(blas_queue_t *queue){
291309

292310
int exec_blas(BLASLONG num, blas_queue_t *queue){
293311

294-
BLASLONG i;
312+
BLASLONG i, buf_index;
295313

296314
if ((num <= 0) || (queue == NULL)) return 0;
297315

@@ -302,16 +320,39 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
302320
}
303321
#endif
304322

323+
while(true) {
324+
for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
325+
#if _STDC_VERSION__ >= 201112L
326+
_Bool inuse = false;
327+
if(atomic_compare_exchange_weak(&blas_buffer_inuse[i], &inuse, true)) {
328+
#else
329+
if(blas_buffer_inuse[i] == false) {
330+
blas_buffer_inuse[i] = true;
331+
#endif
332+
buf_index = i;
333+
break;
334+
}
335+
}
336+
if(i != MAX_PARALLEL_NUMBER)
337+
break;
338+
}
339+
305340
#pragma omp parallel for schedule(static)
306341
for (i = 0; i < num; i ++) {
307342

308343
#ifndef USE_SIMPLE_THREADED_LEVEL3
309344
queue[i].position = i;
310345
#endif
311346

312-
exec_threads(&queue[i]);
347+
exec_threads(&queue[i], buf_index);
313348
}
314349

350+
#if _STDC_VERSION__ >= 201112L
351+
atomic_store(&blas_buffer_inuse[buf_index], false);
352+
#else
353+
blas_buffer_inuse[buf_index] = false;
354+
#endif
355+
315356
return 0;
316357
}
317358

0 commit comments

Comments
 (0)