Skip to content

Commit ee5e11c

Browse files
author
Zhiyong Dang
committed
Fix race condition in blas_server_omp.c
Change-Id: Ic896276cd073d6b41930c7c5a29d66348cd1725d
1 parent 8a3b6fa commit ee5e11c

File tree

4 files changed

+64
-26
lines changed

4 files changed

+64
-26
lines changed

Makefile.system

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,10 @@ endif
184184

185185
endif
186186

187+
ifndef NUM_PARALLEL
188+
NUM_PARALLEL = 1
189+
endif
190+
187191
ifndef NUM_THREADS
188192
NUM_THREADS = $(NUM_CORES)
189193
endif
@@ -961,6 +965,8 @@ endif
961965

962966
CCOMMON_OPT += -DMAX_CPU_NUMBER=$(NUM_THREADS)
963967

968+
CCOMMON_OPT += -DMAX_PARALLEL_NUMBER=$(NUM_PARALLEL)
969+
964970
ifdef USE_SIMPLE_THREADED_LEVEL3
965971
CCOMMON_OPT += -DUSE_SIMPLE_THREADED_LEVEL3
966972
endif

cmake/system.cmake

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ if (NOT CMAKE_CROSSCOMPILING)
9696

9797
endif()
9898

99+
if (NOT DEFINED NUM_PARALLEL)
100+
set(NUM_PARALLEL 1)
101+
endif()
102+
99103
if (NOT DEFINED NUM_THREADS)
100104
if (DEFINED NUM_CORES AND NOT NUM_CORES EQUAL 0)
101105
# HT?
@@ -224,6 +228,8 @@ endif ()
224228

225229
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_CPU_NUMBER=${NUM_THREADS}")
226230

231+
set(CCOMMON_OPT "${CCOMMON_OPT} -DMAX_PARALLEL_NUMBER=${NUM_PARALLEL}")
232+
227233
if (USE_SIMPLE_THREADED_LEVEL3)
228234
set(CCOMMON_OPT "${CCOMMON_OPT} -DUSE_SIMPLE_THREADED_LEVEL3")
229235
endif ()

common.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,7 +179,7 @@ extern "C" {
179179

180180
#define ALLOCA_ALIGN 63UL
181181

182-
#define NUM_BUFFERS (MAX_CPU_NUMBER * 2)
182+
#define NUM_BUFFERS (MAX_CPU_NUMBER * 2 * MAX_PARALLEL_NUMBER)
183183

184184
#ifdef NEEDBUNDERSCORE
185185
#define BLASFUNC(FUNC) FUNC##_

driver/others/blas_server_omp.c

Lines changed: 51 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,11 @@
3636
/* or implied, of The University of Texas at Austin. */
3737
/*********************************************************************/
3838

39+
#ifndef _Atomic
40+
#define _Atomic volatile
41+
#endif
42+
#include <stdatomic.h>
43+
#include <stdbool.h>
3944
#include <stdio.h>
4045
#include <stdlib.h>
4146
//#include <sys/mman.h>
@@ -49,11 +54,12 @@
4954

5055
int blas_server_avail = 0;
5156

52-
static void * blas_thread_buffer[MAX_CPU_NUMBER];
57+
static void * blas_thread_buffer[MAX_PARALLEL_NUMBER][MAX_CPU_NUMBER];
58+
static atomic_bool blas_buffer_inuse[MAX_PARALLEL_NUMBER];
5359

5460
void goto_set_num_threads(int num_threads) {
5561

56-
int i=0;
62+
int i=0, j=0;
5763

5864
if (num_threads < 1) num_threads = blas_num_threads;
5965

@@ -68,15 +74,17 @@ void goto_set_num_threads(int num_threads) {
6874
omp_set_num_threads(blas_cpu_number);
6975

7076
//adjust buffer for each thread
71-
for(i=0; i<blas_cpu_number; i++){
72-
if(blas_thread_buffer[i]==NULL){
73-
blas_thread_buffer[i]=blas_memory_alloc(2);
77+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
78+
for(j=0; j<blas_cpu_number; j++){
79+
if(blas_thread_buffer[i][j]==NULL){
80+
blas_thread_buffer[i][j]=blas_memory_alloc(2);
81+
}
7482
}
75-
}
76-
for(; i<MAX_CPU_NUMBER; i++){
77-
if(blas_thread_buffer[i]!=NULL){
78-
blas_memory_free(blas_thread_buffer[i]);
79-
blas_thread_buffer[i]=NULL;
83+
for(; j<MAX_CPU_NUMBER; j++){
84+
if(blas_thread_buffer[i][j]!=NULL){
85+
blas_memory_free(blas_thread_buffer[i][j]);
86+
blas_thread_buffer[i][j]=NULL;
87+
}
8088
}
8189
}
8290
#if defined(ARCH_MIPS64)
@@ -92,30 +100,34 @@ void openblas_set_num_threads(int num_threads) {
92100

93101
int blas_thread_init(void){
94102

95-
int i=0;
103+
int i=0, j=0;
96104

97105
blas_get_cpu_number();
98106

99107
blas_server_avail = 1;
100108

101-
for(i=0; i<blas_num_threads; i++){
102-
blas_thread_buffer[i]=blas_memory_alloc(2);
103-
}
104-
for(; i<MAX_CPU_NUMBER; i++){
105-
blas_thread_buffer[i]=NULL;
109+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
110+
for(j=0; j<blas_num_threads; j++){
111+
blas_thread_buffer[i][j]=blas_memory_alloc(2);
112+
}
113+
for(; j<MAX_CPU_NUMBER; j++){
114+
blas_thread_buffer[i][j]=NULL;
115+
}
106116
}
107117

108118
return 0;
109119
}
110120

111121
int BLASFUNC(blas_thread_shutdown)(void){
112-
int i=0;
122+
int i=0, j=0;
113123
blas_server_avail = 0;
114124

115-
for(i=0; i<MAX_CPU_NUMBER; i++){
116-
if(blas_thread_buffer[i]!=NULL){
117-
blas_memory_free(blas_thread_buffer[i]);
118-
blas_thread_buffer[i]=NULL;
125+
for(i=0; i<MAX_PARALLEL_NUMBER; i++) {
126+
for(j=0; j<MAX_CPU_NUMBER; j++){
127+
if(blas_thread_buffer[i][j]!=NULL){
128+
blas_memory_free(blas_thread_buffer[i][j]);
129+
blas_thread_buffer[i][j]=NULL;
130+
}
119131
}
120132
}
121133

@@ -206,7 +218,7 @@ static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
206218
}
207219
}
208220

209-
static void exec_threads(blas_queue_t *queue){
221+
static void exec_threads(blas_queue_t *queue, int buf_index){
210222

211223
void *buffer, *sa, *sb;
212224
int pos=0, release_flag=0;
@@ -223,7 +235,7 @@ static void exec_threads(blas_queue_t *queue){
223235
if ((sa == NULL) && (sb == NULL) && ((queue -> mode & BLAS_PTHREAD) == 0)) {
224236

225237
pos = omp_get_thread_num();
226-
buffer = blas_thread_buffer[pos];
238+
buffer = blas_thread_buffer[buf_index][pos];
227239

228240
//fallback
229241
if(buffer==NULL) {
@@ -291,7 +303,7 @@ static void exec_threads(blas_queue_t *queue){
291303

292304
int exec_blas(BLASLONG num, blas_queue_t *queue){
293305

294-
BLASLONG i;
306+
BLASLONG i, buf_index;
295307

296308
if ((num <= 0) || (queue == NULL)) return 0;
297309

@@ -302,16 +314,30 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
302314
}
303315
#endif
304316

317+
while(true) {
318+
for(i=0; i < MAX_PARALLEL_NUMBER; i++) {
319+
_Bool inuse = false;
320+
if(atomic_compare_exchange_weak(&blas_buffer_inuse[i], &inuse, true)) {
321+
buf_index = i;
322+
break;
323+
}
324+
}
325+
if(i != MAX_PARALLEL_NUMBER)
326+
break;
327+
}
328+
305329
#pragma omp parallel for schedule(static)
306330
for (i = 0; i < num; i ++) {
307331

308332
#ifndef USE_SIMPLE_THREADED_LEVEL3
309333
queue[i].position = i;
310334
#endif
311335

312-
exec_threads(&queue[i]);
336+
exec_threads(&queue[i], buf_index);
313337
}
314338

339+
atomic_store(&blas_buffer_inuse[buf_index], false);
340+
315341
return 0;
316342
}
317343

0 commit comments

Comments
 (0)