PML UCX: handle a synchronous send. #3347

Merged 2 commits on Apr 26, 2017
ompi/mca/pml/ucx/pml_ucx.c: 40 additions & 45 deletions
@@ -601,7 +601,7 @@ int mca_pml_ucx_isend_init(const void *buf, size_t count, ompi_datatype_t *datat
     return OMPI_SUCCESS;
 }
 
-static int
+static ucs_status_ptr_t
 mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
                   ompi_datatype_t *datatype, uint64_t pml_tag)
 {
@@ -623,21 +623,21 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
     if (OPAL_UNLIKELY(NULL == packed_data)) {
         OBJ_DESTRUCT(&opal_conv);
         PML_UCX_ERROR("bsend: failed to allocate buffer");
-        return OMPI_ERR_OUT_OF_RESOURCE;
+        return UCS_STATUS_PTR(OMPI_ERROR);
     }
 
     iov_count = 1;
     iov.iov_base = packed_data;
     iov.iov_len = packed_length;
 
-    PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %d", packed_data, packed_length);
+    PML_UCX_VERBOSE(8, "bsend of packed buffer %p len %zu", packed_data, packed_length);
     offset = 0;
     opal_convertor_set_position(&opal_conv, &offset);
     if (0 > opal_convertor_pack(&opal_conv, &iov, &iov_count, &packed_length)) {
         mca_pml_base_bsend_request_free(packed_data);
         OBJ_DESTRUCT(&opal_conv);
         PML_UCX_ERROR("bsend: failed to pack user datatype");
-        return OMPI_ERROR;
+        return UCS_STATUS_PTR(OMPI_ERROR);
     }
 
     OBJ_DESTRUCT(&opal_conv);
@@ -648,17 +648,34 @@ mca_pml_ucx_bsend(ucp_ep_h ep, const void *buf, size_t count,
     if (NULL == req) {
         /* request was completed in place */
         mca_pml_base_bsend_request_free(packed_data);
-        return OMPI_SUCCESS;
+        return NULL;
     }
 
     if (OPAL_UNLIKELY(UCS_PTR_IS_ERR(req))) {
         mca_pml_base_bsend_request_free(packed_data);
         PML_UCX_ERROR("ucx bsend failed: %s", ucs_status_string(UCS_PTR_STATUS(req)));
-        return OMPI_ERROR;
+        return UCS_STATUS_PTR(OMPI_ERROR);
     }
 
     req->req_complete_cb_data = packed_data;
-    return OMPI_SUCCESS;
+    return NULL;
 }
 
+static inline ucs_status_ptr_t mca_pml_ucx_common_send(ucp_ep_h ep, const void *buf,
+                                                       size_t count,
+                                                       ompi_datatype_t *datatype,
+                                                       ucp_datatype_t ucx_datatype,
+                                                       ucp_tag_t tag,
+                                                       mca_pml_base_send_mode_t mode,
+                                                       ucp_send_callback_t cb)
+{
+    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
+        return mca_pml_ucx_bsend(ep, buf, count, datatype, tag);
+    } else if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_SYNCHRONOUS == mode)) {
+        return ucp_tag_send_sync_nb(ep, buf, count, ucx_datatype, tag, cb);
+    } else {
+        return ucp_tag_send_nb(ep, buf, count, ucx_datatype, tag, cb);
+    }
+}
+
 int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datatype,
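The new mca_pml_ucx_common_send() helper gives every send path one UCX-style return convention: NULL means the send completed immediately, an error-encoded pointer means failure, and anything else is a live UCX request that finishes through the supplied callback. A minimal caller-side sketch of that convention, using the UCX status helpers UCS_PTR_IS_ERR, UCS_PTR_STATUS, and ucs_status_string; the function name below is illustrative, not part of this PR:

/* Sketch only: how a caller can interpret the ucs_status_ptr_t that
 * mca_pml_ucx_common_send() returns. */
static int interpret_send_status(ucs_status_ptr_t ptr)
{
    if (NULL == ptr) {
        return OMPI_SUCCESS;   /* completed in place, no UCX request allocated */
    }
    if (UCS_PTR_IS_ERR(ptr)) {
        PML_UCX_ERROR("send failed: %s",
                      ucs_status_string(UCS_PTR_STATUS(ptr)));
        return OMPI_ERROR;     /* failure encoded directly in the pointer */
    }
    /* Otherwise ptr is an in-flight request; completion arrives via the
     * callback passed to mca_pml_ucx_common_send(). */
    return OMPI_SUCCESS;
}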
@@ -674,25 +674,17 @@ int mca_pml_ucx_isend(const void *buf, size_t count, ompi_datatype_t *datat
                     mode == MCA_PML_BASE_SEND_BUFFERED ? "b" : "",
                     (void*)request)
 
-    /* TODO special care to sync/buffered send */
-
     ep = mca_pml_ucx_get_ep(comm, dst);
     if (OPAL_UNLIKELY(NULL == ep)) {
         PML_UCX_ERROR("Failed to get ep for rank %d", dst);
         return OMPI_ERROR;
     }
 
-    /* Special care to sync/buffered send */
-    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
-        *request = &ompi_pml_ucx.completed_send_req;
-        return mca_pml_ucx_bsend(ep, buf, count, datatype,
-                                 PML_UCX_MAKE_SEND_TAG(tag, comm));
-    }
+    req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
+                                                   mca_pml_ucx_get_datatype(datatype),
+                                                   PML_UCX_MAKE_SEND_TAG(tag, comm), mode,
+                                                   mca_pml_ucx_send_completion);
 
-    req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           PML_UCX_MAKE_SEND_TAG(tag, comm),
-                                           mca_pml_ucx_send_completion);
     if (req == NULL) {
         PML_UCX_VERBOSE(8, "returning completed request");
         *request = &ompi_pml_ucx.completed_send_req;
@@ -723,16 +723,11 @@ int mca_pml_ucx_send(const void *buf, size_t count, ompi_datatype_t *datatype, i
         return OMPI_ERROR;
     }
 
-    /* Special care to sync/buffered send */
-    if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == mode)) {
-        return mca_pml_ucx_bsend(ep, buf, count, datatype,
-                                 PML_UCX_MAKE_SEND_TAG(tag, comm));
-    }
+    req = (ompi_request_t*)mca_pml_ucx_common_send(ep, buf, count, datatype,
+                                                   mca_pml_ucx_get_datatype(datatype),
+                                                   PML_UCX_MAKE_SEND_TAG(tag, comm),
+                                                   mode, mca_pml_ucx_send_completion);
 
-    req = (ompi_request_t*)ucp_tag_send_nb(ep, buf, count,
-                                           mca_pml_ucx_get_datatype(datatype),
-                                           PML_UCX_MAKE_SEND_TAG(tag, comm),
-                                           mca_pml_ucx_send_completion);
     if (OPAL_LIKELY(req == NULL)) {
         return OMPI_SUCCESS;
    } else if (!UCS_PTR_IS_ERR(req)) {
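The truncated tail of mca_pml_ucx_send() still has to progress a pending request to completion before returning. A hedged sketch of such a wait loop, assuming the UCX calls ucp_request_check_status(), ucp_worker_progress(), and ucp_request_free(); this is an illustration, not the PR's elided code:

/* Sketch only: spin the UCX worker until a non-blocking send request
 * completes, then release the request object. */
static int wait_send_completion(ucp_worker_h worker, void *ucx_req)
{
    ucs_status_t status;
    while (UCS_INPROGRESS == (status = ucp_request_check_status(ucx_req))) {
        ucp_worker_progress(worker);   /* drive UCX communication forward */
    }
    ucp_request_free(ucx_req);
    return (UCS_OK == status) ? OMPI_SUCCESS : OMPI_ERROR;
}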
@@ -891,7 +895,6 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
     mca_pml_ucx_persistent_request_t *preq;
     ompi_request_t *tmp_req;
     size_t i;
-    int rc;
 
     for (i = 0; i < count; ++i) {
         preq = (mca_pml_ucx_persistent_request_t *)requests[i];
@@ -906,22 +906,14 @@ int mca_pml_ucx_start(size_t count, ompi_request_t** requests)
         mca_pml_ucx_request_reset(&preq->ompi);
 
         if (preq->flags & MCA_PML_UCX_REQUEST_FLAG_SEND) {
-            if (OPAL_UNLIKELY(MCA_PML_BASE_SEND_BUFFERED == preq->send.mode)) {
-                PML_UCX_VERBOSE(8, "start bsend request %p", (void*)preq);
-                rc = mca_pml_ucx_bsend(preq->send.ep, preq->buffer, preq->count,
-                                       preq->ompi_datatype, preq->tag);
-                if (OMPI_SUCCESS != rc) {
-                    return rc;
-                }
-                /* pretend that we got immediate completion */
-                tmp_req = NULL;
-            } else {
-                PML_UCX_VERBOSE(8, "start send request %p", (void*)preq);
-                tmp_req = (ompi_request_t*)ucp_tag_send_nb(preq->send.ep, preq->buffer,
-                                                           preq->count, preq->datatype,
-                                                           preq->tag,
-                                                           mca_pml_ucx_psend_completion);
-            }
+            tmp_req = (ompi_request_t*)mca_pml_ucx_common_send(preq->send.ep,
+                                                               preq->buffer,
+                                                               preq->count,
+                                                               preq->ompi_datatype,
+                                                               preq->datatype,
+                                                               preq->tag,
+                                                               preq->send.mode,
+                                                               mca_pml_ucx_psend_completion);
         } else {
             PML_UCX_VERBOSE(8, "start recv request %p", (void*)preq);
             tmp_req = (ompi_request_t*)ucp_tag_recv_nb(ompi_pml_ucx.ucp_worker,
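End to end, wiring MCA_PML_BASE_SEND_SYNCHRONOUS to ucp_tag_send_sync_nb means MPI_Ssend and MPI_Issend now get genuine synchronous-completion semantics over UCX instead of falling through to a standard send, as the removed TODO comments noted. A small standalone program that exercises this path; plain MPI, not part of the diff:

/* MPI_Ssend completes only after the receiver has matched the message,
 * which this PR maps onto ucp_tag_send_sync_nb inside pml_ucx. */
#include <mpi.h>
#include <stdio.h>

int main(int argc, char **argv)
{
    int rank, buf = 42;
    MPI_Init(&argc, &argv);
    MPI_Comm_rank(MPI_COMM_WORLD, &rank);
    if (0 == rank) {
        MPI_Ssend(&buf, 1, MPI_INT, 1, 0, MPI_COMM_WORLD);  /* synchronous mode */
    } else if (1 == rank) {
        MPI_Recv(&buf, 1, MPI_INT, 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
        printf("rank 1 received %d\n", buf);
    }
    MPI_Finalize();
    return 0;
}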