7
7
* Copyright (c) 2015 Los Alamos National Security, LLC. All rights
8
8
* reserved.
9
9
* Copyright (c) 2024 Triad National Security, LLC. All rights reserved.
10
+ * Copyright (c) 2024 Advanced Micro Devices, Inc. All Rights reserved.
10
11
* $COPYRIGHT$
11
12
*
12
13
* Additional copyrights may follow
21
22
#include "mpi.h"
22
23
#include "ompi/constants.h"
23
24
#include "coll_accelerator.h"
25
+ #include "opal/util/argv.h"
24
26
25
27
/*
26
28
* Public string showing the coll ompi_accelerator component version number
@@ -31,6 +33,7 @@ const char *mca_coll_accelerator_component_version_string =
31
33
/*
32
34
* Local function
33
35
*/
36
+ static int accelerator_open (void );
34
37
static int accelerator_register (void );
35
38
36
39
/*
@@ -52,6 +55,7 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
52
55
OMPI_RELEASE_VERSION ),
53
56
54
57
/* Component open and close functions */
58
+ .mca_open_component = accelerator_open ,
55
59
.mca_register_component_params = accelerator_register ,
56
60
},
57
61
.collm_data = {
@@ -75,7 +79,8 @@ mca_coll_accelerator_component_t mca_coll_accelerator_component = {
75
79
static int accelerator_register (void )
76
80
{
77
81
(void ) mca_base_component_var_register (& mca_coll_accelerator_component .super .collm_version ,
78
- "priority" , "Priority of the accelerator coll component; only relevant if barrier_before or barrier_after is > 0" ,
82
+ "priority" , "Priority of the accelerator coll component; only relevant if barrier_before "
83
+ "or barrier_after is > 0" ,
79
84
MCA_BASE_VAR_TYPE_INT , NULL , 0 , 0 ,
80
85
OPAL_INFO_LVL_6 ,
81
86
MCA_BASE_VAR_SCOPE_READONLY ,
@@ -88,5 +93,76 @@ static int accelerator_register(void)
88
93
MCA_BASE_VAR_SCOPE_READONLY ,
89
94
& mca_coll_accelerator_component .disable_accelerator_coll );
90
95
96
+ mca_coll_accelerator_component .cts = COLL_ACCELERATOR_CTS_STR ;
97
+ (void )mca_base_component_var_register (& mca_coll_accelerator_component .super .collm_version ,
98
+ "cts" , "Comma separated list of collectives to be enabled" ,
99
+ MCA_BASE_VAR_TYPE_STRING , NULL , 0 , MCA_BASE_VAR_FLAG_SETTABLE ,
100
+ OPAL_INFO_LVL_6 , MCA_BASE_VAR_SCOPE_ALL , & mca_coll_accelerator_component .cts );
101
+
102
+ return OMPI_SUCCESS ;
103
+ }
104
+
105
+
106
+ /* The string parsing is based on the code available in the coll/ucc component */
107
+ static uint64_t mca_coll_accelerator_str_to_type (const char * str )
108
+ {
109
+ if (0 == strcasecmp (str , "allreduce" )) {
110
+ return COLL_ACC_ALLREDUCE ;
111
+ } else if (0 == strcasecmp (str , "reduce_scatter_block" )) {
112
+ return COLL_ACC_REDUCE_SCATTER_BLOCK ;
113
+ } else if (0 == strcasecmp (str , "reduce_local" )) {
114
+ return COLL_ACC_REDUCE_LOCAL ;
115
+ } else if (0 == strcasecmp (str , "reduce" )) {
116
+ return COLL_ACC_REDUCE ;
117
+ } else if (0 == strcasecmp (str , "exscan" )) {
118
+ return COLL_ACC_EXSCAN ;
119
+ } else if (0 == strcasecmp (str , "scan" )) {
120
+ return COLL_ACC_SCAN ;
121
+ }
122
+ opal_output (0 , "incorrect value for cts: %s, allowed: %s" ,
123
+ str , COLL_ACCELERATOR_CTS_STR );
124
+ return COLL_ACC_LASTCOLL ;
125
+ }
126
+
127
+ static void accelerator_init_default_cts (void )
128
+ {
129
+ mca_coll_accelerator_component_t * cm = & mca_coll_accelerator_component ;
130
+ bool disable ;
131
+ char * * cts ;
132
+ int n_cts , i ;
133
+ char * str ;
134
+ uint64_t * ct , c ;
135
+
136
+ disable = (cm -> cts [0 ] == '^' ) ? true : false;
137
+ cts = opal_argv_split (disable ? (cm -> cts + 1 ) : cm -> cts , ',' );
138
+ n_cts = opal_argv_count (cts );
139
+ cm -> cts_requested = disable ? COLL_ACCELERATOR_CTS : 0 ;
140
+ for (i = 0 ; i < n_cts ; i ++ ) {
141
+ if (('i' == cts [i ][0 ]) || ('I' == cts [i ][0 ])) {
142
+ /* non blocking collective setting */
143
+ opal_output (0 , "coll/accelerator component does not support non-blocking collectives at this time."
144
+ " Ignoring collective: %s\n" , cts [i ]);
145
+ continue ;
146
+ } else {
147
+ str = cts [i ];
148
+ ct = & cm -> cts_requested ;
149
+ }
150
+ c = mca_coll_accelerator_str_to_type (str );
151
+ if (COLL_ACC_LASTCOLL == c ) {
152
+ * ct = COLL_ACCELERATOR_CTS ;
153
+ break ;
154
+ }
155
+ if (disable ) {
156
+ (* ct ) &= ~c ;
157
+ } else {
158
+ (* ct ) |= c ;
159
+ }
160
+ }
161
+ opal_argv_free (cts );
162
+ }
163
+
164
+ static int accelerator_open (void )
165
+ {
166
+ accelerator_init_default_cts ();
91
167
return OMPI_SUCCESS ;
92
168
}
0 commit comments