Skip to content

Commit 2f4ef91

Browse files
authored
Merge pull request #12613 from tvegas1/ucx_thread_mode_v4.1.x
pml/ucx: Propagate MPI serialized thread mode
2 parents c89eaac + 8cd79c4 commit 2f4ef91

File tree

6 files changed

+34
-5
lines changed

6 files changed

+34
-5
lines changed

ompi/mca/osc/ucx/osc_ucx_component.c

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -415,8 +415,12 @@ static int component_select(struct ompi_win_t *win, void **base, size_t size, in
415415
assert(mca_osc_ucx_component.ucp_worker == NULL);
416416
memset(&worker_params, 0, sizeof(worker_params));
417417
worker_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
418-
worker_params.thread_mode = (mca_osc_ucx_component.enable_mpi_threads == true)
419-
? UCS_THREAD_MODE_MULTI : UCS_THREAD_MODE_SINGLE;
418+
if (mca_osc_ucx_component.enable_mpi_threads) {
419+
worker_params.thread_mode = UCS_THREAD_MODE_MULTI;
420+
} else {
421+
worker_params.thread_mode =
422+
opal_common_ucx_thread_mode(ompi_mpi_thread_provided);
423+
}
420424
status = ucp_worker_create(mca_osc_ucx_component.ucp_context, &worker_params,
421425
&(mca_osc_ucx_component.ucp_worker));
422426
if (UCS_OK != status) {

ompi/mca/pml/ucx/pml_ucx.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -288,12 +288,12 @@ int mca_pml_ucx_init(int enable_mpi_threads)
288288

289289
PML_UCX_VERBOSE(1, "mca_pml_ucx_init");
290290

291-
/* TODO check MPI thread mode */
292291
params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
293292
if (enable_mpi_threads) {
294293
params.thread_mode = UCS_THREAD_MODE_MULTI;
295294
} else {
296-
params.thread_mode = UCS_THREAD_MODE_SINGLE;
295+
params.thread_mode =
296+
opal_common_ucx_thread_mode(ompi_mpi_thread_provided);
297297
}
298298

299299
#if HAVE_DECL_UCP_WORKER_FLAG_IGNORE_REQUEST_LEAK

opal/mca/common/ucx/common_ucx.c

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "opal/util/argv.h"
2626
#include "opal/util/printf.h"
2727

28+
#include "mpi.h"
29+
2830
#include <ucm/api/ucm.h>
2931
#include <fnmatch.h>
3032
#include <stdio.h>
@@ -49,6 +51,23 @@ static void opal_common_ucx_mem_release_cb(void *buf, size_t length,
4951
ucm_vm_munmap(buf, length);
5052
}
5153

54+
ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode)
55+
{
56+
switch (ompi_mode) {
57+
case MPI_THREAD_MULTIPLE:
58+
return UCS_THREAD_MODE_MULTI;
59+
case MPI_THREAD_SERIALIZED:
60+
return UCS_THREAD_MODE_SERIALIZED;
61+
case MPI_THREAD_FUNNELED:
62+
case MPI_THREAD_SINGLE:
63+
return UCS_THREAD_MODE_SINGLE;
64+
default:
65+
MCA_COMMON_UCX_WARN("Unknown MPI thread mode %d, using multithread",
66+
ompi_mode);
67+
return UCS_THREAD_MODE_MULTI;
68+
}
69+
}
70+
5271
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component)
5372
{
5473
char *default_tls = "rc_verbs,ud_verbs,rc_mlx5,dc_mlx5,ud_mlx5,cuda_ipc,rocm_ipc";

opal/mca/common/ucx/common_ucx.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ OPAL_DECLSPEC int opal_common_ucx_del_procs(opal_common_ucx_del_proc_t *procs, s
124124
OPAL_DECLSPEC int opal_common_ucx_del_procs_nofence(opal_common_ucx_del_proc_t *procs, size_t count,
125125
size_t my_rank, size_t max_disconnect, ucp_worker_h worker);
126126
OPAL_DECLSPEC void opal_common_ucx_mca_var_register(const mca_base_component_t *component);
127+
OPAL_DECLSPEC ucs_thread_mode_t opal_common_ucx_thread_mode(int ompi_mode);
127128

128129
static inline
129130
ucs_status_t opal_common_ucx_request_status(ucs_status_ptr_t request)

oshmem/mca/spml/ucx/spml_ucx.c

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,8 +1016,11 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx
10161016
ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync;
10171017

10181018
params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
1019-
if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE || options & SHMEM_CTX_PRIVATE || options & SHMEM_CTX_SERIALIZED) {
1019+
if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE ||
1020+
oshmem_mpi_thread_provided == SHMEM_THREAD_FUNNELED || options & SHMEM_CTX_PRIVATE) {
10201021
params.thread_mode = UCS_THREAD_MODE_SINGLE;
1022+
} else if (oshmem_mpi_thread_provided == SHMEM_THREAD_SERIALIZED || options & SHMEM_CTX_SERIALIZED) {
1023+
params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
10211024
} else {
10221025
params.thread_mode = UCS_THREAD_MODE_MULTI;
10231026
}

oshmem/mca/spml/ucx/spml_ucx_component.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,6 +322,8 @@ static int spml_ucx_init(void)
322322
wkr_params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE;
323323
if (oshmem_mpi_thread_requested == SHMEM_THREAD_MULTIPLE) {
324324
wkr_params.thread_mode = UCS_THREAD_MODE_MULTI;
325+
} else if (oshmem_mpi_thread_requested == SHMEM_THREAD_SERIALIZED) {
326+
wkr_params.thread_mode = UCS_THREAD_MODE_SERIALIZED;
325327
} else {
326328
wkr_params.thread_mode = UCS_THREAD_MODE_SINGLE;
327329
}

0 commit comments

Comments
 (0)