Skip to content

Commit c1d8fbc

Browse files
committed
Introduced callback to Pthread, Win32 and OpenMP backend
1 parent 87f83eb commit c1d8fbc

8 files changed

+342
-209
lines changed

cblas.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ char* openblas_get_config(void);
2626
/*Get the CPU corename on runtime.*/
2727
char* openblas_get_corename(void);
2828

29+
/*Set the threading backend to a custom callback.*/
30+
typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, int dojob_data);
31+
typedef void (*openblas_threads_callback)(int sync, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, int dojob_data);
32+
void openblas_set_threads_callback_function(openblas_threads_callback callback);
33+
2934
#ifdef OPENBLAS_OS_LINUX
3035
/* Sets thread affinity for OpenBLAS threads. `thread_idx` is in [0, openblas_get_num_threads()-1]. */
3136
int openblas_setaffinity(int thread_idx, size_t cpusetsize, cpu_set_t* cpu_set);

common_interface.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ int BLASFUNC(xerbla)(char *, blasint *info, blasint);
4747

4848
void openblas_set_num_threads_(int *);
4949

50+
/*Set the threading backend to a custom callback.*/
51+
typedef void (*openblas_dojob_callback)(int thread_num, void *jobdata, int dojob_data);
52+
typedef void (*openblas_threads_callback)(int sync, openblas_dojob_callback dojob, int numjobs, size_t jobdata_elsize, void *jobdata, int dojob_data);
53+
extern openblas_threads_callback openblas_threads_callback_;
54+
5055
FLOATRET BLASFUNC(sdot) (blasint *, float *, blasint *, float *, blasint *);
5156
FLOATRET BLASFUNC(sdsdot)(blasint *, float *, float *, blasint *, float *, blasint *);
5257

driver/others/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ if (USE_THREAD)
2525
${BLAS_SERVER}
2626
divtable.c # TODO: Makefile has -UDOUBLE
2727
blas_l1_thread.c
28+
blas_server_callback.c
2829
)
2930

3031
if (NOT NO_AFFINITY)

driver/others/Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ COMMONOBJS = memory.$(SUFFIX) xerbla.$(SUFFIX) c_abs.$(SUFFIX) z_abs.$(SUFFIX)
66
#COMMONOBJS += slamch.$(SUFFIX) slamc3.$(SUFFIX) dlamch.$(SUFFIX) dlamc3.$(SUFFIX)
77

88
ifdef SMP
9-
COMMONOBJS += blas_server.$(SUFFIX) divtable.$(SUFFIX) blasL1thread.$(SUFFIX)
9+
COMMONOBJS += blas_server.$(SUFFIX) divtable.$(SUFFIX) blasL1thread.$(SUFFIX) blas_server_callback.$(SUFFIX)
1010
ifneq ($(NO_AFFINITY), 1)
1111
COMMONOBJS += init.$(SUFFIX)
1212
endif
@@ -140,6 +140,9 @@ memory.$(SUFFIX) : $(MEMORY) ../../common.h ../../param.h
140140
blas_server.$(SUFFIX) : $(BLAS_SERVER) ../../common.h ../../common_thread.h ../../param.h
141141
$(CC) $(CFLAGS) -c $< -o $(@F)
142142

143+
blas_server_callback.$(SUFFIX) : blas_server_callback.c ../../common.h
144+
$(CC) $(CFLAGS) -c $< -o $(@F)
145+
143146
openblas_set_num_threads.$(SUFFIX) : openblas_set_num_threads.c
144147
$(CC) $(CFLAGS) -c $< -o $(@F)
145148

driver/others/blas_server.c

Lines changed: 158 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,8 @@ int blas_server_avail __attribute__((aligned(ATTRIBUTE_SIZE))) = 0;
115115

116116
int blas_omp_threads_local = 1;
117117

118+
static void * blas_thread_buffer[MAX_CPU_NUMBER];
119+
118120
/* Local Variables */
119121
#if defined(USE_PTHREAD_LOCK)
120122
static pthread_mutex_t server_lock = PTHREAD_MUTEX_INITIALIZER;
@@ -190,6 +192,10 @@ static int main_status[MAX_CPU_NUMBER];
190192
BLASLONG exit_time[MAX_CPU_NUMBER];
191193
#endif
192194

195+
//Prototypes
196+
static void exec_threads(int , blas_queue_t *, int);
197+
static void adjust_thread_buffers();
198+
193199
static void legacy_exec(void *func, int mode, blas_arg_t *args, void *sb){
194200

195201
if (!(mode & BLAS_COMPLEX)){
@@ -375,7 +381,6 @@ static void* blas_thread_server(void *arg){
375381
/* Thread identifier */
376382
BLASLONG cpu = (BLASLONG)arg;
377383
unsigned int last_tick;
378-
void *buffer, *sa, *sb;
379384
blas_queue_t *queue;
380385

381386
blas_queue_t *tscq;
@@ -395,8 +400,6 @@ blas_queue_t *tscq;
395400
main_status[cpu] = MAIN_ENTER;
396401
#endif
397402

398-
buffer = blas_memory_alloc(2);
399-
400403
#ifdef SMP_DEBUG
401404
fprintf(STDERR, "Server[%2ld] Thread has just been spawned!\n", cpu);
402405
#endif
@@ -456,109 +459,7 @@ blas_queue_t *tscq;
456459
start = rpcc();
457460
#endif
458461

459-
if (queue) {
460-
int (*routine)(blas_arg_t *, void *, void *, void *, void *, BLASLONG) = (int (*)(blas_arg_t *, void *, void *, void *, void *, BLASLONG))queue -> routine;
461-
462-
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)1);
463-
464-
sa = queue -> sa;
465-
sb = queue -> sb;
466-
467-
#ifdef SMP_DEBUG
468-
if (queue -> args) {
469-
fprintf(STDERR, "Server[%2ld] Calculation started. Mode = 0x%03x M = %3ld N=%3ld K=%3ld\n",
470-
cpu, queue->mode, queue-> args ->m, queue->args->n, queue->args->k);
471-
}
472-
#endif
473-
474-
#ifdef CONSISTENT_FPCSR
475-
#ifdef __aarch64__
476-
__asm__ __volatile__ ("msr fpcr, %0" : : "r" (queue -> sse_mode));
477-
#else
478-
__asm__ __volatile__ ("ldmxcsr %0" : : "m" (queue -> sse_mode));
479-
__asm__ __volatile__ ("fldcw %0" : : "m" (queue -> x87_mode));
480-
#endif
481-
#endif
482-
483-
#ifdef MONITOR
484-
main_status[cpu] = MAIN_RUNNING1;
485-
#endif
486-
487-
if (sa == NULL) sa = (void *)((BLASLONG)buffer + GEMM_OFFSET_A);
488-
489-
if (sb == NULL) {
490-
if (!(queue -> mode & BLAS_COMPLEX)){
491-
#ifdef EXPRECISION
492-
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
493-
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
494-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
495-
} else
496-
#endif
497-
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
498-
#ifdef BUILD_DOUBLE
499-
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
500-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
501-
#endif
502-
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
503-
#ifdef BUILD_SINGLE
504-
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
505-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
506-
#endif
507-
} else {
508-
/* Other types in future */
509-
}
510-
} else {
511-
#ifdef EXPRECISION
512-
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
513-
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
514-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
515-
} else
516-
#endif
517-
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
518-
#ifdef BUILD_COMPLEX16
519-
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
520-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
521-
#endif
522-
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
523-
#ifdef BUILD_COMPLEX
524-
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
525-
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
526-
#endif
527-
} else {
528-
/* Other types in future */
529-
}
530-
}
531-
queue->sb=sb;
532-
}
533-
534-
#ifdef MONITOR
535-
main_status[cpu] = MAIN_RUNNING2;
536-
#endif
537-
538-
if (queue -> mode & BLAS_LEGACY) {
539-
legacy_exec(routine, queue -> mode, queue -> args, sb);
540-
} else
541-
if (queue -> mode & BLAS_PTHREAD) {
542-
void (*pthreadcompat)(void *) = (void(*)(void*))queue -> routine;
543-
(pthreadcompat)(queue -> args);
544-
} else
545-
(routine)(queue -> args, queue -> range_m, queue -> range_n, sa, sb, queue -> position);
546-
547-
#ifdef SMP_DEBUG
548-
fprintf(STDERR, "Server[%2ld] Calculation finished!\n", cpu);
549-
#endif
550-
551-
#ifdef MONITOR
552-
main_status[cpu] = MAIN_FINISH;
553-
#endif
554-
555-
// arm: make sure all results are written out _before_
556-
// thread is marked as done and other threads use them
557-
MB;
558-
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)0);
559-
560-
561-
}
462+
exec_threads(cpu, queue, 0);
562463

563464
#ifdef MONITOR
564465
main_status[cpu] = MAIN_DONE;
@@ -580,8 +481,6 @@ blas_queue_t *tscq;
580481
fprintf(STDERR, "Server[%2ld] Shutdown!\n", cpu);
581482
#endif
582483

583-
blas_memory_free(buffer);
584-
585484
//pthread_exit(NULL);
586485

587486
return NULL;
@@ -663,6 +562,9 @@ int blas_thread_init(void){
663562

664563
LOCK_COMMAND(&server_lock);
665564

565+
// Adjust thread buffers
566+
adjust_thread_buffers();
567+
666568
if (!blas_server_avail){
667569

668570
thread_timeout_env=openblas_thread_timeout();
@@ -893,6 +795,18 @@ int exec_blas(BLASLONG num, blas_queue_t *queue){
893795
fprintf(STDERR, "Exec_blas is called. Number of executing threads : %ld\n", num);
894796
#endif
895797

798+
//Redirect to caller's callback routine
799+
if (openblas_threads_callback_) {
800+
int buf_index = 0, i = 0;
801+
#ifndef USE_SIMPLE_THREADED_LEVEL3
802+
for (i = 0; i < num; i ++)
803+
queue[i].position = i;
804+
#endif
805+
openblas_threads_callback_(1, (openblas_dojob_callback) exec_threads, num, sizeof(blas_queue_t), (void*) queue, buf_index);
806+
return 0;
807+
}
808+
809+
896810
#ifdef __ELF__
897811
if (omp_in_parallel && (num > 1)) {
898812
if (omp_in_parallel() > 0) {
@@ -1066,6 +980,14 @@ int BLASFUNC(blas_thread_shutdown)(void){
1066980

1067981
LOCK_COMMAND(&server_lock);
1068982

983+
//Free buffers allocated for threads
984+
for(i=0; i<MAX_CPU_NUMBER; i++){
985+
if(blas_thread_buffer[i]!=NULL){
986+
blas_memory_free(blas_thread_buffer[i]);
987+
blas_thread_buffer[i]=NULL;
988+
}
989+
}
990+
1069991
if (blas_server_avail) {
1070992

1071993
for (i = 0; i < blas_num_threads - 1; i++) {
@@ -1102,5 +1024,132 @@ int BLASFUNC(blas_thread_shutdown)(void){
11021024
return 0;
11031025
}
11041026

1105-
#endif
1027+
static void adjust_thread_buffers() {
1028+
1029+
int i=0;
1030+
1031+
//adjust buffer for each thread
1032+
for(i=0; i < blas_cpu_number; i++){
1033+
if(blas_thread_buffer[i] == NULL){
1034+
blas_thread_buffer[i] = blas_memory_alloc(2);
1035+
}
1036+
}
1037+
for(; i < MAX_CPU_NUMBER; i++){
1038+
if(blas_thread_buffer[i] != NULL){
1039+
blas_memory_free(blas_thread_buffer[i]);
1040+
blas_thread_buffer[i] = NULL;
1041+
}
1042+
}
1043+
}
1044+
1045+
static void exec_threads(int cpu, blas_queue_t *queue, int buf_index)
1046+
{
1047+
1048+
if (queue) {
1049+
int (*routine)(blas_arg_t *, void *, void *, void *, void *, BLASLONG) = (int (*)(blas_arg_t *, void *, void *, void *, void *, BLASLONG))queue -> routine;
1050+
1051+
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)1);
1052+
1053+
void *buffer = blas_thread_buffer[cpu];
1054+
void *sa = queue -> sa;
1055+
void *sb = queue -> sb;
1056+
1057+
#ifdef SMP_DEBUG
1058+
if (queue -> args) {
1059+
fprintf(STDERR, "Server[%2ld] Calculation started. Mode = 0x%03x M = %3ld N=%3ld K=%3ld\n",
1060+
cpu, queue->mode, queue-> args ->m, queue->args->n, queue->args->k);
1061+
}
1062+
#endif
1063+
1064+
#ifdef CONSISTENT_FPCSR
1065+
#ifdef __aarch64__
1066+
__asm__ __volatile__ ("msr fpcr, %0" : : "r" (queue -> sse_mode));
1067+
#else
1068+
__asm__ __volatile__ ("ldmxcsr %0" : : "m" (queue -> sse_mode));
1069+
__asm__ __volatile__ ("fldcw %0" : : "m" (queue -> x87_mode));
1070+
#endif
1071+
#endif
1072+
1073+
#ifdef MONITOR
1074+
main_status[cpu] = MAIN_RUNNING1;
1075+
#endif
1076+
1077+
if (sa == NULL) sa = (void *)((BLASLONG)buffer + GEMM_OFFSET_A);
1078+
1079+
if (sb == NULL) {
1080+
if (!(queue -> mode & BLAS_COMPLEX)){
1081+
#ifdef EXPRECISION
1082+
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
1083+
sb = (void *)(((BLASLONG)sa + ((QGEMM_P * QGEMM_Q * sizeof(xdouble)
1084+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1085+
} else
1086+
#endif
1087+
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE) {
1088+
#ifdef BUILD_DOUBLE
1089+
sb = (void *)(((BLASLONG)sa + ((DGEMM_P * DGEMM_Q * sizeof(double)
1090+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1091+
#endif
1092+
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
1093+
#ifdef BUILD_SINGLE
1094+
sb = (void *)(((BLASLONG)sa + ((SGEMM_P * SGEMM_Q * sizeof(float)
1095+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1096+
#endif
1097+
} else {
1098+
/* Other types in future */
1099+
}
1100+
} else {
1101+
#ifdef EXPRECISION
1102+
if ((queue -> mode & BLAS_PREC) == BLAS_XDOUBLE){
1103+
sb = (void *)(((BLASLONG)sa + ((XGEMM_P * XGEMM_Q * 2 * sizeof(xdouble)
1104+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1105+
} else
1106+
#endif
1107+
if ((queue -> mode & BLAS_PREC) == BLAS_DOUBLE){
1108+
#ifdef BUILD_COMPLEX16
1109+
sb = (void *)(((BLASLONG)sa + ((ZGEMM_P * ZGEMM_Q * 2 * sizeof(double)
1110+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1111+
#endif
1112+
} else if ((queue -> mode & BLAS_PREC) == BLAS_SINGLE) {
1113+
#ifdef BUILD_COMPLEX
1114+
sb = (void *)(((BLASLONG)sa + ((CGEMM_P * CGEMM_Q * 2 * sizeof(float)
1115+
+ GEMM_ALIGN) & ~GEMM_ALIGN)) + GEMM_OFFSET_B);
1116+
#endif
1117+
} else {
1118+
/* Other types in future */
1119+
}
1120+
}
1121+
queue->sb=sb;
1122+
}
1123+
1124+
#ifdef MONITOR
1125+
main_status[cpu] = MAIN_RUNNING2;
1126+
#endif
1127+
1128+
if (queue -> mode & BLAS_LEGACY) {
1129+
legacy_exec(routine, queue -> mode, queue -> args, sb);
1130+
} else
1131+
if (queue -> mode & BLAS_PTHREAD) {
1132+
void (*pthreadcompat)(void *) = (void(*)(void*))queue -> routine;
1133+
(pthreadcompat)(queue -> args);
1134+
} else
1135+
(routine)(queue -> args, queue -> range_m, queue -> range_n, sa, sb, queue -> position);
1136+
1137+
#ifdef SMP_DEBUG
1138+
fprintf(STDERR, "Server[%2ld] Calculation finished!\n", cpu);
1139+
#endif
1140+
1141+
#ifdef MONITOR
1142+
main_status[cpu] = MAIN_FINISH;
1143+
#endif
1144+
1145+
// arm: make sure all results are written out _before_
1146+
// thread is marked as done and other threads use them
1147+
MB;
1148+
atomic_store_queue(&thread_status[cpu].queue, (blas_queue_t *)0);
1149+
1150+
1151+
}
1152+
1153+
}
11061154

1155+
#endif

0 commit comments

Comments
 (0)