Skip to content

Commit 8e5a108

Browse files
committed
Refs #532. Improve gemv parallelism for the small m and large n case.
Split the matrix and perform a reduction.
1 parent 6743beb commit 8e5a108

File tree

1 file changed

+83
-1
lines changed

1 file changed

+83
-1
lines changed

driver/level2/gemv_thread.c

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,11 @@
6262
#endif
6363
#endif
6464

65+
#ifndef TRANSA
66+
#define Y_DUMMY_NUM 1024
67+
static FLOAT y_dummy[Y_DUMMY_NUM];
68+
#endif
69+
6570
static int gemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, FLOAT *dummy1, FLOAT *buffer, BLASLONG pos){
6671

6772
FLOAT *a, *x, *y;
@@ -99,10 +104,15 @@ static int gemv_kernel(blas_arg_t *args, BLASLONG *range_m, BLASLONG *range_n, F
99104
a += n_from * lda * COMPSIZE;
100105
#ifdef TRANSA
101106
y += n_from * incy * COMPSIZE;
107+
#else
108+
//for split matrix row (n) direction and vector x of gemv_n
109+
x += n_from * incx * COMPSIZE;
110+
//store partial result for every thread
111+
y += (m_to - m_from) * 1 * COMPSIZE * pos;
102112
#endif
103113
}
104114

105-
// fprintf(stderr, "M_From = %d M_To = %d N_From = %d N_To = %d\n", m_from, m_to, n_from, n_to);
115+
//fprintf(stderr, "M_From = %d M_To = %d N_From = %d N_To = %d POS=%d\n", m_from, m_to, n_from, n_to, pos);
106116

107117
GEMV(m_to - m_from, n_to - n_from, 0,
108118
*((FLOAT *)args -> alpha + 0),
@@ -126,6 +136,10 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x
126136

127137
BLASLONG width, i, num_cpu;
128138

139+
#ifndef TRANSA
140+
int split_x=0;
141+
#endif
142+
129143
#ifdef SMP
130144
#ifndef COMPLEX
131145
#ifdef XDOUBLE
@@ -198,6 +212,58 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x
198212
i -= width;
199213
}
200214

215+
#ifndef TRANSA
216+
//try to split matrix on row direction and x.
217+
//Then, reduction.
218+
if (num_cpu < nthreads) {
219+
220+
//too small to split or bigger than the y_dummy buffer.
221+
double MN = (double) m * (double) n;
222+
if ( MN <= (24.0 * 24.0 * (double) (GEMM_MULTITHREAD_THRESHOLD*GEMM_MULTITHREAD_THRESHOLD))
223+
|| m*COMPSIZE*nthreads > Y_DUMMY_NUM)
224+
goto Outer;
225+
226+
num_cpu = 0;
227+
range[0] = 0;
228+
229+
memset(y_dummy, 0, sizeof(FLOAT) * m * COMPSIZE * nthreads);
230+
231+
args.ldc = 1;
232+
args.c = (void *)y_dummy;
233+
234+
//split on row (n) and x
235+
i=n;
236+
split_x=1;
237+
while (i > 0){
238+
239+
width = blas_quickdivide(i + nthreads - num_cpu - 1, nthreads - num_cpu);
240+
if (width < 4) width = 4;
241+
if (i < width) width = i;
242+
243+
range[num_cpu + 1] = range[num_cpu] + width;
244+
245+
queue[num_cpu].mode = mode;
246+
queue[num_cpu].routine = gemv_kernel;
247+
queue[num_cpu].args = &args;
248+
249+
queue[num_cpu].position = num_cpu;
250+
251+
queue[num_cpu].range_m = NULL;
252+
queue[num_cpu].range_n = &range[num_cpu];
253+
254+
queue[num_cpu].sa = NULL;
255+
queue[num_cpu].sb = NULL;
256+
queue[num_cpu].next = &queue[num_cpu + 1];
257+
258+
num_cpu ++;
259+
i -= width;
260+
}
261+
262+
}
263+
264+
Outer:
265+
#endif
266+
201267
if (num_cpu) {
202268
queue[0].sa = NULL;
203269
queue[0].sb = buffer;
@@ -206,5 +272,21 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT *alpha, FLOAT *a, BLASLONG lda, FLOAT *x
206272
exec_blas(num_cpu, queue);
207273
}
208274

275+
#ifndef TRANSA
276+
if(split_x==1){
277+
//reduction
278+
for(i=0; i<num_cpu; i++){
279+
280+
int j;
281+
for(j=0; j<m; j++){
282+
y[j*incy*COMPSIZE] +=y_dummy[i*m*COMPSIZE + j*COMPSIZE];
283+
#ifdef COMPLEX
284+
y[j*incy*COMPSIZE+1] +=y_dummy[i*m*COMPSIZE + j*COMPSIZE+1];
285+
#endif
286+
}
287+
}
288+
}
289+
#endif
290+
209291
return 0;
210292
}

0 commit comments

Comments
 (0)