Skip to content

coll/acoll: Remove use of cid as array index #12783

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions ompi/mca/coll/acoll/coll_acoll.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ BEGIN_C_DECLS
/* Globally exported variables */
OMPI_DECLSPEC extern const mca_coll_base_component_3_0_0_t mca_coll_acoll_component;
extern int mca_coll_acoll_priority;
extern int mca_coll_acoll_max_comms;
extern int mca_coll_acoll_sg_size;
extern int mca_coll_acoll_sg_scale;
extern int mca_coll_acoll_node_size;
Expand Down Expand Up @@ -75,7 +76,6 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base

END_C_DECLS

#define MCA_COLL_ACOLL_MAX_CID 100
#define MCA_COLL_ACOLL_ROOT_CHANGE_THRESH 10

typedef enum MCA_COLL_ACOLL_SG_SIZES {
Expand Down Expand Up @@ -208,8 +208,10 @@ struct mca_coll_acoll_module_t {
int mnode_log2_sg_size;
int allg_lin;
int allg_ring;
coll_acoll_subcomms_t subc[MCA_COLL_ACOLL_MAX_CID];
int max_comms;
coll_acoll_subcomms_t **subc;
coll_acoll_reserve_mem_t reserve_mem_s;
int num_subc;
};

#ifdef HAVE_XPMEM_H
Expand Down
12 changes: 6 additions & 6 deletions ompi/mca/coll/acoll/coll_acoll_allgather.c
Original file line number Diff line number Diff line change
Expand Up @@ -481,21 +481,21 @@ int mca_coll_acoll_allgather(const void *sbuf, size_t scount, struct ompi_dataty
int brank, last_brank;
int use_rd_base;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
coll_acoll_subcomms_t *subc = NULL;
char *local_rbuf;
ompi_communicator_t *intra_comm;

/* Fallback to ring if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);
/* Fallback to ring if subc is not obtained */
if (NULL == subc) {
return ompi_coll_base_allgather_intra_ring(sbuf, scount, sdtype, rbuf, rcount, rdtype, comm,
module);
}

subc = &acoll_module->subc[cid];
size = ompi_comm_size(comm);
if (!subc->initialized && size > 2) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
60 changes: 25 additions & 35 deletions ompi/mca/coll/acoll/coll_acoll_allreduce.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,8 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module, int intra);
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc, int intra);


static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
Expand All @@ -52,16 +53,13 @@ static inline int coll_allreduce_decision_fixed(int comm_size, size_t msg_size)
static inline int mca_coll_acoll_reduce_xpmem_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc)
{
int size;
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);
coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand Down Expand Up @@ -188,16 +186,13 @@ static inline int mca_coll_acoll_allreduce_xpmem_f(const void *sbuf, void *rbuf,
struct ompi_datatype_t *dtype,
struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc)
{
int size;
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);
coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand Down Expand Up @@ -361,15 +356,13 @@ void mca_coll_acoll_sync(coll_acoll_data_t *data, int offset, int *group, int gp
int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t count,
struct ompi_datatype_t *dtype, struct ompi_op_t *op,
struct ompi_communicator_t *comm,
mca_coll_base_module_t *module, int intra)
mca_coll_base_module_t *module,
coll_acoll_subcomms_t *subc, int intra)
{
size_t dsize;
int err = MPI_SUCCESS;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];
coll_acoll_init(module, comm, subc->data);

coll_acoll_init(module, comm, subc->data, subc);
coll_acoll_data_t *data = subc->data;
if (NULL == data) {
return -1;
Expand All @@ -385,7 +378,6 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c

int l1_local_rank = data->l1_local_rank;
int l2_local_rank = data->l2_local_rank;
int comm_id = ompi_comm_get_local_cid(comm);

int offset1 = data->offset[0];
int offset2 = data->offset[1];
Expand Down Expand Up @@ -441,8 +433,8 @@ int mca_coll_acoll_allreduce_small_msgs_h(const void *sbuf, void *rbuf, size_t c
}
}

if (intra && (ompi_comm_size(acoll_module->subc[comm_id].numa_comm) > 1)) {
err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, acoll_module->subc[comm_id].numa_comm, module);
if (intra && (ompi_comm_size(subc->numa_comm) > 1)) {
err = mca_coll_acoll_bcast(rbuf, count, dtype, 0, subc->numa_comm, module);
}
return err;
}
Expand All @@ -466,25 +458,23 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
return MPI_SUCCESS;
}

coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
subc = &acoll_module->subc[cid];

/* Falling back to recursivedoubling for non-commutative operators to be safe */
if (!ompi_op_is_commute(op)) {
return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op, comm,
module);
}

/* Fallback to knomial if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
coll_acoll_subcomms_t *subc = NULL;
err = check_and_create_subc(comm, acoll_module, &subc);

/* Fallback to redscat_allgather if subc is not obtained */
if (NULL == subc) {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype, op, comm,
module);
}

subc = &acoll_module->subc[cid];
if (!subc->initialized) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err)
return err;
}
Expand All @@ -499,7 +489,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
comm, module);
} else if (total_dsize < 512) {
return mca_coll_acoll_allreduce_small_msgs_h(sbuf, rbuf, count, dtype, op, comm, module,
1);
subc, 1);
} else if (total_dsize <= 2048) {
return ompi_coll_base_allreduce_intra_recursivedoubling(sbuf, rbuf, count, dtype, op,
comm, module);
Expand All @@ -517,7 +507,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else if (total_dsize < 4194304) {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module);
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module, subc);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
op, comm, module);
Expand All @@ -529,7 +519,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else if (total_dsize <= 16777216) {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module);
mca_coll_acoll_reduce_xpmem_h(sbuf, rbuf, count, dtype, op, comm, module, subc);
return mca_coll_acoll_bcast(rbuf, count, dtype, 0, comm, module);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
Expand All @@ -542,7 +532,7 @@ int mca_coll_acoll_allreduce_intra(const void *sbuf, void *rbuf, size_t count,
} else {
#ifdef HAVE_XPMEM_H
if (((subc->xpmem_use_sr_buf != 0) || (subc->xpmem_buf_size > 2 * total_dsize)) && (subc->without_xpmem != 1)) {
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module);
return mca_coll_acoll_allreduce_xpmem_f(sbuf, rbuf, count, dtype, op, comm, module, subc);
} else {
return ompi_coll_base_allreduce_intra_redscat_allgather(sbuf, rbuf, count, dtype,
op, comm, module);
Expand Down
13 changes: 7 additions & 6 deletions ompi/mca/coll/acoll/coll_acoll_barrier.c
Original file line number Diff line number Diff line change
Expand Up @@ -130,21 +130,22 @@ int mca_coll_acoll_barrier_intra(struct ompi_communicator_t *comm, mca_coll_base
ompi_request_t **reqs;
int num_nodes;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
coll_acoll_subcomms_t *subc;
int cid = ompi_comm_get_local_cid(comm);
coll_acoll_subcomms_t *subc = NULL;

/* Fallback to linear if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);

/* Fallback to linear if subcomms structure is not obtained */
if (NULL == subc) {
return ompi_coll_base_barrier_intra_basic_linear(comm, module);
}

subc = &acoll_module->subc[cid];
size = ompi_comm_size(comm);
if (size == 1) {
return err;
}
if (!subc->initialized && size > 1) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, 0);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, 0);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
15 changes: 8 additions & 7 deletions ompi/mca/coll/acoll/coll_acoll_bcast.c
Original file line number Diff line number Diff line change
Expand Up @@ -444,24 +444,25 @@ int mca_coll_acoll_bcast(void *buff, size_t count, struct ompi_datatype_t *datat
size_t total_dsize, dsize;
mca_coll_acoll_module_t *acoll_module = (mca_coll_acoll_module_t *) module;
bcast_subc_func bcast_func[2] = {&bcast_binomial, &bcast_flat_tree};
coll_acoll_subcomms_t *subc;
coll_acoll_subcomms_t *subc = NULL;
struct ompi_communicator_t *subcomms[MCA_COLL_ACOLL_NUM_SC] = {NULL};
int subc_roots[MCA_COLL_ACOLL_NUM_SC] = {-1};
int cid = ompi_comm_get_local_cid(comm);

/* Fallback to knomial if cid is beyond supported limit */
if (cid >= MCA_COLL_ACOLL_MAX_CID) {
/* Obtain the subcomms structure */
err = check_and_create_subc(comm, acoll_module, &subc);
/* Fallback to knomial if subcomms is not obtained */
if (NULL == subc) {
return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
}

subc = &acoll_module->subc[cid];
/* Fallback to knomial if no. of root changes is beyond a threshold and root differs from prev_init_root */
if (subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH) {
if ((subc->num_root_change > MCA_COLL_ACOLL_ROOT_CHANGE_THRESH)
&& (root != subc->prev_init_root)) {
return ompi_coll_base_bcast_intra_knomial(buff, count, datatype, root, comm, module, 0, 4);
}
size = ompi_comm_size(comm);
if ((!subc->initialized || (root != subc->prev_init_root)) && size > 2) {
err = mca_coll_acoll_comm_split_init(comm, acoll_module, root);
err = mca_coll_acoll_comm_split_init(comm, acoll_module, subc, root);
if (MPI_SUCCESS != err) {
return err;
}
Expand Down
60 changes: 17 additions & 43 deletions ompi/mca/coll/acoll/coll_acoll_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ const char *mca_coll_acoll_component_version_string
* Global variables
*/
int mca_coll_acoll_priority = 0;
int mca_coll_acoll_max_comms = 10;
int mca_coll_acoll_sg_size = 8;
int mca_coll_acoll_sg_scale = 1;
int mca_coll_acoll_node_size = 128;
Expand Down Expand Up @@ -91,6 +92,11 @@ static int acoll_register(void)
MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_priority);

/* Defaults on topology */
(void)
mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "max_comms",
"Maximum no. of communicators using subgroup based algorithms",
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0, OPAL_INFO_LVL_9,
MCA_BASE_VAR_SCOPE_READONLY, &mca_coll_acoll_max_comms);
(void)
mca_base_component_var_register(&mca_coll_acoll_component.collm_version, "sg_size",
"Size of subgroup to be used for subgroup based algorithms",
Expand Down Expand Up @@ -186,47 +192,10 @@ static int acoll_register(void)
*/
static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module)
{
for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) {
coll_acoll_subcomms_t *subc = &module->subc[i];
subc->initialized = 0;
subc->is_root_node = 0;
subc->is_root_sg = 0;
subc->is_root_numa = 0;
subc->outer_grp_root = -1;
subc->subgrp_root = 0;
subc->num_nodes = 1;
subc->prev_init_root = -1;
subc->num_root_change = 0;
subc->numa_root = 0;
subc->socket_ldr_root = -1;
subc->local_comm = NULL;
subc->local_r_comm = NULL;
subc->leader_comm = NULL;
subc->subgrp_comm = NULL;
subc->socket_comm = NULL;
subc->socket_ldr_comm = NULL;
for (int j = 0; j < MCA_COLL_ACOLL_NUM_LAYERS; j++) {
for (int k = 0; k < MCA_COLL_ACOLL_NUM_BASE_LYRS; k++) {
subc->base_comm[k][j] = NULL;
subc->base_root[k][j] = -1;
}
subc->local_root[j] = 0;
}

subc->numa_comm = NULL;
subc->numa_comm_ldrs = NULL;
subc->node_comm = NULL;
subc->inter_comm = NULL;
subc->cid = -1;
subc->initialized_data = false;
subc->initialized_shm_data = false;
subc->data = NULL;
#ifdef HAVE_XPMEM_H
subc->xpmem_buf_size = mca_coll_acoll_xpmem_buffer_size;
subc->without_xpmem = mca_coll_acoll_without_xpmem;
subc->xpmem_use_sr_buf = mca_coll_acoll_xpmem_use_sr_buf;
#endif
}
/* Set number of subcomms to 0 */
module->num_subc = 0;
module->subc = NULL;

/* Reserve memory init. Lazy allocation of memory when needed. */
(module->reserve_mem_s).reserve_mem = NULL;
Expand All @@ -246,9 +215,8 @@ static void mca_coll_acoll_module_construct(mca_coll_acoll_module_t *module)
*/
static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
{

for (int i = 0; i < MCA_COLL_ACOLL_MAX_CID; i++) {
coll_acoll_subcomms_t *subc = &module->subc[i];
for (int i = 0; i < module->num_subc; i++) {
coll_acoll_subcomms_t *subc = module->subc[i];
if (subc->initialized_data) {
if (subc->initialized_shm_data) {
if (subc->orig_comm != NULL) {
Expand Down Expand Up @@ -334,8 +302,14 @@ static void mca_coll_acoll_module_destruct(mca_coll_acoll_module_t *module)
}
}
subc->initialized = 0;
free(subc);
module->subc[i] = NULL;
}

module->num_subc = 0;
free(module->subc);
module->subc = NULL;

if ((true == (module->reserve_mem_s).reserve_mem_allocate)
&& (NULL != (module->reserve_mem_s).reserve_mem)) {
free((module->reserve_mem_s).reserve_mem);
Expand Down
1 change: 1 addition & 0 deletions ompi/mca/coll/acoll/coll_acoll_module.c
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ mca_coll_base_module_t *mca_coll_acoll_comm_query(struct ompi_communicator_t *co
*priority = mca_coll_acoll_priority;

/* Set topology params */
acoll_module->max_comms = mca_coll_acoll_max_comms;
acoll_module->sg_scale = mca_coll_acoll_sg_scale;
acoll_module->sg_size = mca_coll_acoll_sg_size;
acoll_module->sg_cnt = mca_coll_acoll_sg_size / mca_coll_acoll_sg_scale;
Expand Down
Loading
Loading