Skip to content

Commit 519234a

Browse files
committed
coll/accelerator: allow to select functions to register
This PR introduces the ability to register the component only for the selected functions specified by an MCA parameter string. The idea and the code are based on the UCC component, and some of the bits might be moved later to coll/base to make the mechanism more generally available to other components as well. Note that the PR introduces the define statements for all MPI collective operations, not just the ones supported by the component at the moment, since it is a bitmask-based operation, and we anticipate adding support for more collective operations to coll/accelerator shortly. Signed-off-by: Edgar Gabriel <[email protected]>
1 parent 60807d7 commit 519234a

File tree

3 files changed

+123
-12
lines changed

3 files changed

+123
-12
lines changed

ompi/mca/coll/accelerator/coll_accelerator.h

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,36 @@
3232

3333
BEGIN_C_DECLS
3434

35+
/*
 * One bit per MPI collective operation. The flags are OR-ed together into a
 * bitmask describing the set of collectives the component registers itself
 * for. All collectives are listed, not only the ones the component
 * currently implements.
 */
#define COLL_ACC_ALLGATHER             0x00000001
#define COLL_ACC_ALLGATHERV            0x00000002
#define COLL_ACC_ALLREDUCE             0x00000004
#define COLL_ACC_ALLTOALL              0x00000008
#define COLL_ACC_ALLTOALLV             0x00000010
#define COLL_ACC_ALLTOALLW             0x00000020
#define COLL_ACC_BARRIER               0x00000040
#define COLL_ACC_BCAST                 0x00000080
#define COLL_ACC_EXSCAN                0x00000100
#define COLL_ACC_GATHER                0x00000200
#define COLL_ACC_GATHERV               0x00000400
#define COLL_ACC_REDUCE                0x00000800
#define COLL_ACC_REDUCE_SCATTER        0x00001000
#define COLL_ACC_REDUCE_SCATTER_BLOCK  0x00002000
#define COLL_ACC_REDUCE_LOCAL          0x00004000
#define COLL_ACC_SCAN                  0x00008000
#define COLL_ACC_SCATTER               0x00010000
#define COLL_ACC_SCATTERV              0x00020000
#define COLL_ACC_NEIGHBOR_ALLGATHER    0x00040000
#define COLL_ACC_NEIGHBOR_ALLGATHERV   0x00080000
#define COLL_ACC_NEIGHBOR_ALLTOALL     0x00100000
#define COLL_ACC_NEIGHBOR_ALLTOALLV    0x00200000
#define COLL_ACC_NEIGHBOR_ALLTOALLW    0x00400000
/* Sentinel one bit past the last collective; never part of a valid mask */
#define COLL_ACC_LASTCOLL              0x00800000

/* Original (misspelled) names, kept so existing users keep compiling */
#define COLL_ACC_NEIGHBOR_ALLTTOALLV   COLL_ACC_NEIGHBOR_ALLTOALLV
#define COLL_ACC_NEIGHBOR_ALLTTOALLW   COLL_ACC_NEIGHBOR_ALLTOALLW

/* Collectives the component supports, as a comma separated name list ... */
#define COLL_ACCELERATOR_CTS_STR "allreduce,reduce_scatter_block,reduce_local,reduce,scan,exscan"
/*
 * ... and the same set as a bitmask. The expansion is parenthesized so the
 * macro can be combined safely with operators that bind tighter than '|'
 * (e.g. '&' or '~').
 */
#define COLL_ACCELERATOR_CTS                                      \
    (COLL_ACC_ALLREDUCE | COLL_ACC_REDUCE |                       \
     COLL_ACC_REDUCE_SCATTER_BLOCK | COLL_ACC_REDUCE_LOCAL |      \
     COLL_ACC_EXSCAN | COLL_ACC_SCAN)
64+
3565
/* API functions */
3666

3767
int mca_coll_accelerator_init_query(bool enable_progress_threads,
@@ -134,6 +164,8 @@ typedef struct mca_coll_accelerator_component_t {
134164

135165
int priority; /* Priority of this component */
136166
int disable_accelerator_coll; /* Force disable of the accelerator collective component */
167+
char *cts; /* String of collective operations which the component shall register itself */
168+
uint64_t cts_requested;
137169
} mca_coll_accelerator_component_t;
138170

139171
/* Globally exported variables */

ompi/mca/coll/accelerator/coll_accelerator_component.c

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
* Copyright (c) 2015 Los Alamos National Security, LLC. All rights
88
* reserved.
99
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
10+
* Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
1011
* $COPYRIGHT$
1112
*
1213
* Additional copyrights may follow
@@ -21,6 +22,7 @@
2122
#include "mpi.h"
2223
#include "ompi/constants.h"
2324
#include "coll_accelerator.h"
25+
#include "opal/util/argv.h"
2426

2527
/*
2628
* Public string showing the coll ompi_accelerator component version number
@@ -31,6 +33,7 @@ const char *mca_coll_accelerator_component_version_string =
3133
/*
3234
* Local function
3335
*/
36+
static int accelerator_open(void);
3437
static int accelerator_register(void);
3538

3639
/*
@@ -52,6 +55,7 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
5255
OMPI_RELEASE_VERSION),
5356

5457
/* Component open and close functions */
58+
.mca_open_component = accelerator_open,
5559
.mca_register_component_params = accelerator_register,
5660
},
5761
.collm_data = {
@@ -75,7 +79,8 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
7579
static int accelerator_register(void)
7680
{
7781
(void) mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version,
78-
"priority", "Priority of the accelerator coll component; only relevant if barrier_before or barrier_after is > 0",
82+
"priority", "Priority of the accelerator coll component; only relevant if barrier_before "
83+
"or barrier_after is > 0",
7984
MCA_BASE_VAR_TYPE_INT, NULL, 0, 0,
8085
OPAL_INFO_LVL_6,
8186
MCA_BASE_VAR_SCOPE_READONLY,
@@ -88,5 +93,76 @@ static int accelerator_register(void)
8893
MCA_BASE_VAR_SCOPE_READONLY,
8994
&mca_coll_accelerator_component.disable_accelerator_coll);
9095

96+
mca_coll_accelerator_component.cts = COLL_ACCELERATOR_CTS_STR;
97+
(void)mca_base_component_var_register(&mca_coll_accelerator_component.super.collm_version,
98+
"cts", "Comma separated list of collectives to be enabled",
99+
MCA_BASE_VAR_TYPE_STRING, NULL, 0, MCA_BASE_VAR_FLAG_SETTABLE,
100+
OPAL_INFO_LVL_6, MCA_BASE_VAR_SCOPE_ALL, &mca_coll_accelerator_component.cts);
101+
102+
return OMPI_SUCCESS;
103+
}
104+
105+
106+
/* The string parsing is based on the code available in the coll/ucc component */
107+
static uint64_t mca_coll_accelerator_str_to_type(const char *str)
108+
{
109+
if (0 == strcasecmp(str, "allreduce")) {
110+
return COLL_ACC_ALLREDUCE;
111+
} else if (0 == strcasecmp(str, "reduce_scatter_block")) {
112+
return COLL_ACC_REDUCE_SCATTER_BLOCK;
113+
} else if (0 == strcasecmp(str, "reduce_local")) {
114+
return COLL_ACC_REDUCE_LOCAL;
115+
} else if (0 == strcasecmp(str, "reduce")) {
116+
return COLL_ACC_REDUCE;
117+
} else if (0 == strcasecmp(str, "exscan")) {
118+
return COLL_ACC_EXSCAN;
119+
} else if (0 == strcasecmp(str, "scan")) {
120+
return COLL_ACC_SCAN;
121+
}
122+
opal_output(0, "incorrect value for cts: %s, allowed: %s",
123+
str, COLL_ACCELERATOR_CTS_STR);
124+
return COLL_ACC_LASTCOLL;
125+
}
126+
127+
static void accelerator_init_default_cts(void)
128+
{
129+
mca_coll_accelerator_component_t *cm = &mca_coll_accelerator_component;
130+
bool disable;
131+
char** cts;
132+
int n_cts, i;
133+
char* str;
134+
uint64_t *ct, c;
135+
136+
disable = (cm->cts[0] == '^') ? true : false;
137+
cts = opal_argv_split(disable ? (cm->cts + 1) : cm->cts, ',');
138+
n_cts = opal_argv_count(cts);
139+
cm->cts_requested = disable ? COLL_ACCELERATOR_CTS : 0;
140+
for (i = 0; i < n_cts; i++) {
141+
if (('i' == cts[i][0]) || ('I' == cts[i][0])) {
142+
/* non blocking collective setting */
143+
opal_output(0, "coll/accelerator component does not support non-blocking collectives at this time."
144+
" Ignoring collective: %s\n", cts[i]);
145+
continue;
146+
} else {
147+
str = cts[i];
148+
ct = &cm->cts_requested;
149+
}
150+
c = mca_coll_accelerator_str_to_type(str);
151+
if (COLL_ACC_LASTCOLL == c) {
152+
*ct = COLL_ACCELERATOR_CTS;
153+
break;
154+
}
155+
if (disable) {
156+
(*ct) &= ~c;
157+
} else {
158+
(*ct) |= c;
159+
}
160+
}
161+
opal_argv_free(cts);
162+
}
163+
164+
static int accelerator_open(void)
165+
{
166+
accelerator_init_default_cts();
91167
return OMPI_SUCCESS;
92168
}

ompi/mca/coll/accelerator/coll_accelerator_module.c

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* Copyright (c) 2014-2024 NVIDIA Corporation. All rights reserved.
77
* Copyright (c) 2019 Research Organization for Information Science
88
* and Technology (RIST). All rights reserved.
9-
* Copyright (c) 2023 Advanced Micro Devices, Inc. All rights reserved.
9+
* Copyright (c) 2023-2024 Advanced Micro Devices, Inc. All rights reserved.
1010
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
1111
* $COPYRIGHT$
1212
*
@@ -106,18 +106,21 @@ mca_coll_accelerator_comm_query(struct ompi_communicator_t *comm,
106106
}
107107

108108

109-
/*
 * Install the accelerator version of one collective on a communicator.
 *
 *   __comm    communicator to install on
 *   __module  accelerator module (pointer) that stores the previous function
 *   __api     lower-case collective name, e.g. allreduce
 *   __API     upper-case collective name, e.g. ALLREDUCE; selects the
 *             COLL_ACC_<__API> bit tested in cts_requested
 *
 * If the communicator already has a function for the collective and the
 * collective's bit is set in cts_requested, the current function is saved in
 * the module and replaced by mca_coll_accelerator_<__api>. If the bit is not
 * set nothing is changed. If the communicator has no function for the
 * collective at all, a "missing collective" help message is shown.
 */
#define ACCELERATOR_INSTALL_COLL_API(__comm, __module, __api, __API)                                                                     \
    do                                                                                                                                   \
    {                                                                                                                                    \
        if ((__comm)->c_coll->coll_##__api)                                                                                              \
        {                                                                                                                                \
            if (mca_coll_accelerator_component.cts_requested & COLL_ACC_##__API)                                                         \
            {                                                                                                                            \
                MCA_COLL_SAVE_API(__comm, __api, (__module)->c_coll.coll_##__api, (__module)->c_coll.coll_##__api##_module, "accelerator"); \
                MCA_COLL_INSTALL_API(__comm, __api, mca_coll_accelerator_##__api, &(__module)->super, "accelerator");                   \
            }                                                                                                                            \
        }                                                                                                                                \
        else                                                                                                                             \
        {                                                                                                                                \
            opal_show_help("help-mca-coll-base.txt", "comm-select:missing collective", true,                                             \
                           "accelerator", #__api, ompi_process_info.nodename,                                                            \
                           mca_coll_accelerator_component.priority);                                                                     \
        }                                                                                                                                \
    } while (0)
@@ -141,14 +144,14 @@ mca_coll_accelerator_module_enable(mca_coll_base_module_t *module,
141144
{
142145
mca_coll_accelerator_module_t *s = (mca_coll_accelerator_module_t*) module;
143146

144-
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce);
145-
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce);
146-
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local);
147-
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block);
147+
ACCELERATOR_INSTALL_COLL_API(comm, s, allreduce, ALLREDUCE);
148+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce, REDUCE);
149+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_local, REDUCE_LOCAL);
150+
ACCELERATOR_INSTALL_COLL_API(comm, s, reduce_scatter_block, REDUCE_SCATTER_BLOCK);
148151
if (!OMPI_COMM_IS_INTER(comm)) {
149152
/* MPI does not define scan/exscan on intercommunicators */
150-
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan);
151-
ACCELERATOR_INSTALL_COLL_API(comm, s, scan);
153+
ACCELERATOR_INSTALL_COLL_API(comm, s, exscan, EXSCAN);
154+
ACCELERATOR_INSTALL_COLL_API(comm, s, scan, SCAN);
152155
}
153156

154157
return OMPI_SUCCESS;

0 commit comments

Comments
 (0)