Skip to content

coll/han: call fallback functin when HAN module is disabled #11454

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Mar 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ompi/mca/coll/cuda/coll_cuda_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ mca_coll_cuda_component_t mca_coll_cuda_component = {
/* cuda-specific component information */

/* Priority: make it above all point to point collectives including self */
78,
.priority = 78,
};


Expand Down
4 changes: 2 additions & 2 deletions ompi/mca/coll/han/coll_han.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,9 +377,9 @@ OBJ_CLASS_DECLARATION(mca_coll_han_module_t);
#define HAN_LOAD_FALLBACK_COLLECTIVE(HANM, COMM, COLL) \
do { \
if ( ((COMM)->c_coll->coll_ ## COLL ## _module) == (mca_coll_base_module_t*)(HANM) ) { \
(COMM)->c_coll->coll_ ## COLL = (HANM)->fallback.COLL.module_fn.COLL; \
(COMM)->c_coll->coll_ ## COLL = (HANM)->previous_## COLL; \
mca_coll_base_module_t *coll_module = (COMM)->c_coll->coll_ ## COLL ## _module; \
(COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->fallback.COLL.module; \
(COMM)->c_coll->coll_ ## COLL ## _module = (HANM)->previous_ ## COLL ## _module; \
OBJ_RETAIN((COMM)->c_coll->coll_ ## COLL ## _module); \
OBJ_RELEASE(coll_module); \
} \
Expand Down
18 changes: 9 additions & 9 deletions ompi/mca/coll/han/coll_han_allgather.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 The University of Tennessee and The University
* Copyright (c) 2018-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2020 Bull S.A.S. All rights reserved.
Expand Down Expand Up @@ -83,8 +83,8 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount,
"han cannot handle allgather within this communicator. Fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, comm->c_coll->coll_allgather_module);
return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, han_module->previous_allgather_module);
}
ompi_communicator_t *low_comm = han_module->sub_comm[INTRA_NODE];
ompi_communicator_t *up_comm = han_module->sub_comm[INTER_NODE];
Expand All @@ -98,8 +98,8 @@ mca_coll_han_allgather_intra(const void *sbuf, int scount,
OPAL_OUTPUT_VERBOSE((30, mca_coll_han_component.han_output,
"han cannot handle allgather with this communicator (imbalance). Fall back on another component\n"));
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather);
return comm->c_coll->coll_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, comm->c_coll->coll_allgather_module);
return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, han_module->previous_allgather_module);
}

ompi_request_t *temp_request;
Expand Down Expand Up @@ -307,8 +307,8 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount,
"han cannot handle allgather within this communicator. Fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, comm->c_coll->coll_allgather_module);
return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, han_module->previous_allgather_module);
}
/* discovery topology */
int *topo = mca_coll_han_topo_init(comm, han_module, 2);
Expand All @@ -321,8 +321,8 @@ mca_coll_han_allgather_intra_simple(const void *sbuf, int scount,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, allgather);
return comm->c_coll->coll_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, comm->c_coll->coll_allgather_module);
return han_module->previous_allgather(sbuf, scount, sdtype, rbuf, rcount, rdtype,
comm, han_module->previous_allgather_module);
}

ompi_communicator_t *low_comm = han_module->sub_comm[INTRA_NODE];
Expand Down
10 changes: 5 additions & 5 deletions ompi/mca/coll/han/coll_han_allreduce.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 The University of Tennessee and The University
* Copyright (c) 2018-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2020 Bull S.A.S. All rights reserved.
Expand Down Expand Up @@ -110,8 +110,8 @@ mca_coll_han_allreduce_intra(const void *sbuf,
"han cannot handle allreduce with this communicator. Drop HAN support in this communicator and fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_allreduce(sbuf, rbuf, count, dtype, op,
comm, comm->c_coll->coll_allreduce_module);
return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
comm, han_module->previous_allreduce_module);
}

ptrdiff_t extent, lb;
Expand Down Expand Up @@ -450,8 +450,8 @@ mca_coll_han_allreduce_intra_simple(const void *sbuf,
"han cannot handle allreduce with this communicator. Drop HAN support in this communicator and fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_allreduce(sbuf, rbuf, count, dtype, op,
comm, comm->c_coll->coll_allreduce_module);
return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op,
comm, han_module->previous_allreduce_module);
}

low_comm = han_module->sub_comm[INTRA_NODE];
Expand Down
4 changes: 2 additions & 2 deletions ompi/mca/coll/han/coll_han_barrier.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 The University of Tennessee and The University
* Copyright (c) 2018-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2020 Bull S.A.S. All rights reserved.
Expand Down Expand Up @@ -40,7 +40,7 @@ mca_coll_han_barrier_intra_simple(struct ompi_communicator_t *comm,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_barrier(comm, comm->c_coll->coll_bcast_module);
return han_module->previous_barrier(comm, han_module->previous_barrier_module);
}

low_comm = han_module->sub_comm[INTRA_NODE];
Expand Down
30 changes: 15 additions & 15 deletions ompi/mca/coll/han/coll_han_bcast.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 The University of Tennessee and The University
* Copyright (c) 2018-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2020 Bull S.A.S. All rights reserved.
Expand Down Expand Up @@ -63,7 +63,7 @@ mca_coll_han_set_bcast_args(mca_coll_han_bcast_args_t * args, mca_coll_task_t *
* iter 4 | | | | lb | task: t1, contains lb
*/
int
mca_coll_han_bcast_intra(void *buff,
mca_coll_han_bcast_intra(void *buf,
int count,
struct ompi_datatype_t *dtype,
int root,
Expand All @@ -84,8 +84,8 @@ mca_coll_han_bcast_intra(void *buff,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_bcast(buff, count, dtype, root,
comm, comm->c_coll->coll_bcast_module);
return han_module->previous_bcast(buf, count, dtype, root,
comm, han_module->previous_bcast_module);
}
/* Topo must be initialized to know rank distribution which then is used to
* determine if han can be used */
Expand All @@ -97,8 +97,8 @@ mca_coll_han_bcast_intra(void *buff,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast);
return comm->c_coll->coll_bcast(buff, count, dtype, root,
comm, comm->c_coll->coll_bcast_module);
return han_module->previous_bcast(buf, count, dtype, root,
comm, han_module->previous_bcast_module);
}

ompi_datatype_get_extent(dtype, &lb, &extent);
Expand Down Expand Up @@ -129,7 +129,7 @@ mca_coll_han_bcast_intra(void *buff,
mca_coll_task_t *t0 = OBJ_NEW(mca_coll_task_t);
/* Setup up t0 task arguments */
mca_coll_han_bcast_args_t *t = malloc(sizeof(mca_coll_han_bcast_args_t));
mca_coll_han_set_bcast_args(t, t0, (char *) buff, seg_count, dtype,
mca_coll_han_set_bcast_args(t, t0, (char *)buf, seg_count, dtype,
root_up_rank, root_low_rank, up_comm, low_comm,
num_segments, 0, w_rank, count - (num_segments - 1) * seg_count,
low_rank != root_low_rank);
Expand Down Expand Up @@ -222,7 +222,7 @@ int mca_coll_han_bcast_t1_task(void *task_args)
* communications without tasks.
*/
int
mca_coll_han_bcast_intra_simple(void *buff,
mca_coll_han_bcast_intra_simple(void *buf,
int count,
struct ompi_datatype_t *dtype,
int root,
Expand All @@ -246,8 +246,8 @@ mca_coll_han_bcast_intra_simple(void *buff,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_bcast(buff, count, dtype, root,
comm, comm->c_coll->coll_bcast_module);
return han_module->previous_bcast(buf, count, dtype, root,
comm, han_module->previous_bcast_module);
}
/* Topo must be initialized to know rank distribution which then is used to
* determine if han can be used */
Expand All @@ -259,8 +259,8 @@ mca_coll_han_bcast_intra_simple(void *buff,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, bcast);
return comm->c_coll->coll_bcast(buff, count, dtype, root,
comm, comm->c_coll->coll_bcast_module);
return han_module->previous_bcast(buf, count, dtype, root,
comm, han_module->previous_bcast_module);
}

low_comm = han_module->sub_comm[INTRA_NODE];
Expand All @@ -277,18 +277,18 @@ mca_coll_han_bcast_intra_simple(void *buff,
w_rank, root_low_rank, root_up_rank));

if (low_rank == root_low_rank) {
up_comm->c_coll->coll_bcast(buff, count, dtype, root_up_rank,
up_comm->c_coll->coll_bcast(buf, count, dtype, root_up_rank,
up_comm, up_comm->c_coll->coll_bcast_module);

/* To remove when han has better sub-module selection.
For now switching to ibcast enables to make runs with libnbc. */
//ompi_request_t req;
//up_comm->c_coll->coll_ibcast(buff, count, dtype, root_up_rank,
//up_comm->c_coll->coll_ibcast(buf, count, dtype, root_up_rank,
// up_comm, &req, up_comm->c_coll->coll_ibcast_module);
//ompi_request_wait(&req, MPI_STATUS_IGNORE);

}
low_comm->c_coll->coll_bcast(buff, count, dtype, root_low_rank,
low_comm->c_coll->coll_bcast(buf, count, dtype, root_low_rank,
low_comm, low_comm->c_coll->coll_bcast_module);

return OMPI_SUCCESS;
Expand Down
3 changes: 1 addition & 2 deletions ompi/mca/coll/han/coll_han_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ mca_coll_han_component_t mca_coll_han_component = {
/* han-component specific information */

/* (default) priority */
.han_priority = 20,
.han_priority = 35,
/* workaround for nvcc compiler */
.dynamic_rules_filename = NULL,
};
Expand Down Expand Up @@ -251,7 +251,6 @@ static int han_register(void)
TOPO_LVL_T topo_lvl;
COMPONENT_T component;

cs->han_priority = 35;
(void) mca_base_component_var_register(c, "priority", "Priority of the HAN coll component",
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
OPAL_INFO_LVL_9,
Expand Down
28 changes: 28 additions & 0 deletions ompi/mca/coll/han/coll_han_dynamic.c
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,11 @@ mca_coll_han_allreduce_intra_dynamic(const void *sbuf,
size_t dtype_size;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_allreduce(sbuf, rbuf, count, dtype, op, comm,
han_module->previous_allreduce_module);
}

/* Compute configuration information for dynamic rules */
ompi_datatype_type_size(dtype, &dtype_size);
dtype_size = dtype_size * count;
Expand Down Expand Up @@ -722,6 +727,9 @@ mca_coll_han_barrier_intra_dynamic(struct ompi_communicator_t *comm,
mca_coll_base_module_t *sub_module;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_barrier(comm, han_module->previous_barrier_module);
}

/* Compute configuration information for dynamic rules */
sub_module = get_module(BARRIER,
Expand Down Expand Up @@ -821,6 +829,11 @@ mca_coll_han_bcast_intra_dynamic(void *buff,
size_t dtype_size;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_bcast(buff, count, dtype, root, comm,
han_module->previous_bcast_module);
}

/* Compute configuration information for dynamic rules */
ompi_datatype_type_size(dtype, &dtype_size);
dtype_size = dtype_size * count;
Expand Down Expand Up @@ -932,6 +945,11 @@ mca_coll_han_gather_intra_dynamic(const void *sbuf, int scount,
size_t dtype_size;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm,
han_module->previous_gather_module);
}

/* Compute configuration information for dynamic rules */
if( MPI_IN_PLACE != sbuf ) {
ompi_datatype_type_size(sdtype, &dtype_size);
Expand Down Expand Up @@ -1051,6 +1069,11 @@ mca_coll_han_reduce_intra_dynamic(const void *sbuf,
size_t dtype_size;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_reduce(sbuf, rbuf, count, dtype, op, root, comm,
han_module->previous_reduce_module);
}

/* Compute configuration information for dynamic rules */
ompi_datatype_type_size(dtype, &dtype_size);
dtype_size = dtype_size * count;
Expand Down Expand Up @@ -1167,6 +1190,11 @@ mca_coll_han_scatter_intra_dynamic(const void *sbuf, int scount,
size_t dtype_size;
int rank, verbosity = 0;

if (!han_module->enabled) {
return han_module->previous_scatter(sbuf, scount, sdtype, rbuf, rcount, rdtype, root, comm,
han_module->previous_scatter_module);
}

/* Compute configuration information for dynamic rules */
if( MPI_IN_PLACE != rbuf ) {
ompi_datatype_type_size(rdtype, &dtype_size);
Expand Down
22 changes: 9 additions & 13 deletions ompi/mca/coll/han/coll_han_gather.c
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2018-2020 The University of Tennessee and The University
* Copyright (c) 2018-2023 The University of Tennessee and The University
* of Tennessee Research Foundation. All rights
* reserved.
* Copyright (c) 2020 Bull S.A.S. All rights reserved.
Expand Down Expand Up @@ -93,9 +93,8 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
"han cannot handle gather with this communicator. Fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_gather(sbuf, scount, sdtype, rbuf,
rcount, rdtype, root,
comm, comm->c_coll->coll_gather_module);
return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root,
comm, han_module->previous_gather_module);
}

/* Topo must be initialized to know rank distribution which then is used to
Expand All @@ -108,9 +107,8 @@ mca_coll_han_gather_intra(const void *sbuf, int scount,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather);
return comm->c_coll->coll_gather(sbuf, scount, sdtype, rbuf,
rcount, rdtype, root,
comm, comm->c_coll->coll_gather_module);
return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root,
comm, han_module->previous_gather_module);
}

w_rank = ompi_comm_rank(comm);
Expand Down Expand Up @@ -359,9 +357,8 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
"han cannot handle gather with this communicator. Fall back on another component\n"));
/* HAN cannot work with this communicator so fallback on all collectives */
HAN_LOAD_FALLBACK_COLLECTIVES(han_module, comm);
return comm->c_coll->coll_gather(sbuf, scount, sdtype, rbuf,
rcount, rdtype, root,
comm, comm->c_coll->coll_gather_module);
return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root,
comm, han_module->previous_gather_module);
}

/* Topo must be initialized to know rank distribution which then is used to
Expand All @@ -374,9 +371,8 @@ mca_coll_han_gather_intra_simple(const void *sbuf, int scount,
* future calls will then be automatically redirected.
*/
HAN_LOAD_FALLBACK_COLLECTIVE(han_module, comm, gather);
return comm->c_coll->coll_gather(sbuf, scount, sdtype, rbuf,
rcount, rdtype, root,
comm, comm->c_coll->coll_gather_module);
return han_module->previous_gather(sbuf, scount, sdtype, rbuf, rcount, rdtype, root,
comm, han_module->previous_gather_module);
}

ompi_communicator_t *low_comm = han_module->sub_comm[INTRA_NODE];
Expand Down
Loading