@@ -28,7 +28,8 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
28
28
int mca_coll_acoll_allreduce_small_msgs_h (const void * sbuf , void * rbuf , size_t count ,
29
29
struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
30
30
struct ompi_communicator_t * comm ,
31
- mca_coll_base_module_t * module , int intra );
31
+ mca_coll_base_module_t * module ,
32
+ coll_acoll_subcomms_t * subc , int intra );
32
33
33
34
34
35
static inline int coll_allreduce_decision_fixed (int comm_size , size_t msg_size )
@@ -52,16 +53,13 @@ static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
52
53
static inline int mca_coll_acoll_reduce_xpmem_h (const void * sbuf , void * rbuf , size_t count ,
53
54
struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
54
55
struct ompi_communicator_t * comm ,
55
- mca_coll_base_module_t * module )
56
+ mca_coll_base_module_t * module ,
57
+ coll_acoll_subcomms_t * subc )
56
58
{
57
59
int size ;
58
60
size_t total_dsize , dsize ;
59
- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
60
61
61
- coll_acoll_subcomms_t * subc ;
62
- int cid = ompi_comm_get_local_cid (comm );
63
- subc = & acoll_module -> subc [cid ];
64
- coll_acoll_init (module , comm , subc -> data );
62
+ coll_acoll_init (module , comm , subc -> data , subc );
65
63
coll_acoll_data_t * data = subc -> data ;
66
64
if (NULL == data ) {
67
65
return -1 ;
@@ -188,16 +186,13 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
188
186
struct ompi_datatype_t * dtype ,
189
187
struct ompi_op_t * op ,
190
188
struct ompi_communicator_t * comm ,
191
- mca_coll_base_module_t * module )
189
+ mca_coll_base_module_t * module ,
190
+ coll_acoll_subcomms_t * subc )
192
191
{
193
192
int size ;
194
193
size_t total_dsize , dsize ;
195
- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
196
194
197
- coll_acoll_subcomms_t * subc ;
198
- int cid = ompi_comm_get_local_cid (comm );
199
- subc = & acoll_module -> subc [cid ];
200
- coll_acoll_init (module , comm , subc -> data );
195
+ coll_acoll_init (module , comm , subc -> data , subc );
201
196
coll_acoll_data_t * data = subc -> data ;
202
197
if (NULL == data ) {
203
198
return -1 ;
@@ -361,15 +356,13 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
361
356
int mca_coll_acoll_allreduce_small_msgs_h (const void * sbuf , void * rbuf , size_t count ,
362
357
struct ompi_datatype_t * dtype , struct ompi_op_t * op ,
363
358
struct ompi_communicator_t * comm ,
364
- mca_coll_base_module_t * module , int intra )
359
+ mca_coll_base_module_t * module ,
360
+ coll_acoll_subcomms_t * subc , int intra )
365
361
{
366
362
size_t dsize ;
367
363
int err = MPI_SUCCESS ;
368
- mca_coll_acoll_module_t * acoll_module = (mca_coll_acoll_module_t * ) module ;
369
- coll_acoll_subcomms_t * subc ;
370
- int cid = ompi_comm_get_local_cid (comm );
371
- subc = & acoll_module -> subc [cid ];
372
- coll_acoll_init (module , comm , subc -> data );
364
+
365
+ coll_acoll_init (module , comm , subc -> data , subc );
373
366
coll_acoll_data_t * data = subc -> data ;
374
367
if (NULL == data ) {
375
368
return -1 ;
@@ -385,7 +378,6 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
385
378
386
379
int l1_local_rank = data -> l1_local_rank ;
387
380
int l2_local_rank = data -> l2_local_rank ;
388
- int comm_id = ompi_comm_get_local_cid (comm );
389
381
390
382
int offset1 = data -> offset [0 ];
391
383
int offset2 = data -> offset [1 ];
@@ -441,8 +433,8 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
441
433
}
442
434
}
443
435
444
- if (intra && (ompi_comm_size (acoll_module -> subc [ comm_id ]. numa_comm ) > 1 )) {
445
- err = mca_coll_acoll_bcast (rbuf , count , dtype , 0 , acoll_module -> subc [ comm_id ]. numa_comm , module );
436
+ if (intra && (ompi_comm_size (subc -> numa_comm ) > 1 )) {
437
+ err = mca_coll_acoll_bcast (rbuf , count , dtype , 0 , subc -> numa_comm , module );
446
438
}
447
439
return err ;
448
440
}
@@ -466,25 +458,23 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
466
458
return MPI_SUCCESS ;
467
459
}
468
460
469
- coll_acoll_subcomms_t * subc ;
470
- int cid = ompi_comm_get_local_cid (comm );
471
- subc = & acoll_module -> subc [cid ];
472
-
473
461
/* Falling back to recursivedoubling for non-commutative operators to be safe */
474
462
if (!ompi_op_is_commute (op )) {
475
463
return ompi_coll_base_allreduce_intra_recursivedoubling (sbuf , rbuf , count , dtype , op , comm ,
476
464
module );
477
465
}
478
466
479
- /* Fallback to knomial if cid is beyond supported limit */
480
- if (cid >= MCA_COLL_ACOLL_MAX_CID ) {
467
+ /* Obtain the subcomms structure */
468
+ coll_acoll_subcomms_t * subc = NULL ;
469
+ err = check_and_create_subc (comm , acoll_module , & subc );
470
+
471
+ /* Fallback to knomial if subc is not obtained */
472
+ if (subc == NULL ) {
481
473
return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype , op , comm ,
482
474
module );
483
475
}
484
-
485
- subc = & acoll_module -> subc [cid ];
486
476
if (!subc -> initialized ) {
487
- err = mca_coll_acoll_comm_split_init (comm , acoll_module , 0 );
477
+ err = mca_coll_acoll_comm_split_init (comm , acoll_module , subc , 0 );
488
478
if (MPI_SUCCESS != err )
489
479
return err ;
490
480
}
@@ -499,7 +489,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
499
489
comm , module );
500
490
} else if (total_dsize < 512 ) {
501
491
return mca_coll_acoll_allreduce_small_msgs_h (sbuf , rbuf , count , dtype , op , comm , module ,
502
- 1 );
492
+ subc , 1 );
503
493
} else if (total_dsize <= 2048 ) {
504
494
return ompi_coll_base_allreduce_intra_recursivedoubling (sbuf , rbuf , count , dtype , op ,
505
495
comm , module );
@@ -517,7 +507,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
517
507
} else if (total_dsize < 4194304 ) {
518
508
#ifdef HAVE_XPMEM_H
519
509
if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
520
- return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module );
510
+ return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
521
511
} else {
522
512
return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
523
513
op , comm , module );
@@ -529,7 +519,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
529
519
} else if (total_dsize <= 16777216 ) {
530
520
#ifdef HAVE_XPMEM_H
531
521
if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
532
- mca_coll_acoll_reduce_xpmem_h (sbuf , rbuf , count , dtype , op , comm , module );
522
+ mca_coll_acoll_reduce_xpmem_h (sbuf , rbuf , count , dtype , op , comm , module , subc );
533
523
return mca_coll_acoll_bcast (rbuf , count , dtype , 0 , comm , module );
534
524
} else {
535
525
return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
@@ -542,7 +532,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
542
532
} else {
543
533
#ifdef HAVE_XPMEM_H
544
534
if (((subc -> xpmem_use_sr_buf != 0 ) || (subc -> xpmem_buf_size > 2 * total_dsize )) && (subc -> without_xpmem != 1 )) {
545
- return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module );
535
+ return mca_coll_acoll_allreduce_xpmem_f (sbuf , rbuf , count , dtype , op , comm , module , subc );
546
536
} else {
547
537
return ompi_coll_base_allreduce_intra_redscat_allgather (sbuf , rbuf , count , dtype ,
548
538
op , comm , module );
0 commit comments