19
19
#include "nbc_internal.h"
20
20
21
21
static inline int red_sched_binomial (int rank , int p , int root , const void * sendbuf , void * redbuf , int count , MPI_Datatype datatype ,
22
- MPI_Op op , NBC_Schedule * schedule , NBC_Handle * handle );
22
+ MPI_Op op , char inplace , NBC_Schedule * schedule , NBC_Handle * handle );
23
23
static inline int red_sched_chain (int rank , int p , int root , const void * sendbuf , void * recvbuf , int count , MPI_Datatype datatype ,
24
24
MPI_Op op , int ext , size_t size , NBC_Schedule * schedule , NBC_Handle * handle , int fragsize );
25
25
@@ -58,6 +58,7 @@ int ompi_coll_libnbc_ireduce(const void* sendbuf, void* recvbuf, int count, MPI_
58
58
enum { NBC_RED_BINOMIAL , NBC_RED_CHAIN } alg ;
59
59
NBC_Handle * handle ;
60
60
ompi_coll_libnbc_module_t * libnbc_module = (ompi_coll_libnbc_module_t * ) module ;
61
+ ptrdiff_t span , gap ;
61
62
62
63
NBC_IN_PLACE (sendbuf , recvbuf , inplace );
63
64
@@ -92,20 +93,22 @@ int ompi_coll_libnbc_ireduce(const void* sendbuf, void* recvbuf, int count, MPI_
92
93
return res ;
93
94
}
94
95
96
+ span = opal_datatype_span (& datatype -> super , count , & gap );
97
+
95
98
/* algorithm selection */
96
99
if (p > 4 || size * count < 65536 || !ompi_op_is_commute (op )) {
97
100
alg = NBC_RED_BINOMIAL ;
98
101
if (rank == root ) {
99
102
/* root reduces in receivebuffer */
100
- handle -> tmpbuf = malloc (ext * count );
103
+ handle -> tmpbuf = malloc (span );
101
104
redbuf = recvbuf ;
102
105
} else {
103
106
/* recvbuf may not be valid on non-root nodes */
104
- handle -> tmpbuf = malloc (ext * count * 2 );
105
- redbuf = (char * ) handle -> tmpbuf + ext * count ;
107
+ handle -> tmpbuf = malloc (2 * span );
108
+ redbuf = (char * ) handle -> tmpbuf + span - gap ;
106
109
}
107
110
} else {
108
- handle -> tmpbuf = malloc (ext * count );
111
+ handle -> tmpbuf = malloc (span );
109
112
alg = NBC_RED_CHAIN ;
110
113
segsize = 16384 /2 ;
111
114
}
@@ -139,7 +142,7 @@ int ompi_coll_libnbc_ireduce(const void* sendbuf, void* recvbuf, int count, MPI_
139
142
140
143
switch (alg ) {
141
144
case NBC_RED_BINOMIAL :
142
- res = red_sched_binomial (rank , p , root , sendbuf , redbuf , count , datatype , op , schedule , handle );
145
+ res = red_sched_binomial (rank , p , root , sendbuf , redbuf , count , datatype , op , inplace , schedule , handle );
143
146
break ;
144
147
case NBC_RED_CHAIN :
145
148
res = red_sched_chain (rank , p , root , sendbuf , recvbuf , count , datatype , op , ext , size , schedule , handle , segsize );
@@ -292,10 +295,12 @@ int ompi_coll_libnbc_ireduce_inter(const void* sendbuf, void* recvbuf, int count
292
295
if (vrank == root) rank = 0; \
293
296
}
294
297
static inline int red_sched_binomial (int rank , int p , int root , const void * sendbuf , void * redbuf , int count , MPI_Datatype datatype ,
295
- MPI_Op op , NBC_Schedule * schedule , NBC_Handle * handle ) {
298
+ MPI_Op op , char inplace , NBC_Schedule * schedule , NBC_Handle * handle ) {
296
299
int vroot , vrank , vpeer , peer , res , maxr ;
297
300
char * rbuf , * lbuf , * buf ;
298
301
int tmprbuf , tmplbuf ;
302
+ ptrdiff_t gap ;
303
+ (void )opal_datatype_span (& datatype -> super , count , & gap );
299
304
300
305
if (ompi_op_is_commute (op )) {
301
306
vroot = root ;
@@ -307,15 +312,21 @@ static inline int red_sched_binomial (int rank, int p, int root, const void *sen
307
312
308
313
/* ensure the result ends up in redbuf on vrank 0 */
309
314
if (0 == (maxr %2 )) {
310
- rbuf = 0 ;
315
+ rbuf = ( void * )( - gap ) ;
311
316
tmprbuf = true;
312
317
lbuf = redbuf ;
313
318
tmplbuf = false;
314
319
} else {
315
- lbuf = 0 ;
320
+ lbuf = ( void * )( - gap ) ;
316
321
tmplbuf = true;
317
322
rbuf = redbuf ;
318
323
tmprbuf = false;
324
+ if (inplace ) {
325
+ res = NBC_Copy (rbuf , count , datatype , ((char * )handle -> tmpbuf )- gap , count , datatype , MPI_COMM_SELF );
326
+ if (OPAL_UNLIKELY (OMPI_SUCCESS != res )) {
327
+ return res ;
328
+ }
329
+ }
319
330
}
320
331
321
332
for (int r = 1 , firstred = 1 ; r <= maxr ; ++ r ) {
@@ -332,7 +343,7 @@ static inline int red_sched_binomial (int rank, int p, int root, const void *sen
332
343
333
344
/* perform the reduce in my local buffer */
334
345
/* this cannot be done until handle->tmpbuf is unused :-( so barrier after the op */
335
- if (firstred && MPI_IN_PLACE != sendbuf ) {
346
+ if (firstred && ! inplace ) {
336
347
/* perform the reduce with the senbuf */
337
348
res = NBC_Sched_op2 (sendbuf , false, rbuf , tmprbuf , count , datatype , op , schedule , true);
338
349
firstred = 0 ;
@@ -352,7 +363,7 @@ static inline int red_sched_binomial (int rank, int p, int root, const void *sen
352
363
/* we have to send this round */
353
364
vpeer = vrank - (1 << (r - 1 ));
354
365
VRANK2RANK (peer , vpeer , vroot )
355
- if (firstred && MPI_IN_PLACE != sendbuf ) {
366
+ if (firstred && ! inplace ) {
356
367
/* we have to use the sendbuf in the first round .. */
357
368
res = NBC_Sched_send (sendbuf , false, count , datatype , peer , schedule , false);
358
369
} else {
0 commit comments