Skip to content

Commit a9d2b01

Browse files
committed
Introduced callback to Pthread, Win32 and OpenMP backend
1 parent 87f83eb commit a9d2b01

8 files changed

+320
-186
lines changed

cblas.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ char* openblas_get_config(void);
2626
/*Get the CPU corename on runtime.*/
2727
char* openblas_get_corename(void);
2828

29+
/* Replace OpenBLAS's built-in threading backend with a caller-supplied callback. */
30+
typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, int dojob_data);
31+
typedef void (*openblas_threads_callback)(int sync, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, int dojob_data);
32+
void openblas_set_threads_callback_function(openblas_threads_callback callback);
33+
2934
#ifdef OPENBLAS_OS_LINUX
3035
/* Sets thread affinity for OpenBLAS threads. `thread_idx` is in [0, openblas_get_num_threads()-1]. */
3136
int openblas_setaffinity(int thread_idx, size_t cpusetsize, cpu_set_t* cpu_set);

common_interface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ int BLASFUNC(xerbla)(char *, blasint *info, blasint);
4747

4848
void openblas_set_num_threads_(int *);
4949

50+
/* Callback types and the global hook used to replace the threading backend with a caller-supplied callback. */
51+
typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, int dojob_data);
52+
typedef void (*openblas_threads_callback)(int sync, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, int dojob_data);
53+
extern openblas_threads_callback openblas_threads_callback_;
54+
5055
FLOATRET BLASFUNC(sdot) (blasint *, float *, blasint *, float *, blasint *);
5156
FLOATRET BLASFUNC(sdsdot)(blasint *, float *, float *, blasint *, float *, blasint *);
5257

driver/others/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ if (USE_THREAD)
2525
${BLAS_SERVER}
2626
divtable.c # TODO: Makefile has -UDOUBLE
2727
blas_l1_thread.c
28+
blas_server_callback.c
2829
)
2930

3031
if (NOT NO_AFFINITY)

driver/others/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ COMMONOBJS = memory.$(SUFFIX) xerbla.$(SUFFIX) c_abs.$(SUFFIX) z_abs.$(SUFFIX)
66
#COMMONOBJS += slamch.$(SUFFIX) slamc3.$(SUFFIX) dlamch.$(SUFFIX) dlamc3.$(SUFFIX)
77

88
ifdef SMP
9-
COMMONOBJS += blas_server.$(SUFFIX) divtable.$(SUFFIX) blasL1thread.$(SUFFIX)
9+
COMMONOBJS += blas_server.$(SUFFIX) divtable.$(SUFFIX) blasL1thread.$(SUFFIX) blas_server_callback.$(SUFFIX)
1010
ifneq ($(NO_AFFINITY), 1)
1111
COMMONOBJS += init.$(SUFFIX)
1212
endif

driver/others/blas_server.c

Lines changed: 144 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ int blas_server_avail __attribute__((aligned(ATTRIBUTE_SIZE))) = 0;
115115

116116
int blas_omp_threads_local = 1;
117117

118+
static void * blas_thread_buffer[MAX_CPU_NUMBER];
119+
118120
/* Local Variables */
119121
#if defined(USE_PTHREAD_LOCK)
120122
static pthread_mutex_t server_lock = PTHREAD_MUTEX_INITIALIZER;
@@ -190,6 +192,10 @@ static int main_status[MAX_CPU_NUMBER];
190192
BLASLONG exit_time[MAX_CPU_NUMBER];
191193
#endif
192194

195+
//Prototypes
196+
static void exec_threads(int , blas_queue_t *, int);
197+
static void adjust_thread_buffers();
198+
193199
static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
194200

195201
if (!(mode & BLAS_COMPLEX)){
@@ -375,7 +381,6 @@ static void* blas_thread_server(void *arg){
375381
/* Thread identifier */
376382
BLASLONG cpu = (BLASLONG)arg;
377383
unsigned int last_tick;
378-
void *buffer, *sa, *sb;
379384
blas_queue_t *queue;
380385

381386
blas_queue_t *tscq;
@@ -395,8 +400,6 @@ blas_queue_t *tscq;
395400
main_status[cpu] = MAIN_ENTER;
396401
#endif
397402

398-
buffer = blas_memory_alloc(2);
399-
400403
#ifdef SMP_DEBUG
401404
fprintf(STDERR, "Server[%2ld] Thread has just been spawned!\n", cpu);
402405
#endif
@@ -457,92 +460,8 @@ blas_queue_t *tscq;
457460
#endif
458461

459462
if (queue) {
460-
int (*routine)(blas_arg_t *, void *, void *, void *, void *, BLASLONG) = (int (*)(blas_arg_t *, void *, void *, void *, void *, BLASLONG))queue -> routine;
461463

462-
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)1);
463-
464-
sa = queue -> sa;
465-
sb = queue -> sb;
466-
467-
#ifdef SMP_DEBUG
468-
if (queue -> args) {
469-
fprintf(STDERR, "Server[%2ld] Calculation started. Mode = 0x%03x M = %3ld N=%3ld K=%3ld\n",
470-
cpu, queue->mode, queue-> args ->m, queue->args->n, queue->args->k);
471-
}
472-
#endif
473-
474-
#ifdef CONSISTENT_FPCSR
475-
#ifdef __aarch64__
476-
__asm__ __volatile__ ("msr fpcr, %0" : : "r" (queue -> sse_mode));
477-
#else
478-
__asm__ __volatile__ ("ldmxcsr %0" : : "m" (queue -> sse_mode));
479-
__asm__ __volatile__ ("fldcw %0" : : "m" (queue -> x87_mode));
480-
#endif
481-
#endif
482-
483-
#ifdef MONITOR
484-
main_status[cpu] = MAIN_RUNNING1;
485-
#endif
486-
487-
if (sa == NULL) sa = (void *)((BLASLONG)buffer + GEMM_OFFSET_A);
488-
489-
if (sb == NULL) {
490-
if (!(queue -> mode & BLAS_COMPLEX)){
491-
#ifdef EXPRECISION
492-
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
493-
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
494-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
495-
} else
496-
#endif
497-
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
498-
#ifdef BUILD_DOUBLE
499-
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
500-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
501-
#endif
502-
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
503-
#ifdef BUILD_SINGLE
504-
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
505-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
506-
#endif
507-
} else {
508-
/* Other types in future */
509-
}
510-
} else {
511-
#ifdef EXPRECISION
512-
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
513-
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
514-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
515-
} else
516-
#endif
517-
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
518-
#ifdef BUILD_COMPLEX16
519-
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
520-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
521-
#endif
522-
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
523-
#ifdef BUILD_COMPLEX
524-
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
525-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
526-
#endif
527-
} else {
528-
/* Other types in future */
529-
}
530-
}
531-
queue->sb=sb;
532-
}
533-
534-
#ifdef MONITOR
535-
main_status[cpu] = MAIN_RUNNING2;
536-
#endif
537-
538-
if (queue -> mode & BLAS_LEGACY) {
539-
legacy_exec(routine, queue -> mode, queue -> args, sb);
540-
} else
541-
if (queue -> mode & BLAS_PTHREAD) {
542-
void (*pthreadcompat)(void *) = (void(*)(void*))queue -> routine;
543-
(pthreadcompat)(queue -> args);
544-
} else
545-
(routine)(queue -> args, queue -> range_m, queue -> range_n, sa, sb, queue -> position);
464+
exec_threads(cpu, queue, 0);
546465

547466
#ifdef SMP_DEBUG
548467
fprintf(STDERR, "Server[%2ld] Calculation finished!\n", cpu);
@@ -557,7 +476,7 @@ blas_queue_t *tscq;
557476
MB;
558477
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)0);
559478

560-
479+
561480
}
562481

563482
#ifdef MONITOR
@@ -580,8 +499,6 @@ blas_queue_t *tscq;
580499
fprintf(STDERR, "Server[%2ld] Shutdown!\n", cpu);
581500
#endif
582501

583-
blas_memory_free(buffer);
584-
585502
//pthread_exit(NULL);
586503

587504
return NULL;
@@ -663,6 +580,9 @@ int blas_thread_init(void){
663580

664581
LOCK_COMMAND(&server_lock);
665582

583+
// Adjust thread buffers
584+
adjust_thread_buffers();
585+
666586
if (!blas_server_avail){
667587

668588
thread_timeout_env=openblas_thread_timeout();
@@ -893,6 +813,18 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
893813
fprintf(STDERR, "Exec_blas is called. Number of executing threads : %ld\n", num);
894814
#endif
895815

816+
//Redirect to caller's callback routine
817+
if (openblas_threads_callback_) {
818+
int buf_index = 0;
819+
#ifndef USE_SIMPLE_THREADED_LEVEL3
820+
for (int i = 0; i < num; i ++)
821+
queue[i].position = i;
822+
#endif
823+
openblas_threads_callback_(1, (openblas_dojob_callback) exec_threads, num, sizeof(blas_queue_t), (void*) queue, buf_index);
824+
return 0;
825+
}
826+
827+
896828
#ifdef __ELF__
897829
if (omp_in_parallel && (num > 1)) {
898830
if (omp_in_parallel() > 0) {
@@ -1066,6 +998,14 @@ int BLASFUNC(blas_thread_shutdown)(void){
1066998

1067999
LOCK_COMMAND(&server_lock);
10681000

1001+
//Free buffers allocated for threads
1002+
for(i=0; i<MAX_CPU_NUMBER; i++){
1003+
if(blas_thread_buffer[i]!=NULL){
1004+
blas_memory_free(blas_thread_buffer[i]);
1005+
blas_thread_buffer[i]=NULL;
1006+
}
1007+
}
1008+
10691009
if (blas_server_avail) {
10701010

10711011
for (i = 0; i < blas_num_threads - 1; i++) {
@@ -1102,5 +1042,118 @@ int BLASFUNC(blas_thread_shutdown)(void){
11021042
return 0;
11031043
}
11041044

1045+
static void adjust_thread_buffers() {
1046+
1047+
int i=0;
1048+
1049+
//adjust buffer for each thread
1050+
for(i=0; i < blas_cpu_number; i++){
1051+
if(blas_thread_buffer[i] == NULL){
1052+
blas_thread_buffer[i] = blas_memory_alloc(2);
1053+
}
1054+
}
1055+
for(; i < MAX_CPU_NUMBER; i++){
1056+
if(blas_thread_buffer[i] != NULL){
1057+
blas_memory_free(blas_thread_buffer[i]);
1058+
blas_thread_buffer[i] = NULL;
1059+
}
1060+
}
1061+
}
1062+
1063+
/*
 * Run one queued BLAS job on behalf of worker `cpu`.
 *
 * cpu       - worker index; selects the blas_thread_buffer[] entry used as
 *             scratch space and the thread_status[] slot that is updated.
 * queue     - job descriptor (routine, mode, args, ranges, optional sa/sb).
 *             When queue->sb is NULL the computed sb is written back to it.
 * buf_index - not used here; it keeps the parameter list in the
 *             (thread_num, jobdata, dojob_data) shape of openblas_dojob_callback,
 *             to which exec_blas casts this function.
 */
static void exec_threads(int cpu, blas_queue_t *queue, int buf_index)
{

  void *buffer, *sa, *sb;

  (void)buf_index;

  buffer = blas_thread_buffer[cpu];

  int (*routine)(blas_arg_t *, void *, void *, void *, void *, BLASLONG) = (int (*)(blas_arg_t *, void *, void *, void *, void *, BLASLONG))queue -> routine;

  /* Store the sentinel value 1 in this worker's status slot before running the job. */
  atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)1);

  sa = queue -> sa;
  sb = queue -> sb;

#ifdef SMP_DEBUG
  if (queue -> args) {
    /* cpu is an int here, so widen it explicitly to match the %ld conversion. */
    fprintf(STDERR, "Server[%2ld] Calculation started. Mode = 0x%03x M = %3ld N=%3ld K=%3ld\n",
            (long)cpu, queue->mode, queue-> args ->m, queue->args->n, queue->args->k);
  }
#endif

#ifdef CONSISTENT_FPCSR
  /* Load the floating-point control state recorded in the job descriptor. */
#ifdef __aarch64__
  __asm__ __volatile__ ("msr fpcr, %0" : : "r" (queue -> sse_mode));
#else
  __asm__ __volatile__ ("ldmxcsr %0" : : "m" (queue -> sse_mode));
  __asm__ __volatile__ ("fldcw %0" : : "m" (queue -> x87_mode));
#endif
#endif

#ifdef MONITOR
  main_status[cpu] = MAIN_RUNNING1;
#endif

  /* No sa supplied by the job: use this worker's buffer at GEMM_OFFSET_A. */
  if (sa == NULL) sa = (void *)((BLASLONG)buffer + GEMM_OFFSET_A);

  /*
   * No sb supplied by the job: place it after sa, skipping P * Q elements of
   * the job's precision (times 2 for complex modes), rounded with the
   * GEMM_ALIGN mask, plus GEMM_OFFSET_B. sb stays NULL when the matching
   * precision is not compiled in or the precision is not handled.
   */
  if (sb == NULL) {
    if (!(queue -> mode & BLAS_COMPLEX)){
#ifdef EXPRECISION
      if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
        sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
      } else
#endif
      if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
#ifdef BUILD_DOUBLE
        sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#endif
      } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
#ifdef BUILD_SINGLE
        sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#endif
      } else {
        /* Other types in future */
      }
    } else {
#ifdef EXPRECISION
      if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
        sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
      } else
#endif
      if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
#ifdef BUILD_COMPLEX16
        sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#endif
      } else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
#ifdef BUILD_COMPLEX
        sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
                                        + GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
#endif
      } else {
        /* Other types in future */
      }
    }
    queue->sb=sb;
  }

#ifdef MONITOR
  main_status[cpu] = MAIN_RUNNING2;
#endif

  /* Dispatch on the calling convention encoded in the job's mode bits. */
  if (queue -> mode & BLAS_LEGACY) {
    legacy_exec(routine, queue -> mode, queue -> args, sb);
  } else
    if (queue -> mode & BLAS_PTHREAD) {
      /* pthread-compatible job: the routine takes only the argument pointer. */
      void (*pthreadcompat)(void *) = (void(*)(void*))queue -> routine;
      (pthreadcompat)(queue -> args);
    } else
      (routine)(queue -> args, queue -> range_m, queue -> range_n, sa, sb, queue -> position);

}
11061158

1159+
#endif

driver/others/blas_server_callback.c

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
#include "common.h"
2+
3+
/* global variable to change threading backend from openblas-managed to caller-managed */
4+
openblas_threads_callback openblas_threads_callback_ = 0;
5+
6+
/* non-threadsafe function should be called before any other
7+
openblas function to change how threads are managed */
8+
9+
void openblas_set_threads_callback_function(openblas_threads_callback callback)
10+
{
11+
openblas_threads_callback_ = callback;
12+
}

0 commit comments

Comments
 (0)