
Commit 2aea543

Add multithreading support in PML UCX framework.
Signed-off-by: Xin Zhao <[email protected]>
1 parent 330b11c commit 2aea543

File tree

2 files changed: +68 -45 lines


ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 20 additions & 2 deletions
@@ -135,12 +135,15 @@ int mca_pml_ucx_open(void)
                              UCP_PARAM_FIELD_REQUEST_SIZE |
                              UCP_PARAM_FIELD_REQUEST_INIT |
                              UCP_PARAM_FIELD_REQUEST_CLEANUP |
-                             UCP_PARAM_FIELD_TAG_SENDER_MASK;
+                             UCP_PARAM_FIELD_TAG_SENDER_MASK |
+                             UCP_PARAM_FIELD_MT_WORKERS_SHARED;
     params.features        = UCP_FEATURE_TAG;
     params.request_size    = sizeof(ompi_request_t);
     params.request_init    = mca_pml_ucx_request_init;
     params.request_cleanup = mca_pml_ucx_request_cleanup;
     params.tag_sender_mask = PML_UCX_SPECIFIC_SOURCE_MASK;
+    params.mt_workers_shared = 0; /* we do not need mt support for context
+                                     since it will be protected by worker */
 
     status = ucp_init(&params, config, &ompi_pml_ucx.ucp_context);
     ucp_config_release(config);
@@ -178,20 +181,35 @@ int mca_pml_ucx_init(void)
 {
     ucp_worker_params_t params;
     ucs_status_t status;
+    ucp_worker_attr_t attr;
     int rc;
 
     PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
 
     /* TODO check MPI thread mode */
     params.field_mask  = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
-    params.thread_mode = UCS_THREAD_MODE_SINGLE;
+    if (ompi_mpi_thread_multiple) {
+        params.thread_mode = UCS_THREAD_MODE_MULTI;
+    } else {
+        params.thread_mode = UCS_THREAD_MODE_SINGLE;
+    }
 
     status = ucp_worker_create(ompi_pml_ucx.ucp_context, &params,
                                &ompi_pml_ucx.ucp_worker);
     if (UCS_OK != status) {
         return OMPI_ERROR;
     }
 
+    attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
+    status = ucp_worker_query(ompi_pml_ucx.ucp_worker, &attr);
+    if (UCS_OK != status) {
+        return OMPI_ERROR;
+    }
+
+    ompi_mpi_thread_multiple = (attr.thread_mode == UCS_THREAD_MODE_MULTI);
+    ompi_mpi_thread_provided = (ompi_mpi_thread_multiple == true ?
+                                MPI_THREAD_MULTIPLE : MPI_THREAD_SINGLE);
+
    rc = mca_pml_ucx_send_worker_address();
    if (rc < 0) {
        return rc;
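The pattern in this hunk is request-then-verify: the worker is created with the thread mode the MPI layer wants, and ucp_worker_query() reports what UCX actually granted, so the provided MPI thread level can be downgraded when MULTI is unavailable. Below is a minimal standalone sketch of the same two UCX pieces against the public UCP API; it is not part of this commit and assumes a UCX installation (link with -lucp -lucs), with error handling trimmed to the essentials.

/* Standalone sketch (not OMPI code): create a context with
 * mt_workers_shared = 0 (each worker serializes its own access, as the
 * patch's comment explains), request a thread-safe worker, then query
 * the thread mode UCX actually granted. */
#include <stdio.h>
#include <ucp/api/ucp.h>

int main(void)
{
    ucp_config_t *config;
    ucp_context_h context;
    ucp_worker_h worker;
    ucp_params_t ctx_params        = {0};
    ucp_worker_params_t wrk_params = {0};
    ucp_worker_attr_t attr         = {0};

    if (UCS_OK != ucp_config_read(NULL, NULL, &config)) {
        return 1;
    }

    ctx_params.field_mask = UCP_PARAM_FIELD_FEATURES |
                            UCP_PARAM_FIELD_MT_WORKERS_SHARED;
    ctx_params.features   = UCP_FEATURE_TAG;
    ctx_params.mt_workers_shared = 0; /* worker-level protection suffices */

    if (UCS_OK != ucp_init(&ctx_params, config, &context)) {
        ucp_config_release(config);
        return 1;
    }
    ucp_config_release(config);

    /* Ask for a thread-safe worker, as the patch does when
     * ompi_mpi_thread_multiple is set. */
    wrk_params.field_mask  = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
    wrk_params.thread_mode = UCS_THREAD_MODE_MULTI;
    if (UCS_OK != ucp_worker_create(context, &wrk_params, &worker)) {
        ucp_cleanup(context);
        return 1;
    }

    /* The requested mode is only a request: query the worker to learn
     * whether UCX granted MULTI or fell back to a weaker mode. */
    attr.field_mask = UCP_WORKER_ATTR_FIELD_THREAD_MODE;
    if (UCS_OK == ucp_worker_query(worker, &attr)) {
        printf("granted thread mode: %s\n",
               (UCS_THREAD_MODE_MULTI == attr.thread_mode) ? "MULTI"
                                                           : "not MULTI");
    }

    ucp_worker_destroy(worker);
    ucp_cleanup(context);
    return 0;
}

Note that a UCX library built without --enable-mt will typically grant only UCS_THREAD_MODE_SINGLE here, which is exactly the case the patch handles by downgrading ompi_mpi_thread_provided.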

ompi/runtime/ompi_mpi_init.c

Lines changed: 48 additions & 43 deletions
@@ -540,19 +540,6 @@ int ompi_mpi_init(int argc, char **argv, int requested, int *provided)
         goto error;
     }
 
-
-    /* determine the bitflag belonging to the threadlevel_support provided */
-    memset ( &threadlevel_bf, 0, sizeof(uint8_t));
-    OMPI_THREADLEVEL_SET_BITFLAG ( ompi_mpi_thread_provided, threadlevel_bf );
-
-    /* add this bitflag to the modex */
-    OPAL_MODEX_SEND_STRING(ret, OPAL_PMIX_GLOBAL,
-                           "MPI_THREAD_LEVEL", &threadlevel_bf, sizeof(uint8_t));
-    if (OPAL_SUCCESS != ret) {
-        error = "ompi_mpi_init: modex send thread level";
-        goto error;
-    }
-
     /* initialize datatypes. This step should be done early as it will
      * create the local convertor and local arch used in the proc
      * init.
@@ -568,25 +555,6 @@ int ompi_mpi_init(int argc, char **argv, int requested, int *provided)
         goto error;
     }
 
-    /* Initialize the op framework. This has to be done *after*
-       ddt_init, but befor mca_coll_base_open, since some collective
-       modules (e.g., the hierarchical coll component) may need ops in
-       their query function. */
-    if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_op_base_framework, 0))) {
-        error = "ompi_op_base_open() failed";
-        goto error;
-    }
-    if (OMPI_SUCCESS !=
-        (ret = ompi_op_base_find_available(OPAL_ENABLE_PROGRESS_THREADS,
-                                           ompi_mpi_thread_multiple))) {
-        error = "ompi_op_base_find_available() failed";
-        goto error;
-    }
-    if (OMPI_SUCCESS != (ret = ompi_op_init())) {
-        error = "ompi_op_init() failed";
-        goto error;
-    }
-
     /* Open up MPI-related MCA components */
 
     if (OMPI_SUCCESS != (ret = mca_base_framework_open(&opal_allocator_base_framework, 0))) {
@@ -601,6 +569,24 @@ int ompi_mpi_init(int argc, char **argv, int requested, int *provided)
         error = "mca_mpool_base_open() failed";
         goto error;
     }
+
+    /* We need to initialize PML before mca_bml_base_open() and
+       mca_op_base_find_available(), since this may modify ompi_mpi_thread_multiple,
+       which are used in mca_bml_base_open() and mca_op_base_find_available(). */
+
+    if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_pml_base_framework, 0))) {
+        error = "mca_pml_base_open() failed";
+        goto error;
+    }
+
+    /* Select which MPI components to use */
+    if (OMPI_SUCCESS !=
+        (ret = mca_pml_base_select(OPAL_ENABLE_PROGRESS_THREADS,
+                                   ompi_mpi_thread_multiple))) {
+        error = "mca_pml_base_select() failed";
+        goto error;
+    }
+
     if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_bml_base_framework, 0))) {
         error = "mca_bml_base_open() failed";
         goto error;
@@ -609,10 +595,26 @@ int ompi_mpi_init(int argc, char **argv, int requested, int *provided)
         error = "mca_bml_base_init() failed";
         goto error;
     }
-    if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_pml_base_framework, 0))) {
-        error = "mca_pml_base_open() failed";
+
+    /* Initialize the op framework. This has to be done *after*
+       ddt_init, but befor mca_coll_base_open, since some collective
+       modules (e.g., the hierarchical coll component) may need ops in
+       their query function. */
+    if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_op_base_framework, 0))) {
+        error = "ompi_op_base_open() failed";
+        goto error;
+    }
+    if (OMPI_SUCCESS !=
+        (ret = ompi_op_base_find_available(OPAL_ENABLE_PROGRESS_THREADS,
+                                           ompi_mpi_thread_multiple))) {
+        error = "ompi_op_base_find_available() failed";
+        goto error;
+    }
+    if (OMPI_SUCCESS != (ret = ompi_op_init())) {
+        error = "ompi_op_init() failed";
         goto error;
     }
+
     if (OMPI_SUCCESS != (ret = mca_base_framework_open(&ompi_coll_base_framework, 0))) {
         error = "mca_coll_base_open() failed";
         goto error;
@@ -630,21 +632,24 @@ int ompi_mpi_init(int argc, char **argv, int requested, int *provided)
     }
 #endif
 
+    /* determine the bitflag belonging to the threadlevel_support provided */
+    memset ( &threadlevel_bf, 0, sizeof(uint8_t));
+    OMPI_THREADLEVEL_SET_BITFLAG ( ompi_mpi_thread_provided, threadlevel_bf );
+
+    /* add this bitflag to the modex */
+    OPAL_MODEX_SEND_STRING(ret, OPAL_PMIX_GLOBAL,
+                           "MPI_THREAD_LEVEL", &threadlevel_bf, sizeof(uint8_t));
+    if (OPAL_SUCCESS != ret) {
+        error = "ompi_mpi_init: modex send thread level";
+        goto error;
+    }
+
     /* In order to reduce the common case for MPI apps (where they
        don't use MPI-2 IO or MPI-1 topology functions), the io and
        topo frameworks are initialized lazily, at the first use of
        relevant functions (e.g., MPI_FILE_*, MPI_CART_*, MPI_GRAPH_*),
        so they are not opened here. */
 
-    /* Select which MPI components to use */
-
-    if (OMPI_SUCCESS !=
-        (ret = mca_pml_base_select(OPAL_ENABLE_PROGRESS_THREADS,
-                                   ompi_mpi_thread_multiple))) {
-        error = "mca_pml_base_select() failed";
-        goto error;
-    }
-
     /* check for timing request - get stop time and report elapsed time if so */
     OPAL_TIMING_MNEXT((&tm,"time to execute modex"));
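The reordering is the substance of this file's change: mca_pml_base_select() can now lower ompi_mpi_thread_multiple and ompi_mpi_thread_provided (as pml/ucx does above when the UCX worker cannot run in UCS_THREAD_MODE_MULTI), so the op framework, the BML, and the modex broadcast of "MPI_THREAD_LEVEL" are all moved after PML selection to see the final value. From the application side the effect is only observable through the provided argument of MPI_Init_thread(); a minimal check, valid for any MPI implementation (the program itself is illustrative, not part of the commit):

/* Illustrative MPI program: request MPI_THREAD_MULTIPLE and check what
 * the library actually provided after PML selection. */
#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv)
{
    int provided;

    MPI_Init_thread(&argc, &argv, MPI_THREAD_MULTIPLE, &provided);
    if (provided < MPI_THREAD_MULTIPLE) {
        /* e.g. pml/ucx downgraded to MPI_THREAD_SINGLE because the UCX
         * worker could not be created in UCS_THREAD_MODE_MULTI */
        fprintf(stderr, "warning: requested MPI_THREAD_MULTIPLE, got %d\n",
                provided);
    }
    MPI_Finalize();
    return 0;
}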

0 commit comments
